From b4a85c8f13e0ac54cac93dfde144aa000f1540be Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 15 Aug 2020 09:39:58 +0800 Subject: [PATCH] Remove files --- .travis.yml | 2 +- admin.go | 461 --- admin_test.go | 249 -- adminui.go | 356 --- app.go | 516 ---- beego.go | 132 - build_info.go | 34 - cache/README.md | 59 - cache/cache.go | 103 - cache/cache_test.go | 191 -- cache/conv.go | 100 - cache/conv_test.go | 143 - cache/file.go | 258 -- cache/memcache/memcache.go | 188 -- cache/memcache/memcache_test.go | 108 - cache/memory.go | 256 -- cache/redis/redis.go | 272 -- cache/redis/redis_test.go | 144 - cache/ssdb/ssdb.go | 231 -- cache/ssdb/ssdb_test.go | 104 - config.go | 557 ---- config/config.go | 259 -- config/config_test.go | 55 - config/env/env.go | 92 - config/env/env_test.go | 75 - config/fake.go | 151 - config/ini.go | 524 ---- config/ini_test.go | 190 -- config/json.go | 290 -- config/json_test.go | 222 -- config/xml/xml.go | 248 -- config/xml/xml_test.go | 125 - config/yaml/yaml.go | 337 --- config/yaml/yaml_test.go | 115 - config_test.go | 146 - context/acceptencoder.go | 232 -- context/acceptencoder_test.go | 59 - context/context.go | 263 -- context/context_test.go | 54 - context/input.go | 709 ----- context/input_test.go | 217 -- context/output.go | 413 --- context/param/conv.go | 78 - context/param/methodparams.go | 69 - context/param/options.go | 37 - context/param/parsers.go | 149 - context/param/parsers_test.go | 84 - context/renderer.go | 12 - context/response.go | 27 - controller.go | 774 ----- controller_test.go | 181 -- doc.go | 19 - error.go | 492 ---- error_test.go | 88 - filter.go | 47 - filter_test.go | 68 - flash.go | 119 - flash_test.go | 54 - fs.go | 77 - go.mod | 1 + go.sum | 26 + grace/grace.go | 175 -- grace/server.go | 362 --- hooks.go | 104 - httplib/README.md | 97 - httplib/httplib.go | 697 ----- httplib/httplib_test.go | 286 -- logs/README.md | 72 - logs/accesslog.go | 83 - logs/alils/alils.go | 186 -- logs/alils/config.go | 13 - logs/alils/log.pb.go | 1038 ------- logs/alils/log_config.go | 42 - logs/alils/log_project.go | 819 ------ logs/alils/log_store.go | 271 -- logs/alils/machine_group.go | 91 - logs/alils/request.go | 62 - logs/alils/signature.go | 111 - logs/conn.go | 119 - logs/conn_test.go | 80 - logs/console.go | 99 - logs/console_test.go | 64 - logs/es/es.go | 102 - logs/file.go | 409 --- logs/file_test.go | 420 --- logs/jianliao.go | 72 - logs/log.go | 669 ----- logs/logger.go | 176 -- logs/logger_test.go | 57 - logs/multifile.go | 119 - logs/multifile_test.go | 78 - logs/slack.go | 60 - logs/smtp.go | 149 - logs/smtp_test.go | 24 - metric/prometheus.go | 101 - metric/prometheus_test.go | 42 - migration/ddl.go | 395 --- migration/doc.go | 32 - migration/migration.go | 330 --- mime.go | 556 ---- namespace.go | 433 --- namespace_test.go | 168 -- orm/README.md | 159 -- orm/cmd.go | 283 -- orm/cmd_utils.go | 320 --- orm/db.go | 1908 ------------- orm/db_alias.go | 487 ---- orm/db_mysql.go | 189 -- orm/db_oracle.go | 142 - orm/db_postgres.go | 195 -- orm/db_sqlite.go | 167 -- orm/db_tables.go | 482 ---- orm/db_tidb.go | 63 - orm/db_utils.go | 177 -- orm/models.go | 99 - orm/models_boot.go | 347 --- orm/models_fields.go | 783 ------ orm/models_info_f.go | 473 ---- orm/models_info_m.go | 148 - orm/models_test.go | 497 ---- orm/models_utils.go | 227 -- orm/orm.go | 602 ---- orm/orm_conds.go | 153 - orm/orm_log.go | 222 -- orm/orm_object.go | 87 - orm/orm_querym2m.go | 140 - orm/orm_queryset.go | 300 -- orm/orm_raw.go | 867 ------ orm/orm_test.go | 2500 ----------------- orm/qb.go | 62 - orm/qb_mysql.go | 185 -- orm/qb_tidb.go | 182 -- orm/types.go | 474 ---- orm/utils.go | 319 --- orm/utils_test.go | 70 - parser.go | 591 ---- plugins/apiauth/apiauth.go | 165 -- plugins/apiauth/apiauth_test.go | 20 - plugins/auth/basic.go | 107 - plugins/authz/authz.go | 86 - plugins/authz/authz_model.conf | 14 - plugins/authz/authz_policy.csv | 7 - plugins/authz/authz_test.go | 107 - plugins/cors/cors.go | 228 -- plugins/cors/cors_test.go | 253 -- policy.go | 100 - router.go | 1085 ------- router_test.go | 732 ----- session/README.md | 114 - session/couchbase/sess_couchbase.go | 247 -- session/ledis/ledis_session.go | 173 -- session/memcache/sess_memcache.go | 230 -- session/mysql/sess_mysql.go | 228 -- session/postgres/sess_postgresql.go | 243 -- session/redis/sess_redis.go | 261 -- session/redis_cluster/redis_cluster.go | 221 -- session/redis_sentinel/sess_redis_sentinel.go | 234 -- .../sess_redis_sentinel_test.go | 90 - session/sess_cookie.go | 180 -- session/sess_cookie_test.go | 105 - session/sess_file.go | 315 --- session/sess_file_test.go | 386 --- session/sess_mem.go | 196 -- session/sess_mem_test.go | 58 - session/sess_test.go | 131 - session/sess_utils.go | 207 -- session/session.go | 377 --- session/ssdb/sess_ssdb.go | 199 -- staticfile.go | 234 -- staticfile_test.go | 99 - swagger/swagger.go | 174 -- template.go | 417 --- template_test.go | 316 --- templatefunc.go | 798 ------ templatefunc_test.go | 380 --- testing/assertions.go | 15 - testing/client.go | 65 - toolbox/healthcheck.go | 48 - toolbox/profile.go | 184 -- toolbox/profile_test.go | 28 - toolbox/statistics.go | 149 - toolbox/statistics_test.go | 40 - toolbox/task.go | 640 ----- toolbox/task_test.go | 85 - tree.go | 590 ---- tree_test.go | 306 -- unregroute_test.go | 226 -- utils/caller.go | 25 - utils/caller_test.go | 28 - utils/captcha/LICENSE | 19 - utils/captcha/README.md | 45 - utils/captcha/captcha.go | 270 -- utils/captcha/image.go | 501 ---- utils/captcha/image_test.go | 52 - utils/captcha/siprng.go | 277 -- utils/captcha/siprng_test.go | 33 - utils/debug.go | 478 ---- utils/debug_test.go | 46 - utils/file.go | 101 - utils/file_test.go | 76 - utils/mail.go | 424 --- utils/mail_test.go | 41 - utils/pagination/controller.go | 26 - utils/pagination/doc.go | 58 - utils/pagination/paginator.go | 189 -- utils/pagination/utils.go | 34 - utils/rand.go | 44 - utils/rand_test.go | 33 - utils/safemap.go | 91 - utils/safemap_test.go | 89 - utils/slice.go | 170 -- utils/slice_test.go | 29 - utils/testdata/grepe.test | 7 - utils/utils.go | 89 - utils/utils_test.go | 36 - validation/README.md | 147 - validation/util.go | 298 -- validation/util_test.go | 128 - validation/validation.go | 456 --- validation/validation_test.go | 609 ---- validation/validators.go | 738 ----- 221 files changed, 28 insertions(+), 52357 deletions(-) delete mode 100644 admin.go delete mode 100644 admin_test.go delete mode 100644 adminui.go delete mode 100644 app.go delete mode 100644 beego.go delete mode 100644 build_info.go delete mode 100644 cache/README.md delete mode 100644 cache/cache.go delete mode 100644 cache/cache_test.go delete mode 100644 cache/conv.go delete mode 100644 cache/conv_test.go delete mode 100644 cache/file.go delete mode 100644 cache/memcache/memcache.go delete mode 100644 cache/memcache/memcache_test.go delete mode 100644 cache/memory.go delete mode 100644 cache/redis/redis.go delete mode 100644 cache/redis/redis_test.go delete mode 100644 cache/ssdb/ssdb.go delete mode 100644 cache/ssdb/ssdb_test.go delete mode 100644 config.go delete mode 100644 config/config.go delete mode 100644 config/config_test.go delete mode 100644 config/env/env.go delete mode 100644 config/env/env_test.go delete mode 100644 config/fake.go delete mode 100644 config/ini.go delete mode 100644 config/ini_test.go delete mode 100644 config/json.go delete mode 100644 config/json_test.go delete mode 100644 config/xml/xml.go delete mode 100644 config/xml/xml_test.go delete mode 100644 config/yaml/yaml.go delete mode 100644 config/yaml/yaml_test.go delete mode 100644 config_test.go delete mode 100644 context/acceptencoder.go delete mode 100644 context/acceptencoder_test.go delete mode 100644 context/context.go delete mode 100644 context/context_test.go delete mode 100644 context/input.go delete mode 100644 context/input_test.go delete mode 100644 context/output.go delete mode 100644 context/param/conv.go delete mode 100644 context/param/methodparams.go delete mode 100644 context/param/options.go delete mode 100644 context/param/parsers.go delete mode 100644 context/param/parsers_test.go delete mode 100644 context/renderer.go delete mode 100644 context/response.go delete mode 100644 controller.go delete mode 100644 controller_test.go delete mode 100644 doc.go delete mode 100644 error.go delete mode 100644 error_test.go delete mode 100644 filter.go delete mode 100644 filter_test.go delete mode 100644 flash.go delete mode 100644 flash_test.go delete mode 100644 fs.go delete mode 100644 grace/grace.go delete mode 100644 grace/server.go delete mode 100644 hooks.go delete mode 100644 httplib/README.md delete mode 100644 httplib/httplib.go delete mode 100644 httplib/httplib_test.go delete mode 100644 logs/README.md delete mode 100644 logs/accesslog.go delete mode 100644 logs/alils/alils.go delete mode 100755 logs/alils/config.go delete mode 100755 logs/alils/log.pb.go delete mode 100755 logs/alils/log_config.go delete mode 100755 logs/alils/log_project.go delete mode 100755 logs/alils/log_store.go delete mode 100755 logs/alils/machine_group.go delete mode 100755 logs/alils/request.go delete mode 100755 logs/alils/signature.go delete mode 100644 logs/conn.go delete mode 100644 logs/conn_test.go delete mode 100644 logs/console.go delete mode 100644 logs/console_test.go delete mode 100644 logs/es/es.go delete mode 100644 logs/file.go delete mode 100644 logs/file_test.go delete mode 100644 logs/jianliao.go delete mode 100644 logs/log.go delete mode 100644 logs/logger.go delete mode 100644 logs/logger_test.go delete mode 100644 logs/multifile.go delete mode 100644 logs/multifile_test.go delete mode 100644 logs/slack.go delete mode 100644 logs/smtp.go delete mode 100644 logs/smtp_test.go delete mode 100644 metric/prometheus.go delete mode 100644 metric/prometheus_test.go delete mode 100644 migration/ddl.go delete mode 100644 migration/doc.go delete mode 100644 migration/migration.go delete mode 100644 mime.go delete mode 100644 namespace.go delete mode 100644 namespace_test.go delete mode 100644 orm/README.md delete mode 100644 orm/cmd.go delete mode 100644 orm/cmd_utils.go delete mode 100644 orm/db.go delete mode 100644 orm/db_alias.go delete mode 100644 orm/db_mysql.go delete mode 100644 orm/db_oracle.go delete mode 100644 orm/db_postgres.go delete mode 100644 orm/db_sqlite.go delete mode 100644 orm/db_tables.go delete mode 100644 orm/db_tidb.go delete mode 100644 orm/db_utils.go delete mode 100644 orm/models.go delete mode 100644 orm/models_boot.go delete mode 100644 orm/models_fields.go delete mode 100644 orm/models_info_f.go delete mode 100644 orm/models_info_m.go delete mode 100644 orm/models_test.go delete mode 100644 orm/models_utils.go delete mode 100644 orm/orm.go delete mode 100644 orm/orm_conds.go delete mode 100644 orm/orm_log.go delete mode 100644 orm/orm_object.go delete mode 100644 orm/orm_querym2m.go delete mode 100644 orm/orm_queryset.go delete mode 100644 orm/orm_raw.go delete mode 100644 orm/orm_test.go delete mode 100644 orm/qb.go delete mode 100644 orm/qb_mysql.go delete mode 100644 orm/qb_tidb.go delete mode 100644 orm/types.go delete mode 100644 orm/utils.go delete mode 100644 orm/utils_test.go delete mode 100644 parser.go delete mode 100644 plugins/apiauth/apiauth.go delete mode 100644 plugins/apiauth/apiauth_test.go delete mode 100644 plugins/auth/basic.go delete mode 100644 plugins/authz/authz.go delete mode 100644 plugins/authz/authz_model.conf delete mode 100644 plugins/authz/authz_policy.csv delete mode 100644 plugins/authz/authz_test.go delete mode 100644 plugins/cors/cors.go delete mode 100644 plugins/cors/cors_test.go delete mode 100644 policy.go delete mode 100644 router.go delete mode 100644 router_test.go delete mode 100644 session/README.md delete mode 100644 session/couchbase/sess_couchbase.go delete mode 100644 session/ledis/ledis_session.go delete mode 100644 session/memcache/sess_memcache.go delete mode 100644 session/mysql/sess_mysql.go delete mode 100644 session/postgres/sess_postgresql.go delete mode 100644 session/redis/sess_redis.go delete mode 100644 session/redis_cluster/redis_cluster.go delete mode 100644 session/redis_sentinel/sess_redis_sentinel.go delete mode 100644 session/redis_sentinel/sess_redis_sentinel_test.go delete mode 100644 session/sess_cookie.go delete mode 100644 session/sess_cookie_test.go delete mode 100644 session/sess_file.go delete mode 100644 session/sess_file_test.go delete mode 100644 session/sess_mem.go delete mode 100644 session/sess_mem_test.go delete mode 100644 session/sess_test.go delete mode 100644 session/sess_utils.go delete mode 100644 session/session.go delete mode 100644 session/ssdb/sess_ssdb.go delete mode 100644 staticfile.go delete mode 100644 staticfile_test.go delete mode 100644 swagger/swagger.go delete mode 100644 template.go delete mode 100644 template_test.go delete mode 100644 templatefunc.go delete mode 100644 templatefunc_test.go delete mode 100644 testing/assertions.go delete mode 100644 testing/client.go delete mode 100644 toolbox/healthcheck.go delete mode 100644 toolbox/profile.go delete mode 100644 toolbox/profile_test.go delete mode 100644 toolbox/statistics.go delete mode 100644 toolbox/statistics_test.go delete mode 100644 toolbox/task.go delete mode 100644 toolbox/task_test.go delete mode 100644 tree.go delete mode 100644 tree_test.go delete mode 100644 unregroute_test.go delete mode 100644 utils/caller.go delete mode 100644 utils/caller_test.go delete mode 100644 utils/captcha/LICENSE delete mode 100644 utils/captcha/README.md delete mode 100644 utils/captcha/captcha.go delete mode 100644 utils/captcha/image.go delete mode 100644 utils/captcha/image_test.go delete mode 100644 utils/captcha/siprng.go delete mode 100644 utils/captcha/siprng_test.go delete mode 100644 utils/debug.go delete mode 100644 utils/debug_test.go delete mode 100644 utils/file.go delete mode 100644 utils/file_test.go delete mode 100644 utils/mail.go delete mode 100644 utils/mail_test.go delete mode 100644 utils/pagination/controller.go delete mode 100644 utils/pagination/doc.go delete mode 100644 utils/pagination/paginator.go delete mode 100644 utils/pagination/utils.go delete mode 100644 utils/rand.go delete mode 100644 utils/rand_test.go delete mode 100644 utils/safemap.go delete mode 100644 utils/safemap_test.go delete mode 100644 utils/slice.go delete mode 100644 utils/slice_test.go delete mode 100644 utils/testdata/grepe.test delete mode 100644 utils/utils.go delete mode 100644 utils/utils_test.go delete mode 100644 validation/README.md delete mode 100644 validation/util.go delete mode 100644 validation/util_test.go delete mode 100644 validation/validation.go delete mode 100644 validation/validation_test.go delete mode 100644 validation/validators.go diff --git a/.travis.yml b/.travis.yml index 26c3732e..63b31c52 100644 --- a/.travis.yml +++ b/.travis.yml @@ -64,7 +64,7 @@ after_script: - rm -rf ./res/var/* script: - go test ./... - - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" + - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./pkg - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s diff --git a/admin.go b/admin.go deleted file mode 100644 index d9c96dfd..00000000 --- a/admin.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "os" - "reflect" - "strconv" - "text/template" - "time" - - "github.com/prometheus/client_golang/prometheus/promhttp" - - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" -) - -// BeeAdminApp is the default adminApp used by admin module. -var beeAdminApp *adminApp - -// FilterMonitorFunc is default monitor filter when admin module is enable. -// if this func returns, admin module records qps for this request by condition of this function logic. -// usage: -// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { -// if method == "POST" { -// return false -// } -// if t.Nanoseconds() < 100 { -// return false -// } -// if strings.HasPrefix(requestPath, "/astaxie") { -// return false -// } -// return true -// } -// beego.FilterMonitorFunc = MyFilterMonitor. -// Deprecated: using pkg/, we will delete this in v2.1.0 -var FilterMonitorFunc func(string, string, time.Duration, string, int) bool - -func init() { - beeAdminApp = &adminApp{ - routers: make(map[string]http.HandlerFunc), - } - // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Route("/", adminIndex) - beeAdminApp.Route("/qps", qpsIndex) - beeAdminApp.Route("/prof", profIndex) - beeAdminApp.Route("/healthcheck", healthcheck) - beeAdminApp.Route("/task", taskStatus) - beeAdminApp.Route("/listconf", listConf) - beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) - FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } -} - -// AdminIndex is the default http.Handler for admin module. -// it matches url pattern "/". -func adminIndex(rw http.ResponseWriter, _ *http.Request) { - writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) -} - -// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. -// it's registered with url pattern "/qps" in admin module. -func qpsIndex(rw http.ResponseWriter, _ *http.Request) { - data := make(map[interface{}]interface{}) - data["Content"] = toolbox.StatisticsMap.GetMap() - - // do html escape before display path, avoid xss - if content, ok := (data["Content"]).(M); ok { - if resultLists, ok := (content["Data"]).([][]string); ok { - for i := range resultLists { - if len(resultLists[i]) > 0 { - resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) - } - } - } - } - - writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) -} - -// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. -// it's registered with url pattern "/listconf" in admin module. -func listConf(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - rw.Write([]byte("command not support")) - return - } - - data := make(map[interface{}]interface{}) - switch command { - case "conf": - m := make(M) - list("BConfig", BConfig, m) - m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) - m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(configTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - - data["Content"] = m - - tmpl.Execute(rw, data) - - case "router": - content := PrintTree() - content["Fields"] = []string{ - "Router Pattern", - "Methods", - "Controller", - } - data["Content"] = content - data["Title"] = "Routers" - writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) - case "filter": - var ( - content = M{ - "Fields": []string{ - "Router Pattern", - "Filter Function", - }, - } - filterTypes = []string{} - filterTypeData = make(M) - ) - - if BeeApp.Handlers.enableFilter { - var filterType string - for k, fr := range map[int]string{ - BeforeStatic: "Before Static", - BeforeRouter: "Before Router", - BeforeExec: "Before Exec", - AfterExec: "After Exec", - FinishRouter: "Finish Router"} { - if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { - filterType = fr - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - var result = []string{ - // void xss - template.HTMLEscapeString(f.pattern), - template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - } - } - - content["Data"] = filterTypeData - content["Methods"] = filterTypes - - data["Content"] = content - data["Title"] = "Filters" - writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) - default: - rw.Write([]byte("command not support")) - } -} - -func list(root string, p interface{}, m M) { - pt := reflect.TypeOf(p) - pv := reflect.ValueOf(p) - if pt.Kind() == reflect.Ptr { - pt = pt.Elem() - pv = pv.Elem() - } - for i := 0; i < pv.NumField(); i++ { - var key string - if root == "" { - key = pt.Field(i).Name - } else { - key = root + "." + pt.Field(i).Name - } - if pv.Field(i).Kind() == reflect.Struct { - list(key, pv.Field(i).Interface(), m) - } else { - m[key] = pv.Field(i).Interface() - } - } -} - -// PrintTree prints all registered routers. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func PrintTree() M { - var ( - content = M{} - methods = []string{} - methodsData = make(M) - ) - for method, t := range BeeApp.Handlers.routers { - - resultList := new([][]string) - - printTree(resultList, t) - - methods = append(methods, template.HTMLEscapeString(method)) - methodsData[template.HTMLEscapeString(method)] = resultList - } - - content["Data"] = methodsData - content["Methods"] = methods - return content -} - -func printTree(resultList *[][]string, t *Tree) { - for _, tr := range t.fixrouters { - printTree(resultList, tr) - } - if t.wildcard != nil { - printTree(resultList, t.wildcard) - } - for _, l := range t.leaves { - if v, ok := l.runObject.(*ControllerInfo); ok { - if v.routerType == routerTypeBeego { - var result = []string{ - template.HTMLEscapeString(v.pattern), - template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), - template.HTMLEscapeString(v.controllerType.String()), - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeRESTFul { - var result = []string{ - template.HTMLEscapeString(v.pattern), - template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), - "", - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeHandler { - var result = []string{ - template.HTMLEscapeString(v.pattern), - "", - "", - } - *resultList = append(*resultList, result) - } - } - } -} - -// ProfIndex is a http.Handler for showing profile command. -// it's in url pattern "/prof" in admin module. -func profIndex(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - return - } - - var ( - format = r.Form.Get("format") - data = make(map[interface{}]interface{}) - result bytes.Buffer - ) - toolbox.ProcessInput(command, &result) - data["Content"] = template.HTMLEscapeString(result.String()) - - if format == "json" && command == "gc summary" { - dataJSON, err := json.Marshal(data) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(rw, dataJSON) - return - } - - data["Title"] = template.HTMLEscapeString(command) - defaultTpl := defaultScriptsTpl - if command == "gc summary" { - defaultTpl = gcAjaxTpl - } - writeTemplate(rw, data, profillingTpl, defaultTpl) -} - -// Healthcheck is a http.Handler calling health checking and showing the result. -// it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, r *http.Request) { - var ( - result []string - data = make(map[interface{}]interface{}) - resultList = new([][]string) - content = M{ - "Fields": []string{"Name", "Message", "Status"}, - } - ) - - for name, h := range toolbox.AdminCheckList { - if err := h.Check(); err != nil { - result = []string{ - "error", - template.HTMLEscapeString(name), - template.HTMLEscapeString(err.Error()), - } - } else { - result = []string{ - "success", - template.HTMLEscapeString(name), - "OK", - } - } - *resultList = append(*resultList, result) - } - - queryParams := r.URL.Query() - jsonFlag := queryParams.Get("json") - shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) - - if shouldReturnJSON { - response := buildHealthCheckResponseList(resultList) - jsonResponse, err := json.Marshal(response) - - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } else { - writeJSON(rw, jsonResponse) - } - return - } - - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Health Check" - - writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) -} - -func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { - response := make([]map[string]interface{}, len(*healthCheckResults)) - - for i, healthCheckResult := range *healthCheckResults { - currentResultMap := make(map[string]interface{}) - - currentResultMap["name"] = healthCheckResult[0] - currentResultMap["message"] = healthCheckResult[1] - currentResultMap["status"] = healthCheckResult[2] - - response[i] = currentResultMap - } - - return response - -} - -func writeJSON(rw http.ResponseWriter, jsonData []byte) { - rw.Header().Set("Content-Type", "application/json") - rw.Write(jsonData) -} - -// TaskStatus is a http.Handler with running task status (task name, status and the last execution). -// it's in "/task" pattern in admin module. -func taskStatus(rw http.ResponseWriter, req *http.Request) { - data := make(map[interface{}]interface{}) - - // Run Task - req.ParseForm() - taskname := req.Form.Get("taskname") - if taskname != "" { - if t, ok := toolbox.AdminTaskList[taskname]; ok { - if err := t.Run(); err != nil { - data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} - } - data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} - } else { - data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} - } - } - - // List Tasks - content := make(M) - resultList := new([][]string) - var fields = []string{ - "Task Name", - "Task Spec", - "Task Status", - "Last Time", - "", - } - for tname, tk := range toolbox.AdminTaskList { - result := []string{ - template.HTMLEscapeString(tname), - template.HTMLEscapeString(tk.GetSpec()), - template.HTMLEscapeString(tk.GetStatus()), - template.HTMLEscapeString(tk.GetPrev().String()), - } - *resultList = append(*resultList, result) - } - - content["Fields"] = fields - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Tasks" - writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) -} - -func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - for _, tpl := range tpls { - tmpl = template.Must(tmpl.Parse(tpl)) - } - tmpl.Execute(rw, data) -} - -// adminApp is an http.HandlerFunc map used as beeAdminApp. -type adminApp struct { - routers map[string]http.HandlerFunc -} - -// Route adds http.HandlerFunc to adminApp with url pattern. -func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { - admin.routers[pattern] = f -} - -// Run adminApp http server. -// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (admin *adminApp) Run() { - if len(toolbox.AdminTaskList) > 0 { - toolbox.StartTask() - } - addr := BConfig.Listen.AdminAddr - - if BConfig.Listen.AdminPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) - } - for p, f := range admin.routers { - http.Handle(p, f) - } - logs.Info("Admin server Running on %s", addr) - - var err error - if BConfig.Listen.Graceful { - err = grace.ListenAndServe(addr, nil) - } else { - err = http.ListenAndServe(addr, nil) - } - if err != nil { - logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - } -} diff --git a/admin_test.go b/admin_test.go deleted file mode 100644 index 205c76c2..00000000 --- a/admin_test.go +++ /dev/null @@ -1,249 +0,0 @@ -package beego - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/astaxie/beego/toolbox" -) - -type SampleDatabaseCheck struct { -} - -type SampleCacheCheck struct { -} - -func (dc *SampleDatabaseCheck) Check() error { - return nil -} - -func (cc *SampleCacheCheck) Check() error { - return errors.New("no cache detected") -} - -func TestList_01(t *testing.T) { - m := make(M) - list("BConfig", BConfig, m) - t.Log(m) - om := oldMap() - for k, v := range om { - if fmt.Sprint(m[k]) != fmt.Sprint(v) { - t.Log(k, "old-key", v, "new-key", m[k]) - t.FailNow() - } - } -} - -func oldMap() M { - m := make(M) - m["BConfig.AppName"] = BConfig.AppName - m["BConfig.RunMode"] = BConfig.RunMode - m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive - m["BConfig.ServerName"] = BConfig.ServerName - m["BConfig.RecoverPanic"] = BConfig.RecoverPanic - m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody - m["BConfig.EnableGzip"] = BConfig.EnableGzip - m["BConfig.MaxMemory"] = BConfig.MaxMemory - m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow - m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful - m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut - m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 - m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP - m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr - m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort - m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS - m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr - m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort - m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile - m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile - m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin - m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr - m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort - m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi - m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo - m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender - m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs - m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName - m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator - m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex - m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir - m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip - m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize - m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum - m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft - m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight - m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath - m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF - m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire - m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn - m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider - m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName - m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime - m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig - m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime - m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie - m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain - m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly - m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs - m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs - m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat - m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum - m["BConfig.Log.Outputs"] = BConfig.Log.Outputs - return m -} - -func TestWriteJSON(t *testing.T) { - t.Log("Testing the adding of JSON to the response") - - w := httptest.NewRecorder() - originalBody := []int{1, 2, 3} - - res, _ := json.Marshal(originalBody) - - writeJSON(w, res) - - decodedBody := []int{} - err := json.NewDecoder(w.Body).Decode(&decodedBody) - - if err != nil { - t.Fatal("Could not decode response body into slice.") - } - - for i := range decodedBody { - if decodedBody[i] != originalBody[i] { - t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) - } - } -} - -func TestHealthCheckHandlerDefault(t *testing.T) { - endpointPath := "/healthcheck" - - toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) - - req, err := http.NewRequest("GET", endpointPath, nil) - if err != nil { - t.Fatal(err) - } - - w := httptest.NewRecorder() - - handler := http.HandlerFunc(healthcheck) - - handler.ServeHTTP(w, req) - - if status := w.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - if !strings.Contains(w.Body.String(), "database") { - t.Errorf("Expected 'database' in generated template.") - } - -} - -func TestBuildHealthCheckResponseList(t *testing.T) { - healthCheckResults := [][]string{ - []string{ - "error", - "Database", - "Error occured whie starting the db", - }, - []string{ - "success", - "Cache", - "Cache started successfully", - }, - } - - responseList := buildHealthCheckResponseList(&healthCheckResults) - - if len(responseList) != len(healthCheckResults) { - t.Errorf("invalid response map length: got %d want %d", - len(responseList), len(healthCheckResults)) - } - - responseFields := []string{"name", "message", "status"} - - for _, response := range responseList { - for _, field := range responseFields { - _, ok := response[field] - if !ok { - t.Errorf("expected %s to be in the response %v", field, response) - } - } - - } - -} - -func TestHealthCheckHandlerReturnsJSON(t *testing.T) { - - toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) - - req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) - if err != nil { - t.Fatal(err) - } - - w := httptest.NewRecorder() - - handler := http.HandlerFunc(healthcheck) - - handler.ServeHTTP(w, req) - if status := w.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - decodedResponseBody := []map[string]interface{}{} - expectedResponseBody := []map[string]interface{}{} - - expectedJSONString := []byte(` - [ - { - "message":"database", - "name":"success", - "status":"OK" - }, - { - "message":"cache", - "name":"error", - "status":"no cache detected" - } - ] - `) - - json.Unmarshal(expectedJSONString, &expectedResponseBody) - - json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) - - if len(expectedResponseBody) != len(decodedResponseBody) { - t.Errorf("invalid response map length: got %d want %d", - len(decodedResponseBody), len(expectedResponseBody)) - } - assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) - assert.Equal(t, 2, len(decodedResponseBody)) - - var database, cache map[string]interface{} - if decodedResponseBody[0]["message"] == "database" { - database = decodedResponseBody[0] - cache = decodedResponseBody[1] - } else { - database = decodedResponseBody[1] - cache = decodedResponseBody[0] - } - - assert.Equal(t, expectedResponseBody[0], database) - assert.Equal(t, expectedResponseBody[1], cache) - -} diff --git a/adminui.go b/adminui.go deleted file mode 100644 index cdcdef33..00000000 --- a/adminui.go +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -var indexTpl = ` -{{define "content"}} -

Beego Admin Dashboard

-

-For detail usage please check our document: -

-

-Toolbox -

-

-Live Monitor -

-{{.Content}} -{{end}}` - -var profillingTpl = ` -{{define "content"}} -

{{.Title}}

-
-
{{.Content}}
-
-{{end}}` - -var defaultScriptsTpl = `` - -var gcAjaxTpl = ` -{{define "scripts"}} - -{{end}} -` - -var qpsTpl = `{{define "content"}} -

Requests statistics

- - - - {{range .Content.Fields}} - - {{end}} - - - - - {{range $i, $elem := .Content.Data}} - - - - - - - - - - - {{end}} - - -
- {{.}} -
{{index $elem 0}}{{index $elem 1}}{{index $elem 2}}{{index $elem 4}}{{index $elem 6}}{{index $elem 8}}{{index $elem 10}}
-{{end}}` - -var configTpl = ` -{{define "content"}} -

Configurations

-
-{{range $index, $elem := .Content}}
-{{$index}}={{$elem}}
-{{end}}
-
-{{end}} -` - -var routerAndFilterTpl = `{{define "content"}} - - -

{{.Title}}

- -{{range .Content.Methods}} - -
-
{{.}}
-
- - - - {{range $.Content.Fields}} - - {{end}} - - - - - {{$slice := index $.Content.Data .}} - {{range $i, $elem := $slice}} - - - {{range $elem}} - - {{end}} - - - {{end}} - - -
- {{.}} -
- {{.}} -
-
-
-{{end}} - - -{{end}}` - -var tasksTpl = `{{define "content"}} - -

{{.Title}}

- -{{if .Message }} -{{ $messageType := index .Message 0}} -

-{{index .Message 1}} -

-{{end}} - - - - - -{{range .Content.Fields}} - -{{end}} - - - - -{{range $i, $slice := .Content.Data}} - - {{range $slice}} - - {{end}} - - -{{end}} - -
-{{.}} -
- {{.}} - - Run -
- -{{end}}` - -var healthCheckTpl = ` -{{define "content"}} - -

{{.Title}}

- - - -{{range .Content.Fields}} - -{{end}} - - - -{{range $i, $slice := .Content.Data}} - {{ $header := index $slice 0}} - {{ if eq "success" $header}} - - {{else if eq "error" $header}} - - {{else}} - - {{end}} - {{range $j, $elem := $slice}} - {{if ne $j 0}} - - {{end}} - {{end}} - - -{{end}} - - -
- {{.}} -
- {{$elem}} - - {{$header}} -
-{{end}}` - -// The base dashboardTpl -var dashboardTpl = ` - - - - - - - - - - -Welcome to Beego Admin Dashboard - - - - - - - - - - - - - -
-{{template "content" .}} -
- - - - - - - -{{template "scripts" .}} - - -` diff --git a/app.go b/app.go deleted file mode 100644 index d86188c0..00000000 --- a/app.go +++ /dev/null @@ -1,516 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/fcgi" - "os" - "path" - "strings" - "time" - - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" - "golang.org/x/crypto/acme/autocert" -) - -var ( - // BeeApp is an application instance - // Deprecated: using pkg/, we will delete this in v2.1.0 - BeeApp *App -) - -func init() { - // create beego application - BeeApp = NewApp() -} - -// App defines beego application with a new PatternServeMux. -type App struct { - Handlers *ControllerRegister - Server *http.Server -} - -// NewApp returns a new beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewApp() *App { - cr := NewControllerRegister() - app := &App{Handlers: cr, Server: &http.Server{}} - return app -} - -// MiddleWare function for http.Handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -type MiddleWare func(http.Handler) http.Handler - -// Run beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (app *App) Run(mws ...MiddleWare) { - addr := BConfig.Listen.HTTPAddr - - if BConfig.Listen.HTTPPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) - } - - var ( - err error - l net.Listener - endRunning = make(chan bool, 1) - ) - - // run cgi server - if BConfig.Listen.EnableFcgi { - if BConfig.Listen.EnableStdIo { - if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O - logs.Info("Use FCGI via standard I/O") - } else { - logs.Critical("Cannot use FCGI via standard I/O", err) - } - return - } - if BConfig.Listen.HTTPPort == 0 { - // remove the Socket file before start - if utils.FileExists(addr) { - os.Remove(addr) - } - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } - if err != nil { - logs.Critical("Listen: ", err) - } - if err = fcgi.Serve(l, app.Handlers); err != nil { - logs.Critical("fcgi.Serve: ", err) - } - return - } - - app.Server.Handler = app.Handlers - for i := len(mws) - 1; i >= 0; i-- { - if mws[i] == nil { - continue - } - app.Server.Handler = mws[i](app.Server.Handler) - } - app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.ErrorLog = logs.GetLogger("HTTP") - - // run graceful mode - if BConfig.Listen.Graceful { - httpsAddr := BConfig.Listen.HTTPSAddr - app.Server.Addr = httpsAddr - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { - go func() { - time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - app.Server.Addr = httpsAddr - } - server := grace.NewServer(httpsAddr, app.Server.Handler) - server.Server.ReadTimeout = app.Server.ReadTimeout - server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.EnableMutualHTTPS { - if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - } - } else { - if BConfig.Listen.AutoTLS { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), - } - app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } - if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - } - } - endRunning <- true - }() - } - if BConfig.Listen.EnableHTTP { - go func() { - server := grace.NewServer(addr, app.Server.Handler) - server.Server.ReadTimeout = app.Server.ReadTimeout - server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.ListenTCP4 { - server.Network = "tcp4" - } - if err := server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - } - endRunning <- true - }() - } - <-endRunning - return - } - - // run normal mode - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { - go func() { - time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - } else if BConfig.Listen.EnableHTTP { - logs.Info("Start https server error, conflict with http. Please reset https port") - return - } - logs.Info("https server Running on https://%s", app.Server.Addr) - if BConfig.Listen.AutoTLS { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), - } - app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } else if BConfig.Listen.EnableMutualHTTPS { - pool := x509.NewCertPool() - data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) - if err != nil { - logs.Info("MutualHTTPS should provide TrustCaFile") - return - } - pool.AppendCertsFromPEM(data) - app.Server.TLSConfig = &tls.Config{ - ClientCAs: pool, - ClientAuth: BConfig.Listen.ClientAuth, - } - } - if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - - } - if BConfig.Listen.EnableHTTP { - go func() { - app.Server.Addr = addr - logs.Info("http server Running on http://%s", app.Server.Addr) - if BConfig.Listen.ListenTCP4 { - ln, err := net.Listen("tcp4", app.Server.Addr) - if err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - if err = app.Server.Serve(ln); err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - } else { - if err := app.Server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } - }() - } - <-endRunning -} - -// Router adds a patterned controller handler to BeeApp. -// it's an alias method of App.Router. -// usage: -// simple router -// beego.Router("/admin", &admin.UserController{}) -// beego.Router("/admin/index", &admin.ArticleController{}) -// -// regex router -// -// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) -// -// custom rules -// beego.Router("/api/list",&RestController{},"*:ListFood") -// beego.Router("/api/create",&RestController{},"post:CreateFood") -// beego.Router("/api/update",&RestController{},"put:UpdateFood") -// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - BeeApp.Handlers.Add(rootpath, c, mappingMethods...) - return BeeApp -} - -// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful -// in web applications that inherit most routes from a base webapp via the underscore -// import, and aim to overwrite only certain paths. -// The method parameter can be empty or "*" for all HTTP methods, or a particular -// method type (e.g. "GET" or "POST") for selective removal. -// -// Usage (replace "GET" with "*" for all methods): -// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") -// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func UnregisterFixedRoute(fixedRoute string, method string) *App { - subPaths := splitPath(fixedRoute) - if method == "" || method == "*" { - for m := range HTTPMETHOD { - if _, ok := BeeApp.Handlers.routers[m]; !ok { - continue - } - if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) - continue - } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) - } - return BeeApp - } - // Single HTTP method - um := strings.ToUpper(method) - if _, ok := BeeApp.Handlers.routers[um]; ok { - if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) - return BeeApp - } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) - } - return BeeApp -} - -func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { - for i := range entryPointTree.fixrouters { - if entryPointTree.fixrouters[i].prefix == paths[0] { - if len(paths) == 1 { - if len(entryPointTree.fixrouters[i].fixrouters) > 0 { - // If the route had children subtrees, remove just the functional leaf, - // to allow children to function as before - if len(entryPointTree.fixrouters[i].leaves) > 0 { - entryPointTree.fixrouters[i].leaves[0] = nil - entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] - } - } else { - // Remove the *Tree from the fixrouters slice - entryPointTree.fixrouters[i] = nil - - if i == len(entryPointTree.fixrouters)-1 { - entryPointTree.fixrouters = entryPointTree.fixrouters[:i] - } else { - entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) - } - } - return - } - findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) - } - } -} - -func findAndRemoveSingleTree(entryPointTree *Tree) { - if entryPointTree == nil { - return - } - if len(entryPointTree.fixrouters) > 0 { - // If the route had children subtrees, remove just the functional leaf, - // to allow children to function as before - if len(entryPointTree.leaves) > 0 { - entryPointTree.leaves[0] = nil - entryPointTree.leaves = entryPointTree.leaves[1:] - } - } -} - -// Include will generate router file in the router/xxx.go from the controller's comments -// usage: -// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -// type BankAccount struct{ -// beego.Controller -// } -// -// register the function -// func (b *BankAccount)Mapping(){ -// b.Mapping("ShowAccount" , b.ShowAccount) -// b.Mapping("ModifyAccount", b.ModifyAccount) -//} -// -// //@router /account/:id [get] -// func (b *BankAccount) ShowAccount(){ -// //logic -// } -// -// -// //@router /account/:id [post] -// func (b *BankAccount) ModifyAccount(){ -// //logic -// } -// -// the comments @router url methodlist -// url support all the function Router's pattern -// methodlist [get post head put delete options *] -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Include(cList ...ControllerInterface) *App { - BeeApp.Handlers.Include(cList...) - return BeeApp -} - -// RESTRouter adds a restful controller handler to BeeApp. -// its' controller implements beego.ControllerInterface and -// defines a param "pattern/:objectId" to visit each resource. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func RESTRouter(rootpath string, c ControllerInterface) *App { - Router(rootpath, c) - Router(path.Join(rootpath, ":objectId"), c) - return BeeApp -} - -// AutoRouter adds defined controller handler to BeeApp. -// it's same to App.AutoRouter. -// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, -// visit the url /main/list to exec List function or /main/page to exec Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AutoRouter(c ControllerInterface) *App { - BeeApp.Handlers.AddAuto(c) - return BeeApp -} - -// AutoPrefix adds controller handler to BeeApp with prefix. -// it's same to App.AutoRouterWithPrefix. -// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, -// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AutoPrefix(prefix string, c ControllerInterface) *App { - BeeApp.Handlers.AddAutoPrefix(prefix, c) - return BeeApp -} - -// Get used to register router for Get method -// usage: -// beego.Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Get(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Get(rootpath, f) - return BeeApp -} - -// Post used to register router for Post method -// usage: -// beego.Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Post(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Post(rootpath, f) - return BeeApp -} - -// Delete used to register router for Delete method -// usage: -// beego.Delete("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Delete(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Delete(rootpath, f) - return BeeApp -} - -// Put used to register router for Put method -// usage: -// beego.Put("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Put(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Put(rootpath, f) - return BeeApp -} - -// Head used to register router for Head method -// usage: -// beego.Head("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Head(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Head(rootpath, f) - return BeeApp -} - -// Options used to register router for Options method -// usage: -// beego.Options("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Options(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Options(rootpath, f) - return BeeApp -} - -// Patch used to register router for Patch method -// usage: -// beego.Patch("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Patch(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Patch(rootpath, f) - return BeeApp -} - -// Any used to register router for all methods -// usage: -// beego.Any("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Any(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Any(rootpath, f) - return BeeApp -} - -// Handler used to register a Handler router -// usage: -// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { -// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) -// })) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Handler(rootpath string, h http.Handler, options ...interface{}) *App { - BeeApp.Handlers.Handler(rootpath, h, options...) - return BeeApp -} - -// InsertFilter adds a FilterFunc with pattern condition and action constant. -// The pos means action constant including -// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. -// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) - return BeeApp -} diff --git a/beego.go b/beego.go deleted file mode 100644 index ef93134d..00000000 --- a/beego.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "os" - "path/filepath" - "strconv" - "strings" -) - -const ( - // VERSION represent beego web framework version. - // Deprecated: using pkg/, we will delete this in v2.1.0 - VERSION = "1.12.2" - - // DEV is for develop - // Deprecated: using pkg/, we will delete this in v2.1.0 - DEV = "dev" - // PROD is for production - // Deprecated: using pkg/, we will delete this in v2.1.0 - PROD = "prod" -) - -// M is Map shortcut -// Deprecated: using pkg/, we will delete this in v2.1.0 -type M map[string]interface{} - -// Hook function to run -type hookfunc func() error - -var ( - hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc -) - -// AddAPPStartHook is used to register the hookfunc -// The hookfuncs will run in beego.Run() -// such as initiating session , starting middleware , building template, starting admin control and so on. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddAPPStartHook(hf ...hookfunc) { - hooks = append(hooks, hf...) -} - -// Run beego application. -// beego.Run() default run on HttpPort -// beego.Run("localhost") -// beego.Run(":8089") -// beego.Run("127.0.0.1:8089") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Run(params ...string) { - - initBeforeHTTPRun() - - if len(params) > 0 && params[0] != "" { - strs := strings.Split(params[0], ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BConfig.Listen.Domains = params - } - - BeeApp.Run() -} - -// RunWithMiddleWares Run beego application with middlewares. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func RunWithMiddleWares(addr string, mws ...MiddleWare) { - initBeforeHTTPRun() - - strs := strings.Split(addr, ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - BConfig.Listen.Domains = []string{strs[0]} - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BeeApp.Run(mws...) -} - -func initBeforeHTTPRun() { - //init hooks - AddAPPStartHook( - registerMime, - registerDefaultErrorHandler, - registerSession, - registerTemplate, - registerAdmin, - registerGzip, - ) - - for _, hk := range hooks { - if err := hk(); err != nil { - panic(err) - } - } -} - -// TestBeegoInit is for test package init -// Deprecated: using pkg/, we will delete this in v2.1.0 -func TestBeegoInit(ap string) { - path := filepath.Join(ap, "conf", "app.conf") - os.Chdir(ap) - InitBeegoBeforeTest(path) -} - -// InitBeegoBeforeTest is for test package init -// Deprecated: using pkg/, we will delete this in v2.1.0 -func InitBeegoBeforeTest(appConfigPath string) { - if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { - panic(err) - } - BConfig.RunMode = "test" - initBeforeHTTPRun() -} diff --git a/build_info.go b/build_info.go deleted file mode 100644 index 59e78127..00000000 --- a/build_info.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 astaxie -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -var ( - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildVersion string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildGitRevision string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildStatus string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTag string - // Deprecated: using pkg/, we will delete this in v2.1.0 - BuildTime string - - // Deprecated: using pkg/, we will delete this in v2.1.0 - GoVersion string - - // Deprecated: using pkg/, we will delete this in v2.1.0 - GitBranch string -) diff --git a/cache/README.md b/cache/README.md deleted file mode 100644 index b467760a..00000000 --- a/cache/README.md +++ /dev/null @@ -1,59 +0,0 @@ -## cache -cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . - - -## How to install? - - go get github.com/astaxie/beego/cache - - -## What adapters are supported? - -As of now this cache support memory, Memcache and Redis. - - -## How to use it? - -First you must import it - - import ( - "github.com/astaxie/beego/cache" - ) - -Then init a Cache (example with memory adapter) - - bm, err := cache.NewCache("memory", `{"interval":60}`) - -Use it like this: - - bm.Put("astaxie", 1, 10 * time.Second) - bm.Get("astaxie") - bm.IsExist("astaxie") - bm.Delete("astaxie") - - -## Memory adapter - -Configure memory adapter like this: - - {"interval":60} - -interval means the gc time. The cache will check at each time interval, whether item has expired. - - -## Memcache adapter - -Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. - -Configure like this: - - {"conn":"127.0.0.1:11211"} - - -## Redis adapter - -Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. - -Configure like this: - - {"conn":":6039"} diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index 82585c4e..00000000 --- a/cache/cache.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package cache provide a Cache interface and some implement engine -// Usage: -// -// import( -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("memory", `{"interval":60}`) -// -// Use it like this: -// -// bm.Put("astaxie", 1, 10 * time.Second) -// bm.Get("astaxie") -// bm.IsExist("astaxie") -// bm.Delete("astaxie") -// -// more docs http://beego.me/docs/module/cache.md -package cache - -import ( - "fmt" - "time" -) - -// Cache interface contains all behaviors for cache adapter. -// usage: -// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. -// c,err := cache.NewCache("file","{....}") -// c.Put("key",value, 3600 * time.Second) -// v := c.Get("key") -// -// c.Incr("counter") // now is 1 -// c.Incr("counter") // now is 2 -// count := c.Get("counter").(int) -type Cache interface { - // get cached value by key. - Get(key string) interface{} - // GetMulti is a batch version of Get. - GetMulti(keys []string) []interface{} - // set cached value with key and expire time. - Put(key string, val interface{}, timeout time.Duration) error - // delete cached value by key. - Delete(key string) error - // increase cached int value by key, as a counter. - Incr(key string) error - // decrease cached int value by key, as a counter. - Decr(key string) error - // check if cached value exists or not. - IsExist(key string) bool - // clear all cache. - ClearAll() error - // start gc routine based on config string settings. - StartAndGC(config string) error -} - -// Instance is a function create a new Cache Instance -type Instance func() Cache - -var adapters = make(map[string]Instance) - -// Register makes a cache adapter available by the adapter name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, adapter Instance) { - if adapter == nil { - panic("cache: Register adapter is nil") - } - if _, ok := adapters[name]; ok { - panic("cache: Register called twice for adapter " + name) - } - adapters[name] = adapter -} - -// NewCache Create a new cache driver by adapter name and config string. -// config need to be correct JSON as string: {"interval":360}. -// it will start gc automatically. -func NewCache(adapterName, config string) (adapter Cache, err error) { - instanceFunc, ok := adapters[adapterName] - if !ok { - err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) - return - } - adapter = instanceFunc() - err = adapter.StartAndGC(config) - if err != nil { - adapter = nil - } - return -} diff --git a/cache/cache_test.go b/cache/cache_test.go deleted file mode 100644 index 470c0a43..00000000 --- a/cache/cache_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "os" - "sync" - "testing" - "time" -) - -func TestCacheIncr(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - //timeoutDuration := 10 * time.Second - - bm.Put("edwardhey", 0, time.Second*20) - wg := sync.WaitGroup{} - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - bm.Incr("edwardhey") - }() - } - wg.Wait() - if bm.Get("edwardhey").(int) != 10 { - t.Error("Incr err") - } -} - -func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - time.Sleep(30 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } -} - -func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - - os.RemoveAll("cache") -} diff --git a/cache/conv.go b/cache/conv.go deleted file mode 100644 index 87800586..00000000 --- a/cache/conv.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "fmt" - "strconv" -) - -// GetString convert interface to string. -func GetString(v interface{}) string { - switch result := v.(type) { - case string: - return result - case []byte: - return string(result) - default: - if v != nil { - return fmt.Sprint(result) - } - } - return "" -} - -// GetInt convert interface to int. -func GetInt(v interface{}) int { - switch result := v.(type) { - case int: - return result - case int32: - return int(result) - case int64: - return int(result) - default: - if d := GetString(v); d != "" { - value, _ := strconv.Atoi(d) - return value - } - } - return 0 -} - -// GetInt64 convert interface to int64. -func GetInt64(v interface{}) int64 { - switch result := v.(type) { - case int: - return int64(result) - case int32: - return int64(result) - case int64: - return result - default: - - if d := GetString(v); d != "" { - value, _ := strconv.ParseInt(d, 10, 64) - return value - } - } - return 0 -} - -// GetFloat64 convert interface to float64. -func GetFloat64(v interface{}) float64 { - switch result := v.(type) { - case float64: - return result - default: - if d := GetString(v); d != "" { - value, _ := strconv.ParseFloat(d, 64) - return value - } - } - return 0 -} - -// GetBool convert interface to bool. -func GetBool(v interface{}) bool { - switch result := v.(type) { - case bool: - return result - default: - if d := GetString(v); d != "" { - value, _ := strconv.ParseBool(d) - return value - } - } - return false -} diff --git a/cache/conv_test.go b/cache/conv_test.go deleted file mode 100644 index b90e224a..00000000 --- a/cache/conv_test.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "testing" -) - -func TestGetString(t *testing.T) { - var t1 = "test1" - if "test1" != GetString(t1) { - t.Error("get string from string error") - } - var t2 = []byte("test2") - if "test2" != GetString(t2) { - t.Error("get string from byte array error") - } - var t3 = 1 - if "1" != GetString(t3) { - t.Error("get string from int error") - } - var t4 int64 = 1 - if "1" != GetString(t4) { - t.Error("get string from int64 error") - } - var t5 = 1.1 - if "1.1" != GetString(t5) { - t.Error("get string from float64 error") - } - - if "" != GetString(nil) { - t.Error("get string from nil error") - } -} - -func TestGetInt(t *testing.T) { - var t1 = 1 - if 1 != GetInt(t1) { - t.Error("get int from int error") - } - var t2 int32 = 32 - if 32 != GetInt(t2) { - t.Error("get int from int32 error") - } - var t3 int64 = 64 - if 64 != GetInt(t3) { - t.Error("get int from int64 error") - } - var t4 = "128" - if 128 != GetInt(t4) { - t.Error("get int from num string error") - } - if 0 != GetInt(nil) { - t.Error("get int from nil error") - } -} - -func TestGetInt64(t *testing.T) { - var i int64 = 1 - var t1 = 1 - if i != GetInt64(t1) { - t.Error("get int64 from int error") - } - var t2 int32 = 1 - if i != GetInt64(t2) { - t.Error("get int64 from int32 error") - } - var t3 int64 = 1 - if i != GetInt64(t3) { - t.Error("get int64 from int64 error") - } - var t4 = "1" - if i != GetInt64(t4) { - t.Error("get int64 from num string error") - } - if 0 != GetInt64(nil) { - t.Error("get int64 from nil") - } -} - -func TestGetFloat64(t *testing.T) { - var f = 1.11 - var t1 float32 = 1.11 - if f != GetFloat64(t1) { - t.Error("get float64 from float32 error") - } - var t2 = 1.11 - if f != GetFloat64(t2) { - t.Error("get float64 from float64 error") - } - var t3 = "1.11" - if f != GetFloat64(t3) { - t.Error("get float64 from string error") - } - - var f2 float64 = 1 - var t4 = 1 - if f2 != GetFloat64(t4) { - t.Error("get float64 from int error") - } - - if 0 != GetFloat64(nil) { - t.Error("get float64 from nil error") - } -} - -func TestGetBool(t *testing.T) { - var t1 = true - if !GetBool(t1) { - t.Error("get bool from bool error") - } - var t2 = "true" - if !GetBool(t2) { - t.Error("get bool from string error") - } - if GetBool(nil) { - t.Error("get bool from nil error") - } -} - -func byteArrayEquals(a []byte, b []byte) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} diff --git a/cache/file.go b/cache/file.go deleted file mode 100644 index 6f12d3ee..00000000 --- a/cache/file.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "bytes" - "crypto/md5" - "encoding/gob" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "reflect" - "strconv" - "time" -) - -// FileCacheItem is basic unit of file cache adapter. -// it contains data and expire time. -type FileCacheItem struct { - Data interface{} - Lastaccess time.Time - Expired time.Time -} - -// FileCache Config -var ( - FileCachePath = "cache" // cache directory - FileCacheFileSuffix = ".bin" // cache file suffix - FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. - FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. -) - -// FileCache is cache adapter for file storage. -type FileCache struct { - CachePath string - FileSuffix string - DirectoryLevel int - EmbedExpiry int -} - -// NewFileCache Create new file cache with no config. -// the level and expiry need set in method StartAndGC as config string. -func NewFileCache() Cache { - // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} - return &FileCache{} -} - -// StartAndGC will start and begin gc for file cache. -// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} -func (fc *FileCache) StartAndGC(config string) error { - - cfg := make(map[string]string) - err := json.Unmarshal([]byte(config), &cfg) - if err != nil { - return err - } - if _, ok := cfg["CachePath"]; !ok { - cfg["CachePath"] = FileCachePath - } - if _, ok := cfg["FileSuffix"]; !ok { - cfg["FileSuffix"] = FileCacheFileSuffix - } - if _, ok := cfg["DirectoryLevel"]; !ok { - cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) - } - if _, ok := cfg["EmbedExpiry"]; !ok { - cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) - } - fc.CachePath = cfg["CachePath"] - fc.FileSuffix = cfg["FileSuffix"] - fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) - fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) - - fc.Init() - return nil -} - -// Init will make new dir for file cache if not exist. -func (fc *FileCache) Init() { - if ok, _ := exists(fc.CachePath); !ok { // todo : error handle - _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle - } -} - -// get cached file name. it's md5 encoded. -func (fc *FileCache) getCacheFileName(key string) string { - m := md5.New() - io.WriteString(m, key) - keyMd5 := hex.EncodeToString(m.Sum(nil)) - cachePath := fc.CachePath - switch fc.DirectoryLevel { - case 2: - cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) - case 1: - cachePath = filepath.Join(cachePath, keyMd5[0:2]) - } - - if ok, _ := exists(cachePath); !ok { // todo : error handle - _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle - } - - return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) -} - -// Get value from file cache. -// if non-exist or expired, return empty string. -func (fc *FileCache) Get(key string) interface{} { - fileData, err := FileGetContents(fc.getCacheFileName(key)) - if err != nil { - return "" - } - var to FileCacheItem - GobDecode(fileData, &to) - if to.Expired.Before(time.Now()) { - return "" - } - return to.Data -} - -// GetMulti gets values from file cache. -// if non-exist or expired, return empty string. -func (fc *FileCache) GetMulti(keys []string) []interface{} { - var rc []interface{} - for _, key := range keys { - rc = append(rc, fc.Get(key)) - } - return rc -} - -// Put value into file cache. -// timeout means how long to keep this file, unit of ms. -// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. -func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { - gob.Register(val) - - item := FileCacheItem{Data: val} - if timeout == time.Duration(fc.EmbedExpiry) { - item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years - } else { - item.Expired = time.Now().Add(timeout) - } - item.Lastaccess = time.Now() - data, err := GobEncode(item) - if err != nil { - return err - } - return FilePutContents(fc.getCacheFileName(key), data) -} - -// Delete file cache value. -func (fc *FileCache) Delete(key string) error { - filename := fc.getCacheFileName(key) - if ok, _ := exists(filename); ok { - return os.Remove(filename) - } - return nil -} - -// Incr will increase cached int value. -// fc value is saving forever unless Delete. -func (fc *FileCache) Incr(key string) error { - data := fc.Get(key) - var incr int - if reflect.TypeOf(data).Name() != "int" { - incr = 0 - } else { - incr = data.(int) + 1 - } - fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) - return nil -} - -// Decr will decrease cached int value. -func (fc *FileCache) Decr(key string) error { - data := fc.Get(key) - var decr int - if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { - decr = 0 - } else { - decr = data.(int) - 1 - } - fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) - return nil -} - -// IsExist check value is exist. -func (fc *FileCache) IsExist(key string) bool { - ret, _ := exists(fc.getCacheFileName(key)) - return ret -} - -// ClearAll will clean cached files. -// not implemented. -func (fc *FileCache) ClearAll() error { - return nil -} - -// check file exist. -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return false, err -} - -// FileGetContents Get bytes to file. -// if non-exist, create this file. -func FileGetContents(filename string) (data []byte, e error) { - return ioutil.ReadFile(filename) -} - -// FilePutContents Put bytes to file. -// if non-exist, create this file. -func FilePutContents(filename string, content []byte) error { - return ioutil.WriteFile(filename, content, os.ModePerm) -} - -// GobEncode Gob encodes file cache item. -func GobEncode(data interface{}) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(data) - if err != nil { - return nil, err - } - return buf.Bytes(), err -} - -// GobDecode Gob decodes file cache item. -func GobDecode(data []byte, to *FileCacheItem) error { - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - return dec.Decode(&to) -} - -func init() { - Register("file", NewFileCache) -} diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go deleted file mode 100644 index 19116bfa..00000000 --- a/cache/memcache/memcache.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memcache for cache provider -// -// depend on github.com/bradfitz/gomemcache/memcache -// -// go install github.com/bradfitz/gomemcache/memcache -// -// Usage: -// import( -// _ "github.com/astaxie/beego/cache/memcache" -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) -// -// more docs http://beego.me/docs/module/cache.md -package memcache - -import ( - "encoding/json" - "errors" - "strings" - "time" - - "github.com/astaxie/beego/cache" - "github.com/bradfitz/gomemcache/memcache" -) - -// Cache Memcache adapter. -type Cache struct { - conn *memcache.Client - conninfo []string -} - -// NewMemCache create new memcache adapter. -func NewMemCache() cache.Cache { - return &Cache{} -} - -// Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - if item, err := rc.conn.Get(key); err == nil { - return item.Value - } - return nil -} - -// GetMulti get value from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) - var rv []interface{} - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv - } - } - mv, err := rc.conn.GetMulti(keys) - if err == nil { - for _, v := range mv { - rv = append(rv, v.Value) - } - return rv - } - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv -} - -// Put put value to memcache. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} - if v, ok := val.([]byte); ok { - item.Value = v - } else if str, ok := val.(string); ok { - item.Value = []byte(str) - } else { - return errors.New("val only support string and []byte") - } - return rc.conn.Set(&item) -} - -// Delete delete value in memcache. -func (rc *Cache) Delete(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.Delete(key) -} - -// Incr increase counter. -func (rc *Cache) Incr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Increment(key, 1) - return err -} - -// Decr decrease counter. -func (rc *Cache) Decr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Decrement(key, 1) - return err -} - -// IsExist check value exists in memcache. -func (rc *Cache) IsExist(key string) bool { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false - } - } - _, err := rc.conn.Get(key) - return err == nil -} - -// ClearAll clear all cached in memcache. -func (rc *Cache) ClearAll() error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.FlushAll() -} - -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { - rc.conn = memcache.New(rc.conninfo...) - return nil -} - -func init() { - cache.Register("memcache", NewMemCache) -} diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go deleted file mode 100644 index d9129b69..00000000 --- a/cache/memcache/memcache_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memcache - -import ( - _ "github.com/bradfitz/gomemcache/memcache" - - "strconv" - "testing" - "time" - - "github.com/astaxie/beego/cache" -) - -func TestMemcacheCache(t *testing.T) { - bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - time.Sleep(11 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie").([]byte); string(v) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { - t.Error("GetMulti ERROR") - } - - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } -} diff --git a/cache/memory.go b/cache/memory.go deleted file mode 100644 index d8314e3c..00000000 --- a/cache/memory.go +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "encoding/json" - "errors" - "sync" - "time" -) - -var ( - // DefaultEvery means the clock time of recycling the expired cache items in memory. - DefaultEvery = 60 // 1 minute -) - -// MemoryItem store memory cache item. -type MemoryItem struct { - val interface{} - createdTime time.Time - lifespan time.Duration -} - -func (mi *MemoryItem) isExpire() bool { - // 0 means forever - if mi.lifespan == 0 { - return false - } - return time.Now().Sub(mi.createdTime) > mi.lifespan -} - -// MemoryCache is Memory cache adapter. -// it contains a RW locker for safe map storage. -type MemoryCache struct { - sync.RWMutex - dur time.Duration - items map[string]*MemoryItem - Every int // run an expiration check Every clock time -} - -// NewMemoryCache returns a new MemoryCache. -func NewMemoryCache() Cache { - cache := MemoryCache{items: make(map[string]*MemoryItem)} - return &cache -} - -// Get cache from memory. -// if non-existed or expired, return nil. -func (bc *MemoryCache) Get(name string) interface{} { - bc.RLock() - defer bc.RUnlock() - if itm, ok := bc.items[name]; ok { - if itm.isExpire() { - return nil - } - return itm.val - } - return nil -} - -// GetMulti gets caches from memory. -// if non-existed or expired, return nil. -func (bc *MemoryCache) GetMulti(names []string) []interface{} { - var rc []interface{} - for _, name := range names { - rc = append(rc, bc.Get(name)) - } - return rc -} - -// Put cache to memory. -// if lifespan is 0, it will be forever till restart. -func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { - bc.Lock() - defer bc.Unlock() - bc.items[name] = &MemoryItem{ - val: value, - createdTime: time.Now(), - lifespan: lifespan, - } - return nil -} - -// Delete cache in memory. -func (bc *MemoryCache) Delete(name string) error { - bc.Lock() - defer bc.Unlock() - if _, ok := bc.items[name]; !ok { - return errors.New("key not exist") - } - delete(bc.items, name) - if _, ok := bc.items[name]; ok { - return errors.New("delete key error") - } - return nil -} - -// Incr increase cache counter in memory. -// it supports int,int32,int64,uint,uint32,uint64. -func (bc *MemoryCache) Incr(key string) error { - bc.Lock() - defer bc.Unlock() - itm, ok := bc.items[key] - if !ok { - return errors.New("key not exist") - } - switch val := itm.val.(type) { - case int: - itm.val = val + 1 - case int32: - itm.val = val + 1 - case int64: - itm.val = val + 1 - case uint: - itm.val = val + 1 - case uint32: - itm.val = val + 1 - case uint64: - itm.val = val + 1 - default: - return errors.New("item val is not (u)int (u)int32 (u)int64") - } - return nil -} - -// Decr decrease counter in memory. -func (bc *MemoryCache) Decr(key string) error { - bc.Lock() - defer bc.Unlock() - itm, ok := bc.items[key] - if !ok { - return errors.New("key not exist") - } - switch val := itm.val.(type) { - case int: - itm.val = val - 1 - case int64: - itm.val = val - 1 - case int32: - itm.val = val - 1 - case uint: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint32: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint64: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - default: - return errors.New("item val is not int int64 int32") - } - return nil -} - -// IsExist check cache exist in memory. -func (bc *MemoryCache) IsExist(name string) bool { - bc.RLock() - defer bc.RUnlock() - if v, ok := bc.items[name]; ok { - return !v.isExpire() - } - return false -} - -// ClearAll will delete all cache in memory. -func (bc *MemoryCache) ClearAll() error { - bc.Lock() - defer bc.Unlock() - bc.items = make(map[string]*MemoryItem) - return nil -} - -// StartAndGC start memory cache. it will check expiration in every clock time. -func (bc *MemoryCache) StartAndGC(config string) error { - var cf map[string]int - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["interval"]; !ok { - cf = make(map[string]int) - cf["interval"] = DefaultEvery - } - dur := time.Duration(cf["interval"]) * time.Second - bc.Every = cf["interval"] - bc.dur = dur - go bc.vacuum() - return nil -} - -// check expiration. -func (bc *MemoryCache) vacuum() { - bc.RLock() - every := bc.Every - bc.RUnlock() - - if every < 1 { - return - } - for { - <-time.After(bc.dur) - bc.RLock() - if bc.items == nil { - bc.RUnlock() - return - } - bc.RUnlock() - if keys := bc.expiredKeys(); len(keys) != 0 { - bc.clearItems(keys) - } - } -} - -// expiredKeys returns key list which are expired. -func (bc *MemoryCache) expiredKeys() (keys []string) { - bc.RLock() - defer bc.RUnlock() - for key, itm := range bc.items { - if itm.isExpire() { - keys = append(keys, key) - } - } - return -} - -// clearItems removes all the items which key in keys. -func (bc *MemoryCache) clearItems(keys []string) { - bc.Lock() - defer bc.Unlock() - for _, key := range keys { - delete(bc.items, key) - } -} - -func init() { - Register("memory", NewMemoryCache) -} diff --git a/cache/redis/redis.go b/cache/redis/redis.go deleted file mode 100644 index d8737b3c..00000000 --- a/cache/redis/redis.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for cache provider -// -// depend on github.com/gomodule/redigo/redis -// -// go install github.com/gomodule/redigo/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/cache/redis" -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) -// -// more docs http://beego.me/docs/module/cache.md -package redis - -import ( - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/gomodule/redigo/redis" - - "github.com/astaxie/beego/cache" - "strings" -) - -var ( - // DefaultKey the collection name of redis for cache adapter. - DefaultKey = "beecacheRedis" -) - -// Cache is Redis cache adapter. -type Cache struct { - p *redis.Pool // redis connection pool - conninfo string - dbNum int - key string - password string - maxIdle int - - //the timeout to a value less than the redis server's timeout. - timeout time.Duration -} - -// NewRedisCache create new redis cache with default collection name. -func NewRedisCache() cache.Cache { - return &Cache{key: DefaultKey} -} - -// actually do the redis cmds, args[0] must be the key name. -func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { - if len(args) < 1 { - return nil, errors.New("missing required arguments") - } - args[0] = rc.associate(args[0]) - c := rc.p.Get() - defer c.Close() - - return c.Do(commandName, args...) -} - -// associate with config key. -func (rc *Cache) associate(originKey interface{}) string { - return fmt.Sprintf("%s:%s", rc.key, originKey) -} - -// Get cache from redis. -func (rc *Cache) Get(key string) interface{} { - if v, err := rc.do("GET", key); err == nil { - return v - } - return nil -} - -// GetMulti get cache from redis. -func (rc *Cache) GetMulti(keys []string) []interface{} { - c := rc.p.Get() - defer c.Close() - var args []interface{} - for _, key := range keys { - args = append(args, rc.associate(key)) - } - values, err := redis.Values(c.Do("MGET", args...)) - if err != nil { - return nil - } - return values -} - -// Put put cache to redis. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { - _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) - return err -} - -// Delete delete cache in redis. -func (rc *Cache) Delete(key string) error { - _, err := rc.do("DEL", key) - return err -} - -// IsExist check cache's existence in redis. -func (rc *Cache) IsExist(key string) bool { - v, err := redis.Bool(rc.do("EXISTS", key)) - if err != nil { - return false - } - return v -} - -// Incr increase counter in redis. -func (rc *Cache) Incr(key string) error { - _, err := redis.Bool(rc.do("INCRBY", key, 1)) - return err -} - -// Decr decrease counter in redis. -func (rc *Cache) Decr(key string) error { - _, err := redis.Bool(rc.do("INCRBY", key, -1)) - return err -} - -// ClearAll clean all cache in redis. delete this redis collection. -func (rc *Cache) ClearAll() error { - cachedKeys, err := rc.Scan(rc.key + ":*") - if err != nil { - return err - } - c := rc.p.Get() - defer c.Close() - for _, str := range cachedKeys { - if _, err = c.Do("DEL", str); err != nil { - return err - } - } - return err -} - -// Scan scan all keys matching the pattern. a better choice than `keys` -func (rc *Cache) Scan(pattern string) (keys []string, err error) { - c := rc.p.Get() - defer c.Close() - var ( - cursor uint64 = 0 // start - result []interface{} - list []string - ) - for { - result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) - if err != nil { - return - } - list, err = redis.Strings(result[1], nil) - if err != nil { - return - } - keys = append(keys, list...) - cursor, err = redis.Uint64(result[0], nil) - if err != nil { - return - } - if cursor == 0 { // over - return - } - } -} - -// StartAndGC start redis cache adapter. -// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} -// the cache item in redis are stored forever, -// so no gc operation. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - - if _, ok := cf["key"]; !ok { - cf["key"] = DefaultKey - } - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - - // Format redis://@: - cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) - if i := strings.Index(cf["conn"], "@"); i > -1 { - cf["password"] = cf["conn"][0:i] - cf["conn"] = cf["conn"][i+1:] - } - - if _, ok := cf["dbNum"]; !ok { - cf["dbNum"] = "0" - } - if _, ok := cf["password"]; !ok { - cf["password"] = "" - } - if _, ok := cf["maxIdle"]; !ok { - cf["maxIdle"] = "3" - } - if _, ok := cf["timeout"]; !ok { - cf["timeout"] = "180s" - } - rc.key = cf["key"] - rc.conninfo = cf["conn"] - rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) - rc.password = cf["password"] - rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) - - if v, err := time.ParseDuration(cf["timeout"]); err == nil { - rc.timeout = v - } else { - rc.timeout = 180 * time.Second - } - - rc.connectInit() - - c := rc.p.Get() - defer c.Close() - - return c.Err() -} - -// connect to redis. -func (rc *Cache) connectInit() { - dialFunc := func() (c redis.Conn, err error) { - c, err = redis.Dial("tcp", rc.conninfo) - if err != nil { - return nil, err - } - - if rc.password != "" { - if _, err := c.Do("AUTH", rc.password); err != nil { - c.Close() - return nil, err - } - } - - _, selecterr := c.Do("SELECT", rc.dbNum) - if selecterr != nil { - c.Close() - return nil, selecterr - } - return - } - // initialize a new pool - rc.p = &redis.Pool{ - MaxIdle: rc.maxIdle, - IdleTimeout: rc.timeout, - Dial: dialFunc, - } -} - -func init() { - cache.Register("redis", NewRedisCache) -} diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go deleted file mode 100644 index 60a19180..00000000 --- a/cache/redis/redis_test.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package redis - -import ( - "fmt" - "testing" - "time" - - "github.com/astaxie/beego/cache" - "github.com/gomodule/redigo/redis" - "github.com/stretchr/testify/assert" -) - -func TestRedisCache(t *testing.T) { - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - time.Sleep(11 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } - - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } -} - -func TestCache_Scan(t *testing.T) { - timeoutDuration := 10 * time.Second - // init - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) - if err != nil { - t.Error("init err") - } - // insert all - for i := 0; i < 10000; i++ { - if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { - t.Error("set Error", err) - } - } - // scan all for the first time - keys, err := bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } - - assert.Equal(t, 10000, len(keys), "scan all error") - - // clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } - - // scan all for the second time - keys, err = bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } - if len(keys) != 0 { - t.Error("scan all err") - } -} diff --git a/cache/ssdb/ssdb.go b/cache/ssdb/ssdb.go deleted file mode 100644 index fa2ce04b..00000000 --- a/cache/ssdb/ssdb.go +++ /dev/null @@ -1,231 +0,0 @@ -package ssdb - -import ( - "encoding/json" - "errors" - "strconv" - "strings" - "time" - - "github.com/ssdb/gossdb/ssdb" - - "github.com/astaxie/beego/cache" -) - -// Cache SSDB adapter -type Cache struct { - conn *ssdb.Client - conninfo []string -} - -//NewSsdbCache create new ssdb adapter. -func NewSsdbCache() cache.Cache { - return &Cache{} -} - -// Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil - } - } - value, err := rc.conn.Get(key) - if err == nil { - return value - } - return nil -} - -// GetMulti get value from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) - var values []interface{} - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - values = append(values, err) - } - return values - } - } - res, err := rc.conn.Do("multi_get", keys) - resSize := len(res) - if err == nil { - for i := 1; i < resSize; i += 2 { - values = append(values, res[i+1]) - } - return values - } - for i := 0; i < size; i++ { - values = append(values, err) - } - return values -} - -// DelMulti get value from memcache. -func (rc *Cache) DelMulti(keys []string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("multi_del", keys) - return err -} - -// Put put value to memcache. only support string. -func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - v, ok := value.(string) - if !ok { - return errors.New("value must string") - } - var resp []string - var err error - ttl := int(timeout / time.Second) - if ttl < 0 { - resp, err = rc.conn.Do("set", key, v) - } else { - resp, err = rc.conn.Do("setx", key, v, ttl) - } - if err != nil { - return err - } - if len(resp) == 2 && resp[0] == "ok" { - return nil - } - return errors.New("bad response") -} - -// Delete delete value in memcache. -func (rc *Cache) Delete(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Del(key) - return err -} - -// Incr increase counter. -func (rc *Cache) Incr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("incr", key, 1) - return err -} - -// Decr decrease counter. -func (rc *Cache) Decr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("incr", key, -1) - return err -} - -// IsExist check value exists in memcache. -func (rc *Cache) IsExist(key string) bool { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false - } - } - resp, err := rc.conn.Do("exists", key) - if err != nil { - return false - } - if len(resp) == 2 && resp[1] == "1" { - return true - } - return false - -} - -// ClearAll clear all cached in memcache. -func (rc *Cache) ClearAll() error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - keyStart, keyEnd, limit := "", "", 50 - resp, err := rc.Scan(keyStart, keyEnd, limit) - for err == nil { - size := len(resp) - if size == 1 { - return nil - } - keys := []string{} - for i := 1; i < size; i += 2 { - keys = append(keys, resp[i]) - } - _, e := rc.conn.Do("multi_del", keys) - if e != nil { - return e - } - keyStart = resp[size-2] - resp, err = rc.Scan(keyStart, keyEnd, limit) - } - return err -} - -// Scan key all cached in ssdb. -func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } - resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) - if err != nil { - return nil, err - } - return resp, nil -} - -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { - conninfoArray := strings.Split(rc.conninfo[0], ":") - host := conninfoArray[0] - port, e := strconv.Atoi(conninfoArray[1]) - if e != nil { - return e - } - var err error - rc.conn, err = ssdb.Connect(host, port) - return err -} - -func init() { - cache.Register("ssdb", NewSsdbCache) -} diff --git a/cache/ssdb/ssdb_test.go b/cache/ssdb/ssdb_test.go deleted file mode 100644 index dd474960..00000000 --- a/cache/ssdb/ssdb_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package ssdb - -import ( - "strconv" - "testing" - "time" - - "github.com/astaxie/beego/cache" -) - -func TestSsdbcacheCache(t *testing.T) { - ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) - if err != nil { - t.Error("init err") - } - - // test put and exist - if ssdb.IsExist("ssdb") { - t.Error("check err") - } - timeoutDuration := 10 * time.Second - //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - - // Get test done - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v := ssdb.Get("ssdb"); v != "ssdb" { - t.Error("get Error") - } - - //inc/dec test done - if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr("ssdb"); err != nil { - t.Error("incr Error", err) - } - - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - - if err = ssdb.Decr("ssdb"); err != nil { - t.Error("decr error") - } - - // test del - if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete("ssdb"); err == nil { - if ssdb.IsExist("ssdb") { - t.Error("delete err") - } - } - - //test string - if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - if v := ssdb.Get("ssdb").(string); v != "ssdb" { - t.Error("get err") - } - - //test GetMulti done - if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb1") { - t.Error("check err") - } - vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } - - // test clear all done - if err = ssdb.ClearAll(); err != nil { - t.Error("clear all err") - } - if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { - t.Error("check err") - } -} diff --git a/config.go b/config.go deleted file mode 100644 index 7917528e..00000000 --- a/config.go +++ /dev/null @@ -1,557 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "crypto/tls" - "fmt" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" - "github.com/astaxie/beego/utils" -) - -// Config is the main struct for BConfig -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Config struct { - AppName string //Application name - RunMode string //Running Mode: dev | prod - RouterCaseSensitive bool - ServerName string - RecoverPanic bool - RecoverFunc func(*context.Context) - CopyRequestBody bool - EnableGzip bool - MaxMemory int64 - EnableErrorsShow bool - EnableErrorsRender bool - Listen Listen - WebConfig WebConfig - Log LogConfig -} - -// Listen holds for http and https related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Listen struct { - Graceful bool // Graceful means use graceful module to start the server - ServerTimeOut int64 - ListenTCP4 bool - EnableHTTP bool - HTTPAddr string - HTTPPort int - AutoTLS bool - Domains []string - TLSCacheDir string - EnableHTTPS bool - EnableMutualHTTPS bool - HTTPSAddr string - HTTPSPort int - HTTPSCertFile string - HTTPSKeyFile string - TrustCaFile string - ClientAuth tls.ClientAuthType - EnableAdmin bool - AdminAddr string - AdminPort int - EnableFcgi bool - EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O -} - -// WebConfig holds web related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type WebConfig struct { - AutoRender bool - EnableDocs bool - FlashName string - FlashSeparator string - DirectoryIndex bool - StaticDir map[string]string - StaticExtensionsToGzip []string - StaticCacheFileSize int - StaticCacheFileNum int - TemplateLeft string - TemplateRight string - ViewsPath string - EnableXSRF bool - XSRFKey string - XSRFExpire int - Session SessionConfig -} - -// SessionConfig holds session related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string - SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. - SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader string - SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params -} - -// LogConfig holds Log related config -// Deprecated: using pkg/, we will delete this in v2.1.0 -type LogConfig struct { - AccessLogs bool - EnableStaticLogs bool //log static files requests default: false - AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string - FileLineNum bool - Outputs map[string]string // Store Adaptor : config -} - -var ( - // BConfig is the default config for Application - // Deprecated: using pkg/, we will delete this in v2.1.0 - BConfig *Config - // AppConfig is the instance of Config, store the config information from file - // Deprecated: using pkg/, we will delete this in v2.1.0 - AppConfig *beegoAppConfig - // AppPath is the absolute path to the app - // Deprecated: using pkg/, we will delete this in v2.1.0 - AppPath string - // GlobalSessions is the instance for the session manager - // Deprecated: using pkg/, we will delete this in v2.1.0 - GlobalSessions *session.Manager - - // appConfigPath is the path to the config files - appConfigPath string - // appConfigProvider is the provider for the config, default is ini - appConfigProvider = "ini" - // WorkPath is the absolute path to project root directory - // Deprecated: using pkg/, we will delete this in v2.1.0 - WorkPath string -) - -func init() { - BConfig = newBConfig() - var err error - if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { - panic(err) - } - WorkPath, err = os.Getwd() - if err != nil { - panic(err) - } - var filename = "app.conf" - if os.Getenv("BEEGO_RUNMODE") != "" { - filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" - } - appConfigPath = filepath.Join(WorkPath, "conf", filename) - if configPath := os.Getenv("BEEGO_CONFIG_PATH"); configPath != "" { - appConfigPath = configPath - } - if !utils.FileExists(appConfigPath) { - appConfigPath = filepath.Join(AppPath, "conf", filename) - if !utils.FileExists(appConfigPath) { - AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} - return - } - } - if err = parseConfig(appConfigPath); err != nil { - panic(err) - } -} - -func recoverPanic(ctx *context.Context) { - if err := recover(); err != nil { - if err == ErrAbort { - return - } - if !BConfig.RecoverPanic { - panic(err) - } - if BConfig.EnableErrorsShow { - if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { - exception(fmt.Sprint(err), ctx) - return - } - } - var stack string - logs.Critical("the request url is ", ctx.Input.URL()) - logs.Critical("Handler crashed with error", err) - for i := 1; ; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - logs.Critical(fmt.Sprintf("%s:%d", file, line)) - stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) - } - if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { - showErr(err, ctx, stack) - } - if ctx.Output.Status != 0 { - ctx.ResponseWriter.WriteHeader(ctx.Output.Status) - } else { - ctx.ResponseWriter.WriteHeader(500) - } - } -} - -func newBConfig() *Config { - return &Config{ - AppName: "beego", - RunMode: PROD, - RouterCaseSensitive: true, - ServerName: "beegoServer:" + VERSION, - RecoverPanic: true, - RecoverFunc: recoverPanic, - CopyRequestBody: false, - EnableGzip: false, - MaxMemory: 1 << 26, //64MB - EnableErrorsShow: true, - EnableErrorsRender: true, - Listen: Listen{ - Graceful: false, - ServerTimeOut: 0, - ListenTCP4: false, - EnableHTTP: true, - AutoTLS: false, - Domains: []string{}, - TLSCacheDir: ".", - HTTPAddr: "", - HTTPPort: 8080, - EnableHTTPS: false, - HTTPSAddr: "", - HTTPSPort: 10443, - HTTPSCertFile: "", - HTTPSKeyFile: "", - EnableAdmin: false, - AdminAddr: "", - AdminPort: 8088, - EnableFcgi: false, - EnableStdIo: false, - ClientAuth: tls.RequireAndVerifyClientCert, - }, - WebConfig: WebConfig{ - AutoRender: true, - EnableDocs: false, - FlashName: "BEEGO_FLASH", - FlashSeparator: "BEEGOFLASH", - DirectoryIndex: false, - StaticDir: map[string]string{"/static": "static"}, - StaticExtensionsToGzip: []string{".css", ".js"}, - StaticCacheFileSize: 1024 * 100, - StaticCacheFileNum: 1000, - TemplateLeft: "{{", - TemplateRight: "}}", - ViewsPath: "views", - EnableXSRF: false, - XSRFKey: "beegoxsrf", - XSRFExpire: 0, - Session: SessionConfig{ - SessionOn: false, - SessionProvider: "memory", - SessionName: "beegosessionID", - SessionGCMaxLifetime: 3600, - SessionProviderConfig: "", - SessionDisableHTTPOnly: false, - SessionCookieLifeTime: 0, //set cookie default is the browser life - SessionAutoSetCookie: true, - SessionDomain: "", - SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader: "Beegosessionid", - SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params - }, - }, - Log: LogConfig{ - AccessLogs: false, - EnableStaticLogs: false, - AccessLogsFormat: "APACHE_FORMAT", - FileLineNum: true, - Outputs: map[string]string{"console": ""}, - }, - } -} - -// now only support ini, next will support json. -func parseConfig(appConfigPath string) (err error) { - AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) - if err != nil { - return err - } - return assignConfig(AppConfig) -} - -func assignConfig(ac config.Configer) error { - for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { - assignSingleConfig(i, ac) - } - // set the run mode first - if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { - BConfig.RunMode = envRunMode - } else if runMode := ac.String("RunMode"); runMode != "" { - BConfig.RunMode = runMode - } - - if sd := ac.String("StaticDir"); sd != "" { - BConfig.WebConfig.StaticDir = map[string]string{} - sds := strings.Fields(sd) - for _, v := range sds { - if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { - BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] - } else { - BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] - } - } - } - - if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { - extensions := strings.Split(sgz, ",") - fileExts := []string{} - for _, ext := range extensions { - ext = strings.TrimSpace(ext) - if ext == "" { - continue - } - if !strings.HasPrefix(ext, ".") { - ext = "." + ext - } - fileExts = append(fileExts, ext) - } - if len(fileExts) > 0 { - BConfig.WebConfig.StaticExtensionsToGzip = fileExts - } - } - - if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { - BConfig.WebConfig.StaticCacheFileSize = sfs - } - - if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { - BConfig.WebConfig.StaticCacheFileNum = sfn - } - - if lo := ac.String("LogOutputs"); lo != "" { - // if lo is not nil or empty - // means user has set his own LogOutputs - // clear the default setting to BConfig.Log.Outputs - BConfig.Log.Outputs = make(map[string]string) - los := strings.Split(lo, ";") - for _, v := range los { - if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { - BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] - } else { - continue - } - } - } - - //init log - logs.Reset() - for adaptor, config := range BConfig.Log.Outputs { - err := logs.SetLogger(adaptor, config) - if err != nil { - fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) - } - } - logs.SetLogFuncCall(BConfig.Log.FileLineNum) - - return nil -} - -func assignSingleConfig(p interface{}, ac config.Configer) { - pt := reflect.TypeOf(p) - if pt.Kind() != reflect.Ptr { - return - } - pt = pt.Elem() - if pt.Kind() != reflect.Struct { - return - } - pv := reflect.ValueOf(p).Elem() - - for i := 0; i < pt.NumField(); i++ { - pf := pv.Field(i) - if !pf.CanSet() { - continue - } - name := pt.Field(i).Name - switch pf.Kind() { - case reflect.String: - pf.SetString(ac.DefaultString(name, pf.String())) - case reflect.Int, reflect.Int64: - pf.SetInt(ac.DefaultInt64(name, pf.Int())) - case reflect.Bool: - pf.SetBool(ac.DefaultBool(name, pf.Bool())) - case reflect.Struct: - default: - //do nothing here - } - } - -} - -// LoadAppConfig allow developer to apply a config file -// Deprecated: using pkg/, we will delete this in v2.1.0 -func LoadAppConfig(adapterName, configPath string) error { - absConfigPath, err := filepath.Abs(configPath) - if err != nil { - return err - } - - if !utils.FileExists(absConfigPath) { - return fmt.Errorf("the target config file: %s don't exist", configPath) - } - - appConfigPath = absConfigPath - appConfigProvider = adapterName - - return parseConfig(appConfigPath) -} - -type beegoAppConfig struct { - innerConfig config.Configer -} - -func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { - ac, err := config.NewConfig(appConfigProvider, appConfigPath) - if err != nil { - return nil, err - } - return &beegoAppConfig{ac}, nil -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Set(key, val string) error { - if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { - return b.innerConfig.Set(key, val) - } - return nil -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) String(key string) string { - if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { - return v - } - return b.innerConfig.String(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { - return v - } - return b.innerConfig.Strings(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Int(key string) (int, error) { - if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Int(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Int64(key string) (int64, error) { - if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Int64(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Bool(key string) (bool, error) { - if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Bool(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) Float(key string) (float64, error) { - if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Float(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { - if v := b.String(key); v != "" { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { - if v := b.Strings(key); len(v) != 0 { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { - if v, err := b.Int(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { - if v, err := b.Int64(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { - if v, err := b.Bool(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { - if v, err := b.Float(key); err == nil { - return v - } - return defaultVal -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) DIY(key string) (interface{}, error) { - return b.innerConfig.DIY(key) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { - return b.innerConfig.GetSection(section) -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (b *beegoAppConfig) SaveConfigFile(filename string) error { - return b.innerConfig.SaveConfigFile(filename) -} diff --git a/config/config.go b/config/config.go deleted file mode 100644 index db2e96f6..00000000 --- a/config/config.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package config is used to parse config. -// Usage: -// import "github.com/astaxie/beego/config" -//Examples. -// -// cnf, err := config.NewConfig("ini", "config.conf") -// -// cnf APIS: -// -// cnf.Set(key, val string) error -// cnf.String(key string) string -// cnf.Strings(key string) []string -// cnf.Int(key string) (int, error) -// cnf.Int64(key string) (int64, error) -// cnf.Bool(key string) (bool, error) -// cnf.Float(key string) (float64, error) -// cnf.DefaultString(key string, defaultVal string) string -// cnf.DefaultStrings(key string, defaultVal []string) []string -// cnf.DefaultInt(key string, defaultVal int) int -// cnf.DefaultInt64(key string, defaultVal int64) int64 -// cnf.DefaultBool(key string, defaultVal bool) bool -// cnf.DefaultFloat(key string, defaultVal float64) float64 -// cnf.DIY(key string) (interface{}, error) -// cnf.GetSection(section string) (map[string]string, error) -// cnf.SaveConfigFile(filename string) error -//More docs http://beego.me/docs/module/config.md -package config - -import ( - "fmt" - "os" - "reflect" - "time" -) - -// Configer defines how to get and set value from configuration raw data. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type Configer interface { - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Set(key, val string) error //support section::key type in given key when using ini type. - // Deprecated: using pkg/config, we will delete this in v2.1.0 - String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Strings(key string) []string //get string slice - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Int(key string) (int, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Int64(key string) (int64, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Bool(key string) (bool, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - Float(key string) (float64, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultStrings(key string, defaultVal []string) []string //get string slice - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultInt(key string, defaultVal int) int - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultInt64(key string, defaultVal int64) int64 - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultBool(key string, defaultVal bool) bool - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DefaultFloat(key string, defaultVal float64) float64 - // Deprecated: using pkg/config, we will delete this in v2.1.0 - DIY(key string) (interface{}, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - GetSection(section string) (map[string]string, error) - // Deprecated: using pkg/config, we will delete this in v2.1.0 - SaveConfigFile(filename string) error -} - -// Config is the adapter interface for parsing config file to get raw data to Configer. -type Config interface { - Parse(key string) (Configer, error) - ParseData(data []byte) (Configer, error) -} - -var adapters = make(map[string]Config) - -// Register makes a config adapter available by the adapter name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, adapter Config) { - if adapter == nil { - panic("config: Register adapter is nil") - } - if _, ok := adapters[name]; ok { - panic("config: Register called twice for adapter " + name) - } - adapters[name] = adapter -} - -// NewConfig adapterName is ini/json/xml/yaml. -// filename is the config file path. -func NewConfig(adapterName, filename string) (Configer, error) { - adapter, ok := adapters[adapterName] - if !ok { - return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) - } - return adapter.Parse(filename) -} - -// NewConfigData adapterName is ini/json/xml/yaml. -// data is the config data. -func NewConfigData(adapterName string, data []byte) (Configer, error) { - adapter, ok := adapters[adapterName] - if !ok { - return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) - } - return adapter.ParseData(data) -} - -// ExpandValueEnvForMap convert all string value with environment variable. -func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { - for k, v := range m { - switch value := v.(type) { - case string: - m[k] = ExpandValueEnv(value) - case map[string]interface{}: - m[k] = ExpandValueEnvForMap(value) - case map[string]string: - for k2, v2 := range value { - value[k2] = ExpandValueEnv(v2) - } - m[k] = value - } - } - return m -} - -// ExpandValueEnv returns value of convert with environment variable. -// -// Return environment variable if value start with "${" and end with "}". -// Return default value if environment variable is empty or not exist. -// -// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". -// Examples: -// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. -// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". -// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". -func ExpandValueEnv(value string) (realValue string) { - realValue = value - - vLen := len(value) - // 3 = ${} - if vLen < 3 { - return - } - // Need start with "${" and end with "}", then return. - if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { - return - } - - key := "" - defaultV := "" - // value start with "${" - for i := 2; i < vLen; i++ { - if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { - key = value[2:i] - defaultV = value[i+2 : vLen-1] // other string is default value. - break - } else if value[i] == '}' { - key = value[2:i] - break - } - } - - realValue = os.Getenv(key) - if realValue == "" { - realValue = defaultV - } - - return -} - -// ParseBool returns the boolean value represented by the string. -// -// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, -// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. -// Any other value returns an error. -func ParseBool(val interface{}) (value bool, err error) { - if val != nil { - switch v := val.(type) { - case bool: - return v, nil - case string: - switch v { - case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On": - return true, nil - case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off": - return false, nil - } - case int8, int32, int64: - strV := fmt.Sprintf("%d", v) - if strV == "1" { - return true, nil - } else if strV == "0" { - return false, nil - } - case float64: - if v == 1.0 { - return true, nil - } else if v == 0.0 { - return false, nil - } - } - return false, fmt.Errorf("parsing %q: invalid syntax", val) - } - return false, fmt.Errorf("parsing : invalid syntax") -} - -// ToString converts values of any type to string. -func ToString(x interface{}) string { - switch y := x.(type) { - - // Handle dates with special logic - // This needs to come above the fmt.Stringer - // test since time.Time's have a .String() - // method - case time.Time: - return y.Format("A Monday") - - // Handle type string - case string: - return y - - // Handle type with .String() method - case fmt.Stringer: - return y.String() - - // Handle type with .Error() method - case error: - return y.Error() - - } - - // Handle named string type - if v := reflect.ValueOf(x); v.Kind() == reflect.String { - return v.String() - } - - // Fallback to fmt package for anything else like numeric types - return fmt.Sprint(x) -} diff --git a/config/config_test.go b/config/config_test.go deleted file mode 100644 index 15d6ffa6..00000000 --- a/config/config_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "os" - "testing" -) - -func TestExpandValueEnv(t *testing.T) { - - testCases := []struct { - item string - want string - }{ - {"", ""}, - {"$", "$"}, - {"{", "{"}, - {"{}", "{}"}, - {"${}", ""}, - {"${|}", ""}, - {"${}", ""}, - {"${{}}", ""}, - {"${{||}}", "}"}, - {"${pwd||}", ""}, - {"${pwd||}", ""}, - {"${pwd||}", ""}, - {"${pwd||}}", "}"}, - {"${pwd||{{||}}}", "{{||}}"}, - {"${GOPATH}", os.Getenv("GOPATH")}, - {"${GOPATH||}", os.Getenv("GOPATH")}, - {"${GOPATH||root}", os.Getenv("GOPATH")}, - {"${GOPATH_NOT||root}", "root"}, - {"${GOPATH_NOT||||root}", "||root"}, - } - - for _, c := range testCases { - if got := ExpandValueEnv(c.item); got != c.want { - t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) - } - } - -} diff --git a/config/env/env.go b/config/env/env.go deleted file mode 100644 index 1a6c2527..00000000 --- a/config/env/env.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// Copyright 2017 Faissal Elamraoui. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package env is used to parse environment. -package env - -import ( - "fmt" - "os" - "strings" - - "github.com/astaxie/beego/utils" -) - -var env *utils.BeeMap - -func init() { - env = utils.NewBeeMap() - for _, e := range os.Environ() { - splits := strings.Split(e, "=") - env.Set(splits[0], os.Getenv(splits[0])) - } -} - -// Get returns a value by key. -// If the key does not exist, the default value will be returned. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func Get(key string, defVal string) string { - if val := env.Get(key); val != nil { - return val.(string) - } - return defVal -} - -// MustGet returns a value by key. -// If the key does not exist, it will return an error. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func MustGet(key string) (string, error) { - if val := env.Get(key); val != nil { - return val.(string), nil - } - return "", fmt.Errorf("no env variable with %s", key) -} - -// Set sets a value in the ENV copy. -// This does not affect the child process environment. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func Set(key string, value string) { - env.Set(key, value) -} - -// MustSet sets a value in the ENV copy and the child process environment. -// It returns an error in case the set operation failed. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func MustSet(key string, value string) error { - err := os.Setenv(key, value) - if err != nil { - return err - } - env.Set(key, value) - return nil -} - -// GetAll returns all keys/values in the current child process environment. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func GetAll() map[string]string { - items := env.Items() - envs := make(map[string]string, env.Count()) - - for key, val := range items { - switch key := key.(type) { - case string: - switch val := val.(type) { - case string: - envs[key] = val - } - } - } - return envs -} diff --git a/config/env/env_test.go b/config/env/env_test.go deleted file mode 100644 index 3f1d4dba..00000000 --- a/config/env/env_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// Copyright 2017 Faissal Elamraoui. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package env - -import ( - "os" - "testing" -) - -func TestEnvGet(t *testing.T) { - gopath := Get("GOPATH", "") - if gopath != os.Getenv("GOPATH") { - t.Error("expected GOPATH not empty.") - } - - noExistVar := Get("NOEXISTVAR", "foo") - if noExistVar != "foo" { - t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) - } -} - -func TestEnvMustGet(t *testing.T) { - gopath, err := MustGet("GOPATH") - if err != nil { - t.Error(err) - } - - if gopath != os.Getenv("GOPATH") { - t.Errorf("expected GOPATH to be the same, got %s.", gopath) - } - - _, err = MustGet("NOEXISTVAR") - if err == nil { - t.Error("expected error to be non-nil") - } -} - -func TestEnvSet(t *testing.T) { - Set("MYVAR", "foo") - myVar := Get("MYVAR", "bar") - if myVar != "foo" { - t.Errorf("expected MYVAR to equal foo, got %s.", myVar) - } -} - -func TestEnvMustSet(t *testing.T) { - err := MustSet("FOO", "bar") - if err != nil { - t.Error(err) - } - - fooVar := os.Getenv("FOO") - if fooVar != "bar" { - t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) - } -} - -func TestEnvGetAll(t *testing.T) { - envMap := GetAll() - if len(envMap) == 0 { - t.Error("expected environment not empty.") - } -} diff --git a/config/fake.go b/config/fake.go deleted file mode 100644 index 8093ad61..00000000 --- a/config/fake.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "errors" - "strconv" - "strings" -) - -type fakeConfigContainer struct { - data map[string]string -} - -func (c *fakeConfigContainer) getData(key string) string { - return c.data[strings.ToLower(key)] -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Set(key, val string) error { - c.data[strings.ToLower(key)] = val - return nil -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) String(key string) string { - return c.getData(key) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.getData(key)) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.getData(key), 10, 64) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Bool(key string) (bool, error) { - return ParseBool(c.getData(key)) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.getData(key), 64) -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { - if v, ok := c.data[strings.ToLower(key)]; ok { - return v, nil - } - return nil, errors.New("key not find") -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { - return nil, errors.New("not implement in the fakeConfigContainer") -} - -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *fakeConfigContainer) SaveConfigFile(filename string) error { - return errors.New("not implement in the fakeConfigContainer") -} - -var _ Configer = new(fakeConfigContainer) - -// NewFakeConfig return a fake Configer -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func NewFakeConfig() Configer { - return &fakeConfigContainer{ - data: make(map[string]string), - } -} diff --git a/config/ini.go b/config/ini.go deleted file mode 100644 index 1da293dc..00000000 --- a/config/ini.go +++ /dev/null @@ -1,524 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "bufio" - "bytes" - "errors" - "io" - "io/ioutil" - "os" - "os/user" - "path/filepath" - "strconv" - "strings" - "sync" -) - -var ( - defaultSection = "default" // default section means if some ini items not in a section, make them in default section, - bNumComment = []byte{'#'} // number signal - bSemComment = []byte{';'} // semicolon signal - bEmpty = []byte{} - bEqual = []byte{'='} // equal signal - bDQuote = []byte{'"'} // quote signal - sectionStart = []byte{'['} // section start signal - sectionEnd = []byte{']'} // section end signal - lineBreak = "\n" -) - -// IniConfig implements Config to parse ini file. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type IniConfig struct { -} - -// Parse creates a new Config and parses the file configuration from the named file. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (ini *IniConfig) Parse(name string) (Configer, error) { - return ini.parseFile(name) -} - -func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { - data, err := ioutil.ReadFile(name) - if err != nil { - return nil, err - } - - return ini.parseData(filepath.Dir(name), data) -} - -func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { - cfg := &IniConfigContainer{ - data: make(map[string]map[string]string), - sectionComment: make(map[string]string), - keyComment: make(map[string]string), - RWMutex: sync.RWMutex{}, - } - cfg.Lock() - defer cfg.Unlock() - - var comment bytes.Buffer - buf := bufio.NewReader(bytes.NewBuffer(data)) - // check the BOM - head, err := buf.Peek(3) - if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { - for i := 1; i <= 3; i++ { - buf.ReadByte() - } - } - section := defaultSection - tmpBuf := bytes.NewBuffer(nil) - for { - tmpBuf.Reset() - - shouldBreak := false - for { - tmp, isPrefix, err := buf.ReadLine() - if err == io.EOF { - shouldBreak = true - break - } - - //It might be a good idea to throw a error on all unknonw errors? - if _, ok := err.(*os.PathError); ok { - return nil, err - } - - tmpBuf.Write(tmp) - if isPrefix { - continue - } - - if !isPrefix { - break - } - } - if shouldBreak { - break - } - - line := tmpBuf.Bytes() - line = bytes.TrimSpace(line) - if bytes.Equal(line, bEmpty) { - continue - } - var bComment []byte - switch { - case bytes.HasPrefix(line, bNumComment): - bComment = bNumComment - case bytes.HasPrefix(line, bSemComment): - bComment = bSemComment - } - if bComment != nil { - line = bytes.TrimLeft(line, string(bComment)) - // Need append to a new line if multi-line comments. - if comment.Len() > 0 { - comment.WriteByte('\n') - } - comment.Write(line) - continue - } - - if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) { - section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive - if comment.Len() > 0 { - cfg.sectionComment[section] = comment.String() - comment.Reset() - } - if _, ok := cfg.data[section]; !ok { - cfg.data[section] = make(map[string]string) - } - continue - } - - if _, ok := cfg.data[section]; !ok { - cfg.data[section] = make(map[string]string) - } - keyValue := bytes.SplitN(line, bEqual, 2) - - key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive - key = strings.ToLower(key) - - // handle include "other.conf" - if len(keyValue) == 1 && strings.HasPrefix(key, "include") { - - includefiles := strings.Fields(key) - if includefiles[0] == "include" && len(includefiles) == 2 { - - otherfile := strings.Trim(includefiles[1], "\"") - if !filepath.IsAbs(otherfile) { - otherfile = filepath.Join(dir, otherfile) - } - - i, err := ini.parseFile(otherfile) - if err != nil { - return nil, err - } - - for sec, dt := range i.data { - if _, ok := cfg.data[sec]; !ok { - cfg.data[sec] = make(map[string]string) - } - for k, v := range dt { - cfg.data[sec][k] = v - } - } - - for sec, comm := range i.sectionComment { - cfg.sectionComment[sec] = comm - } - - for k, comm := range i.keyComment { - cfg.keyComment[k] = comm - } - - continue - } - } - - if len(keyValue) != 2 { - return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") - } - val := bytes.TrimSpace(keyValue[1]) - if bytes.HasPrefix(val, bDQuote) { - val = bytes.Trim(val, `"`) - } - - cfg.data[section][key] = ExpandValueEnv(string(val)) - if comment.Len() > 0 { - cfg.keyComment[section+"."+key] = comment.String() - comment.Reset() - } - - } - return cfg, nil -} - -// ParseData parse ini the data -// When include other.conf,other.conf is either absolute directory -// or under beego in default temporary directory(/tmp/beego[-username]). -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (ini *IniConfig) ParseData(data []byte) (Configer, error) { - dir := "beego" - currentUser, err := user.Current() - if err == nil { - dir = "beego-" + currentUser.Username - } - dir = filepath.Join(os.TempDir(), dir) - if err = os.MkdirAll(dir, os.ModePerm); err != nil { - return nil, err - } - - return ini.parseData(dir, data) -} - -// IniConfigContainer A Config represents the ini configuration. -// When set and get value, support key as section:name type. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type IniConfigContainer struct { - data map[string]map[string]string // section=> key:val - sectionComment map[string]string // section : comment - keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. - sync.RWMutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Bool(key string) (bool, error) { - return ParseBool(c.getdata(key)) -} - -// DefaultBool returns the boolean value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.getdata(key)) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.getdata(key), 10, 64) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.getdata(key), 64) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) String(key string) string { - return c.getdata(key) -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Return nil if config value does not exist or is empty. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v, nil - } - return nil, errors.New("not exist section") -} - -// SaveConfigFile save the config into file. -// -// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - - // Get section or key comments. Fixed #1607 - getCommentStr := func(section, key string) string { - var ( - comment string - ok bool - ) - if len(key) == 0 { - comment, ok = c.sectionComment[section] - } else { - comment, ok = c.keyComment[section+"."+key] - } - - if ok { - // Empty comment - if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 { - return string(bNumComment) - } - prefix := string(bNumComment) - // Add the line head character "#" - return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1) - } - return "" - } - - buf := bytes.NewBuffer(nil) - // Save default section at first place - if dt, ok := c.data[defaultSection]; ok { - for key, val := range dt { - if key != " " { - // Write key comments. - if v := getCommentStr(defaultSection, key); len(v) > 0 { - if _, err = buf.WriteString(v + lineBreak); err != nil { - return err - } - } - - // Write key and value. - if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { - return err - } - } - } - - // Put a line between sections. - if _, err = buf.WriteString(lineBreak); err != nil { - return err - } - } - // Save named sections - for section, dt := range c.data { - if section != defaultSection { - // Write section comments. - if v := getCommentStr(section, ""); len(v) > 0 { - if _, err = buf.WriteString(v + lineBreak); err != nil { - return err - } - } - - // Write section name. - if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { - return err - } - - for key, val := range dt { - if key != " " { - // Write key comments. - if v := getCommentStr(section, key); len(v) > 0 { - if _, err = buf.WriteString(v + lineBreak); err != nil { - return err - } - } - - // Write key and value. - if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { - return err - } - } - } - - // Put a line between sections. - if _, err = buf.WriteString(lineBreak); err != nil { - return err - } - } - } - _, err = buf.WriteTo(f) - return err -} - -// Set writes a new value for key. -// if write to one section, the key need be "section::key". -// if the section is not existed, it panics. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) Set(key, value string) error { - c.Lock() - defer c.Unlock() - if len(key) == 0 { - return errors.New("key is empty") - } - - var ( - section, k string - sectionKey = strings.Split(strings.ToLower(key), "::") - ) - - if len(sectionKey) >= 2 { - section = sectionKey[0] - k = sectionKey[1] - } else { - section = defaultSection - k = sectionKey[0] - } - - if _, ok := c.data[section]; !ok { - c.data[section] = make(map[string]string) - } - c.data[section][k] = value - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { - if v, ok := c.data[strings.ToLower(key)]; ok { - return v, nil - } - return v, errors.New("key not find") -} - -// section.key or key -func (c *IniConfigContainer) getdata(key string) string { - if len(key) == 0 { - return "" - } - c.RLock() - defer c.RUnlock() - - var ( - section, k string - sectionKey = strings.Split(strings.ToLower(key), "::") - ) - if len(sectionKey) >= 2 { - section = sectionKey[0] - k = sectionKey[1] - } else { - section = defaultSection - k = sectionKey[0] - } - if v, ok := c.data[section]; ok { - if vv, ok := v[k]; ok { - return vv - } - } - return "" -} - -func init() { - Register("ini", &IniConfig{}) -} diff --git a/config/ini_test.go b/config/ini_test.go deleted file mode 100644 index ffcdb294..00000000 --- a/config/ini_test.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "fmt" - "io/ioutil" - "os" - "strings" - "testing" -) - -func TestIni(t *testing.T) { - - var ( - inicontext = ` -;comment one -#comment two -appname = beeapi -httpport = 8080 -mysqlport = 3600 -PI = 3.1415976 -runmode = "dev" -autorender = false -copyrequestbody = true -session= on -cookieon= off -newreg = OFF -needlogin = ON -enableSession = Y -enableCookie = N -flag = 1 -path1 = ${GOPATH} -path2 = ${GOPATH||/home/go} -[demo] -key1="asta" -key2 = "xie" -CaseInsensitive = true -peers = one;two;three -password = ${GOPATH} -` - - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "pi": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "session": true, - "cookieon": false, - "newreg": false, - "needlogin": true, - "enableSession": true, - "enableCookie": false, - "flag": true, - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "demo::key1": "asta", - "demo::key2": "xie", - "demo::CaseInsensitive": true, - "demo::peers": []string{"one", "two", "three"}, - "demo::password": os.Getenv("GOPATH"), - "null": "", - "demo2::key1": "", - "error": "", - "emptystrings": []string{}, - } - ) - - f, err := os.Create("testini.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(inicontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testini.conf") - iniconf, err := NewConfig("ini", "testini.conf") - if err != nil { - t.Fatal(err) - } - for k, v := range keyValue { - var err error - var value interface{} - switch v.(type) { - case int: - value, err = iniconf.Int(k) - case int64: - value, err = iniconf.Int64(k) - case float64: - value, err = iniconf.Float(k) - case bool: - value, err = iniconf.Bool(k) - case []string: - value = iniconf.Strings(k) - case string: - value = iniconf.String(k) - default: - value, err = iniconf.DIY(k) - } - if err != nil { - t.Fatalf("get key %q value fail,err %s", k, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Fatalf("get key %q value, want %v got %v .", k, v, value) - } - - } - if err = iniconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if iniconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - -} - -func TestIniSave(t *testing.T) { - - const ( - inicontext = ` -app = app -;comment one -#comment two -# comment three -appname = beeapi -httpport = 8080 -# DB Info -# enable db -[dbinfo] -# db type name -# suport mysql,sqlserver -name = mysql -` - - saveResult = ` -app=app -#comment one -#comment two -# comment three -appname=beeapi -httpport=8080 - -# DB Info -# enable db -[dbinfo] -# db type name -# suport mysql,sqlserver -name=mysql -` - ) - cfg, err := NewConfigData("ini", []byte(inicontext)) - if err != nil { - t.Fatal(err) - } - name := "newIniConfig.ini" - if err := cfg.SaveConfigFile(name); err != nil { - t.Fatal(err) - } - defer os.Remove(name) - - if data, err := ioutil.ReadFile(name); err != nil { - t.Fatal(err) - } else { - cfgData := string(data) - datas := strings.Split(saveResult, "\n") - for _, line := range datas { - if !strings.Contains(cfgData, line+"\n") { - t.Fatalf("different after save ini config file. need contains %q", line) - } - } - - } -} diff --git a/config/json.go b/config/json.go deleted file mode 100644 index 74a50d34..00000000 --- a/config/json.go +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "os" - "strconv" - "strings" - "sync" -) - -// JSONConfig is a json config parser and implements Config interface. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type JSONConfig struct { -} - -// Parse returns a ConfigContainer with parsed json config map. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (js *JSONConfig) Parse(filename string) (Configer, error) { - file, err := os.Open(filename) - if err != nil { - return nil, err - } - defer file.Close() - content, err := ioutil.ReadAll(file) - if err != nil { - return nil, err - } - - return js.ParseData(content) -} - -// ParseData returns a ConfigContainer with json string -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (js *JSONConfig) ParseData(data []byte) (Configer, error) { - x := &JSONConfigContainer{ - data: make(map[string]interface{}), - } - err := json.Unmarshal(data, &x.data) - if err != nil { - var wrappingArray []interface{} - err2 := json.Unmarshal(data, &wrappingArray) - if err2 != nil { - return nil, err - } - x.data["rootArray"] = wrappingArray - } - - x.data = ExpandValueEnvForMap(x.data) - - return x, nil -} - -// JSONConfigContainer A Config represents the json configuration. -// Only when get value, support key as section:name type. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type JSONConfigContainer struct { - data map[string]interface{} - sync.RWMutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Bool(key string) (bool, error) { - val := c.getData(key) - if val != nil { - return ParseBool(val) - } - return false, fmt.Errorf("not exist key: %q", key) -} - -// DefaultBool return the bool value if has no error -// otherwise return the defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err == nil { - return v - } - return defaultval -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Int(key string) (int, error) { - val := c.getData(key) - if val != nil { - if v, ok := val.(float64); ok { - return int(v), nil - } else if v, ok := val.(string); ok { - return strconv.Atoi(v) - } - return 0, errors.New("not valid value") - } - return 0, errors.New("not exist key:" + key) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err == nil { - return v - } - return defaultval -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Int64(key string) (int64, error) { - val := c.getData(key) - if val != nil { - if v, ok := val.(float64); ok { - return int64(v), nil - } - return 0, errors.New("not int64 value") - } - return 0, errors.New("not exist key:" + key) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err == nil { - return v - } - return defaultval -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Float(key string) (float64, error) { - val := c.getData(key) - if val != nil { - if v, ok := val.(float64); ok { - return v, nil - } - return 0.0, errors.New("not float64 value") - } - return 0.0, errors.New("not exist key:" + key) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err == nil { - return v - } - return defaultval -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) String(key string) string { - val := c.getData(key) - if val != nil { - if v, ok := val.(string); ok { - return v - } - } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { - // TODO FIXME should not use "" to replace non existence - if v := c.String(key); v != "" { - return v - } - return defaultval -} - -// Strings returns the []string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Strings(key string) []string { - stringVal := c.String(key) - if stringVal == "" { - return nil - } - return strings.Split(c.String(key), ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); v != nil { - return v - } - return defaultval -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil - } - return nil, errors.New("nonexist section " + section) -} - -// SaveConfigFile save the config into file -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - b, err := json.MarshalIndent(c.data, "", " ") - if err != nil { - return err - } - _, err = f.Write(b) - return err -} - -// Set writes a new value for key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) Set(key, val string) error { - c.Lock() - defer c.Unlock() - c.data[key] = val - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { - val := c.getData(key) - if val != nil { - return val, nil - } - return nil, errors.New("not exist key") -} - -// section.key or key -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *JSONConfigContainer) getData(key string) interface{} { - if len(key) == 0 { - return nil - } - - c.RLock() - defer c.RUnlock() - - sectionKeys := strings.Split(key, "::") - if len(sectionKeys) >= 2 { - curValue, ok := c.data[sectionKeys[0]] - if !ok { - return nil - } - for _, key := range sectionKeys[1:] { - if v, ok := curValue.(map[string]interface{}); ok { - if curValue, ok = v[key]; !ok { - return nil - } - } - } - return curValue - } - if v, ok := c.data[key]; ok { - return v - } - return nil -} - -func init() { - Register("json", &JSONConfig{}) -} diff --git a/config/json_test.go b/config/json_test.go deleted file mode 100644 index 16f42409..00000000 --- a/config/json_test.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "fmt" - "os" - "testing" -) - -func TestJsonStartsWithArray(t *testing.T) { - - const jsoncontextwitharray = `[ - { - "url": "user", - "serviceAPI": "http://www.test.com/user" - }, - { - "url": "employee", - "serviceAPI": "http://www.test.com/employee" - } -]` - f, err := os.Create("testjsonWithArray.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(jsoncontextwitharray) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testjsonWithArray.conf") - jsonconf, err := NewConfig("json", "testjsonWithArray.conf") - if err != nil { - t.Fatal(err) - } - rootArray, err := jsonconf.DIY("rootArray") - if err != nil { - t.Error("array does not exist as element") - } - rootArrayCasted := rootArray.([]interface{}) - if rootArrayCasted == nil { - t.Error("array from root is nil") - } else { - elem := rootArrayCasted[0].(map[string]interface{}) - if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { - t.Error("array[0] values are not valid") - } - - elem2 := rootArrayCasted[1].(map[string]interface{}) - if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { - t.Error("array[1] values are not valid") - } - } -} - -func TestJson(t *testing.T) { - - var ( - jsoncontext = `{ -"appname": "beeapi", -"testnames": "foo;bar", -"httpport": 8080, -"mysqlport": 3600, -"PI": 3.1415976, -"runmode": "dev", -"autorender": false, -"copyrequestbody": true, -"session": "on", -"cookieon": "off", -"newreg": "OFF", -"needlogin": "ON", -"enableSession": "Y", -"enableCookie": "N", -"flag": 1, -"path1": "${GOPATH}", -"path2": "${GOPATH||/home/go}", -"database": { - "host": "host", - "port": "port", - "database": "database", - "username": "username", - "password": "${GOPATH}", - "conns":{ - "maxconnection":12, - "autoconnect":true, - "connectioninfo":"info", - "root": "${GOPATH}" - } - } -}` - keyValue = map[string]interface{}{ - "appname": "beeapi", - "testnames": []string{"foo", "bar"}, - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "session": true, - "cookieon": false, - "newreg": false, - "needlogin": true, - "enableSession": true, - "enableCookie": false, - "flag": true, - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "database::host": "host", - "database::port": "port", - "database::database": "database", - "database::password": os.Getenv("GOPATH"), - "database::conns::maxconnection": 12, - "database::conns::autoconnect": true, - "database::conns::connectioninfo": "info", - "database::conns::root": os.Getenv("GOPATH"), - "unknown": "", - } - ) - - f, err := os.Create("testjson.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(jsoncontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testjson.conf") - jsonconf, err := NewConfig("json", "testjson.conf") - if err != nil { - t.Fatal(err) - } - - for k, v := range keyValue { - var err error - var value interface{} - switch v.(type) { - case int: - value, err = jsonconf.Int(k) - case int64: - value, err = jsonconf.Int64(k) - case float64: - value, err = jsonconf.Float(k) - case bool: - value, err = jsonconf.Bool(k) - case []string: - value = jsonconf.Strings(k) - case string: - value = jsonconf.String(k) - default: - value, err = jsonconf.DIY(k) - } - if err != nil { - t.Fatalf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Fatalf("get key %q value, want %v got %v .", k, v, value) - } - - } - if err = jsonconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if jsonconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - - if db, err := jsonconf.DIY("database"); err != nil { - t.Fatal(err) - } else if m, ok := db.(map[string]interface{}); !ok { - t.Log(db) - t.Fatal("db not map[string]interface{}") - } else { - if m["host"].(string) != "host" { - t.Fatal("get host err") - } - } - - if _, err := jsonconf.Int("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int") - } - - if _, err := jsonconf.Int64("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int64") - } - - if _, err := jsonconf.Float("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Float") - } - - if _, err := jsonconf.DIY("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an interface{}") - } - - if val := jsonconf.String("unknown"); val != "" { - t.Error("unknown keys should return an empty string when expecting a String") - } - - if _, err := jsonconf.Bool("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Bool") - } - - if !jsonconf.DefaultBool("unknown", true) { - t.Error("unknown keys with default value wrong") - } -} diff --git a/config/xml/xml.go b/config/xml/xml.go deleted file mode 100644 index 1601561f..00000000 --- a/config/xml/xml.go +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package xml for config provider. -// -// depend on github.com/beego/x2j. -// -// go install github.com/beego/x2j. -// -// Usage: -// import( -// _ "github.com/astaxie/beego/config/xml" -// "github.com/astaxie/beego/config" -// ) -// -// cnf, err := config.NewConfig("xml", "config.xml") -// -//More docs http://beego.me/docs/module/config.md -package xml - -import ( - "encoding/xml" - "errors" - "fmt" - "io/ioutil" - "os" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/config" - "github.com/beego/x2j" -) - -// Config is a xml config parser and implements Config interface. -// xml configurations should be included in tag. -// only support key/value pair as value as each item. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type Config struct{} - -// Parse returns a ConfigContainer with parsed xml config map. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (xc *Config) Parse(filename string) (config.Configer, error) { - context, err := ioutil.ReadFile(filename) - if err != nil { - return nil, err - } - - return xc.ParseData(context) -} - -// ParseData xml data -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (xc *Config) ParseData(data []byte) (config.Configer, error) { - x := &ConfigContainer{data: make(map[string]interface{})} - - d, err := x2j.DocToMap(string(data)) - if err != nil { - return nil, err - } - - x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) - - return x, nil -} - -// ConfigContainer A Config represents the xml configuration. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type ConfigContainer struct { - data map[string]interface{} - sync.Mutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Bool(key string) (bool, error) { - if v := c.data[key]; v != nil { - return config.ParseBool(v) - } - return false, fmt.Errorf("not exist key: %q", key) -} - -// DefaultBool return the bool value if has no error -// otherwise return the defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int(key string) (int, error) { - return strconv.Atoi(c.data[key].(string)) -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int64(key string) (int64, error) { - return strconv.ParseInt(c.data[key].(string), 10, 64) -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v - -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Float(key string) (float64, error) { - return strconv.ParseFloat(c.data[key].(string), 64) -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) String(key string) string { - if v, ok := c.data[key].(string); ok { - return v - } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section].(map[string]interface{}); ok { - mapstr := make(map[string]string) - for k, val := range v { - mapstr[k] = config.ToString(val) - } - return mapstr, nil - } - return nil, fmt.Errorf("section '%s' not found", section) -} - -// SaveConfigFile save the config into file -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - b, err := xml.MarshalIndent(c.data, " ", " ") - if err != nil { - return err - } - _, err = f.Write(b) - return err -} - -// Set writes a new value for key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Set(key, val string) error { - c.Lock() - defer c.Unlock() - c.data[key] = val - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { - if v, ok := c.data[key]; ok { - return v, nil - } - return nil, errors.New("not exist key") -} - -func init() { - config.Register("xml", &Config{}) -} diff --git a/config/xml/xml_test.go b/config/xml/xml_test.go deleted file mode 100644 index 346c866e..00000000 --- a/config/xml/xml_test.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package xml - -import ( - "fmt" - "os" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestXML(t *testing.T) { - - var ( - //xml parse should incluce in tags - xmlcontext = ` - -beeapi -8080 -3600 -3.1415976 -dev -false -true -${GOPATH} -${GOPATH||/home/go} - -1 -MySection - - -` - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "error": "", - "emptystrings": []string{}, - } - ) - - f, err := os.Create("testxml.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(xmlcontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testxml.conf") - - xmlconf, err := config.NewConfig("xml", "testxml.conf") - if err != nil { - t.Fatal(err) - } - - var xmlsection map[string]string - xmlsection, err = xmlconf.GetSection("mysection") - if err != nil { - t.Fatal(err) - } - - if len(xmlsection) == 0 { - t.Error("section should not be empty") - } - - for k, v := range keyValue { - - var ( - value interface{} - err error - ) - - switch v.(type) { - case int: - value, err = xmlconf.Int(k) - case int64: - value, err = xmlconf.Int64(k) - case float64: - value, err = xmlconf.Float(k) - case bool: - value, err = xmlconf.Bool(k) - case []string: - value = xmlconf.Strings(k) - case string: - value = xmlconf.String(k) - default: - value, err = xmlconf.DIY(k) - } - if err != nil { - t.Errorf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Errorf("get key %q value, want %v got %v .", k, v, value) - } - - } - - if err = xmlconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if xmlconf.String("name") != "astaxie" { - t.Fatal("get name error") - } -} diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go deleted file mode 100644 index 725f905b..00000000 --- a/config/yaml/yaml.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package yaml for config provider -// -// depend on github.com/beego/goyaml2 -// -// go install github.com/beego/goyaml2 -// -// Usage: -// import( -// _ "github.com/astaxie/beego/config/yaml" -// "github.com/astaxie/beego/config" -// ) -// -// cnf, err := config.NewConfig("yaml", "config.yaml") -// -//More docs http://beego.me/docs/module/config.md -package yaml - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "log" - "os" - "strings" - "sync" - - "github.com/astaxie/beego/config" - "github.com/beego/goyaml2" -) - -// Config is a yaml config parser and implements Config interface. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type Config struct{} - -// Parse returns a ConfigContainer with parsed yaml config map. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (yaml *Config) Parse(filename string) (y config.Configer, err error) { - cnf, err := ReadYmlReader(filename) - if err != nil { - return - } - y = &ConfigContainer{ - data: cnf, - } - return -} - -// ParseData parse yaml data -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (yaml *Config) ParseData(data []byte) (config.Configer, error) { - cnf, err := parseYML(data) - if err != nil { - return nil, err - } - - return &ConfigContainer{ - data: cnf, - }, nil -} - -// ReadYmlReader Read yaml file to map. -// if json like, use json package, unless goyaml2 package. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { - buf, err := ioutil.ReadFile(path) - if err != nil { - return - } - - return parseYML(buf) -} - -// parseYML parse yaml formatted []byte to map. -func parseYML(buf []byte) (cnf map[string]interface{}, err error) { - if len(buf) < 3 { - return - } - - if string(buf[0:1]) == "{" { - log.Println("Look like a Json, try json umarshal") - err = json.Unmarshal(buf, &cnf) - if err == nil { - log.Println("It is Json Map") - return - } - } - - data, err := goyaml2.Read(bytes.NewReader(buf)) - if err != nil { - log.Println("Goyaml2 ERR>", string(buf), err) - return - } - - if data == nil { - log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) - return - } - cnf, ok := data.(map[string]interface{}) - if !ok { - log.Println("Not a Map? >> ", string(buf), data) - cnf = nil - } - cnf = config.ExpandValueEnvForMap(cnf) - return -} - -// ConfigContainer A Config represents the yaml configuration. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -type ConfigContainer struct { - data map[string]interface{} - sync.RWMutex -} - -// Bool returns the boolean value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Bool(key string) (bool, error) { - v, err := c.getData(key) - if err != nil { - return false, err - } - return config.ParseBool(v) -} - -// DefaultBool return the bool value if has no error -// otherwise return the defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) - if err != nil { - return defaultval - } - return v -} - -// Int returns the integer value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int(key string) (int, error) { - if v, err := c.getData(key); err != nil { - return 0, err - } else if vv, ok := v.(int); ok { - return vv, nil - } else if vv, ok := v.(int64); ok { - return int(vv), nil - } - return 0, errors.New("not int value") -} - -// DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) - if err != nil { - return defaultval - } - return v -} - -// Int64 returns the int64 value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Int64(key string) (int64, error) { - if v, err := c.getData(key); err != nil { - return 0, err - } else if vv, ok := v.(int64); ok { - return vv, nil - } - return 0, errors.New("not bool value") -} - -// DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) - if err != nil { - return defaultval - } - return v -} - -// Float returns the float value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Float(key string) (float64, error) { - if v, err := c.getData(key); err != nil { - return 0.0, err - } else if vv, ok := v.(float64); ok { - return vv, nil - } else if vv, ok := v.(int); ok { - return float64(vv), nil - } else if vv, ok := v.(int64); ok { - return float64(vv), nil - } - return 0.0, errors.New("not float64 value") -} - -// DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) - if err != nil { - return defaultval - } - return v -} - -// String returns the string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) String(key string) string { - if v, err := c.getData(key); err == nil { - if vv, ok := v.(string); ok { - return vv - } - } - return "" -} - -// DefaultString returns the string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -// Strings returns the []string value for a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -// DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -// GetSection returns map for the given section -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil - } - return nil, errors.New("not exist section") -} - -// SaveConfigFile save the config into file -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { - // Write configuration file by filename. - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - err = goyaml2.Write(f, c.data) - return err -} - -// Set writes a new value for key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) Set(key, val string) error { - c.Lock() - defer c.Unlock() - c.data[key] = val - return nil -} - -// DIY returns the raw value by a given key. -// Deprecated: using pkg/config, we will delete this in v2.1.0 -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { - return c.getData(key) -} - -func (c *ConfigContainer) getData(key string) (interface{}, error) { - - if len(key) == 0 { - return nil, errors.New("key is empty") - } - c.RLock() - defer c.RUnlock() - - keys := strings.Split(key, ".") - tmpData := c.data - for idx, k := range keys { - if v, ok := tmpData[k]; ok { - switch v.(type) { - case map[string]interface{}: - { - tmpData = v.(map[string]interface{}) - if idx == len(keys)-1 { - return tmpData, nil - } - } - default: - { - return v, nil - } - - } - } - } - return nil, fmt.Errorf("not exist key %q", key) -} - -func init() { - config.Register("yaml", &Config{}) -} diff --git a/config/yaml/yaml_test.go b/config/yaml/yaml_test.go deleted file mode 100644 index 49cc1d1e..00000000 --- a/config/yaml/yaml_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package yaml - -import ( - "fmt" - "os" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestYaml(t *testing.T) { - - var ( - yamlcontext = ` -"appname": beeapi -"httpport": 8080 -"mysqlport": 3600 -"PI": 3.1415976 -"runmode": dev -"autorender": false -"copyrequestbody": true -"PATH": GOPATH -"path1": ${GOPATH} -"path2": ${GOPATH||/home/go} -"empty": "" -` - - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "PATH": "GOPATH", - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "error": "", - "emptystrings": []string{}, - } - ) - f, err := os.Create("testyaml.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(yamlcontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testyaml.conf") - yamlconf, err := config.NewConfig("yaml", "testyaml.conf") - if err != nil { - t.Fatal(err) - } - - if yamlconf.String("appname") != "beeapi" { - t.Fatal("appname not equal to beeapi") - } - - for k, v := range keyValue { - - var ( - value interface{} - err error - ) - - switch v.(type) { - case int: - value, err = yamlconf.Int(k) - case int64: - value, err = yamlconf.Int64(k) - case float64: - value, err = yamlconf.Float(k) - case bool: - value, err = yamlconf.Bool(k) - case []string: - value = yamlconf.Strings(k) - case string: - value = yamlconf.String(k) - default: - value, err = yamlconf.DIY(k) - } - if err != nil { - t.Errorf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Errorf("get key %q value, want %v got %v .", k, v, value) - } - - } - - if err = yamlconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if yamlconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - -} diff --git a/config_test.go b/config_test.go deleted file mode 100644 index 5f71f1c3..00000000 --- a/config_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "encoding/json" - "reflect" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestDefaults(t *testing.T) { - if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { - t.Errorf("FlashName was not set to default.") - } - - if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { - t.Errorf("FlashName was not set to default.") - } -} - -func TestAssignConfig_01(t *testing.T) { - _BConfig := &Config{} - _BConfig.AppName = "beego_test" - jcf := &config.JSONConfig{} - ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) - assignSingleConfig(_BConfig, ac) - if _BConfig.AppName != "beego_json" { - t.Log(_BConfig) - t.FailNow() - } -} - -func TestAssignConfig_02(t *testing.T) { - _BConfig := &Config{} - bs, _ := json.Marshal(newBConfig()) - - jsonMap := M{} - json.Unmarshal(bs, &jsonMap) - - configMap := M{} - for k, v := range jsonMap { - if reflect.TypeOf(v).Kind() == reflect.Map { - for k1, v1 := range v.(M) { - if reflect.TypeOf(v1).Kind() == reflect.Map { - for k2, v2 := range v1.(M) { - configMap[k2] = v2 - } - } else { - configMap[k1] = v1 - } - } - } else { - configMap[k] = v - } - } - configMap["MaxMemory"] = 1024 - configMap["Graceful"] = true - configMap["XSRFExpire"] = 32 - configMap["SessionProviderConfig"] = "file" - configMap["FileLineNum"] = true - - jcf := &config.JSONConfig{} - bs, _ = json.Marshal(configMap) - ac, _ := jcf.ParseData(bs) - - for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { - assignSingleConfig(i, ac) - } - - if _BConfig.MaxMemory != 1024 { - t.Log(_BConfig.MaxMemory) - t.FailNow() - } - - if !_BConfig.Listen.Graceful { - t.Log(_BConfig.Listen.Graceful) - t.FailNow() - } - - if _BConfig.WebConfig.XSRFExpire != 32 { - t.Log(_BConfig.WebConfig.XSRFExpire) - t.FailNow() - } - - if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { - t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) - t.FailNow() - } - - if !_BConfig.Log.FileLineNum { - t.Log(_BConfig.Log.FileLineNum) - t.FailNow() - } - -} - -func TestAssignConfig_03(t *testing.T) { - jcf := &config.JSONConfig{} - ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) - ac.Set("AppName", "test_app") - ac.Set("RunMode", "online") - ac.Set("StaticDir", "download:down download2:down2") - ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") - ac.Set("StaticCacheFileSize", "87456") - ac.Set("StaticCacheFileNum", "1254") - assignConfig(ac) - - t.Logf("%#v", BConfig) - - if BConfig.AppName != "test_app" { - t.FailNow() - } - - if BConfig.RunMode != "online" { - t.FailNow() - } - if BConfig.WebConfig.StaticDir["/download"] != "down" { - t.FailNow() - } - if BConfig.WebConfig.StaticDir["/download2"] != "down2" { - t.FailNow() - } - if BConfig.WebConfig.StaticCacheFileSize != 87456 { - t.FailNow() - } - if BConfig.WebConfig.StaticCacheFileNum != 1254 { - t.FailNow() - } - if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { - t.FailNow() - } -} diff --git a/context/acceptencoder.go b/context/acceptencoder.go deleted file mode 100644 index b4e2492c..00000000 --- a/context/acceptencoder.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2015 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "compress/zlib" - "io" - "net/http" - "os" - "strconv" - "strings" - "sync" -) - -var ( - //Default size==20B same as nginx - defaultGzipMinLength = 20 - //Content will only be compressed if content length is either unknown or greater than gzipMinLength. - gzipMinLength = defaultGzipMinLength - //The compression level used for deflate compression. (0-9). - gzipCompressLevel int - //List of HTTP methods to compress. If not set, only GET requests are compressed. - includedMethods map[string]bool - getMethodOnly bool -) - -// InitGzip init the gzipcompress -func InitGzip(minLength, compressLevel int, methods []string) { - if minLength >= 0 { - gzipMinLength = minLength - } - gzipCompressLevel = compressLevel - if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { - gzipCompressLevel = flate.BestSpeed - } - getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") - includedMethods = make(map[string]bool, len(methods)) - for _, v := range methods { - includedMethods[strings.ToUpper(v)] = true - } -} - -type resetWriter interface { - io.Writer - Reset(w io.Writer) -} - -type nopResetWriter struct { - io.Writer -} - -func (n nopResetWriter) Reset(w io.Writer) { - //do nothing -} - -type acceptEncoder struct { - name string - levelEncode func(int) resetWriter - customCompressLevelPool *sync.Pool - bestCompressionPool *sync.Pool -} - -func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { - if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { - return nopResetWriter{wr} - } - var rwr resetWriter - switch level { - case flate.BestSpeed: - rwr = ac.customCompressLevelPool.Get().(resetWriter) - case flate.BestCompression: - rwr = ac.bestCompressionPool.Get().(resetWriter) - default: - rwr = ac.levelEncode(level) - } - rwr.Reset(wr) - return rwr -} - -func (ac acceptEncoder) put(wr resetWriter, level int) { - if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { - return - } - wr.Reset(nil) - - //notice - //compressionLevel==BestCompression DOES NOT MATTER - //sync.Pool will not memory leak - - switch level { - case gzipCompressLevel: - ac.customCompressLevelPool.Put(wr) - case flate.BestCompression: - ac.bestCompressionPool.Put(wr) - } -} - -var ( - noneCompressEncoder = acceptEncoder{"", nil, nil, nil} - gzipCompressEncoder = acceptEncoder{ - name: "gzip", - levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, - customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, - bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, - } - - //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed - //deflate - //The "zlib" format defined in RFC 1950 [31] in combination with - //the "deflate" compression mechanism described in RFC 1951 [29]. - deflateCompressEncoder = acceptEncoder{ - name: "deflate", - levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, - customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, - bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, - } -) - -var ( - encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore - "gzip": gzipCompressEncoder, - "deflate": deflateCompressEncoder, - "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip - "identity": noneCompressEncoder, // identity means none-compress - } -) - -// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) -func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { - return writeLevel(encoding, writer, file, flate.BestCompression) -} - -// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) -func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { - if encoding == "" || len(content) < gzipMinLength { - _, err := writer.Write(content) - return false, "", err - } - return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) -} - -// writeLevel reads from reader,writes to writer by specific encoding and compress level -// the compress level is defined by deflate package -func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { - var outputWriter resetWriter - var err error - var ce = noneCompressEncoder - - if cf, ok := encoderMap[encoding]; ok { - ce = cf - } - encoding = ce.name - outputWriter = ce.encode(writer, level) - defer ce.put(outputWriter, level) - - _, err = io.Copy(outputWriter, reader) - if err != nil { - return false, "", err - } - - switch outputWriter.(type) { - case io.WriteCloser: - outputWriter.(io.WriteCloser).Close() - } - return encoding != "", encoding, nil -} - -// ParseEncoding will extract the right encoding for response -// the Accept-Encoding's sec is here: -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 -func ParseEncoding(r *http.Request) string { - if r == nil { - return "" - } - if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { - return parseEncoding(r) - } - return "" -} - -type q struct { - name string - value float64 -} - -func parseEncoding(r *http.Request) string { - acceptEncoding := r.Header.Get("Accept-Encoding") - if acceptEncoding == "" { - return "" - } - var lastQ q - for _, v := range strings.Split(acceptEncoding, ",") { - v = strings.TrimSpace(v) - if v == "" { - continue - } - vs := strings.Split(v, ";") - var cf acceptEncoder - var ok bool - if cf, ok = encoderMap[vs[0]]; !ok { - continue - } - if len(vs) == 1 { - return cf.name - } - if len(vs) == 2 { - f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) - if f == 0 { - continue - } - if f > lastQ.value { - lastQ = q{cf.name, f} - } - } - } - return lastQ.name -} diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go deleted file mode 100644 index e3d61e27..00000000 --- a/context/acceptencoder_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2015 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "testing" -) - -func Test_ExtractEncoding(t *testing.T) { - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" { - t.Fail() - } - if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" { - t.Fail() - } -} diff --git a/context/context.go b/context/context.go deleted file mode 100644 index 7c161ac0..00000000 --- a/context/context.go +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provide the context utils -// Usage: -// -// import "github.com/astaxie/beego/context" -// -// ctx := context.Context{Request:req,ResponseWriter:rw} -// -// more docs http://beego.me/docs/module/context.md -package context - -import ( - "bufio" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "errors" - "fmt" - "net" - "net/http" - "strconv" - "strings" - "time" - - "github.com/astaxie/beego/utils" -) - -//commonly used mime-types -const ( - ApplicationJSON = "application/json" - ApplicationXML = "application/xml" - ApplicationYAML = "application/x-yaml" - TextXML = "text/xml" -) - -// NewContext return the Context with Input and Output -func NewContext() *Context { - return &Context{ - Input: NewInput(), - Output: NewOutput(), - } -} - -// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. -// BeegoInput and BeegoOutput provides some api to operate request and response more easily. -type Context struct { - Input *BeegoInput - Output *BeegoOutput - Request *http.Request - ResponseWriter *Response - _xsrfToken string -} - -// Reset init Context, BeegoInput and BeegoOutput -func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { - ctx.Request = r - if ctx.ResponseWriter == nil { - ctx.ResponseWriter = &Response{} - } - ctx.ResponseWriter.reset(rw) - ctx.Input.Reset(ctx) - ctx.Output.Reset(ctx) - ctx._xsrfToken = "" -} - -// Redirect does redirection to localurl with http header status code. -func (ctx *Context) Redirect(status int, localurl string) { - http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) -} - -// Abort stops this request. -// if beego.ErrorMaps exists, panic body. -func (ctx *Context) Abort(status int, body string) { - ctx.Output.SetStatus(status) - panic(body) -} - -// WriteString Write string to response body. -// it sends response body. -func (ctx *Context) WriteString(content string) { - ctx.ResponseWriter.Write([]byte(content)) -} - -// GetCookie Get cookie from request by a given key. -// It's alias of BeegoInput.Cookie. -func (ctx *Context) GetCookie(key string) string { - return ctx.Input.Cookie(key) -} - -// SetCookie Set cookie for response. -// It's alias of BeegoOutput.Cookie. -func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { - ctx.Output.Cookie(name, value, others...) -} - -// GetSecureCookie Get secure cookie from request by a given key. -func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { - val := ctx.Input.Cookie(key) - if val == "" { - return "", false - } - - parts := strings.SplitN(val, "|", 3) - - if len(parts) != 3 { - return "", false - } - - vs := parts[0] - timestamp := parts[1] - sig := parts[2] - - h := hmac.New(sha256.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - - if fmt.Sprintf("%02x", h.Sum(nil)) != sig { - return "", false - } - res, _ := base64.URLEncoding.DecodeString(vs) - return string(res), true -} - -// SetSecureCookie Set Secure cookie for response. -func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { - vs := base64.URLEncoding.EncodeToString([]byte(value)) - timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) - h := hmac.New(sha256.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - sig := fmt.Sprintf("%02x", h.Sum(nil)) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") - ctx.Output.Cookie(name, cookie, others...) -} - -// XSRFToken creates a xsrf token string and returns. -func (ctx *Context) XSRFToken(key string, expire int64) string { - if ctx._xsrfToken == "" { - token, ok := ctx.GetSecureCookie(key, "_xsrf") - if !ok { - token = string(utils.RandomCreateBytes(32)) - ctx.SetSecureCookie(key, "_xsrf", token, expire, "", "", true, true) - } - ctx._xsrfToken = token - } - return ctx._xsrfToken -} - -// CheckXSRFCookie checks xsrf token in this request is valid or not. -// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" -// or in form field value named as "_xsrf". -func (ctx *Context) CheckXSRFCookie() bool { - token := ctx.Input.Query("_xsrf") - if token == "" { - token = ctx.Request.Header.Get("X-Xsrftoken") - } - if token == "" { - token = ctx.Request.Header.Get("X-Csrftoken") - } - if token == "" { - ctx.Abort(422, "422") - return false - } - if ctx._xsrfToken != token { - ctx.Abort(417, "417") - return false - } - return true -} - -// RenderMethodResult renders the return value of a controller method to the output -func (ctx *Context) RenderMethodResult(result interface{}) { - if result != nil { - renderer, ok := result.(Renderer) - if !ok { - err, ok := result.(error) - if ok { - renderer = errorRenderer(err) - } else { - renderer = jsonRenderer(result) - } - } - renderer.Render(ctx) - } -} - -//Response is a wrapper for the http.ResponseWriter -//started set to true if response was written to then don't execute other handler -type Response struct { - http.ResponseWriter - Started bool - Status int - Elapsed time.Duration -} - -func (r *Response) reset(rw http.ResponseWriter) { - r.ResponseWriter = rw - r.Status = 0 - r.Started = false -} - -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (r *Response) Write(p []byte) (int, error) { - r.Started = true - return r.ResponseWriter.Write(p) -} - -// WriteHeader sends an HTTP response header with status code, -// and sets `started` to true. -func (r *Response) WriteHeader(code int) { - if r.Status > 0 { - //prevent multiple response.WriteHeader calls - return - } - r.Status = code - r.Started = true - r.ResponseWriter.WriteHeader(code) -} - -// Hijack hijacker for http -func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hj, ok := r.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("webserver doesn't support hijacking") - } - return hj.Hijack() -} - -// Flush http.Flusher -func (r *Response) Flush() { - if f, ok := r.ResponseWriter.(http.Flusher); ok { - f.Flush() - } -} - -// CloseNotify http.CloseNotifier -func (r *Response) CloseNotify() <-chan bool { - if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { - return cn.CloseNotify() - } - return nil -} - -// Pusher http.Pusher -func (r *Response) Pusher() (pusher http.Pusher) { - if pusher, ok := r.ResponseWriter.(http.Pusher); ok { - return pusher - } - return nil -} diff --git a/context/context_test.go b/context/context_test.go deleted file mode 100644 index e81e8191..00000000 --- a/context/context_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestXsrfReset_01(t *testing.T) { - r := &http.Request{} - c := NewContext() - c.Request = r - c.ResponseWriter = &Response{} - c.ResponseWriter.reset(httptest.NewRecorder()) - c.Output.Reset(c) - c.Input.Reset(c) - c.XSRFToken("key", 16) - if c._xsrfToken == "" { - t.FailNow() - } - token := c._xsrfToken - c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) - if c._xsrfToken != "" { - t.FailNow() - } - c.XSRFToken("key", 16) - if c._xsrfToken == "" { - t.FailNow() - } - if token == c._xsrfToken { - t.FailNow() - } - - ck := c.ResponseWriter.Header().Get("Set-Cookie") - assert.True(t, strings.Contains(ck, "Secure")) - assert.True(t, strings.Contains(ck, "HttpOnly")) -} diff --git a/context/input.go b/context/input.go deleted file mode 100644 index c2c1c63d..00000000 --- a/context/input.go +++ /dev/null @@ -1,709 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "bytes" - "compress/gzip" - "errors" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "reflect" - "regexp" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/session" -) - -// Regexes for checking the accept headers -// TODO make sure these are correct -var ( - acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) - acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) - acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) - acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) - maxParam = 50 -) - -// BeegoInput operates the http request header, data, cookie and body. -// it also contains router params and current session. -type BeegoInput struct { - Context *Context - CruSession session.Store - pnames []string - pvalues []string - data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. - dataLock sync.RWMutex - RequestBody []byte - RunMethod string - RunController reflect.Type -} - -// NewInput return BeegoInput generated by Context. -func NewInput() *BeegoInput { - return &BeegoInput{ - pnames: make([]string, 0, maxParam), - pvalues: make([]string, 0, maxParam), - data: make(map[interface{}]interface{}), - } -} - -// Reset init the BeegoInput -func (input *BeegoInput) Reset(ctx *Context) { - input.Context = ctx - input.CruSession = nil - input.pnames = input.pnames[:0] - input.pvalues = input.pvalues[:0] - input.dataLock.Lock() - input.data = nil - input.dataLock.Unlock() - input.RequestBody = []byte{} -} - -// Protocol returns request protocol name, such as HTTP/1.1 . -func (input *BeegoInput) Protocol() string { - return input.Context.Request.Proto -} - -// URI returns full request url with query string, fragment. -func (input *BeegoInput) URI() string { - return input.Context.Request.RequestURI -} - -// URL returns request url path (without query string, fragment). -func (input *BeegoInput) URL() string { - return input.Context.Request.URL.EscapedPath() -} - -// Site returns base site url as scheme://domain type. -func (input *BeegoInput) Site() string { - return input.Scheme() + "://" + input.Domain() -} - -// Scheme returns request scheme as "http" or "https". -func (input *BeegoInput) Scheme() string { - if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { - return scheme - } - if input.Context.Request.URL.Scheme != "" { - return input.Context.Request.URL.Scheme - } - if input.Context.Request.TLS == nil { - return "http" - } - return "https" -} - -// Domain returns host name. -// Alias of Host method. -func (input *BeegoInput) Domain() string { - return input.Host() -} - -// Host returns host name. -// if no host info in request, return localhost. -func (input *BeegoInput) Host() string { - if input.Context.Request.Host != "" { - if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { - return hostPart - } - return input.Context.Request.Host - } - return "localhost" -} - -// Method returns http request method. -func (input *BeegoInput) Method() string { - return input.Context.Request.Method -} - -// Is returns boolean of this request is on given method, such as Is("POST"). -func (input *BeegoInput) Is(method string) bool { - return input.Method() == method -} - -// IsGet Is this a GET method request? -func (input *BeegoInput) IsGet() bool { - return input.Is("GET") -} - -// IsPost Is this a POST method request? -func (input *BeegoInput) IsPost() bool { - return input.Is("POST") -} - -// IsHead Is this a Head method request? -func (input *BeegoInput) IsHead() bool { - return input.Is("HEAD") -} - -// IsOptions Is this a OPTIONS method request? -func (input *BeegoInput) IsOptions() bool { - return input.Is("OPTIONS") -} - -// IsPut Is this a PUT method request? -func (input *BeegoInput) IsPut() bool { - return input.Is("PUT") -} - -// IsDelete Is this a DELETE method request? -func (input *BeegoInput) IsDelete() bool { - return input.Is("DELETE") -} - -// IsPatch Is this a PATCH method request? -func (input *BeegoInput) IsPatch() bool { - return input.Is("PATCH") -} - -// IsAjax returns boolean of this request is generated by ajax. -func (input *BeegoInput) IsAjax() bool { - return input.Header("X-Requested-With") == "XMLHttpRequest" -} - -// IsSecure returns boolean of this request is in https. -func (input *BeegoInput) IsSecure() bool { - return input.Scheme() == "https" -} - -// IsWebsocket returns boolean of this request is in webSocket. -func (input *BeegoInput) IsWebsocket() bool { - return input.Header("Upgrade") == "websocket" -} - -// IsUpload returns boolean of whether file uploads in this request or not.. -func (input *BeegoInput) IsUpload() bool { - return strings.Contains(input.Header("Content-Type"), "multipart/form-data") -} - -// AcceptsHTML Checks if request accepts html response -func (input *BeegoInput) AcceptsHTML() bool { - return acceptsHTMLRegex.MatchString(input.Header("Accept")) -} - -// AcceptsXML Checks if request accepts xml response -func (input *BeegoInput) AcceptsXML() bool { - return acceptsXMLRegex.MatchString(input.Header("Accept")) -} - -// AcceptsJSON Checks if request accepts json response -func (input *BeegoInput) AcceptsJSON() bool { - return acceptsJSONRegex.MatchString(input.Header("Accept")) -} - -// AcceptsYAML Checks if request accepts json response -func (input *BeegoInput) AcceptsYAML() bool { - return acceptsYAMLRegex.MatchString(input.Header("Accept")) -} - -// IP returns request client ip. -// if in proxy, return first proxy id. -// if error, return RemoteAddr. -func (input *BeegoInput) IP() string { - ips := input.Proxy() - if len(ips) > 0 && ips[0] != "" { - rip, _, err := net.SplitHostPort(ips[0]) - if err != nil { - rip = ips[0] - } - return rip - } - if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil { - return ip - } - return input.Context.Request.RemoteAddr -} - -// Proxy returns proxy client ips slice. -func (input *BeegoInput) Proxy() []string { - if ips := input.Header("X-Forwarded-For"); ips != "" { - return strings.Split(ips, ",") - } - return []string{} -} - -// Referer returns http referer header. -func (input *BeegoInput) Referer() string { - return input.Header("Referer") -} - -// Refer returns http referer header. -func (input *BeegoInput) Refer() string { - return input.Referer() -} - -// SubDomains returns sub domain string. -// if aa.bb.domain.com, returns aa.bb . -func (input *BeegoInput) SubDomains() string { - parts := strings.Split(input.Host(), ".") - if len(parts) >= 3 { - return strings.Join(parts[:len(parts)-2], ".") - } - return "" -} - -// Port returns request client port. -// when error or empty, return 80. -func (input *BeegoInput) Port() int { - if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil { - port, _ := strconv.Atoi(portPart) - return port - } - return 80 -} - -// UserAgent returns request client user agent string. -func (input *BeegoInput) UserAgent() string { - return input.Header("User-Agent") -} - -// ParamsLen return the length of the params -func (input *BeegoInput) ParamsLen() int { - return len(input.pnames) -} - -// Param returns router param by a given key. -func (input *BeegoInput) Param(key string) string { - for i, v := range input.pnames { - if v == key && i <= len(input.pvalues) { - // we cannot use url.PathEscape(input.pvalues[i]) - // for example, if the value is /a/b - // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb - // However, the value is used in ControllerRegister.ServeHTTP - // and split by "/", so function crash... - return input.pvalues[i] - } - } - return "" -} - -// Params returns the map[key]value. -func (input *BeegoInput) Params() map[string]string { - m := make(map[string]string) - for i, v := range input.pnames { - if i <= len(input.pvalues) { - m[v] = input.pvalues[i] - } - } - return m -} - -// SetParam will set the param with key and value -func (input *BeegoInput) SetParam(key, val string) { - // check if already exists - for i, v := range input.pnames { - if v == key && i <= len(input.pvalues) { - input.pvalues[i] = val - return - } - } - input.pvalues = append(input.pvalues, val) - input.pnames = append(input.pnames, key) -} - -// ResetParams clears any of the input's Params -// This function is used to clear parameters so they may be reset between filter -// passes. -func (input *BeegoInput) ResetParams() { - input.pnames = input.pnames[:0] - input.pvalues = input.pvalues[:0] -} - -// ResetData: reset data -func (input *BeegoInput) ResetData() { - input.dataLock.Lock() - input.data = nil - input.dataLock.Unlock() -} - -// ResetBody: reset body -func (input *BeegoInput) ResetBody() { - input.RequestBody = []byte{} -} - -// Clear: clear all data in input -func (input *BeegoInput) Clear() { - input.ResetParams() - input.ResetData() - input.ResetBody() - -} - -// Query returns input data item string by a given string. -func (input *BeegoInput) Query(key string) string { - if val := input.Param(key); val != "" { - return val - } - if input.Context.Request.Form == nil { - input.dataLock.Lock() - if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() - } - input.dataLock.Unlock() - } - input.dataLock.RLock() - defer input.dataLock.RUnlock() - return input.Context.Request.Form.Get(key) -} - -// Header returns request header item string by a given string. -// if non-existed, return empty string. -func (input *BeegoInput) Header(key string) string { - return input.Context.Request.Header.Get(key) -} - -// Cookie returns request cookie item string by a given key. -// if non-existed, return empty string. -func (input *BeegoInput) Cookie(key string) string { - ck, err := input.Context.Request.Cookie(key) - if err != nil { - return "" - } - return ck.Value -} - -// Session returns current session item value by a given key. -// if non-existed, return nil. -func (input *BeegoInput) Session(key interface{}) interface{} { - return input.CruSession.Get(key) -} - -// CopyBody returns the raw request body data as bytes. -func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { - if input.Context.Request.Body == nil { - return []byte{} - } - - var requestbody []byte - safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} - if input.Header("Content-Encoding") == "gzip" { - reader, err := gzip.NewReader(safe) - if err != nil { - return nil - } - requestbody, _ = ioutil.ReadAll(reader) - } else { - requestbody, _ = ioutil.ReadAll(safe) - } - - input.Context.Request.Body.Close() - bf := bytes.NewBuffer(requestbody) - input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) - input.RequestBody = requestbody - return requestbody -} - -// Data return the implicit data in the input -func (input *BeegoInput) Data() map[interface{}]interface{} { - input.dataLock.Lock() - defer input.dataLock.Unlock() - if input.data == nil { - input.data = make(map[interface{}]interface{}) - } - return input.data -} - -// GetData returns the stored data in this context. -func (input *BeegoInput) GetData(key interface{}) interface{} { - input.dataLock.Lock() - defer input.dataLock.Unlock() - if v, ok := input.data[key]; ok { - return v - } - return nil -} - -// SetData stores data with given key in this context. -// This data are only available in this context. -func (input *BeegoInput) SetData(key, val interface{}) { - input.dataLock.Lock() - defer input.dataLock.Unlock() - if input.data == nil { - input.data = make(map[interface{}]interface{}) - } - input.data[key] = val -} - -// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type -func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { - // Parse the body depending on the content type. - if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { - if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { - return errors.New("Error parsing request body:" + err.Error()) - } - } else if err := input.Context.Request.ParseForm(); err != nil { - return errors.New("Error parsing request body:" + err.Error()) - } - return nil -} - -// Bind data from request.Form[key] to dest -// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie -// var id int beegoInput.Bind(&id, "id") id ==123 -// var isok bool beegoInput.Bind(&isok, "isok") isok ==true -// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 -// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] -// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] -// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} -func (input *BeegoInput) Bind(dest interface{}, key string) error { - value := reflect.ValueOf(dest) - if value.Kind() != reflect.Ptr { - return errors.New("beego: non-pointer passed to Bind: " + key) - } - value = value.Elem() - if !value.CanSet() { - return errors.New("beego: non-settable variable passed to Bind: " + key) - } - typ := value.Type() - // Get real type if dest define with interface{}. - // e.g var dest interface{} dest=1.0 - if value.Kind() == reflect.Interface { - typ = value.Elem().Type() - } - rv := input.bind(key, typ) - if !rv.IsValid() { - return errors.New("beego: reflect value is empty") - } - value.Set(rv) - return nil -} - -func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { - if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() - } - rv := reflect.Zero(typ) - switch typ.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindInt(val, typ) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindUint(val, typ) - case reflect.Float32, reflect.Float64: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindFloat(val, typ) - case reflect.String: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindString(val, typ) - case reflect.Bool: - val := input.Query(key) - if len(val) == 0 { - return rv - } - rv = input.bindBool(val, typ) - case reflect.Slice: - rv = input.bindSlice(&input.Context.Request.Form, key, typ) - case reflect.Struct: - rv = input.bindStruct(&input.Context.Request.Form, key, typ) - case reflect.Ptr: - rv = input.bindPoint(key, typ) - case reflect.Map: - rv = input.bindMap(&input.Context.Request.Form, key, typ) - } - return rv -} - -func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { - rv := reflect.Zero(typ) - switch typ.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - rv = input.bindInt(val, typ) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - rv = input.bindUint(val, typ) - case reflect.Float32, reflect.Float64: - rv = input.bindFloat(val, typ) - case reflect.String: - rv = input.bindString(val, typ) - case reflect.Bool: - rv = input.bindBool(val, typ) - case reflect.Slice: - rv = input.bindSlice(&url.Values{"": {val}}, "", typ) - case reflect.Struct: - rv = input.bindStruct(&url.Values{"": {val}}, "", typ) - case reflect.Ptr: - rv = input.bindPoint(val, typ) - case reflect.Map: - rv = input.bindMap(&url.Values{"": {val}}, "", typ) - } - return rv -} - -func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { - intValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return reflect.Zero(typ) - } - pValue := reflect.New(typ) - pValue.Elem().SetInt(intValue) - return pValue.Elem() -} - -func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { - uintValue, err := strconv.ParseUint(val, 10, 64) - if err != nil { - return reflect.Zero(typ) - } - pValue := reflect.New(typ) - pValue.Elem().SetUint(uintValue) - return pValue.Elem() -} - -func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { - floatValue, err := strconv.ParseFloat(val, 64) - if err != nil { - return reflect.Zero(typ) - } - pValue := reflect.New(typ) - pValue.Elem().SetFloat(floatValue) - return pValue.Elem() -} - -func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { - return reflect.ValueOf(val) -} - -func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { - val = strings.TrimSpace(strings.ToLower(val)) - switch val { - case "true", "on", "1": - return reflect.ValueOf(true) - } - return reflect.ValueOf(false) -} - -type sliceValue struct { - index int // Index extracted from brackets. If -1, no index was provided. - value reflect.Value // the bound value for this slice element. -} - -func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { - maxIndex := -1 - numNoIndex := 0 - sliceValues := []sliceValue{} - for reqKey, vals := range *params { - if !strings.HasPrefix(reqKey, key+"[") { - continue - } - // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) - index := -1 - leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) - if rightBracket > leftBracket+1 { - index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) - } - subKeyIndex := rightBracket + 1 - - // Handle the indexed case. - if index > -1 { - if index > maxIndex { - maxIndex = index - } - sliceValues = append(sliceValues, sliceValue{ - index: index, - value: input.bind(reqKey[:subKeyIndex], typ.Elem()), - }) - continue - } - - // It's an un-indexed element. (e.g. element[]) - numNoIndex += len(vals) - for _, val := range vals { - // Unindexed values can only be direct-bound. - sliceValues = append(sliceValues, sliceValue{ - index: -1, - value: input.bindValue(val, typ.Elem()), - }) - } - } - resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) - for _, sv := range sliceValues { - if sv.index != -1 { - resultArray.Index(sv.index).Set(sv.value) - } else { - resultArray = reflect.Append(resultArray, sv.value) - } - } - return resultArray -} - -func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { - result := reflect.New(typ).Elem() - fieldValues := make(map[string]reflect.Value) - for reqKey, val := range *params { - var fieldName string - if strings.HasPrefix(reqKey, key+".") { - fieldName = reqKey[len(key)+1:] - } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' { - fieldName = reqKey[len(key)+1 : len(reqKey)-1] - } else { - continue - } - - if _, ok := fieldValues[fieldName]; !ok { - // Time to bind this field. Get it and make sure we can set it. - fieldValue := result.FieldByName(fieldName) - if !fieldValue.IsValid() { - continue - } - if !fieldValue.CanSet() { - continue - } - boundVal := input.bindValue(val[0], fieldValue.Type()) - fieldValue.Set(boundVal) - fieldValues[fieldName] = boundVal - } - } - - return result -} - -func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { - return input.bind(key, typ.Elem()).Addr() -} - -func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { - var ( - result = reflect.MakeMap(typ) - keyType = typ.Key() - valueType = typ.Elem() - ) - for paramName, values := range *params { - if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { - continue - } - - key := paramName[len(key)+1 : len(paramName)-1] - result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) - } - return result -} diff --git a/context/input_test.go b/context/input_test.go deleted file mode 100644 index 3a6c2e7b..00000000 --- a/context/input_test.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "net/http/httptest" - "reflect" - "testing" -) - -func TestBind(t *testing.T) { - type testItem struct { - field string - empty interface{} - want interface{} - } - type Human struct { - ID int - Nick string - Pwd string - Ms bool - } - - cases := []struct { - request string - valueGp []testItem - }{ - {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, - - {"/?p=", []testItem{{"p", "", ""}}}, - {"/?p=str", []testItem{{"p", "", "str"}}}, - - {"/?p=123", []testItem{{"p", 0, 123}}}, - {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, - - {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, - {"/?p=1", []testItem{{"p", false, true}}}, - - {"/?p=true", []testItem{{"p", false, true}}}, - {"/?p=ON", []testItem{{"p", false, true}}}, - {"/?p=on", []testItem{{"p", false, true}}}, - {"/?p=1", []testItem{{"p", false, true}}}, - {"/?p=2", []testItem{{"p", false, false}}}, - {"/?p=false", []testItem{{"p", false, false}}}, - - {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, - {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, - - {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, - {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, - {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, - {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, - - {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, - {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, - - {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, - - {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, - {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, - {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", - []testItem{{"human", []Human{}, []Human{ - {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, - {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, - }}}}, - - { - "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", - []testItem{ - {"id", 0, 123}, - {"isok", false, true}, - {"ft", 0.0, 1.2}, - {"ol", []int{}, []int{1, 2}}, - {"ul", []string{}, []string{"str", "array"}}, - {"human", Human{}, Human{Nick: "astaxie"}}, - }, - }, - } - for _, c := range cases { - r, _ := http.NewRequest("GET", c.request, nil) - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Reset(httptest.NewRecorder(), r) - - for _, item := range c.valueGp { - got := item.empty - err := beegoInput.Bind(&got, item.field) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, item.want) { - t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) - } - } - - } -} - -func TestSubDomain(t *testing.T) { - r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Reset(httptest.NewRecorder(), r) - - subdomain := beegoInput.SubDomains() - if subdomain != "www" { - t.Fatal("Subdomain parse error, got" + subdomain) - } - - r, _ = http.NewRequest("GET", "http://localhost/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "" { - t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) - } - - r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "aa.bb" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } - - /* TODO Fix this - r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } - */ - - r, _ = http.NewRequest("GET", "http://example.com/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } - - r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) - beegoInput.Context.Request = r - if beegoInput.SubDomains() != "aa.bb.cc.dd" { - t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) - } -} - -func TestParams(t *testing.T) { - inp := NewInput() - - inp.SetParam("p1", "val1_ver1") - inp.SetParam("p2", "val2_ver1") - inp.SetParam("p3", "val3_ver1") - if l := inp.ParamsLen(); l != 3 { - t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) - } - - if val := inp.Param("p1"); val != "val1_ver1" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1") - } - if val := inp.Param("p3"); val != "val3_ver1" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1") - } - vals := inp.Params() - expected := map[string]string{ - "p1": "val1_ver1", - "p2": "val2_ver1", - "p3": "val3_ver1", - } - if !reflect.DeepEqual(vals, expected) { - t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) - } - - // overwriting existing params - inp.SetParam("p1", "val1_ver2") - inp.SetParam("p2", "val2_ver2") - expected = map[string]string{ - "p1": "val1_ver2", - "p2": "val2_ver2", - "p3": "val3_ver1", - } - vals = inp.Params() - if !reflect.DeepEqual(vals, expected) { - t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) - } - - if l := inp.ParamsLen(); l != 3 { - t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) - } - - if val := inp.Param("p1"); val != "val1_ver2" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") - } - - if val := inp.Param("p2"); val != "val2_ver2" { - t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") - } - -} -func BenchmarkQuery(b *testing.B) { - beegoInput := NewInput() - beegoInput.Context = NewContext() - beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - beegoInput.Query("q") - } - }) -} diff --git a/context/output.go b/context/output.go deleted file mode 100644 index 7409e4e5..00000000 --- a/context/output.go +++ /dev/null @@ -1,413 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "bytes" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "html/template" - "io" - "mime" - "net/http" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - yaml "gopkg.in/yaml.v2" -) - -// BeegoOutput does work for sending response header. -type BeegoOutput struct { - Context *Context - Status int - EnableGzip bool -} - -// NewOutput returns new BeegoOutput. -// it contains nothing now. -func NewOutput() *BeegoOutput { - return &BeegoOutput{} -} - -// Reset init BeegoOutput -func (output *BeegoOutput) Reset(ctx *Context) { - output.Context = ctx - output.Clear() -} - -// Clear: clear all data in output -func (output *BeegoOutput) Clear() { - output.Status = 0 -} - -// Header sets response header item string via given key. -func (output *BeegoOutput) Header(key, val string) { - output.Context.ResponseWriter.Header().Set(key, val) -} - -// Body sets response body content. -// if EnableGzip, compress content string. -// it sends out response body directly. -func (output *BeegoOutput) Body(content []byte) error { - var encoding string - var buf = &bytes.Buffer{} - if output.EnableGzip { - encoding = ParseEncoding(output.Context.Request) - } - if b, n, _ := WriteBody(encoding, buf, content); b { - output.Header("Content-Encoding", n) - output.Header("Content-Length", strconv.Itoa(buf.Len())) - } else { - output.Header("Content-Length", strconv.Itoa(len(content))) - } - // Write status code if it has been set manually - // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" - if output.Status != 0 { - output.Context.ResponseWriter.WriteHeader(output.Status) - output.Status = 0 - } else { - output.Context.ResponseWriter.Started = true - } - io.Copy(output.Context.ResponseWriter, buf) - return nil -} - -// Cookie sets cookie value via given key. -// others are ordered as cookie's max age time, path,domain, secure and httponly. -func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { - var b bytes.Buffer - fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) - - //fix cookie not work in IE - if len(others) > 0 { - var maxAge int64 - - switch v := others[0].(type) { - case int: - maxAge = int64(v) - case int32: - maxAge = int64(v) - case int64: - maxAge = v - } - - switch { - case maxAge > 0: - fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) - case maxAge < 0: - fmt.Fprintf(&b, "; Max-Age=0") - } - } - - // the settings below - // Path, Domain, Secure, HttpOnly - // can use nil skip set - - // default "/" - if len(others) > 1 { - if v, ok := others[1].(string); ok && len(v) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) - } - } else { - fmt.Fprintf(&b, "; Path=%s", "/") - } - - // default empty - if len(others) > 2 { - if v, ok := others[2].(string); ok && len(v) > 0 { - fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) - } - } - - // default empty - if len(others) > 3 { - var secure bool - switch v := others[3].(type) { - case bool: - secure = v - default: - if others[3] != nil { - secure = true - } - } - if secure { - fmt.Fprintf(&b, "; Secure") - } - } - - // default false. for session cookie default true - if len(others) > 4 { - if v, ok := others[4].(bool); ok && v { - fmt.Fprintf(&b, "; HttpOnly") - } - } - - output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) -} - -var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") - -func sanitizeName(n string) string { - return cookieNameSanitizer.Replace(n) -} - -var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") - -func sanitizeValue(v string) string { - return cookieValueSanitizer.Replace(v) -} - -func jsonRenderer(value interface{}) Renderer { - return rendererFunc(func(ctx *Context) { - ctx.Output.JSON(value, false, false) - }) -} - -func errorRenderer(err error) Renderer { - return rendererFunc(func(ctx *Context) { - ctx.Output.SetStatus(500) - ctx.Output.Body([]byte(err.Error())) - }) -} - -// JSON writes json to response body. -// if encoding is true, it converts utf-8 to \u0000 type. -func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { - output.Header("Content-Type", "application/json; charset=utf-8") - var content []byte - var err error - if hasIndent { - content, err = json.MarshalIndent(data, "", " ") - } else { - content, err = json.Marshal(data) - } - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - if encoding { - content = []byte(stringsToJSON(string(content))) - } - return output.Body(content) -} - -// YAML writes yaml to response body. -func (output *BeegoOutput) YAML(data interface{}) error { - output.Header("Content-Type", "application/x-yaml; charset=utf-8") - var content []byte - var err error - content, err = yaml.Marshal(data) - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - return output.Body(content) -} - -// JSONP writes jsonp to response body. -func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { - output.Header("Content-Type", "application/javascript; charset=utf-8") - var content []byte - var err error - if hasIndent { - content, err = json.MarshalIndent(data, "", " ") - } else { - content, err = json.Marshal(data) - } - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - callback := output.Context.Input.Query("callback") - if callback == "" { - return errors.New(`"callback" parameter required`) - } - callback = template.JSEscapeString(callback) - callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback) - callbackContent.WriteString("(") - callbackContent.Write(content) - callbackContent.WriteString(");\r\n") - return output.Body(callbackContent.Bytes()) -} - -// XML writes xml string to response body. -func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { - output.Header("Content-Type", "application/xml; charset=utf-8") - var content []byte - var err error - if hasIndent { - content, err = xml.MarshalIndent(data, "", " ") - } else { - content, err = xml.Marshal(data) - } - if err != nil { - http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) - return err - } - return output.Body(content) -} - -// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header -func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { - accept := output.Context.Input.Header("Accept") - switch accept { - case ApplicationYAML: - output.YAML(data) - case ApplicationXML, TextXML: - output.XML(data, hasIndent) - default: - output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) - } -} - -// Download forces response for download file. -// it prepares the download response header automatically. -func (output *BeegoOutput) Download(file string, filename ...string) { - // check get file error, file not found or other error. - if _, err := os.Stat(file); err != nil { - http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) - return - } - - var fName string - if len(filename) > 0 && filename[0] != "" { - fName = filename[0] - } else { - fName = filepath.Base(file) - } - //https://tools.ietf.org/html/rfc6266#section-4.3 - fn := url.PathEscape(fName) - if fName == fn { - fn = "filename=" + fn - } else { - /** - The parameters "filename" and "filename*" differ only in that - "filename*" uses the encoding defined in [RFC5987], allowing the use - of characters not present in the ISO-8859-1 character set - ([ISO-8859-1]). - */ - fn = "filename=" + fName + "; filename*=utf-8''" + fn - } - output.Header("Content-Disposition", "attachment; "+fn) - output.Header("Content-Description", "File Transfer") - output.Header("Content-Type", "application/octet-stream") - output.Header("Content-Transfer-Encoding", "binary") - output.Header("Expires", "0") - output.Header("Cache-Control", "must-revalidate") - output.Header("Pragma", "public") - http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) -} - -// ContentType sets the content type from ext string. -// MIME type is given in mime package. -func (output *BeegoOutput) ContentType(ext string) { - if !strings.HasPrefix(ext, ".") { - ext = "." + ext - } - ctype := mime.TypeByExtension(ext) - if ctype != "" { - output.Header("Content-Type", ctype) - } -} - -// SetStatus sets response status code. -// It writes response header directly. -func (output *BeegoOutput) SetStatus(status int) { - output.Status = status -} - -// IsCachable returns boolean of this request is cached. -// HTTP 304 means cached. -func (output *BeegoOutput) IsCachable() bool { - return output.Status >= 200 && output.Status < 300 || output.Status == 304 -} - -// IsEmpty returns boolean of this request is empty. -// HTTP 201,204 and 304 means empty. -func (output *BeegoOutput) IsEmpty() bool { - return output.Status == 201 || output.Status == 204 || output.Status == 304 -} - -// IsOk returns boolean of this request runs well. -// HTTP 200 means ok. -func (output *BeegoOutput) IsOk() bool { - return output.Status == 200 -} - -// IsSuccessful returns boolean of this request runs successfully. -// HTTP 2xx means ok. -func (output *BeegoOutput) IsSuccessful() bool { - return output.Status >= 200 && output.Status < 300 -} - -// IsRedirect returns boolean of this request is redirection header. -// HTTP 301,302,307 means redirection. -func (output *BeegoOutput) IsRedirect() bool { - return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 -} - -// IsForbidden returns boolean of this request is forbidden. -// HTTP 403 means forbidden. -func (output *BeegoOutput) IsForbidden() bool { - return output.Status == 403 -} - -// IsNotFound returns boolean of this request is not found. -// HTTP 404 means not found. -func (output *BeegoOutput) IsNotFound() bool { - return output.Status == 404 -} - -// IsClientError returns boolean of this request client sends error data. -// HTTP 4xx means client error. -func (output *BeegoOutput) IsClientError() bool { - return output.Status >= 400 && output.Status < 500 -} - -// IsServerError returns boolean of this server handler errors. -// HTTP 5xx means server internal error. -func (output *BeegoOutput) IsServerError() bool { - return output.Status >= 500 && output.Status < 600 -} - -func stringsToJSON(str string) string { - var jsons bytes.Buffer - for _, r := range str { - rint := int(r) - if rint < 128 { - jsons.WriteRune(r) - } else { - jsons.WriteString("\\u") - if rint < 0x100 { - jsons.WriteString("00") - } else if rint < 0x1000 { - jsons.WriteString("0") - } - jsons.WriteString(strconv.FormatInt(int64(rint), 16)) - } - } - return jsons.String() -} - -// Session sets session item value with given key. -func (output *BeegoOutput) Session(name interface{}, value interface{}) { - output.Context.Input.CruSession.Set(name, value) -} diff --git a/context/param/conv.go b/context/param/conv.go deleted file mode 100644 index c200e008..00000000 --- a/context/param/conv.go +++ /dev/null @@ -1,78 +0,0 @@ -package param - -import ( - "fmt" - "reflect" - - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" -) - -// ConvertParams converts http method params to values that will be passed to the method controller as arguments -func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { - result = make([]reflect.Value, 0, len(methodParams)) - for i := 0; i < len(methodParams); i++ { - reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) - result = append(result, reflectValue) - } - return -} - -func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { - paramValue := getParamValue(param, ctx) - if paramValue == "" { - if param.required { - ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) - } else { - paramValue = param.defaultValue - } - } - - reflectValue, err := parseValue(param, paramValue, paramType) - if err != nil { - logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) - ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) - } - - return reflectValue -} - -func getParamValue(param *MethodParam, ctx *beecontext.Context) string { - switch param.in { - case body: - return string(ctx.Input.RequestBody) - case header: - return ctx.Input.Header(param.name) - case path: - return ctx.Input.Query(":" + param.name) - default: - return ctx.Input.Query(param.name) - } -} - -func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { - if paramValue == "" { - return reflect.Zero(paramType), nil - } - parser := getParser(param, paramType) - value, err := parser.parse(paramValue, paramType) - if err != nil { - return result, err - } - - return safeConvert(reflect.ValueOf(value), paramType) -} - -func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { - defer func() { - if r := recover(); r != nil { - var ok bool - err, ok = r.(error) - if !ok { - err = fmt.Errorf("%v", r) - } - } - }() - result = value.Convert(t) - return -} diff --git a/context/param/methodparams.go b/context/param/methodparams.go deleted file mode 100644 index cd6708a2..00000000 --- a/context/param/methodparams.go +++ /dev/null @@ -1,69 +0,0 @@ -package param - -import ( - "fmt" - "strings" -) - -//MethodParam keeps param information to be auto passed to controller methods -type MethodParam struct { - name string - in paramType - required bool - defaultValue string -} - -type paramType byte - -const ( - param paramType = iota - path - body - header -) - -//New creates a new MethodParam with name and specific options -func New(name string, opts ...MethodParamOption) *MethodParam { - return newParam(name, nil, opts) -} - -func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { - param = &MethodParam{name: name} - for _, option := range opts { - option(param) - } - return -} - -//Make creates an array of MethodParmas or an empty array -func Make(list ...*MethodParam) []*MethodParam { - if len(list) > 0 { - return list - } - return nil -} - -func (mp *MethodParam) String() string { - options := []string{} - result := "param.New(\"" + mp.name + "\"" - if mp.required { - options = append(options, "param.IsRequired") - } - switch mp.in { - case path: - options = append(options, "param.InPath") - case body: - options = append(options, "param.InBody") - case header: - options = append(options, "param.InHeader") - } - if mp.defaultValue != "" { - options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) - } - if len(options) > 0 { - result += ", " - } - result += strings.Join(options, ", ") - result += ")" - return result -} diff --git a/context/param/options.go b/context/param/options.go deleted file mode 100644 index 3d5ba013..00000000 --- a/context/param/options.go +++ /dev/null @@ -1,37 +0,0 @@ -package param - -import ( - "fmt" -) - -// MethodParamOption defines a func which apply options on a MethodParam -type MethodParamOption func(*MethodParam) - -// IsRequired indicates that this param is required and can not be omitted from the http request -var IsRequired MethodParamOption = func(p *MethodParam) { - p.required = true -} - -// InHeader indicates that this param is passed via an http header -var InHeader MethodParamOption = func(p *MethodParam) { - p.in = header -} - -// InPath indicates that this param is part of the URL path -var InPath MethodParamOption = func(p *MethodParam) { - p.in = path -} - -// InBody indicates that this param is passed as an http request body -var InBody MethodParamOption = func(p *MethodParam) { - p.in = body -} - -// Default provides a default value for the http param -func Default(defaultValue interface{}) MethodParamOption { - return func(p *MethodParam) { - if defaultValue != nil { - p.defaultValue = fmt.Sprint(defaultValue) - } - } -} diff --git a/context/param/parsers.go b/context/param/parsers.go deleted file mode 100644 index 421aecf0..00000000 --- a/context/param/parsers.go +++ /dev/null @@ -1,149 +0,0 @@ -package param - -import ( - "encoding/json" - "reflect" - "strconv" - "strings" - "time" -) - -type paramParser interface { - parse(value string, toType reflect.Type) (interface{}, error) -} - -func getParser(param *MethodParam, t reflect.Type) paramParser { - switch t.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return intParser{} - case reflect.Slice: - if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string - return stringParser{} - } - if param.in == body { - return jsonParser{} - } - elemParser := getParser(param, t.Elem()) - if elemParser == (jsonParser{}) { - return elemParser - } - return sliceParser(elemParser) - case reflect.Bool: - return boolParser{} - case reflect.String: - return stringParser{} - case reflect.Float32, reflect.Float64: - return floatParser{} - case reflect.Ptr: - elemParser := getParser(param, t.Elem()) - if elemParser == (jsonParser{}) { - return elemParser - } - return ptrParser(elemParser) - default: - if t.PkgPath() == "time" && t.Name() == "Time" { - return timeParser{} - } - return jsonParser{} - } -} - -type parserFunc func(value string, toType reflect.Type) (interface{}, error) - -func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { - return f(value, toType) -} - -type boolParser struct { -} - -func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { - return strconv.ParseBool(value) -} - -type stringParser struct { -} - -func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { - return value, nil -} - -type intParser struct { -} - -func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { - return strconv.Atoi(value) -} - -type floatParser struct { -} - -func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { - if toType.Kind() == reflect.Float32 { - res, err := strconv.ParseFloat(value, 32) - if err != nil { - return nil, err - } - return float32(res), nil - } - return strconv.ParseFloat(value, 64) -} - -type timeParser struct { -} - -func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { - result, err = time.Parse(time.RFC3339, value) - if err != nil { - result, err = time.Parse("2006-01-02", value) - } - return -} - -type jsonParser struct { -} - -func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { - pResult := reflect.New(toType) - v := pResult.Interface() - err := json.Unmarshal([]byte(value), v) - if err != nil { - return nil, err - } - return pResult.Elem().Interface(), nil -} - -func sliceParser(elemParser paramParser) paramParser { - return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { - values := strings.Split(value, ",") - result := reflect.MakeSlice(toType, 0, len(values)) - elemType := toType.Elem() - for _, v := range values { - parsedValue, err := elemParser.parse(v, elemType) - if err != nil { - return nil, err - } - result = reflect.Append(result, reflect.ValueOf(parsedValue)) - } - return result.Interface(), nil - }) -} - -func ptrParser(elemParser paramParser) paramParser { - return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { - parsedValue, err := elemParser.parse(value, toType.Elem()) - if err != nil { - return nil, err - } - newValPtr := reflect.New(toType.Elem()) - newVal := reflect.Indirect(newValPtr) - convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) - if err != nil { - return nil, err - } - - newVal.Set(convertedVal) - return newValPtr.Interface(), nil - }) -} diff --git a/context/param/parsers_test.go b/context/param/parsers_test.go deleted file mode 100644 index 7065a28e..00000000 --- a/context/param/parsers_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package param - -import "testing" -import "reflect" -import "time" - -type testDefinition struct { - strValue string - expectedValue interface{} - expectedParser paramParser -} - -func Test_Parsers(t *testing.T) { - - //ints - checkParser(testDefinition{"1", 1, intParser{}}, t) - checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) - checkParser(testDefinition{"1", uint64(1), intParser{}}, t) - - //floats - checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) - checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) - - //strings - checkParser(testDefinition{"AB", "AB", stringParser{}}, t) - checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) - - //bools - checkParser(testDefinition{"true", true, boolParser{}}, t) - checkParser(testDefinition{"0", false, boolParser{}}, t) - - //timeParser - checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) - checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) - - //json - checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { - X int - Y string - }{5, "Z"}, jsonParser{}}, t) - - //slice in query is parsed as comma delimited - checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) - - //slice in body is parsed as json - checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) - - //pointers - var someInt = 1 - checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) - - var someStruct = struct{ X int }{5} - checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) - -} - -func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { - toType := reflect.TypeOf(def.expectedValue) - var mp MethodParam - if len(methodParam) == 0 { - mp = MethodParam{} - } else { - mp = methodParam[0] - } - parser := getParser(&mp, toType) - - if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { - t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) - return - } - result, err := parser.parse(def.strValue, toType) - if err != nil { - t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) - return - } - convResult, err := safeConvert(reflect.ValueOf(result), toType) - if err != nil { - t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) - return - } - if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { - t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) - } -} diff --git a/context/renderer.go b/context/renderer.go deleted file mode 100644 index 36a7cb53..00000000 --- a/context/renderer.go +++ /dev/null @@ -1,12 +0,0 @@ -package context - -// Renderer defines an http response renderer -type Renderer interface { - Render(ctx *Context) -} - -type rendererFunc func(ctx *Context) - -func (f rendererFunc) Render(ctx *Context) { - f(ctx) -} diff --git a/context/response.go b/context/response.go deleted file mode 100644 index 9c3c715a..00000000 --- a/context/response.go +++ /dev/null @@ -1,27 +0,0 @@ -package context - -import ( - "strconv" - - "net/http" -) - -const ( - //BadRequest indicates http error 400 - BadRequest StatusCode = http.StatusBadRequest - - //NotFound indicates http error 404 - NotFound StatusCode = http.StatusNotFound -) - -// StatusCode sets the http response status code -type StatusCode int - -func (s StatusCode) Error() string { - return strconv.Itoa(int(s)) -} - -// Render sets the http status code -func (s StatusCode) Render(ctx *Context) { - ctx.Output.SetStatus(int(s)) -} diff --git a/controller.go b/controller.go deleted file mode 100644 index 0b4a79a8..00000000 --- a/controller.go +++ /dev/null @@ -1,774 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "errors" - "fmt" - "html/template" - "io" - "mime/multipart" - "net/http" - "net/url" - "os" - "reflect" - "strconv" - "strings" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/session" -) - -var ( - // ErrAbort custom error when user stop request handler manually. - // Deprecated: using pkg/, we will delete this in v2.1.0 - ErrAbort = errors.New("user stop run") - // GlobalControllerRouter store comments with controller. pkgpath+controller:comments - // Deprecated: using pkg/, we will delete this in v2.1.0 - GlobalControllerRouter = make(map[string][]ControllerComments) -) - -// ControllerFilter store the filter for controller -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerFilter struct { - Pattern string - Pos int - Filter FilterFunc - ReturnOnOutput bool - ResetParams bool -} - -// ControllerFilterComments store the comment for controller level filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerFilterComments struct { - Pattern string - Pos int - Filter string // NOQA - ReturnOnOutput bool - ResetParams bool -} - -// ControllerImportComments store the import comment for controller needed -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerImportComments struct { - ImportPath string - ImportAlias string -} - -// ControllerComments store the comment for the controller method -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerComments struct { - Method string - Router string - Filters []*ControllerFilter - ImportComments []*ControllerImportComments - FilterComments []*ControllerFilterComments - AllowHTTPMethods []string - Params []map[string]string - MethodParams []*param.MethodParam -} - -// ControllerCommentsSlice implements the sort interface -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerCommentsSlice []ControllerComments - -func (p ControllerCommentsSlice) Len() int { return len(p) } -func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } -func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } - -// Controller defines some basic http request handler operations, such as -// http context, template and view, session and xsrf. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Controller struct { - // context data - Ctx *context.Context - Data map[interface{}]interface{} - - // route controller info - controllerName string - actionName string - methodMapping map[string]func() //method:routertree - AppController interface{} - - // template data - TplName string - ViewPath string - Layout string - LayoutSections map[string]string // the key is the section name and the value is the template name - TplPrefix string - TplExt string - EnableRender bool - - // xsrf data - _xsrfToken string - XSRFExpire int - EnableXSRF bool - - // session - CruSession session.Store -} - -// ControllerInterface is an interface to uniform all controller handler. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerInterface interface { - Init(ct *context.Context, controllerName, actionName string, app interface{}) - Prepare() - Get() - Post() - Delete() - Put() - Head() - Patch() - Options() - Trace() - Finish() - Render() error - XSRFToken() string - CheckXSRFCookie() bool - HandlerFunc(fn string) bool - URLMapping() -} - -// Init generates default values of controller operations. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { - c.Layout = "" - c.TplName = "" - c.controllerName = controllerName - c.actionName = actionName - c.Ctx = ctx - c.TplExt = "tpl" - c.AppController = app - c.EnableRender = true - c.EnableXSRF = true - c.Data = ctx.Input.Data() - c.methodMapping = make(map[string]func()) -} - -// Prepare runs after Init before request function execution. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Prepare() {} - -// Finish runs after request function execution. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Finish() {} - -// Get adds a request function to handle GET request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Get() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Post adds a request function to handle POST request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Post() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Delete adds a request function to handle DELETE request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Delete() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Put adds a request function to handle PUT request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Put() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Head adds a request function to handle HEAD request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Head() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Patch adds a request function to handle PATCH request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Patch() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Options adds a request function to handle OPTIONS request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Options() { - http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) -} - -// Trace adds a request function to handle Trace request. -// this method SHOULD NOT be overridden. -// https://tools.ietf.org/html/rfc7231#section-4.3.8 -// The TRACE method requests a remote, application-level loop-back of -// the request message. The final recipient of the request SHOULD -// reflect the message received, excluding some fields described below, -// back to the client as the message body of a 200 (OK) response with a -// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Trace() { - ts := func(h http.Header) (hs string) { - for k, v := range h { - hs += fmt.Sprintf("\r\n%s: %s", k, v) - } - return - } - hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) - c.Ctx.Output.Header("Content-Type", "message/http") - c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) - c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") - c.Ctx.WriteString(hs) -} - -// HandlerFunc call function with the name -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) HandlerFunc(fnname string) bool { - if v, ok := c.methodMapping[fnname]; ok { - v() - return true - } - return false -} - -// URLMapping register the internal Controller router. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) URLMapping() {} - -// Mapping the method to function -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Mapping(method string, fn func()) { - c.methodMapping[method] = fn -} - -// Render sends the response with rendered template bytes as text/html type. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Render() error { - if !c.EnableRender { - return nil - } - rb, err := c.RenderBytes() - if err != nil { - return err - } - - if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { - c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") - } - - return c.Ctx.Output.Body(rb) -} - -// RenderString returns the rendered template string. Do not send out response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) RenderString() (string, error) { - b, e := c.RenderBytes() - return string(b), e -} - -// RenderBytes returns the bytes of rendered template string. Do not send out response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) RenderBytes() ([]byte, error) { - buf, err := c.renderTemplate() - //if the controller has set layout, then first get the tplName's content set the content to the layout - if err == nil && c.Layout != "" { - c.Data["LayoutContent"] = template.HTML(buf.String()) - - if c.LayoutSections != nil { - for sectionName, sectionTpl := range c.LayoutSections { - if sectionTpl == "" { - c.Data[sectionName] = "" - continue - } - buf.Reset() - err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) - if err != nil { - return nil, err - } - c.Data[sectionName] = template.HTML(buf.String()) - } - } - - buf.Reset() - ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) - } - return buf.Bytes(), err -} - -func (c *Controller) renderTemplate() (bytes.Buffer, error) { - var buf bytes.Buffer - if c.TplName == "" { - c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt - } - if c.TplPrefix != "" { - c.TplName = c.TplPrefix + c.TplName - } - if BConfig.RunMode == DEV { - buildFiles := []string{c.TplName} - if c.Layout != "" { - buildFiles = append(buildFiles, c.Layout) - if c.LayoutSections != nil { - for _, sectionTpl := range c.LayoutSections { - if sectionTpl == "" { - continue - } - buildFiles = append(buildFiles, sectionTpl) - } - } - } - BuildTemplate(c.viewPath(), buildFiles...) - } - return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) -} - -func (c *Controller) viewPath() string { - if c.ViewPath == "" { - return BConfig.WebConfig.ViewsPath - } - return c.ViewPath -} - -// Redirect sends the redirection response to url with status code. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Redirect(url string, code int) { - LogAccess(c.Ctx, nil, code) - c.Ctx.Redirect(code, url) -} - -// SetData set the data depending on the accepted -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SetData(data interface{}) { - accept := c.Ctx.Input.Header("Accept") - switch accept { - case context.ApplicationYAML: - c.Data["yaml"] = data - case context.ApplicationXML, context.TextXML: - c.Data["xml"] = data - default: - c.Data["json"] = data - } -} - -// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Abort(code string) { - status, err := strconv.Atoi(code) - if err != nil { - status = 200 - } - c.CustomAbort(status, code) -} - -// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) CustomAbort(status int, body string) { - // first panic from ErrorMaps, it is user defined error functions. - if _, ok := ErrorMaps[body]; ok { - c.Ctx.Output.Status = status - panic(body) - } - // last panic user string - c.Ctx.ResponseWriter.WriteHeader(status) - c.Ctx.ResponseWriter.Write([]byte(body)) - panic(ErrAbort) -} - -// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) StopRun() { - panic(ErrAbort) -} - -// URLFor does another controller handler in this request function. -// it goes to this controller method if endpoint is not clear. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) URLFor(endpoint string, values ...interface{}) string { - if len(endpoint) == 0 { - return "" - } - if endpoint[0] == '.' { - return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) - } - return URLFor(endpoint, values...) -} - -// ServeJSON sends a json response with encoding charset. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeJSON(encoding ...bool) { - var ( - hasIndent = BConfig.RunMode != PROD - hasEncoding = len(encoding) > 0 && encoding[0] - ) - - c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) -} - -// ServeJSONP sends a jsonp response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeJSONP() { - hasIndent := BConfig.RunMode != PROD - c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) -} - -// ServeXML sends xml response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeXML() { - hasIndent := BConfig.RunMode != PROD - c.Ctx.Output.XML(c.Data["xml"], hasIndent) -} - -// ServeYAML sends yaml response. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeYAML() { - c.Ctx.Output.YAML(c.Data["yaml"]) -} - -// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ServeFormatted(encoding ...bool) { - hasIndent := BConfig.RunMode != PROD - hasEncoding := len(encoding) > 0 && encoding[0] - c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) -} - -// Input returns the input data map from POST or PUT request body and query string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) Input() url.Values { - if c.Ctx.Request.Form == nil { - c.Ctx.Request.ParseForm() - } - return c.Ctx.Request.Form -} - -// ParseForm maps input data map to obj struct. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) ParseForm(obj interface{}) error { - return ParseForm(c.Input(), obj) -} - -// GetString returns the input value by key string or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetString(key string, def ...string) string { - if v := c.Ctx.Input.Query(key); v != "" { - return v - } - if len(def) > 0 { - return def[0] - } - return "" -} - -// GetStrings returns the input string slice by key string or the default value while it's present and input is blank -// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetStrings(key string, def ...[]string) []string { - var defv []string - if len(def) > 0 { - defv = def[0] - } - - if f := c.Input(); f == nil { - return defv - } else if vs := f[key]; len(vs) > 0 { - return vs - } - - return defv -} - -// GetInt returns input as an int or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt(key string, def ...int) (int, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.Atoi(strv) -} - -// GetInt8 return input as an int8 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - i64, err := strconv.ParseInt(strv, 10, 8) - return int8(i64), err -} - -// GetUint8 return input as an uint8 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - u64, err := strconv.ParseUint(strv, 10, 8) - return uint8(u64), err -} - -// GetInt16 returns input as an int16 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - i64, err := strconv.ParseInt(strv, 10, 16) - return int16(i64), err -} - -// GetUint16 returns input as an uint16 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - u64, err := strconv.ParseUint(strv, 10, 16) - return uint16(u64), err -} - -// GetInt32 returns input as an int32 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - i64, err := strconv.ParseInt(strv, 10, 32) - return int32(i64), err -} - -// GetUint32 returns input as an uint32 or the default value while it's present and input is blank -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - u64, err := strconv.ParseUint(strv, 10, 32) - return uint32(u64), err -} - -// GetInt64 returns input value as int64 or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseInt(strv, 10, 64) -} - -// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseUint(strv, 10, 64) -} - -// GetBool returns input value as bool or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetBool(key string, def ...bool) (bool, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseBool(strv) -} - -// GetFloat returns input value as float64 or the default value while it's present and input is blank. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { - strv := c.Ctx.Input.Query(key) - if len(strv) == 0 && len(def) > 0 { - return def[0], nil - } - return strconv.ParseFloat(strv, 64) -} - -// GetFile returns the file data in file upload field named as key. -// it returns the first one of multi-uploaded files. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { - return c.Ctx.Request.FormFile(key) -} - -// GetFiles return multi-upload files -// files, err:=c.GetFiles("myfiles") -// if err != nil { -// http.Error(w, err.Error(), http.StatusNoContent) -// return -// } -// for i, _ := range files { -// //for each fileheader, get a handle to the actual file -// file, err := files[i].Open() -// defer file.Close() -// if err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// //create destination file making sure the path is writeable. -// dst, err := os.Create("upload/" + files[i].Filename) -// defer dst.Close() -// if err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// //copy the uploaded file to the destination file -// if _, err := io.Copy(dst, file); err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// } -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { - if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { - return files, nil - } - return nil, http.ErrMissingFile -} - -// SaveToFile saves uploaded file to new path. -// it only operates the first one of mutil-upload form file field. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SaveToFile(fromfile, tofile string) error { - file, _, err := c.Ctx.Request.FormFile(fromfile) - if err != nil { - return err - } - defer file.Close() - f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if err != nil { - return err - } - defer f.Close() - io.Copy(f, file) - return nil -} - -// StartSession starts session and load old session data info this controller. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) StartSession() session.Store { - if c.CruSession == nil { - c.CruSession = c.Ctx.Input.CruSession - } - return c.CruSession -} - -// SetSession puts value into session. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SetSession(name interface{}, value interface{}) { - if c.CruSession == nil { - c.StartSession() - } - c.CruSession.Set(name, value) -} - -// GetSession gets value from session. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetSession(name interface{}) interface{} { - if c.CruSession == nil { - c.StartSession() - } - return c.CruSession.Get(name) -} - -// DelSession removes value from session. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) DelSession(name interface{}) { - if c.CruSession == nil { - c.StartSession() - } - c.CruSession.Delete(name) -} - -// SessionRegenerateID regenerates session id for this session. -// the session data have no changes. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SessionRegenerateID() { - if c.CruSession != nil { - c.CruSession.SessionRelease(c.Ctx.ResponseWriter) - } - c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) - c.Ctx.Input.CruSession = c.CruSession -} - -// DestroySession cleans session data and session cookie. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) DestroySession() { - c.Ctx.Input.CruSession.Flush() - c.Ctx.Input.CruSession = nil - GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) -} - -// IsAjax returns this request is ajax or not. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) IsAjax() bool { - return c.Ctx.Input.IsAjax() -} - -// GetSecureCookie returns decoded cookie value from encoded browser cookie values. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { - return c.Ctx.GetSecureCookie(Secret, key) -} - -// SetSecureCookie puts value into cookie after encoded the value. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { - c.Ctx.SetSecureCookie(Secret, name, value, others...) -} - -// XSRFToken creates a CSRF token string and returns. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) XSRFToken() string { - if c._xsrfToken == "" { - expire := int64(BConfig.WebConfig.XSRFExpire) - if c.XSRFExpire > 0 { - expire = int64(c.XSRFExpire) - } - c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) - } - return c._xsrfToken -} - -// CheckXSRFCookie checks xsrf token in this request is valid or not. -// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" -// or in form field value named as "_xsrf". -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) CheckXSRFCookie() bool { - if !c.EnableXSRF { - return true - } - return c.Ctx.CheckXSRFCookie() -} - -// XSRFFormHTML writes an input field contains xsrf token value. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) XSRFFormHTML() string { - return `` -} - -// GetControllerAndAction gets the executing controller name and action name. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *Controller) GetControllerAndAction() (string, string) { - return c.controllerName, c.actionName -} diff --git a/controller_test.go b/controller_test.go deleted file mode 100644 index 1e53416d..00000000 --- a/controller_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "math" - "strconv" - "testing" - - "github.com/astaxie/beego/context" - "os" - "path/filepath" -) - -func TestGetInt(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt("age") - if val != 40 { - t.Errorf("TestGetInt expect 40,get %T,%v", val, val) - } -} - -func TestGetInt8(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt8("age") - if val != 40 { - t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) - } - //Output: int8 -} - -func TestGetInt16(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt16("age") - if val != 40 { - t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) - } -} - -func TestGetInt32(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt32("age") - if val != 40 { - t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) - } -} - -func TestGetInt64(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt64("age") - if val != 40 { - t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) - } -} - -func TestGetUint8(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint8("age") - if val != math.MaxUint8 { - t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) - } -} - -func TestGetUint16(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint16("age") - if val != math.MaxUint16 { - t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) - } -} - -func TestGetUint32(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint32("age") - if val != math.MaxUint32 { - t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) - } -} - -func TestGetUint64(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint64("age") - if val != math.MaxUint64 { - t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) - } -} - -func TestAdditionalViewPaths(t *testing.T) { - dir1 := "_beeTmp" - dir2 := "_beeTmp2" - defer os.RemoveAll(dir1) - defer os.RemoveAll(dir2) - - dir1file := "file1.tpl" - dir2file := "file2.tpl" - - genFile := func(dir string, name string, content string) { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - defer f.Close() - f.WriteString(content) - f.Close() - } - - } - genFile(dir1, dir1file, `
{{.Content}}
`) - genFile(dir2, dir2file, `{{.Content}}`) - - AddViewPath(dir1) - AddViewPath(dir2) - - ctrl := Controller{ - TplName: "file1.tpl", - ViewPath: dir1, - } - ctrl.Data = map[interface{}]interface{}{ - "Content": "value2", - } - if result, err := ctrl.RenderString(); err != nil { - t.Fatal(err) - } else { - if result != "
value2
" { - t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) - } - } - - func() { - ctrl.TplName = "file2.tpl" - defer func() { - if r := recover(); r == nil { - t.Fatal("TestAdditionalViewPaths expected error") - } - }() - ctrl.RenderString() - }() - - ctrl.TplName = "file2.tpl" - ctrl.ViewPath = dir2 - ctrl.RenderString() -} diff --git a/doc.go b/doc.go deleted file mode 100644 index 72284c67..00000000 --- a/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -/* -Package beego provide a MVC framework -beego: an open-source, high-performance, modular, full-stack web framework - -It is used for rapid development of RESTful APIs, web apps and backend services in Go. -beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. - - package main - import "github.com/astaxie/beego" - - func main() { - beego.Run() - } - -more information: http://beego.me - -Deprecated: using pkg/, we will delete this in v2.1.0 -*/ -package beego diff --git a/error.go b/error.go deleted file mode 100644 index 40eea5fa..00000000 --- a/error.go +++ /dev/null @@ -1,492 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "fmt" - "html/template" - "net/http" - "reflect" - "runtime" - "strconv" - "strings" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" -) - -const ( - errorTypeHandler = iota - errorTypeController -) - -var tpl = ` - - - - - beego application error - - - - - -
- - - - - - - - - - -
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
-
- Stack -
{{.Stack}}
-
-
- - - -` - -// render default application error page with error and stack string. -func showErr(err interface{}, ctx *context.Context, stack string) { - t, _ := template.New("beegoerrortemp").Parse(tpl) - data := map[string]string{ - "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), - "RequestMethod": ctx.Input.Method(), - "RequestURL": ctx.Input.URI(), - "RemoteAddr": ctx.Input.IP(), - "Stack": stack, - "BeegoVersion": VERSION, - "GoVersion": runtime.Version(), - } - t.Execute(ctx.ResponseWriter, data) -} - -var errtpl = ` - - - - - {{.Title}} - - - -
-
- -
- {{.Content}} - Go Home
- -
Powered by beego {{.BeegoVersion}} -
-
-
- - -` - -type errorInfo struct { - controllerType reflect.Type - handler http.HandlerFunc - method string - errorType int -} - -// ErrorMaps holds map of http handlers for each error string. -// there is 10 kinds default error(40x and 50x) -// Deprecated: using pkg/, we will delete this in v2.1.0 -var ErrorMaps = make(map[string]*errorInfo, 10) - -// show 401 unauthorized error. -func unauthorized(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 401, - "
The page you have requested can't be authorized."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The credentials you supplied are incorrect"+ - "
    There are errors in the website address"+ - "
", - ) -} - -// show 402 Payment Required -func paymentRequired(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 402, - "
The page you have requested Payment Required."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The credentials you supplied are incorrect"+ - "
    There are errors in the website address"+ - "
", - ) -} - -// show 403 forbidden error. -func forbidden(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 403, - "
The page you have requested is forbidden."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    Your address may be blocked"+ - "
    The site may be disabled"+ - "
    You need to log in"+ - "
", - ) -} - -// show 422 missing xsrf token -func missingxsrf(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 422, - "
The page you have requested is forbidden."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    '_xsrf' argument missing from POST"+ - "
", - ) -} - -// show 417 invalid xsrf token -func invalidxsrf(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 417, - "
The page you have requested is forbidden."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    expected XSRF not found"+ - "
", - ) -} - -// show 404 not found error. -func notFound(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 404, - "
The page you have requested has flown the coop."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The page has moved"+ - "
    The page no longer exists"+ - "
    You were looking for your puppy and got lost"+ - "
    You like 404 pages"+ - "
", - ) -} - -// show 405 Method Not Allowed -func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 405, - "
The method you have requested Not Allowed."+ - "
Perhaps you are here because:"+ - "

    "+ - "
    The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+ - "
    The response MUST include an Allow header containing a list of valid methods for the requested resource."+ - "
", - ) -} - -// show 500 internal server error. -func internalServerError(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 500, - "
The page you have requested is down right now."+ - "

    "+ - "
    Please try again later and report the error to the website administrator"+ - "
", - ) -} - -// show 501 Not Implemented. -func notImplemented(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 501, - "
The page you have requested is Not Implemented."+ - "

    "+ - "
    Please try again later and report the error to the website administrator"+ - "
", - ) -} - -// show 502 Bad Gateway. -func badGateway(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 502, - "
The page you have requested is down right now."+ - "

    "+ - "
    The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+ - "
    Please try again later and report the error to the website administrator"+ - "
", - ) -} - -// show 503 service unavailable error. -func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 503, - "
The page you have requested is unavailable."+ - "
Perhaps you are here because:"+ - "

    "+ - "

    The page is overloaded"+ - "
    Please try again later."+ - "
", - ) -} - -// show 504 Gateway Timeout. -func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 504, - "
The page you have requested is unavailable"+ - "
Perhaps you are here because:"+ - "

    "+ - "

    The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+ - "
    Please try again later."+ - "
", - ) -} - -// show 413 Payload Too Large -func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { - responseError(rw, r, - 413, - `
The page you have requested is unavailable. -
Perhaps you are here because:

-
    -
    The request entity is larger than limits defined by server. -
    Please change the request entity and try again. -
- `, - ) -} - -func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { - t, _ := template.New("beegoerrortemp").Parse(errtpl) - data := M{ - "Title": http.StatusText(errCode), - "BeegoVersion": VERSION, - "Content": template.HTML(errContent), - } - t.Execute(rw, data) -} - -// ErrorHandler registers http.HandlerFunc to each http err code string. -// usage: -// beego.ErrorHandler("404",NotFound) -// beego.ErrorHandler("500",InternalServerError) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ErrorHandler(code string, h http.HandlerFunc) *App { - ErrorMaps[code] = &errorInfo{ - errorType: errorTypeHandler, - handler: h, - method: code, - } - return BeeApp -} - -// ErrorController registers ControllerInterface to each http err code string. -// usage: -// beego.ErrorController(&controllers.ErrorController{}) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ErrorController(c ControllerInterface) *App { - reflectVal := reflect.ValueOf(c) - rt := reflectVal.Type() - ct := reflect.Indirect(reflectVal).Type() - for i := 0; i < rt.NumMethod(); i++ { - methodName := rt.Method(i).Name - if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { - errName := strings.TrimPrefix(methodName, "Error") - ErrorMaps[errName] = &errorInfo{ - errorType: errorTypeController, - controllerType: ct, - method: methodName, - } - } - } - return BeeApp -} - -// Exception Write HttpStatus with errCode and Exec error handler if exist. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Exception(errCode uint64, ctx *context.Context) { - exception(strconv.FormatUint(errCode, 10), ctx) -} - -// show error string as simple text message. -// if error string is empty, show 503 or 500 error as default. -func exception(errCode string, ctx *context.Context) { - atoi := func(code string) int { - v, err := strconv.Atoi(code) - if err == nil { - return v - } - if ctx.Output.Status == 0 { - return 503 - } - return ctx.Output.Status - } - - for _, ec := range []string{errCode, "503", "500"} { - if h, ok := ErrorMaps[ec]; ok { - executeError(h, ctx, atoi(ec)) - return - } - } - //if 50x error has been removed from errorMap - ctx.ResponseWriter.WriteHeader(atoi(errCode)) - ctx.WriteString(errCode) -} - -func executeError(err *errorInfo, ctx *context.Context, code int) { - //make sure to log the error in the access log - LogAccess(ctx, nil, code) - - if err.errorType == errorTypeHandler { - ctx.ResponseWriter.WriteHeader(code) - err.handler(ctx.ResponseWriter, ctx.Request) - return - } - if err.errorType == errorTypeController { - ctx.Output.SetStatus(code) - //Invoke the request handler - vc := reflect.New(err.controllerType) - execController, ok := vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - //call the controller init function - execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) - - //call prepare function - execController.Prepare() - - execController.URLMapping() - - method := vc.MethodByName(err.method) - method.Call([]reflect.Value{}) - - //render template - if BConfig.WebConfig.AutoRender { - if err := execController.Render(); err != nil { - panic(err) - } - } - - // finish all runrouter. release resource - execController.Finish() - } -} diff --git a/error_test.go b/error_test.go deleted file mode 100644 index 378aa953..00000000 --- a/error_test.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strconv" - "strings" - "testing" -) - -type errorTestController struct { - Controller -} - -const parseCodeError = "parse code error" - -func (ec *errorTestController) Get() { - errorCode, err := ec.GetInt("code") - if err != nil { - ec.Abort(parseCodeError) - } - if errorCode != 0 { - ec.CustomAbort(errorCode, ec.GetString("code")) - } - ec.Abort("404") -} - -func TestErrorCode_01(t *testing.T) { - registerDefaultErrorHandler() - for k := range ErrorMaps { - r, _ := http.NewRequest("GET", "/error?code="+k, nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/error", &errorTestController{}) - handler.ServeHTTP(w, r) - code, _ := strconv.Atoi(k) - if w.Code != code { - t.Fail() - } - if !strings.Contains(w.Body.String(), http.StatusText(code)) { - t.Fail() - } - } -} - -func TestErrorCode_02(t *testing.T) { - registerDefaultErrorHandler() - r, _ := http.NewRequest("GET", "/error?code=0", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/error", &errorTestController{}) - handler.ServeHTTP(w, r) - if w.Code != 404 { - t.Fail() - } -} - -func TestErrorCode_03(t *testing.T) { - registerDefaultErrorHandler() - r, _ := http.NewRequest("GET", "/error?code=panic", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/error", &errorTestController{}) - handler.ServeHTTP(w, r) - if w.Code != 200 { - t.Fail() - } - if w.Body.String() != parseCodeError { - t.Fail() - } -} diff --git a/filter.go b/filter.go deleted file mode 100644 index 8596d288..00000000 --- a/filter.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import "github.com/astaxie/beego/context" - -// FilterFunc defines a filter function which is invoked before the controller handler is executed. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FilterFunc func(*context.Context) - -// FilterRouter defines a filter operation which is invoked before the controller handler is executed. -// It can match the URL against a pattern, and execute a filter function -// when a request with a matching URL arrives. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FilterRouter struct { - filterFunc FilterFunc - tree *Tree - pattern string - returnOnOutput bool - resetParams bool -} - -// ValidRouter checks if the current request is matched by this filter. -// If the request is matched, the values of the URL parameters defined -// by the filter pattern are also returned. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { - isOk := f.tree.Match(url, ctx) - if isOk != nil { - if b, ok := isOk.(bool); ok { - return b - } - } - return false -} diff --git a/filter_test.go b/filter_test.go deleted file mode 100644 index 4ca4d2b8..00000000 --- a/filter_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/astaxie/beego/context" -) - -var FilterUser = func(ctx *context.Context) { - ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) -} - -func TestFilter(t *testing.T) { - r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) - w := httptest.NewRecorder() - handler := NewControllerRegister() - handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser) - handler.Add("/person/:last/:first", &TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am astaXie" { - t.Errorf("user define func can't run") - } -} - -var FilterAdminUser = func(ctx *context.Context) { - ctx.Output.Body([]byte("i am admin")) -} - -// Filter pattern /admin/:all -// all url like /admin/ /admin/xie will all get filter - -func TestPatternTwo(t *testing.T) { - r, _ := http.NewRequest("GET", "/admin/", nil) - w := httptest.NewRecorder() - handler := NewControllerRegister() - handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am admin" { - t.Errorf("filter /admin/ can't run") - } -} - -func TestPatternThree(t *testing.T) { - r, _ := http.NewRequest("GET", "/admin/astaxie", nil) - w := httptest.NewRecorder() - handler := NewControllerRegister() - handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am admin" { - t.Errorf("filter /admin/astaxie can't run") - } -} diff --git a/flash.go b/flash.go deleted file mode 100644 index fe3fb974..00000000 --- a/flash.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "fmt" - "net/url" - "strings" -) - -// FlashData is a tools to maintain data when using across request. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FlashData struct { - Data map[string]string -} - -// NewFlash return a new empty FlashData struct. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewFlash() *FlashData { - return &FlashData{ - Data: make(map[string]string), - } -} - -// Set message to flash -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Set(key string, msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data[key] = msg - } else { - fd.Data[key] = fmt.Sprintf(msg, args...) - } -} - -// Success writes success message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Success(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["success"] = msg - } else { - fd.Data["success"] = fmt.Sprintf(msg, args...) - } -} - -// Notice writes notice message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Notice(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["notice"] = msg - } else { - fd.Data["notice"] = fmt.Sprintf(msg, args...) - } -} - -// Warning writes warning message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Warning(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["warning"] = msg - } else { - fd.Data["warning"] = fmt.Sprintf(msg, args...) - } -} - -// Error writes error message to flash. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Error(msg string, args ...interface{}) { - if len(args) == 0 { - fd.Data["error"] = msg - } else { - fd.Data["error"] = fmt.Sprintf(msg, args...) - } -} - -// Store does the saving operation of flash data. -// the data are encoded and saved in cookie. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (fd *FlashData) Store(c *Controller) { - c.Data["flash"] = fd.Data - var flashValue string - for key, value := range fd.Data { - flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" - } - c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") -} - -// ReadFromRequest parsed flash data from encoded values in cookie. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ReadFromRequest(c *Controller) *FlashData { - flash := NewFlash() - if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { - v, _ := url.QueryUnescape(cookie.Value) - vals := strings.Split(v, "\x00") - for _, v := range vals { - if len(v) > 0 { - kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") - if len(kv) == 2 { - flash.Data[kv[0]] = kv[1] - } - } - } - //read one time then delete it - c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") - } - c.Data["flash"] = flash.Data - return flash -} diff --git a/flash_test.go b/flash_test.go deleted file mode 100644 index d5e9608d..00000000 --- a/flash_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -type TestFlashController struct { - Controller -} - -func (t *TestFlashController) TestWriteFlash() { - flash := NewFlash() - flash.Notice("TestFlashString") - flash.Store(&t.Controller) - // we choose to serve json because we don't want to load a template html file - t.ServeJSON(true) -} - -func TestFlashHeader(t *testing.T) { - // create fake GET request - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - // setup the handler - handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") - handler.ServeHTTP(w, r) - - // get the Set-Cookie value - sc := w.Header().Get("Set-Cookie") - // match for the expected header - res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") - // validate the assertion - if !res { - t.Errorf("TestFlashHeader() unable to validate flash message") - } -} diff --git a/fs.go b/fs.go deleted file mode 100644 index 3300813d..00000000 --- a/fs.go +++ /dev/null @@ -1,77 +0,0 @@ -package beego - -import ( - "net/http" - "os" - "path/filepath" -) - -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FileSystem struct { -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (d FileSystem) Open(name string) (http.File, error) { - return os.Open(name) -} - -// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or -// directory in the tree, including root. All errors that arise visiting files -// and directories are filtered by walkFn. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { - - f, err := fs.Open(root) - if err != nil { - return err - } - info, err := f.Stat() - if err != nil { - err = walkFn(root, nil, err) - } else { - err = walk(fs, root, info, walkFn) - } - if err == filepath.SkipDir { - return nil - } - return err -} - -// walk recursively descends path, calling walkFn. -func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { - var err error - if !info.IsDir() { - return walkFn(path, info, nil) - } - - dir, err := fs.Open(path) - if err != nil { - if err1 := walkFn(path, info, err); err1 != nil { - return err1 - } - return err - } - defer dir.Close() - dirs, err := dir.Readdir(-1) - err1 := walkFn(path, info, err) - // If err != nil, walk can't walk into this directory. - // err1 != nil means walkFn want walk to skip this directory or stop walking. - // Therefore, if one of err and err1 isn't nil, walk will return. - if err != nil || err1 != nil { - // The caller's behavior is controlled by the return value, which is decided - // by walkFn. walkFn may ignore err and return nil. - // If walkFn returns SkipDir, it will be handled by the caller. - // So walk should return whatever walkFn returns. - return err1 - } - - for _, fileInfo := range dirs { - filename := filepath.Join(path, fileInfo.Name()) - if err = walk(fs, filename, fileInfo, walkFn); err != nil { - if !fileInfo.IsDir() || err != filepath.SkipDir { - return err - } - } - } - return nil -} diff --git a/go.mod b/go.mod index e1b9fcc2..91bd9aef 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( golang.org/x/tools v0.0.0-20200117065230-39095c1d176c google.golang.org/grpc v1.31.0 // indirect gopkg.in/yaml.v2 v2.2.8 + honnef.co/go/tools v0.0.1-2020.1.5 // indirect ) replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 diff --git a/go.sum b/go.sum index 1666981d..95babc92 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,7 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= @@ -100,6 +101,7 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -159,6 +161,7 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 h1:X+yvsM2yrEktyI+b2qND5gpH8YhURn0k8OCaeRnkINo= github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644/go.mod h1:nkxAfR/5quYxwPZhyDxgasBMnRtBZd0FCEpawpjMUFg= github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0 h1:QIF48X1cihydXibm+4wfAc0r/qyPyuFiPFRNphdMpEE= @@ -182,16 +185,23 @@ github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2K github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8= github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/gopher-lua v0.0.0-20171031051903-609c9cd26973/go.mod h1:aEV29XrmTYFr3CiRxZeGHpkvbwq+prZduBqMaascyCU= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -204,12 +214,15 @@ golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -220,6 +233,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= @@ -231,11 +245,18 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200815165600-90abf76919f3 h1:0aScV/0rLmANzEYIhjCOi2pTvDyhZNduBUMD2q3iqs4= +golang.org/x/tools v0.0.0-20200815165600-90abf76919f3/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -254,9 +275,11 @@ google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyz google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= @@ -268,4 +291,7 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= +honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= diff --git a/grace/grace.go b/grace/grace.go deleted file mode 100644 index 39d067fd..00000000 --- a/grace/grace.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package grace use to hot reload -// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ -// -// Usage: -// -// import( -// "log" -// "net/http" -// "os" -// -// "github.com/astaxie/beego/grace" -// ) -// -// func handler(w http.ResponseWriter, r *http.Request) { -// w.Write([]byte("WORLD!")) -// } -// -// func main() { -// mux := http.NewServeMux() -// mux.HandleFunc("/hello", handler) -// -// err := grace.ListenAndServe("localhost:8080", mux) -// if err != nil { -// log.Println(err) -// } -// log.Println("Server on 8080 stopped") -// os.Exit(0) -// } -package grace - -import ( - "flag" - "net/http" - "os" - "strings" - "sync" - "syscall" - "time" -) - -const ( - // PreSignal is the position to add filter before signal - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - PreSignal = iota - // PostSignal is the position to add filter after signal - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - PostSignal - // StateInit represent the application inited - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateInit - // StateRunning represent the application is running - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateRunning - // StateShuttingDown represent the application is shutting down - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateShuttingDown - // StateTerminate represent the application is killed - // Deprecated: using pkg/grace, we will delete this in v2.1.0 - StateTerminate -) - -var ( - regLock *sync.Mutex - runningServers map[string]*Server - runningServersOrder []string - socketPtrOffsetMap map[string]uint - runningServersForked bool - - // DefaultReadTimeOut is the HTTP read timeout - DefaultReadTimeOut time.Duration - // DefaultWriteTimeOut is the HTTP Write timeout - DefaultWriteTimeOut time.Duration - // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit - DefaultMaxHeaderBytes int - // DefaultTimeout is the shutdown server's timeout. default is 60s - DefaultTimeout = 60 * time.Second - - isChild bool - socketOrder string - - hookableSignals []os.Signal -) - -func init() { - flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") - flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") - - regLock = &sync.Mutex{} - runningServers = make(map[string]*Server) - runningServersOrder = []string{} - socketPtrOffsetMap = make(map[string]uint) - - hookableSignals = []os.Signal{ - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - } -} - -// NewServer returns a new graceServer. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func NewServer(addr string, handler http.Handler) (srv *Server) { - regLock.Lock() - defer regLock.Unlock() - - if !flag.Parsed() { - flag.Parse() - } - if len(socketOrder) > 0 { - for i, addr := range strings.Split(socketOrder, ",") { - socketPtrOffsetMap[addr] = uint(i) - } - } else { - socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) - } - - srv = &Server{ - sigChan: make(chan os.Signal), - isChild: isChild, - SignalHooks: map[int]map[os.Signal][]func(){ - PreSignal: { - syscall.SIGHUP: {}, - syscall.SIGINT: {}, - syscall.SIGTERM: {}, - }, - PostSignal: { - syscall.SIGHUP: {}, - syscall.SIGINT: {}, - syscall.SIGTERM: {}, - }, - }, - state: StateInit, - Network: "tcp", - terminalChan: make(chan error), //no cache channel - } - srv.Server = &http.Server{ - Addr: addr, - ReadTimeout: DefaultReadTimeOut, - WriteTimeout: DefaultWriteTimeOut, - MaxHeaderBytes: DefaultMaxHeaderBytes, - Handler: handler, - } - - runningServersOrder = append(runningServersOrder, addr) - runningServers[addr] = srv - return srv -} - -// ListenAndServe refer http.ListenAndServe -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func ListenAndServe(addr string, handler http.Handler) error { - server := NewServer(addr, handler) - return server.ListenAndServe() -} - -// ListenAndServeTLS refer http.ListenAndServeTLS -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { - server := NewServer(addr, handler) - return server.ListenAndServeTLS(certFile, keyFile) -} diff --git a/grace/server.go b/grace/server.go deleted file mode 100644 index cd659f82..00000000 --- a/grace/server.go +++ /dev/null @@ -1,362 +0,0 @@ -package grace - -import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "log" - "net" - "net/http" - "os" - "os/exec" - "os/signal" - "strings" - "syscall" - "time" -) - -// Server embedded http.Server -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -type Server struct { - *http.Server - ln net.Listener - SignalHooks map[int]map[os.Signal][]func() - sigChan chan os.Signal - isChild bool - state uint8 - Network string - terminalChan chan error -} - -// Serve accepts incoming connections on the Listener l, -// creating a new service goroutine for each. -// The service goroutines read requests and then call srv.Handler to reply to them. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) Serve() (err error) { - srv.state = StateRunning - defer func() { srv.state = StateTerminate }() - - // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS - // immediately return ErrServerClosed. Make sure the program doesn't exit - // and waits instead for Shutdown to return. - if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { - log.Println(syscall.Getpid(), "Server.Serve() error:", err) - return err - } - - log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") - // wait for Shutdown to return - if shutdownErr := <-srv.terminalChan; shutdownErr != nil { - return shutdownErr - } - return -} - -// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve -// to handle requests on incoming connections. If srv.Addr is blank, ":http" is -// used. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) ListenAndServe() (err error) { - addr := srv.Addr - if addr == "" { - addr = ":http" - } - - go srv.handleSignals() - - srv.ln, err = srv.getListener(addr) - if err != nil { - log.Println(err) - return err - } - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() -} - -// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls -// Serve to handle requests on incoming TLS connections. -// -// Filenames containing a certificate and matching private key for the server must -// be provided. If the certificate is signed by a certificate authority, the -// certFile should be the concatenation of the server's certificate followed by the -// CA's certificate. -// -// If srv.Addr is blank, ":https" is used. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { - addr := srv.Addr - if addr == "" { - addr = ":https" - } - - if srv.TLSConfig == nil { - srv.TLSConfig = &tls.Config{} - } - if srv.TLSConfig.NextProtos == nil { - srv.TLSConfig.NextProtos = []string{"http/1.1"} - } - - srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return - } - - go srv.handleSignals() - - ln, err := srv.getListener(addr) - if err != nil { - log.Println(err) - return err - } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() -} - -// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls -// Serve to handle requests on incoming mutual TLS connections. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { - addr := srv.Addr - if addr == "" { - addr = ":https" - } - - if srv.TLSConfig == nil { - srv.TLSConfig = &tls.Config{} - } - if srv.TLSConfig.NextProtos == nil { - srv.TLSConfig.NextProtos = []string{"http/1.1"} - } - - srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return - } - srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert - pool := x509.NewCertPool() - data, err := ioutil.ReadFile(trustFile) - if err != nil { - log.Println(err) - return err - } - pool.AppendCertsFromPEM(data) - srv.TLSConfig.ClientCAs = pool - log.Println("Mutual HTTPS") - go srv.handleSignals() - - ln, err := srv.getListener(addr) - if err != nil { - log.Println(err) - return err - } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() -} - -// getListener either opens a new socket to listen on, or takes the acceptor socket -// it got passed when restarted. -func (srv *Server) getListener(laddr string) (l net.Listener, err error) { - if srv.isChild { - var ptrOffset uint - if len(socketPtrOffsetMap) > 0 { - ptrOffset = socketPtrOffsetMap[laddr] - log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) - } - - f := os.NewFile(uintptr(3+ptrOffset), "") - l, err = net.FileListener(f) - if err != nil { - err = fmt.Errorf("net.FileListener error: %v", err) - return - } - } else { - l, err = net.Listen(srv.Network, laddr) - if err != nil { - err = fmt.Errorf("net.Listen error: %v", err) - return - } - } - return -} - -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} - -// handleSignals listens for os Signals and calls any hooked in function that the -// user had registered with the signal. -func (srv *Server) handleSignals() { - var sig os.Signal - - signal.Notify( - srv.sigChan, - hookableSignals..., - ) - - pid := syscall.Getpid() - for { - sig = <-srv.sigChan - srv.signalHooks(PreSignal, sig) - switch sig { - case syscall.SIGHUP: - log.Println(pid, "Received SIGHUP. forking.") - err := srv.fork() - if err != nil { - log.Println("Fork err:", err) - } - case syscall.SIGINT: - log.Println(pid, "Received SIGINT.") - srv.shutdown() - case syscall.SIGTERM: - log.Println(pid, "Received SIGTERM.") - srv.shutdown() - default: - log.Printf("Received %v: nothing i care about...\n", sig) - } - srv.signalHooks(PostSignal, sig) - } -} - -func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { - if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { - return - } - for _, f := range srv.SignalHooks[ppFlag][sig] { - f() - } -} - -// shutdown closes the listener so that no new connections are accepted. it also -// starts a goroutine that will serverTimeout (stop all running requests) the server -// after DefaultTimeout. -func (srv *Server) shutdown() { - if srv.state != StateRunning { - return - } - - srv.state = StateShuttingDown - log.Println(syscall.Getpid(), "Waiting for connections to finish...") - ctx := context.Background() - if DefaultTimeout >= 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) - defer cancel() - } - srv.terminalChan <- srv.Server.Shutdown(ctx) -} - -func (srv *Server) fork() (err error) { - regLock.Lock() - defer regLock.Unlock() - if runningServersForked { - return - } - runningServersForked = true - - var files = make([]*os.File, len(runningServers)) - var orderArgs = make([]string, len(runningServers)) - for _, srvPtr := range runningServers { - f, _ := srvPtr.ln.(*net.TCPListener).File() - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f - orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr - } - - log.Println(files) - path := os.Args[0] - var args []string - if len(os.Args) > 1 { - for _, arg := range os.Args[1:] { - if arg == "-graceful" { - break - } - args = append(args, arg) - } - } - args = append(args, "-graceful") - if len(runningServers) > 1 { - args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) - log.Println(args) - } - cmd := exec.Command(path, args...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.ExtraFiles = files - err = cmd.Start() - if err != nil { - log.Fatalf("Restart: Failed to launch, error: %v", err) - } - - return -} - -// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. -// Deprecated: using pkg/grace, we will delete this in v2.1.0 -func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { - if ppFlag != PreSignal && ppFlag != PostSignal { - err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") - return - } - for _, s := range hookableSignals { - if s == sig { - srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) - return - } - } - err = fmt.Errorf("Signal '%v' is not supported", sig) - return -} diff --git a/hooks.go b/hooks.go deleted file mode 100644 index 49c42d5a..00000000 --- a/hooks.go +++ /dev/null @@ -1,104 +0,0 @@ -package beego - -import ( - "encoding/json" - "mime" - "net/http" - "path/filepath" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" -) - -// register MIME type with content type -func registerMime() error { - for k, v := range mimemaps { - mime.AddExtensionType(k, v) - } - return nil -} - -// register default error http handlers, 404,401,403,500 and 503. -func registerDefaultErrorHandler() error { - m := map[string]func(http.ResponseWriter, *http.Request){ - "401": unauthorized, - "402": paymentRequired, - "403": forbidden, - "404": notFound, - "405": methodNotAllowed, - "500": internalServerError, - "501": notImplemented, - "502": badGateway, - "503": serviceUnavailable, - "504": gatewayTimeout, - "417": invalidxsrf, - "422": missingxsrf, - "413": payloadTooLarge, - } - for e, h := range m { - if _, ok := ErrorMaps[e]; !ok { - ErrorHandler(e, h) - } - } - return nil -} - -func registerSession() error { - if BConfig.WebConfig.Session.SessionOn { - var err error - sessionConfig := AppConfig.String("sessionConfig") - conf := new(session.ManagerConfig) - if sessionConfig == "" { - conf.CookieName = BConfig.WebConfig.Session.SessionName - conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie - conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime - conf.Secure = BConfig.Listen.EnableHTTPS - conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime - conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) - conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly - conf.Domain = BConfig.WebConfig.Session.SessionDomain - conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader - conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader - conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery - } else { - if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { - return err - } - } - if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil { - return err - } - go GlobalSessions.GC() - } - return nil -} - -func registerTemplate() error { - defer lockViewPaths() - if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { - if BConfig.RunMode == DEV { - logs.Warn(err) - } - return err - } - return nil -} - -func registerAdmin() error { - if BConfig.Listen.EnableAdmin { - go beeAdminApp.Run() - } - return nil -} - -func registerGzip() error { - if BConfig.EnableGzip { - context.InitGzip( - AppConfig.DefaultInt("gzipMinLength", -1), - AppConfig.DefaultInt("gzipCompressLevel", -1), - AppConfig.DefaultStrings("includedMethods", []string{"GET"}), - ) - } - return nil -} diff --git a/httplib/README.md b/httplib/README.md deleted file mode 100644 index 97df8e6b..00000000 --- a/httplib/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# httplib -httplib is an libs help you to curl remote url. - -# How to use? - -## GET -you can use Get to crawl data. - - import "github.com/astaxie/beego/httplib" - - str, err := httplib.Get("http://beego.me/").String() - if err != nil { - // error - } - fmt.Println(str) - -## POST -POST data to remote url - - req := httplib.Post("http://beego.me/") - req.Param("username","astaxie") - req.Param("password","123456") - str, err := req.String() - if err != nil { - // error - } - fmt.Println(str) - -## Set timeout - -The default timeout is `60` seconds, function prototype: - - SetTimeout(connectTimeout, readWriteTimeout time.Duration) - -Example: - - // GET - httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - - // POST - httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - - -## Debug - -If you want to debug the request info, set the debug on - - httplib.Get("http://beego.me/").Debug(true) - -## Set HTTP Basic Auth - - str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() - if err != nil { - // error - } - fmt.Println(str) - -## Set HTTPS - -If request url is https, You can set the client support TSL: - - httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) - -More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config - -## Set HTTP Version - -some servers need to specify the protocol version of HTTP - - httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") - -## Set Cookie - -some http request need setcookie. So set it like this: - - cookie := &http.Cookie{} - cookie.Name = "username" - cookie.Value = "astaxie" - httplib.Get("http://beego.me/").SetCookie(cookie) - -## Upload file - -httplib support mutil file upload, use `req.PostFile()` - - req := httplib.Post("http://beego.me/") - req.Param("username","astaxie") - req.PostFile("uploadfile1", "httplib.pdf") - str, err := req.String() - if err != nil { - // error - } - fmt.Println(str) - - -See godoc for further documentation and examples. - -* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) diff --git a/httplib/httplib.go b/httplib/httplib.go deleted file mode 100644 index 8ae95641..00000000 --- a/httplib/httplib.go +++ /dev/null @@ -1,697 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package httplib is used as http.Client -// Usage: -// -// import "github.com/astaxie/beego/httplib" -// -// b := httplib.Post("http://beego.me/") -// b.Param("username","astaxie") -// b.Param("password","123456") -// b.PostFile("uploadfile1", "httplib.pdf") -// b.PostFile("uploadfile2", "httplib.txt") -// str, err := b.String() -// if err != nil { -// t.Fatal(err) -// } -// fmt.Println(str) -// -// more docs http://beego.me/docs/module/httplib.md -package httplib - -import ( - "bytes" - "compress/gzip" - "crypto/tls" - "encoding/json" - "encoding/xml" - "io" - "io/ioutil" - "log" - "mime/multipart" - "net" - "net/http" - "net/http/cookiejar" - "net/http/httputil" - "net/url" - "os" - "path" - "strings" - "sync" - "time" - - "gopkg.in/yaml.v2" -) - -var defaultSetting = BeegoHTTPSettings{ - UserAgent: "beegoServer", - ConnectTimeout: 60 * time.Second, - ReadWriteTimeout: 60 * time.Second, - Gzip: true, - DumpBody: true, -} - -var defaultCookieJar http.CookieJar -var settingMutex sync.Mutex - -// createDefaultCookie creates a global cookiejar to store cookies. -func createDefaultCookie() { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultCookieJar, _ = cookiejar.New(nil) -} - -// SetDefaultSetting Overwrite default settings -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func SetDefaultSetting(setting BeegoHTTPSettings) { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultSetting = setting -} - -// NewBeegoRequest return *BeegoHttpRequest with specific method -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { - var resp http.Response - u, err := url.Parse(rawurl) - if err != nil { - log.Println("Httplib:", err) - } - req := http.Request{ - URL: u, - Method: method, - Header: make(http.Header), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - } - return &BeegoHTTPRequest{ - url: rawurl, - req: &req, - params: map[string][]string{}, - files: map[string]string{}, - setting: defaultSetting, - resp: &resp, - } -} - -// Get returns *BeegoHttpRequest with GET method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Get(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "GET") -} - -// Post returns *BeegoHttpRequest with POST method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Post(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "POST") -} - -// Put returns *BeegoHttpRequest with PUT method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Put(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "PUT") -} - -// Delete returns *BeegoHttpRequest DELETE method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Delete(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "DELETE") -} - -// Head returns *BeegoHttpRequest with HEAD method. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func Head(url string) *BeegoHTTPRequest { - return NewBeegoRequest(url, "HEAD") -} - -// BeegoHTTPSettings is the http.Client setting -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -type BeegoHTTPSettings struct { - ShowDebug bool - UserAgent string - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - TLSClientConfig *tls.Config - Proxy func(*http.Request) (*url.URL, error) - Transport http.RoundTripper - CheckRedirect func(req *http.Request, via []*http.Request) error - EnableCookie bool - Gzip bool - DumpBody bool - Retries int // if set to -1 means will retry forever - RetryDelay time.Duration -} - -// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -type BeegoHTTPRequest struct { - url string - req *http.Request - params map[string][]string - files map[string]string - setting BeegoHTTPSettings - resp *http.Response - body []byte - dump []byte -} - -// GetRequest return the request object -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) GetRequest() *http.Request { - return b.req -} - -// Setting Change request settings -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { - b.setting = setting - return b -} - -// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { - b.req.SetBasicAuth(username, password) - return b -} - -// SetEnableCookie sets enable/disable cookiejar -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { - b.setting.EnableCookie = enable - return b -} - -// SetUserAgent sets User-Agent header field -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { - b.setting.UserAgent = useragent - return b -} - -// Debug sets show debug or not when executing request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.setting.ShowDebug = isdebug - return b -} - -// Retries sets Retries times. -// default is 0 means no retried. -// -1 means retried forever. -// others means retried times. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { - b.setting.Retries = times - return b -} - -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { - b.setting.RetryDelay = delay - return b -} - -// DumpBody setting whether need to Dump the Body. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.setting.DumpBody = isdump - return b -} - -// DumpRequest return the DumpRequest -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.dump -} - -// SetTimeout sets connect time out and read-write time out for BeegoRequest. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { - b.setting.ConnectTimeout = connectTimeout - b.setting.ReadWriteTimeout = readWriteTimeout - return b -} - -// SetTLSClientConfig sets tls connection configurations if visiting https url. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { - b.setting.TLSClientConfig = config - return b -} - -// Header add header item string in request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { - b.req.Header.Set(key, value) - return b -} - -// SetHost set the request host -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { - b.req.Host = host - return b -} - -// SetProtocolVersion Set the protocol version for incoming requests. -// Client requests always use HTTP/1.1. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { - if len(vers) == 0 { - vers = "HTTP/1.1" - } - - major, minor, ok := http.ParseHTTPVersion(vers) - if ok { - b.req.Proto = vers - b.req.ProtoMajor = major - b.req.ProtoMinor = minor - } - - return b -} - -// SetCookie add cookie into request. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { - b.req.Header.Add("Cookie", cookie.String()) - return b -} - -// SetTransport set the setting transport -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { - b.setting.Transport = transport - return b -} - -// SetProxy set the http proxy -// example: -// -// func(req *http.Request) (*url.URL, error) { -// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") -// return u, nil -// } -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { - b.setting.Proxy = proxy - return b -} - -// SetCheckRedirect specifies the policy for handling redirects. -// -// If CheckRedirect is nil, the Client uses its default policy, -// which is to stop after 10 consecutive requests. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { - b.setting.CheckRedirect = redirect - return b -} - -// Param adds query param in to request. -// params build query string as ?key1=value1&key2=value2... -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { - if param, ok := b.params[key]; ok { - b.params[key] = append(param, value) - } else { - b.params[key] = []string{value} - } - return b -} - -// PostFile add a post file to the request -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { - b.files[formname] = filename - return b -} - -// Body adds request raw body. -// it supports string and []byte. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { - switch t := data.(type) { - case string: - bf := bytes.NewBufferString(t) - b.req.Body = ioutil.NopCloser(bf) - b.req.ContentLength = int64(len(t)) - case []byte: - bf := bytes.NewBuffer(t) - b.req.Body = ioutil.NopCloser(bf) - b.req.ContentLength = int64(len(t)) - } - return b -} - -// XMLBody adds request raw body encoding by XML. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { - if b.req.Body == nil && obj != nil { - byts, err := xml.Marshal(obj) - if err != nil { - return b, err - } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/xml") - } - return b, nil -} - -// YAMLBody adds request raw body encoding by YAML. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { - if b.req.Body == nil && obj != nil { - byts, err := yaml.Marshal(obj) - if err != nil { - return b, err - } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/x+yaml") - } - return b, nil -} - -// JSONBody adds request raw body encoding by JSON. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { - if b.req.Body == nil && obj != nil { - byts, err := json.Marshal(obj) - if err != nil { - return b, err - } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/json") - } - return b, nil -} - -func (b *BeegoHTTPRequest) buildURL(paramBody string) { - // build GET url with query string - if b.req.Method == "GET" && len(paramBody) > 0 { - if strings.Contains(b.url, "?") { - b.url += "&" + paramBody - } else { - b.url = b.url + "?" + paramBody - } - return - } - - // build POST/PUT/PATCH url and body - if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { - // with files - if len(b.files) > 0 { - pr, pw := io.Pipe() - bodyWriter := multipart.NewWriter(pw) - go func() { - for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - log.Println("Httplib:", err) - } - fh, err := os.Open(filename) - if err != nil { - log.Println("Httplib:", err) - } - //iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - log.Println("Httplib:", err) - } - } - for k, v := range b.params { - for _, vv := range v { - bodyWriter.WriteField(k, vv) - } - } - bodyWriter.Close() - pw.Close() - }() - b.Header("Content-Type", bodyWriter.FormDataContentType()) - b.req.Body = ioutil.NopCloser(pr) - b.Header("Transfer-Encoding", "chunked") - return - } - - // with params - if len(paramBody) > 0 { - b.Header("Content-Type", "application/x-www-form-urlencoded") - b.Body(paramBody) - } - } -} - -func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { - if b.resp.StatusCode != 0 { - return b.resp, nil - } - resp, err := b.DoRequest() - if err != nil { - return nil, err - } - b.resp = resp - return resp, nil -} - -// DoRequest will do the client.Do -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { - var paramBody string - if len(b.params) > 0 { - var buf bytes.Buffer - for k, v := range b.params { - for _, vv := range v { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(vv)) - buf.WriteByte('&') - } - } - paramBody = buf.String() - paramBody = paramBody[0 : len(paramBody)-1] - } - - b.buildURL(paramBody) - urlParsed, err := url.Parse(b.url) - if err != nil { - return nil, err - } - - b.req.URL = urlParsed - - trans := b.setting.Transport - - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: b.setting.TLSClientConfig, - Proxy: b.setting.Proxy, - Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: 100, - } - } else { - // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.Dial == nil { - t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } - } - } - - var jar http.CookieJar - if b.setting.EnableCookie { - if defaultCookieJar == nil { - createDefaultCookie() - } - jar = defaultCookieJar - } - - client := &http.Client{ - Transport: trans, - Jar: jar, - } - - if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { - b.req.Header.Set("User-Agent", b.setting.UserAgent) - } - - if b.setting.CheckRedirect != nil { - client.CheckRedirect = b.setting.CheckRedirect - } - - if b.setting.ShowDebug { - dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) - if err != nil { - log.Println(err.Error()) - } - b.dump = dump - } - // retries default value is 0, it will run once. - // retries equal to -1, it will run forever until success - // retries is setted, it will retries fixed times. - // Sleeps for a 400ms inbetween calls to reduce spam - for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { - resp, err = client.Do(b.req) - if err == nil { - break - } - time.Sleep(b.setting.RetryDelay) - } - return resp, err -} - -// String returns the body string in response. -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) String() (string, error) { - data, err := b.Bytes() - if err != nil { - return "", err - } - - return string(data), nil -} - -// Bytes returns the body []byte in response. -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { - if b.body != nil { - return b.body, nil - } - resp, err := b.getResponse() - if err != nil { - return nil, err - } - if resp.Body == nil { - return nil, nil - } - defer resp.Body.Close() - if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { - reader, err := gzip.NewReader(resp.Body) - if err != nil { - return nil, err - } - b.body, err = ioutil.ReadAll(reader) - return b.body, err - } - b.body, err = ioutil.ReadAll(resp.Body) - return b.body, err -} - -// ToFile saves the body data in response to one file. -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToFile(filename string) error { - resp, err := b.getResponse() - if err != nil { - return err - } - if resp.Body == nil { - return nil - } - defer resp.Body.Close() - err = pathExistAndMkdir(filename) - if err != nil { - return err - } - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - _, err = io.Copy(f, resp.Body) - return err -} - -//Check that the file directory exists, there is no automatically created -func pathExistAndMkdir(filename string) (err error) { - filename = path.Dir(filename) - _, err = os.Stat(filename) - if err == nil { - return nil - } - if os.IsNotExist(err) { - err = os.MkdirAll(filename, os.ModePerm) - if err == nil { - return nil - } - } - return err -} - -// ToJSON returns the map that marshals from the body bytes as json in response . -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { - data, err := b.Bytes() - if err != nil { - return err - } - return json.Unmarshal(data, v) -} - -// ToXML returns the map that marshals from the body bytes as xml in response . -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToXML(v interface{}) error { - data, err := b.Bytes() - if err != nil { - return err - } - return xml.Unmarshal(data, v) -} - -// ToYAML returns the map that marshals from the body bytes as yaml in response . -// it calls Response inner. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { - data, err := b.Bytes() - if err != nil { - return err - } - return yaml.Unmarshal(data, v) -} - -// Response executes request client gets response mannually. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func (b *BeegoHTTPRequest) Response() (*http.Response, error) { - return b.getResponse() -} - -// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. -// Deprecated: using pkg/httplib, we will delete this in v2.1.0 -func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { - return func(netw, addr string) (net.Conn, error) { - conn, err := net.DialTimeout(netw, addr, cTimeout) - if err != nil { - return nil, err - } - err = conn.SetDeadline(time.Now().Add(rwTimeout)) - return conn, err - } -} diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go deleted file mode 100644 index f6be8571..00000000 --- a/httplib/httplib_test.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httplib - -import ( - "errors" - "io/ioutil" - "net" - "net/http" - "os" - "strings" - "testing" - "time" -) - -func TestResponse(t *testing.T) { - req := Get("http://httpbin.org/get") - resp, err := req.Response() - if err != nil { - t.Fatal(err) - } - t.Log(resp) -} - -func TestDoRequest(t *testing.T) { - req := Get("https://goolnk.com/33BD2j") - retryAmount := 1 - req.Retries(1) - req.RetryDelay(1400 * time.Millisecond) - retryDelay := 1400 * time.Millisecond - - req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { - return errors.New("Redirect triggered") - } - - startTime := time.Now().UnixNano() / int64(time.Millisecond) - - _, err := req.Response() - if err == nil { - t.Fatal("Response should have yielded an error") - } - - endTime := time.Now().UnixNano() / int64(time.Millisecond) - elapsedTime := endTime - startTime - delayedTime := int64(retryAmount) * retryDelay.Milliseconds() - - if elapsedTime < delayedTime { - t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) - } - -} - -func TestGet(t *testing.T) { - req := Get("http://httpbin.org/get") - b, err := req.Bytes() - if err != nil { - t.Fatal(err) - } - t.Log(b) - - s, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(s) - - if string(b) != s { - t.Fatal("request data not match") - } -} - -func TestSimplePost(t *testing.T) { - v := "smallfish" - req := Post("http://httpbin.org/post") - req.Param("username", v) - - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in post") - } -} - -//func TestPostFile(t *testing.T) { -// v := "smallfish" -// req := Post("http://httpbin.org/post") -// req.Debug(true) -// req.Param("username", v) -// req.PostFile("uploadfile", "httplib_test.go") - -// str, err := req.String() -// if err != nil { -// t.Fatal(err) -// } -// t.Log(str) - -// n := strings.Index(str, v) -// if n == -1 { -// t.Fatal(v + " not found in post") -// } -//} - -func TestSimplePut(t *testing.T) { - str, err := Put("http://httpbin.org/put").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestSimpleDelete(t *testing.T) { - str, err := Delete("http://httpbin.org/delete").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestSimpleDeleteParam(t *testing.T) { - str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestWithCookie(t *testing.T) { - v := "smallfish" - str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in cookie") - } -} - -func TestWithBasicAuth(t *testing.T) { - str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - n := strings.Index(str, "authenticated") - if n == -1 { - t.Fatal("authenticated not found in response") - } -} - -func TestWithUserAgent(t *testing.T) { - v := "beego" - str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in user-agent") - } -} - -func TestWithSetting(t *testing.T) { - v := "beego" - var setting BeegoHTTPSettings - setting.EnableCookie = true - setting.UserAgent = v - setting.Transport = &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 50, - IdleConnTimeout: 90 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - setting.ReadWriteTimeout = 5 * time.Second - SetDefaultSetting(setting) - - str, err := Get("http://httpbin.org/get").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in user-agent") - } -} - -func TestToJson(t *testing.T) { - req := Get("http://httpbin.org/ip") - resp, err := req.Response() - if err != nil { - t.Fatal(err) - } - t.Log(resp) - - // httpbin will return http remote addr - type IP struct { - Origin string `json:"origin"` - } - var ip IP - err = req.ToJSON(&ip) - if err != nil { - t.Fatal(err) - } - t.Log(ip.Origin) - ips := strings.Split(ip.Origin, ",") - if len(ips) == 0 { - t.Fatal("response is not valid ip") - } - for i := range ips { - if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { - t.Fatal("response is not valid ip") - } - } - -} - -func TestToFile(t *testing.T) { - f := "beego_testfile" - req := Get("http://httpbin.org/ip") - err := req.ToFile(f) - if err != nil { - t.Fatal(err) - } - defer os.Remove(f) - b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { - t.Fatal(err) - } -} - -func TestToFileDir(t *testing.T) { - f := "./files/beego_testfile" - req := Get("http://httpbin.org/ip") - err := req.ToFile(f) - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll("./files") - b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { - t.Fatal(err) - } -} - -func TestHeader(t *testing.T) { - req := Get("http://httpbin.org/headers") - req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} diff --git a/logs/README.md b/logs/README.md deleted file mode 100644 index c05bcc04..00000000 --- a/logs/README.md +++ /dev/null @@ -1,72 +0,0 @@ -## logs -logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . - - -## How to install? - - go get github.com/astaxie/beego/logs - - -## What adapters are supported? - -As of now this logs support console, file,smtp and conn. - - -## How to use it? - -First you must import it - -```golang -import ( - "github.com/astaxie/beego/logs" -) -``` - -Then init a Log (example with console adapter) - -```golang -log := logs.NewLogger(10000) -log.SetLogger("console", "") -``` - -> the first params stand for how many channel - -Use it like this: - -```golang -log.Trace("trace") -log.Info("info") -log.Warn("warning") -log.Debug("debug") -log.Critical("critical") -``` - -## File adapter - -Configure file adapter like this: - -```golang -log := NewLogger(10000) -log.SetLogger("file", `{"filename":"test.log"}`) -``` - -## Conn adapter - -Configure like this: - -```golang -log := NewLogger(1000) -log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) -log.Info("info") -``` - -## Smtp adapter - -Configure like this: - -```golang -log := NewLogger(10000) -log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) -log.Critical("sendmail critical") -time.Sleep(time.Second * 30) -``` diff --git a/logs/accesslog.go b/logs/accesslog.go deleted file mode 100644 index 9011b602..00000000 --- a/logs/accesslog.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bytes" - "encoding/json" - "fmt" - "strings" - "time" -) - -const ( - apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" - apacheFormat = "APACHE_FORMAT" - jsonFormat = "JSON_FORMAT" -) - -// AccessLogRecord struct for holding access log data. -type AccessLogRecord struct { - RemoteAddr string `json:"remote_addr"` - RequestTime time.Time `json:"request_time"` - RequestMethod string `json:"request_method"` - Request string `json:"request"` - ServerProtocol string `json:"server_protocol"` - Host string `json:"host"` - Status int `json:"status"` - BodyBytesSent int64 `json:"body_bytes_sent"` - ElapsedTime time.Duration `json:"elapsed_time"` - HTTPReferrer string `json:"http_referrer"` - HTTPUserAgent string `json:"http_user_agent"` - RemoteUser string `json:"remote_user"` -} - -func (r *AccessLogRecord) json() ([]byte, error) { - buffer := &bytes.Buffer{} - encoder := json.NewEncoder(buffer) - disableEscapeHTML(encoder) - - err := encoder.Encode(r) - return buffer.Bytes(), err -} - -func disableEscapeHTML(i interface{}) { - if e, ok := i.(interface { - SetEscapeHTML(bool) - }); ok { - e.SetEscapeHTML(false) - } -} - -// AccessLog - Format and print access log. -func AccessLog(r *AccessLogRecord, format string) { - var msg string - switch format { - case apacheFormat: - timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") - msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, - r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) - case jsonFormat: - fallthrough - default: - jsonData, err := r.json() - if err != nil { - msg = fmt.Sprintf(`{"Error": "%s"}`, err) - } else { - msg = string(jsonData) - } - } - beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) -} diff --git a/logs/alils/alils.go b/logs/alils/alils.go deleted file mode 100644 index 867ff4cb..00000000 --- a/logs/alils/alils.go +++ /dev/null @@ -1,186 +0,0 @@ -package alils - -import ( - "encoding/json" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/logs" - "github.com/gogo/protobuf/proto" -) - -const ( - // CacheSize set the flush size - CacheSize int = 64 - // Delimiter define the topic delimiter - Delimiter string = "##" -) - -// Config is the Config for Ali Log -type Config struct { - Project string `json:"project"` - Endpoint string `json:"endpoint"` - KeyID string `json:"key_id"` - KeySecret string `json:"key_secret"` - LogStore string `json:"log_store"` - Topics []string `json:"topics"` - Source string `json:"source"` - Level int `json:"level"` - FlushWhen int `json:"flush_when"` -} - -// aliLSWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. -type aliLSWriter struct { - store *LogStore - group []*LogGroup - withMap bool - groupMap map[string]*LogGroup - lock *sync.Mutex - Config -} - -// NewAliLS create a new Logger -func NewAliLS() logs.Logger { - alils := new(aliLSWriter) - alils.Level = logs.LevelTrace - return alils -} - -// Init parse config and init struct -func (c *aliLSWriter) Init(jsonConfig string) (err error) { - - json.Unmarshal([]byte(jsonConfig), c) - - if c.FlushWhen > CacheSize { - c.FlushWhen = CacheSize - } - - prj := &LogProject{ - Name: c.Project, - Endpoint: c.Endpoint, - AccessKeyID: c.KeyID, - AccessKeySecret: c.KeySecret, - } - - c.store, err = prj.GetLogStore(c.LogStore) - if err != nil { - return err - } - - // Create default Log Group - c.group = append(c.group, &LogGroup{ - Topic: proto.String(""), - Source: proto.String(c.Source), - Logs: make([]*Log, 0, c.FlushWhen), - }) - - // Create other Log Group - c.groupMap = make(map[string]*LogGroup) - for _, topic := range c.Topics { - - lg := &LogGroup{ - Topic: proto.String(topic), - Source: proto.String(c.Source), - Logs: make([]*Log, 0, c.FlushWhen), - } - - c.group = append(c.group, lg) - c.groupMap[topic] = lg - } - - if len(c.group) == 1 { - c.withMap = false - } else { - c.withMap = true - } - - c.lock = &sync.Mutex{} - - return nil -} - -// WriteMsg write message in connection. -// if connection is down, try to re-connect. -func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { - - if level > c.Level { - return nil - } - - var topic string - var content string - var lg *LogGroup - if c.withMap { - - // Topic,LogGroup - strs := strings.SplitN(msg, Delimiter, 2) - if len(strs) == 2 { - pos := strings.LastIndex(strs[0], " ") - topic = strs[0][pos+1 : len(strs[0])] - content = strs[0][0:pos] + strs[1] - lg = c.groupMap[topic] - } - - // send to empty Topic - if lg == nil { - content = msg - lg = c.group[0] - } - } else { - content = msg - lg = c.group[0] - } - - c1 := &LogContent{ - Key: proto.String("msg"), - Value: proto.String(content), - } - - l := &Log{ - Time: proto.Uint32(uint32(when.Unix())), - Contents: []*LogContent{ - c1, - }, - } - - c.lock.Lock() - lg.Logs = append(lg.Logs, l) - c.lock.Unlock() - - if len(lg.Logs) >= c.FlushWhen { - c.flush(lg) - } - - return nil -} - -// Flush implementing method. empty. -func (c *aliLSWriter) Flush() { - - // flush all group - for _, lg := range c.group { - c.flush(lg) - } -} - -// Destroy destroy connection writer and close tcp listener. -func (c *aliLSWriter) Destroy() { -} - -func (c *aliLSWriter) flush(lg *LogGroup) { - - c.lock.Lock() - defer c.lock.Unlock() - err := c.store.PutLogs(lg) - if err != nil { - return - } - - lg.Logs = make([]*Log, 0, c.FlushWhen) -} - -func init() { - logs.Register(logs.AdapterAliLS, NewAliLS) -} diff --git a/logs/alils/config.go b/logs/alils/config.go deleted file mode 100755 index e8c24448..00000000 --- a/logs/alils/config.go +++ /dev/null @@ -1,13 +0,0 @@ -package alils - -const ( - version = "0.5.0" // SDK version - signatureMethod = "hmac-sha1" // Signature method - - // OffsetNewest stands for the log head offset, i.e. the offset that will be - // assigned to the next message that will be produced to the shard. - OffsetNewest = "end" - // OffsetOldest stands for the oldest offset available on the logstore for a - // shard. - OffsetOldest = "begin" -) diff --git a/logs/alils/log.pb.go b/logs/alils/log.pb.go deleted file mode 100755 index 601b0d78..00000000 --- a/logs/alils/log.pb.go +++ /dev/null @@ -1,1038 +0,0 @@ -package alils - -import ( - "fmt" - "io" - "math" - - "github.com/gogo/protobuf/proto" - github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -var ( - // ErrInvalidLengthLog invalid proto - ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") - // ErrIntOverflowLog overflow - ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") -) - -// Log define the proto Log -type Log struct { - Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` - Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset the Log -func (m *Log) Reset() { *m = Log{} } - -// String return the Compact Log -func (m *Log) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*Log) ProtoMessage() {} - -// GetTime return the Log's Time -func (m *Log) GetTime() uint32 { - if m != nil && m.Time != nil { - return *m.Time - } - return 0 -} - -// GetContents return the Log's Contents -func (m *Log) GetContents() []*LogContent { - if m != nil { - return m.Contents - } - return nil -} - -// LogContent define the Log content struct -type LogContent struct { - Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` - Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset LogContent -func (m *LogContent) Reset() { *m = LogContent{} } - -// String return the compact text -func (m *LogContent) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*LogContent) ProtoMessage() {} - -// GetKey return the Key -func (m *LogContent) GetKey() string { - if m != nil && m.Key != nil { - return *m.Key - } - return "" -} - -// GetValue return the Value -func (m *LogContent) GetValue() string { - if m != nil && m.Value != nil { - return *m.Value - } - return "" -} - -// LogGroup define the logs struct -type LogGroup struct { - Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` - Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` - Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` - Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset LogGroup -func (m *LogGroup) Reset() { *m = LogGroup{} } - -// String return the compact text -func (m *LogGroup) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*LogGroup) ProtoMessage() {} - -// GetLogs return the loggroup logs -func (m *LogGroup) GetLogs() []*Log { - if m != nil { - return m.Logs - } - return nil -} - -// GetReserved return Reserved -func (m *LogGroup) GetReserved() string { - if m != nil && m.Reserved != nil { - return *m.Reserved - } - return "" -} - -// GetTopic return Topic -func (m *LogGroup) GetTopic() string { - if m != nil && m.Topic != nil { - return *m.Topic - } - return "" -} - -// GetSource return Source -func (m *LogGroup) GetSource() string { - if m != nil && m.Source != nil { - return *m.Source - } - return "" -} - -// LogGroupList define the LogGroups -type LogGroupList struct { - LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` - XXXUnrecognized []byte `json:"-"` -} - -// Reset LogGroupList -func (m *LogGroupList) Reset() { *m = LogGroupList{} } - -// String return compact text -func (m *LogGroupList) String() string { return proto.CompactTextString(m) } - -// ProtoMessage not implemented -func (*LogGroupList) ProtoMessage() {} - -// GetLogGroups return the LogGroups -func (m *LogGroupList) GetLogGroups() []*LogGroup { - if m != nil { - return m.LogGroups - } - return nil -} - -// Marshal the logs to byte slice -func (m *Log) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo data -func (m *Log) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.Time == nil { - return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") - } - data[i] = 0x8 - i++ - i = encodeVarintLog(data, i, uint64(*m.Time)) - if len(m.Contents) > 0 { - for _, msg := range m.Contents { - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(msg.Size())) - n, err := msg.MarshalTo(data[i:]) - if err != nil { - return 0, err - } - i += n - } - } - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -// Marshal LogContent -func (m *LogContent) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo logcontent to data -func (m *LogContent) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.Key == nil { - return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") - } - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Key))) - i += copy(data[i:], *m.Key) - - if m.Value == nil { - return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") - } - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Value))) - i += copy(data[i:], *m.Value) - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -// Marshal LogGroup -func (m *LogGroup) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo LogGroup to data -func (m *LogGroup) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if len(m.Logs) > 0 { - for _, msg := range m.Logs { - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(msg.Size())) - n, err := msg.MarshalTo(data[i:]) - if err != nil { - return 0, err - } - i += n - } - } - if m.Reserved != nil { - data[i] = 0x12 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) - i += copy(data[i:], *m.Reserved) - } - if m.Topic != nil { - data[i] = 0x1a - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Topic))) - i += copy(data[i:], *m.Topic) - } - if m.Source != nil { - data[i] = 0x22 - i++ - i = encodeVarintLog(data, i, uint64(len(*m.Source))) - i += copy(data[i:], *m.Source) - } - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -// Marshal LogGroupList -func (m *LogGroupList) Marshal() (data []byte, err error) { - size := m.Size() - data = make([]byte, size) - n, err := m.MarshalTo(data) - if err != nil { - return nil, err - } - return data[:n], nil -} - -// MarshalTo LogGroupList to data -func (m *LogGroupList) MarshalTo(data []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if len(m.LogGroups) > 0 { - for _, msg := range m.LogGroups { - data[i] = 0xa - i++ - i = encodeVarintLog(data, i, uint64(msg.Size())) - n, err := msg.MarshalTo(data[i:]) - if err != nil { - return 0, err - } - i += n - } - } - if m.XXXUnrecognized != nil { - i += copy(data[i:], m.XXXUnrecognized) - } - return i, nil -} - -func encodeFixed64Log(data []byte, offset int, v uint64) int { - data[offset] = uint8(v) - data[offset+1] = uint8(v >> 8) - data[offset+2] = uint8(v >> 16) - data[offset+3] = uint8(v >> 24) - data[offset+4] = uint8(v >> 32) - data[offset+5] = uint8(v >> 40) - data[offset+6] = uint8(v >> 48) - data[offset+7] = uint8(v >> 56) - return offset + 8 -} -func encodeFixed32Log(data []byte, offset int, v uint32) int { - data[offset] = uint8(v) - data[offset+1] = uint8(v >> 8) - data[offset+2] = uint8(v >> 16) - data[offset+3] = uint8(v >> 24) - return offset + 4 -} -func encodeVarintLog(data []byte, offset int, v uint64) int { - for v >= 1<<7 { - data[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - data[offset] = uint8(v) - return offset + 1 -} - -// Size return the log's size -func (m *Log) Size() (n int) { - var l int - _ = l - if m.Time != nil { - n += 1 + sovLog(uint64(*m.Time)) - } - if len(m.Contents) > 0 { - for _, e := range m.Contents { - l = e.Size() - n += 1 + l + sovLog(uint64(l)) - } - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -// Size return LogContent size based on Key and Value -func (m *LogContent) Size() (n int) { - var l int - _ = l - if m.Key != nil { - l = len(*m.Key) - n += 1 + l + sovLog(uint64(l)) - } - if m.Value != nil { - l = len(*m.Value) - n += 1 + l + sovLog(uint64(l)) - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -// Size return LogGroup size based on Logs -func (m *LogGroup) Size() (n int) { - var l int - _ = l - if len(m.Logs) > 0 { - for _, e := range m.Logs { - l = e.Size() - n += 1 + l + sovLog(uint64(l)) - } - } - if m.Reserved != nil { - l = len(*m.Reserved) - n += 1 + l + sovLog(uint64(l)) - } - if m.Topic != nil { - l = len(*m.Topic) - n += 1 + l + sovLog(uint64(l)) - } - if m.Source != nil { - l = len(*m.Source) - n += 1 + l + sovLog(uint64(l)) - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -// Size return LogGroupList size -func (m *LogGroupList) Size() (n int) { - var l int - _ = l - if len(m.LogGroups) > 0 { - for _, e := range m.LogGroups { - l = e.Size() - n += 1 + l + sovLog(uint64(l)) - } - } - if m.XXXUnrecognized != nil { - n += len(m.XXXUnrecognized) - } - return n -} - -func sovLog(x uint64) (n int) { - for { - n++ - x >>= 7 - if x == 0 { - break - } - } - return n -} -func sozLog(x uint64) (n int) { - return sovLog((x << 1) ^ (x >> 63)) -} - -// Unmarshal data to log -func (m *Log) Unmarshal(data []byte) error { - var hasFields [1]uint64 - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: Log: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) - } - var v uint32 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - v |= (uint32(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - m.Time = &v - hasFields[0] |= uint64(0x00000001) - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Contents = append(m.Contents, &LogContent{}) - if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - if hasFields[0]&uint64(0x00000001) == 0 { - return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -// Unmarshal data to LogContent -func (m *LogContent) Unmarshal(data []byte) error { - var hasFields [1]uint64 - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: Content: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Key = &s - iNdEx = postIndex - hasFields[0] |= uint64(0x00000001) - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Value = &s - iNdEx = postIndex - hasFields[0] |= uint64(0x00000002) - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - if hasFields[0]&uint64(0x00000001) == 0 { - return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") - } - if hasFields[0]&uint64(0x00000002) == 0 { - return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -// Unmarshal data to LogGroup -func (m *LogGroup) Unmarshal(data []byte) error { - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.Logs = append(m.Logs, &Log{}) - if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Reserved = &s - iNdEx = postIndex - case 3: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Topic = &s - iNdEx = postIndex - case 4: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - stringLen |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + intStringLen - if postIndex > l { - return io.ErrUnexpectedEOF - } - s := string(data[iNdEx:postIndex]) - m.Source = &s - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -// Unmarshal data to LogGroupList -func (m *LogGroupList) Unmarshal(data []byte) error { - l := len(data) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowLog - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthLog - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.LogGroups = append(m.LogGroups, &LogGroup{}) - if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipLog(data[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthLog - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} - -func skipLog(data []byte) (n int, err error) { - l := len(data) - iNdEx := 0 - for iNdEx < l { - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if data[iNdEx-1] < 0x80 { - break - } - } - return iNdEx, nil - case 1: - iNdEx += 8 - return iNdEx, nil - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - iNdEx += length - if length < 0 { - return 0, ErrInvalidLengthLog - } - return iNdEx, nil - case 3: - for { - var innerWire uint64 - var start = iNdEx - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowLog - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := data[iNdEx] - iNdEx++ - innerWire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - innerWireType := int(innerWire & 0x7) - if innerWireType == 4 { - break - } - next, err := skipLog(data[start:]) - if err != nil { - return 0, err - } - iNdEx = start + next - } - return iNdEx, nil - case 4: - return iNdEx, nil - case 5: - iNdEx += 4 - return iNdEx, nil - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) - } - } - panic("unreachable") -} diff --git a/logs/alils/log_config.go b/logs/alils/log_config.go deleted file mode 100755 index e8564efb..00000000 --- a/logs/alils/log_config.go +++ /dev/null @@ -1,42 +0,0 @@ -package alils - -// InputDetail define log detail -type InputDetail struct { - LogType string `json:"logType"` - LogPath string `json:"logPath"` - FilePattern string `json:"filePattern"` - LocalStorage bool `json:"localStorage"` - TimeFormat string `json:"timeFormat"` - LogBeginRegex string `json:"logBeginRegex"` - Regex string `json:"regex"` - Keys []string `json:"key"` - FilterKeys []string `json:"filterKey"` - FilterRegex []string `json:"filterRegex"` - TopicFormat string `json:"topicFormat"` -} - -// OutputDetail define the output detail -type OutputDetail struct { - Endpoint string `json:"endpoint"` - LogStoreName string `json:"logstoreName"` -} - -// LogConfig define Log Config -type LogConfig struct { - Name string `json:"configName"` - InputType string `json:"inputType"` - InputDetail InputDetail `json:"inputDetail"` - OutputType string `json:"outputType"` - OutputDetail OutputDetail `json:"outputDetail"` - - CreateTime uint32 - LastModifyTime uint32 - - project *LogProject -} - -// GetAppliedMachineGroup returns applied machine group of this config. -func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { - groupNames, err = c.project.GetAppliedMachineGroups(c.Name) - return -} diff --git a/logs/alils/log_project.go b/logs/alils/log_project.go deleted file mode 100755 index 59db8cbf..00000000 --- a/logs/alils/log_project.go +++ /dev/null @@ -1,819 +0,0 @@ -/* -Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). - -For more description about SLS, please read this article: -http://gitlab.alibaba-inc.com/sls/doc. -*/ -package alils - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" -) - -// Error message in SLS HTTP response. -type errorMessage struct { - Code string `json:"errorCode"` - Message string `json:"errorMessage"` -} - -// LogProject Define the Ali Project detail -type LogProject struct { - Name string // Project name - Endpoint string // IP or hostname of SLS endpoint - AccessKeyID string - AccessKeySecret string -} - -// NewLogProject creates a new SLS project. -func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { - p = &LogProject{ - Name: name, - Endpoint: endpoint, - AccessKeyID: AccessKeyID, - AccessKeySecret: accessKeySecret, - } - return p, nil -} - -// ListLogStore returns all logstore names of project p. -func (p *LogProject) ListLogStore() (storeNames []string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/logstores") - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to list logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Count int - LogStores []string - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - storeNames = body.LogStores - - return -} - -// GetLogStore returns logstore according by logstore name. -func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "GET", "/logstores/"+name, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - s = &LogStore{} - err = json.Unmarshal(buf, s) - if err != nil { - return - } - s.project = p - return -} - -// CreateLogStore creates a new logstore in SLS, -// where name is logstore name, -// and ttl is time-to-live(in day) of logs, -// and shardCnt is the number of shards. -func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { - - type Body struct { - Name string `json:"logstoreName"` - TTL int `json:"ttl"` - ShardCount int `json:"shardCount"` - } - - store := &Body{ - Name: name, - TTL: ttl, - ShardCount: shardCnt, - } - - body, err := json.Marshal(store) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "POST", "/logstores", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to create logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// DeleteLogStore deletes a logstore according by logstore name. -func (p *LogProject) DeleteLogStore(name string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "DELETE", "/logstores/"+name, h, nil) - if err != nil { - return - } - - body, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// UpdateLogStore updates a logstore according by logstore name, -// obviously we can't modify the logstore name itself. -func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { - - type Body struct { - Name string `json:"logstoreName"` - TTL int `json:"ttl"` - ShardCount int `json:"shardCount"` - } - - store := &Body{ - Name: name, - TTL: ttl, - ShardCount: shardCnt, - } - - body, err := json.Marshal(store) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "PUT", "/logstores", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// ListMachineGroup returns machine group name list and the total number of machine groups. -// The offset starts from 0 and the size is the max number of machine groups could be returned. -func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - if size <= 0 { - size = 500 - } - - uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to list machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - MachineGroups []string - Count int - Total int - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - m = body.MachineGroups - total = body.Total - - return -} - -// GetMachineGroup retruns machine group according by machine group name. -func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "GET", "/machinegroups/"+name, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get machine group:%v", name) - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - m = &MachineGroup{} - err = json.Unmarshal(buf, m) - if err != nil { - return - } - m.project = p - return -} - -// CreateMachineGroup creates a new machine group in SLS. -func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { - - body, err := json.Marshal(m) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "POST", "/machinegroups", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to create machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// UpdateMachineGroup updates a machine group. -func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { - - body, err := json.Marshal(m) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// DeleteMachineGroup deletes machine group according machine group name. -func (p *LogProject) DeleteMachineGroup(name string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) - if err != nil { - return - } - - body, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// ListConfig returns config names list and the total number of configs. -// The offset starts from 0 and the size is the max number of configs could be returned. -func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - if size <= 0 { - size = 100 - } - - uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Total int - Configs []string - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - cfgNames = body.Configs - total = body.Total - return -} - -// GetConfig returns config according by config name. -func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "GET", "/configs/"+name, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - c = &LogConfig{} - err = json.Unmarshal(buf, c) - if err != nil { - return - } - c.project = p - return -} - -// UpdateConfig updates a config. -func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { - - body, err := json.Marshal(c) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "PUT", "/configs/"+c.Name, h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// CreateConfig creates a new config in SLS. -func (p *LogProject) CreateConfig(c *LogConfig) (err error) { - - body, err := json.Marshal(c) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/json", - "Accept-Encoding": "deflate", // TODO: support lz4 - } - - r, err := request(p, "POST", "/configs", h, body) - if err != nil { - return - } - - body, err = ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to update config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - return -} - -// DeleteConfig deletes a config according by config name. -func (p *LogProject) DeleteConfig(name string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - r, err := request(p, "DELETE", "/configs/"+name, h, nil) - if err != nil { - return - } - - body, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(body, errMsg) - if err != nil { - err = fmt.Errorf("failed to delete config") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// GetAppliedMachineGroups returns applied machine group names list according config name. -func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/configs/%v/machinegroups", confName) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get applied machine groups") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Count int - Machinegroups []string - } - - body := &Body{} - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - groupNames = body.Machinegroups - return -} - -// GetAppliedConfigs returns applied config names list according machine group name groupName. -func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) - r, err := request(p, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to applied configs") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Cfg struct { - Count int `json:"count"` - Configs []string `json:"configs"` - } - - body := &Cfg{} - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - confNames = body.Configs - return -} - -// ApplyConfigToMachineGroup applies config to machine group. -func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) - r, err := request(p, "PUT", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to apply config to machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// RemoveConfigFromMachineGroup removes config from machine group. -func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) - r, err := request(p, "DELETE", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to remove config from machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Printf("%s\n", dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} diff --git a/logs/alils/log_store.go b/logs/alils/log_store.go deleted file mode 100755 index fa502736..00000000 --- a/logs/alils/log_store.go +++ /dev/null @@ -1,271 +0,0 @@ -package alils - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" - "strconv" - - lz4 "github.com/cloudflare/golz4" - "github.com/gogo/protobuf/proto" -) - -// LogStore Store the logs -type LogStore struct { - Name string `json:"logstoreName"` - TTL int - ShardCount int - - CreateTime uint32 - LastModifyTime uint32 - - project *LogProject -} - -// Shard define the Log Shard -type Shard struct { - ShardID int `json:"shardID"` -} - -// ListShards returns shard id list of this logstore. -func (s *LogStore) ListShards() (shardIDs []int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/logstores/%v/shards", s.Name) - r, err := request(s.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to list logstore") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - var shards []*Shard - err = json.Unmarshal(buf, &shards) - if err != nil { - return - } - - for _, v := range shards { - shardIDs = append(shardIDs, v.ShardID) - } - return -} - -// PutLogs put logs into logstore. -// The callers should transform user logs into LogGroup. -func (s *LogStore) PutLogs(lg *LogGroup) (err error) { - body, err := proto.Marshal(lg) - if err != nil { - return - } - - // Compresse body with lz4 - out := make([]byte, lz4.CompressBound(body)) - n, err := lz4.Compress(body, out) - if err != nil { - return - } - - h := map[string]string{ - "x-sls-compresstype": "lz4", - "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), - "Content-Type": "application/x-protobuf", - } - - uri := fmt.Sprintf("/logstores/%v", s.Name) - r, err := request(s.project, "POST", uri, h, out[:n]) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to put logs") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - return -} - -// GetCursor gets log cursor of one shard specified by shardID. -// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". -// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore -func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", - s.Name, shardID, from) - - r, err := request(s.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get cursor") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - type Body struct { - Cursor string - } - body := &Body{} - - err = json.Unmarshal(buf, body) - if err != nil { - return - } - cursor = body.Cursor - return -} - -// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. -// The logGroupMaxCount is the max number of logGroup could be returned. -// The nextCursor is the next curosr can be used to read logs at next time. -func (s *LogStore) GetLogsBytes(shardID int, cursor string, - logGroupMaxCount int) (out []byte, nextCursor string, err error) { - - h := map[string]string{ - "x-sls-bodyrawsize": "0", - "Accept": "application/x-protobuf", - "Accept-Encoding": "lz4", - } - - uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", - s.Name, shardID, cursor, logGroupMaxCount) - - r, err := request(s.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to get cursor") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - v, ok := r.Header["X-Sls-Compresstype"] - if !ok || len(v) == 0 { - err = fmt.Errorf("can't find 'x-sls-compresstype' header") - return - } - if v[0] != "lz4" { - err = fmt.Errorf("unexpected compress type:%v", v[0]) - return - } - - v, ok = r.Header["X-Sls-Cursor"] - if !ok || len(v) == 0 { - err = fmt.Errorf("can't find 'x-sls-cursor' header") - return - } - nextCursor = v[0] - - v, ok = r.Header["X-Sls-Bodyrawsize"] - if !ok || len(v) == 0 { - err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") - return - } - bodyRawSize, err := strconv.Atoi(v[0]) - if err != nil { - return - } - - out = make([]byte, bodyRawSize) - err = lz4.Uncompress(buf, out) - if err != nil { - return - } - - return -} - -// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API -func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { - - gl = &LogGroupList{} - err = proto.Unmarshal(data, gl) - if err != nil { - return - } - - return -} - -// GetLogs gets logs from shard specified by shardID according cursor. -// The logGroupMaxCount is the max number of logGroup could be returned. -// The nextCursor is the next curosr can be used to read logs at next time. -func (s *LogStore) GetLogs(shardID int, cursor string, - logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { - - out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) - if err != nil { - return - } - - gl, err = LogsBytesDecode(out) - if err != nil { - return - } - - return -} diff --git a/logs/alils/machine_group.go b/logs/alils/machine_group.go deleted file mode 100755 index b6c69a14..00000000 --- a/logs/alils/machine_group.go +++ /dev/null @@ -1,91 +0,0 @@ -package alils - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httputil" -) - -// MachineGroupAttribute define the Attribute -type MachineGroupAttribute struct { - ExternalName string `json:"externalName"` - TopicName string `json:"groupTopic"` -} - -// MachineGroup define the machine Group -type MachineGroup struct { - Name string `json:"groupName"` - Type string `json:"groupType"` - MachineIDType string `json:"machineIdentifyType"` - MachineIDList []string `json:"machineList"` - - Attribute MachineGroupAttribute `json:"groupAttribute"` - - CreateTime uint32 - LastModifyTime uint32 - - project *LogProject -} - -// Machine define the Machine -type Machine struct { - IP string - UniqueID string `json:"machine-uniqueid"` - UserdefinedID string `json:"userdefined-id"` -} - -// MachineList define the Machine List -type MachineList struct { - Total int - Machines []*Machine -} - -// ListMachines returns machine list of this machine group. -func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { - h := map[string]string{ - "x-sls-bodyrawsize": "0", - } - - uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) - r, err := request(m.project, "GET", uri, h, nil) - if err != nil { - return - } - - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return - } - - if r.StatusCode != http.StatusOK { - errMsg := &errorMessage{} - err = json.Unmarshal(buf, errMsg) - if err != nil { - err = fmt.Errorf("failed to remove config from machine group") - dump, _ := httputil.DumpResponse(r, true) - fmt.Println(dump) - return - } - err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) - return - } - - body := &MachineList{} - err = json.Unmarshal(buf, body) - if err != nil { - return - } - - ms = body.Machines - total = body.Total - - return -} - -// GetAppliedConfigs returns applied configs of this machine group. -func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { - confNames, err = m.project.GetAppliedConfigs(m.Name) - return -} diff --git a/logs/alils/request.go b/logs/alils/request.go deleted file mode 100755 index 50d9c43c..00000000 --- a/logs/alils/request.go +++ /dev/null @@ -1,62 +0,0 @@ -package alils - -import ( - "bytes" - "crypto/md5" - "fmt" - "net/http" -) - -// request sends a request to SLS. -func request(project *LogProject, method, uri string, headers map[string]string, - body []byte) (resp *http.Response, err error) { - - // The caller should provide 'x-sls-bodyrawsize' header - if _, ok := headers["x-sls-bodyrawsize"]; !ok { - err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") - return - } - - // SLS public request headers - headers["Host"] = project.Name + "." + project.Endpoint - headers["Date"] = nowRFC1123() - headers["x-sls-apiversion"] = version - headers["x-sls-signaturemethod"] = signatureMethod - if body != nil { - bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) - headers["Content-MD5"] = bodyMD5 - - if _, ok := headers["Content-Type"]; !ok { - err = fmt.Errorf("Can't find 'Content-Type' header") - return - } - } - - // Calc Authorization - // Authorization = "SLS :" - digest, err := signature(project, method, uri, headers) - if err != nil { - return - } - auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) - headers["Authorization"] = auth - - // Initialize http request - reader := bytes.NewReader(body) - urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) - req, err := http.NewRequest(method, urlStr, reader) - if err != nil { - return - } - for k, v := range headers { - req.Header.Add(k, v) - } - - // Get ready to do request - resp, err = http.DefaultClient.Do(req) - if err != nil { - return - } - - return -} diff --git a/logs/alils/signature.go b/logs/alils/signature.go deleted file mode 100755 index 2d611307..00000000 --- a/logs/alils/signature.go +++ /dev/null @@ -1,111 +0,0 @@ -package alils - -import ( - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "fmt" - "net/url" - "sort" - "strings" - "time" -) - -// GMT location -var gmtLoc = time.FixedZone("GMT", 0) - -// NowRFC1123 returns now time in RFC1123 format with GMT timezone, -// eg. "Mon, 02 Jan 2006 15:04:05 GMT". -func nowRFC1123() string { - return time.Now().In(gmtLoc).Format(time.RFC1123) -} - -// signature calculates a request's signature digest. -func signature(project *LogProject, method, uri string, - headers map[string]string) (digest string, err error) { - var contentMD5, contentType, date, canoHeaders, canoResource string - var slsHeaderKeys sort.StringSlice - - // SignString = VERB + "\n" - // + CONTENT-MD5 + "\n" - // + CONTENT-TYPE + "\n" - // + DATE + "\n" - // + CanonicalizedSLSHeaders + "\n" - // + CanonicalizedResource - - if val, ok := headers["Content-MD5"]; ok { - contentMD5 = val - } - - if val, ok := headers["Content-Type"]; ok { - contentType = val - } - - date, ok := headers["Date"] - if !ok { - err = fmt.Errorf("Can't find 'Date' header") - return - } - - // Calc CanonicalizedSLSHeaders - slsHeaders := make(map[string]string, len(headers)) - for k, v := range headers { - l := strings.TrimSpace(strings.ToLower(k)) - if strings.HasPrefix(l, "x-sls-") { - slsHeaders[l] = strings.TrimSpace(v) - slsHeaderKeys = append(slsHeaderKeys, l) - } - } - - sort.Sort(slsHeaderKeys) - for i, k := range slsHeaderKeys { - canoHeaders += k + ":" + slsHeaders[k] - if i+1 < len(slsHeaderKeys) { - canoHeaders += "\n" - } - } - - // Calc CanonicalizedResource - u, err := url.Parse(uri) - if err != nil { - return - } - - canoResource += url.QueryEscape(u.Path) - if u.RawQuery != "" { - var keys sort.StringSlice - - vals := u.Query() - for k := range vals { - keys = append(keys, k) - } - - sort.Sort(keys) - canoResource += "?" - for i, k := range keys { - if i > 0 { - canoResource += "&" - } - - for _, v := range vals[k] { - canoResource += k + "=" + v - } - } - } - - signStr := method + "\n" + - contentMD5 + "\n" + - contentType + "\n" + - date + "\n" + - canoHeaders + "\n" + - canoResource - - // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) - mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) - _, err = mac.Write([]byte(signStr)) - if err != nil { - return - } - digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) - return -} diff --git a/logs/conn.go b/logs/conn.go deleted file mode 100644 index 74c458ab..00000000 --- a/logs/conn.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "encoding/json" - "io" - "net" - "time" -) - -// connWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. -type connWriter struct { - lg *logWriter - innerWriter io.WriteCloser - ReconnectOnMsg bool `json:"reconnectOnMsg"` - Reconnect bool `json:"reconnect"` - Net string `json:"net"` - Addr string `json:"addr"` - Level int `json:"level"` -} - -// NewConn create new ConnWrite returning as LoggerInterface. -func NewConn() Logger { - conn := new(connWriter) - conn.Level = LevelTrace - return conn -} - -// Init init connection writer with json config. -// json config only need key "level". -func (c *connWriter) Init(jsonConfig string) error { - return json.Unmarshal([]byte(jsonConfig), c) -} - -// WriteMsg write message in connection. -// if connection is down, try to re-connect. -func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { - return nil - } - if c.needToConnectOnMsg() { - err := c.connect() - if err != nil { - return err - } - } - - if c.ReconnectOnMsg { - defer c.innerWriter.Close() - } - - _, err := c.lg.writeln(when, msg) - if err != nil { - return err - } - return nil -} - -// Flush implementing method. empty. -func (c *connWriter) Flush() { - -} - -// Destroy destroy connection writer and close tcp listener. -func (c *connWriter) Destroy() { - if c.innerWriter != nil { - c.innerWriter.Close() - } -} - -func (c *connWriter) connect() error { - if c.innerWriter != nil { - c.innerWriter.Close() - c.innerWriter = nil - } - - conn, err := net.Dial(c.Net, c.Addr) - if err != nil { - return err - } - - if tcpConn, ok := conn.(*net.TCPConn); ok { - tcpConn.SetKeepAlive(true) - } - - c.innerWriter = conn - c.lg = newLogWriter(conn) - return nil -} - -func (c *connWriter) needToConnectOnMsg() bool { - if c.Reconnect { - return true - } - - if c.innerWriter == nil { - return true - } - - return c.ReconnectOnMsg -} - -func init() { - Register(AdapterConn, NewConn) -} diff --git a/logs/conn_test.go b/logs/conn_test.go deleted file mode 100644 index 7cfb4d2b..00000000 --- a/logs/conn_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "net" - "os" - "testing" -) - -// ConnTCPListener takes a TCP listener and accepts n TCP connections -// Returns connections using connChan -func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { - - // Listen and accept n incoming connections - for i := 0; i < n; i++ { - conn, err := ln.Accept() - if err != nil { - t.Log("Error accepting connection: ", err.Error()) - os.Exit(1) - } - - // Send accepted connection to channel - connChan <- conn - } - ln.Close() - close(connChan) -} - -func TestConn(t *testing.T) { - log := NewLogger(1000) - log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) - log.Informational("informational") -} - -func TestReconnect(t *testing.T) { - // Setup connection listener - newConns := make(chan net.Conn) - connNum := 2 - ln, err := net.Listen("tcp", ":6002") - if err != nil { - t.Log("Error listening:", err.Error()) - os.Exit(1) - } - go connTCPListener(t, connNum, ln, newConns) - - // Setup logger - log := NewLogger(1000) - log.SetPrefix("test") - log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) - log.Informational("informational 1") - - // Refuse first connection - first := <-newConns - first.Close() - - // Send another log after conn closed - log.Informational("informational 2") - - // Check if there was a second connection attempt - // close this because we moved the codes to pkg/logs - // select { - // case second := <-newConns: - // second.Close() - // default: - // t.Error("Did not reconnect") - // } -} diff --git a/logs/console.go b/logs/console.go deleted file mode 100644 index 3dcaee1d..00000000 --- a/logs/console.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "encoding/json" - "os" - "strings" - "time" - - "github.com/shiena/ansicolor" -) - -// brush is a color join function -type brush func(string) string - -// newBrush return a fix color Brush -func newBrush(color string) brush { - pre := "\033[" - reset := "\033[0m" - return func(text string) string { - return pre + color + "m" + text + reset - } -} - -var colors = []brush{ - newBrush("1;37"), // Emergency white - newBrush("1;36"), // Alert cyan - newBrush("1;35"), // Critical magenta - newBrush("1;31"), // Error red - newBrush("1;33"), // Warning yellow - newBrush("1;32"), // Notice green - newBrush("1;34"), // Informational blue - newBrush("1;44"), // Debug Background blue -} - -// consoleWriter implements LoggerInterface and writes messages to terminal. -type consoleWriter struct { - lg *logWriter - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color -} - -// NewConsole create ConsoleWriter returning as LoggerInterface. -func NewConsole() Logger { - cw := &consoleWriter{ - lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), - Level: LevelDebug, - Colorful: true, - } - return cw -} - -// Init init console logger. -// jsonConfig like '{"level":LevelTrace}'. -func (c *consoleWriter) Init(jsonConfig string) error { - if len(jsonConfig) == 0 { - return nil - } - return json.Unmarshal([]byte(jsonConfig), c) -} - -// WriteMsg write message in console. -func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { - return nil - } - if c.Colorful { - msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) - } - c.lg.writeln(when, msg) - return nil -} - -// Destroy implementing method. empty. -func (c *consoleWriter) Destroy() { - -} - -// Flush implementing method. empty. -func (c *consoleWriter) Flush() { - -} - -func init() { - Register(AdapterConsole, NewConsole) -} diff --git a/logs/console_test.go b/logs/console_test.go deleted file mode 100644 index 4bc45f57..00000000 --- a/logs/console_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "testing" - "time" -) - -// Try each log level in decreasing order of priority. -func testConsoleCalls(bl *BeeLogger) { - bl.Emergency("emergency") - bl.Alert("alert") - bl.Critical("critical") - bl.Error("error") - bl.Warning("warning") - bl.Notice("notice") - bl.Informational("informational") - bl.Debug("debug") -} - -// Test console logging by visually comparing the lines being output with and -// without a log level specification. -func TestConsole(t *testing.T) { - log1 := NewLogger(10000) - log1.EnableFuncCallDepth(true) - log1.SetLogger("console", "") - testConsoleCalls(log1) - - log2 := NewLogger(100) - log2.SetLogger("console", `{"level":3}`) - testConsoleCalls(log2) -} - -// Test console without color -func TestConsoleNoColor(t *testing.T) { - log := NewLogger(100) - log.SetLogger("console", `{"color":false}`) - testConsoleCalls(log) -} - -// Test console async -func TestConsoleAsync(t *testing.T) { - log := NewLogger(100) - log.SetLogger("console") - log.Async() - //log.Close() - testConsoleCalls(log) - for len(log.msgChan) != 0 { - time.Sleep(1 * time.Millisecond) - } -} diff --git a/logs/es/es.go b/logs/es/es.go deleted file mode 100644 index 2b7b1710..00000000 --- a/logs/es/es.go +++ /dev/null @@ -1,102 +0,0 @@ -package es - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/url" - "strings" - "time" - - "github.com/elastic/go-elasticsearch/v6" - "github.com/elastic/go-elasticsearch/v6/esapi" - - "github.com/astaxie/beego/logs" -) - -// NewES return a LoggerInterface -func NewES() logs.Logger { - cw := &esLogger{ - Level: logs.LevelDebug, - } - return cw -} - -// esLogger will log msg into ES -// before you using this implementation, -// please import this package -// usually means that you can import this package in your main package -// for example, anonymous: -// import _ "github.com/astaxie/beego/logs/es" -type esLogger struct { - *elasticsearch.Client - DSN string `json:"dsn"` - Level int `json:"level"` -} - -// {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), el) - if err != nil { - return err - } - if el.DSN == "" { - return errors.New("empty dsn") - } else if u, err := url.Parse(el.DSN); err != nil { - return err - } else if u.Path == "" { - return errors.New("missing prefix") - } else { - conn, err := elasticsearch.NewClient(elasticsearch.Config{ - Addresses: []string{el.DSN}, - }) - if err != nil { - return err - } - el.Client = conn - } - return nil -} - -// WriteMsg will write the msg and level into es -func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { - if level > el.Level { - return nil - } - - idx := LogDocument{ - Timestamp: when.Format(time.RFC3339), - Msg: msg, - } - - body, err := json.Marshal(idx) - if err != nil { - return err - } - req := esapi.IndexRequest{ - Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), - DocumentType: "logs", - Body: strings.NewReader(string(body)), - } - _, err = req.Do(context.Background(), el.Client) - return err -} - -// Destroy is a empty method -func (el *esLogger) Destroy() { -} - -// Flush is a empty method -func (el *esLogger) Flush() { - -} - -type LogDocument struct { - Timestamp string `json:"timestamp"` - Msg string `json:"msg"` -} - -func init() { - logs.Register(logs.AdapterEs, NewES) -} diff --git a/logs/file.go b/logs/file.go deleted file mode 100644 index 40a3572a..00000000 --- a/logs/file.go +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" -) - -// fileLogWriter implements LoggerInterface. -// It writes messages by lines limit, file size limit, or time frequency. -type fileLogWriter struct { - sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize - // The opened file - Filename string `json:"filename"` - fileWriter *os.File - - // Rotate at line - MaxLines int `json:"maxlines"` - maxLinesCurLines int - - MaxFiles int `json:"maxfiles"` - MaxFilesCurFiles int - - // Rotate at size - MaxSize int `json:"maxsize"` - maxSizeCurSize int - - // Rotate daily - Daily bool `json:"daily"` - MaxDays int64 `json:"maxdays"` - dailyOpenDate int - dailyOpenTime time.Time - - // Rotate hourly - Hourly bool `json:"hourly"` - MaxHours int64 `json:"maxhours"` - hourlyOpenDate int - hourlyOpenTime time.Time - - Rotate bool `json:"rotate"` - - Level int `json:"level"` - - Perm string `json:"perm"` - - RotatePerm string `json:"rotateperm"` - - fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix -} - -// newFileWriter create a FileLogWriter returning as LoggerInterface. -func newFileWriter() Logger { - w := &fileLogWriter{ - Daily: true, - MaxDays: 7, - Hourly: false, - MaxHours: 168, - Rotate: true, - RotatePerm: "0440", - Level: LevelTrace, - Perm: "0660", - MaxLines: 10000000, - MaxFiles: 999, - MaxSize: 1 << 28, - } - return w -} - -// Init file logger with json config. -// jsonConfig like: -// { -// "filename":"logs/beego.log", -// "maxLines":10000, -// "maxsize":1024, -// "daily":true, -// "maxDays":15, -// "rotate":true, -// "perm":"0600" -// } -func (w *fileLogWriter) Init(jsonConfig string) error { - err := json.Unmarshal([]byte(jsonConfig), w) - if err != nil { - return err - } - if len(w.Filename) == 0 { - return errors.New("jsonconfig must have filename") - } - w.suffix = filepath.Ext(w.Filename) - w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix) - if w.suffix == "" { - w.suffix = ".log" - } - err = w.startLogger() - return err -} - -// start file logger. create log file and set to locker-inside file writer. -func (w *fileLogWriter) startLogger() error { - file, err := w.createLogFile() - if err != nil { - return err - } - if w.fileWriter != nil { - w.fileWriter.Close() - } - w.fileWriter = file - return w.initFd() -} - -func (w *fileLogWriter) needRotateDaily(size int, day int) bool { - return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || - (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || - (w.Daily && day != w.dailyOpenDate) -} - -func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { - return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || - (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || - (w.Hourly && hour != w.hourlyOpenDate) - -} - -// WriteMsg write logger message into file. -func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > w.Level { - return nil - } - hd, d, h := formatTimeHeader(when) - msg = string(hd) + msg + "\n" - if w.Rotate { - w.RLock() - if w.needRotateHourly(len(msg), h) { - w.RUnlock() - w.Lock() - if w.needRotateHourly(len(msg), h) { - if err := w.doRotate(when); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() - } else if w.needRotateDaily(len(msg), d) { - w.RUnlock() - w.Lock() - if w.needRotateDaily(len(msg), d) { - if err := w.doRotate(when); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() - } else { - w.RUnlock() - } - } - - w.Lock() - _, err := w.fileWriter.Write([]byte(msg)) - if err == nil { - w.maxLinesCurLines++ - w.maxSizeCurSize += len(msg) - } - w.Unlock() - return err -} - -func (w *fileLogWriter) createLogFile() (*os.File, error) { - // Open the log file - perm, err := strconv.ParseInt(w.Perm, 8, 64) - if err != nil { - return nil, err - } - - filepath := path.Dir(w.Filename) - os.MkdirAll(filepath, os.FileMode(perm)) - - fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) - if err == nil { - // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask - os.Chmod(w.Filename, os.FileMode(perm)) - } - return fd, err -} - -func (w *fileLogWriter) initFd() error { - fd := w.fileWriter - fInfo, err := fd.Stat() - if err != nil { - return fmt.Errorf("get stat err: %s", err) - } - w.maxSizeCurSize = int(fInfo.Size()) - w.dailyOpenTime = time.Now() - w.dailyOpenDate = w.dailyOpenTime.Day() - w.hourlyOpenTime = time.Now() - w.hourlyOpenDate = w.hourlyOpenTime.Hour() - w.maxLinesCurLines = 0 - if w.Hourly { - go w.hourlyRotate(w.hourlyOpenTime) - } else if w.Daily { - go w.dailyRotate(w.dailyOpenTime) - } - if fInfo.Size() > 0 && w.MaxLines > 0 { - count, err := w.lines() - if err != nil { - return err - } - w.maxLinesCurLines = count - } - return nil -} - -func (w *fileLogWriter) dailyRotate(openTime time.Time) { - y, m, d := openTime.Add(24 * time.Hour).Date() - nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) - tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) - <-tm.C - w.Lock() - if w.needRotateDaily(0, time.Now().Day()) { - if err := w.doRotate(time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() -} - -func (w *fileLogWriter) hourlyRotate(openTime time.Time) { - y, m, d := openTime.Add(1 * time.Hour).Date() - h, _, _ := openTime.Add(1 * time.Hour).Clock() - nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location()) - tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) - <-tm.C - w.Lock() - if w.needRotateHourly(0, time.Now().Hour()) { - if err := w.doRotate(time.Now()); err != nil { - fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) - } - } - w.Unlock() -} - -func (w *fileLogWriter) lines() (int, error) { - fd, err := os.Open(w.Filename) - if err != nil { - return 0, err - } - defer fd.Close() - - buf := make([]byte, 32768) // 32k - count := 0 - lineSep := []byte{'\n'} - - for { - c, err := fd.Read(buf) - if err != nil && err != io.EOF { - return count, err - } - - count += bytes.Count(buf[:c], lineSep) - - if err == io.EOF { - break - } - } - - return count, nil -} - -// DoRotate means it need to write file in new file. -// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) -func (w *fileLogWriter) doRotate(logTime time.Time) error { - // file exists - // Find the next available number - num := w.MaxFilesCurFiles + 1 - fName := "" - format := "" - var openTime time.Time - rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) - if err != nil { - return err - } - - _, err = os.Lstat(w.Filename) - if err != nil { - //even if the file is not exist or other ,we should RESTART the logger - goto RESTART_LOGGER - } - - if w.Hourly { - format = "2006010215" - openTime = w.hourlyOpenTime - } else if w.Daily { - format = "2006-01-02" - openTime = w.dailyOpenTime - } - - // only when one of them be setted, then the file would be splited - if w.MaxLines > 0 || w.MaxSize > 0 { - for ; err == nil && num <= w.MaxFiles; num++ { - fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix) - _, err = os.Lstat(fName) - } - } else { - fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) - _, err = os.Lstat(fName) - w.MaxFilesCurFiles = num - } - - // return error if the last file checked still existed - if err == nil { - return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) - } - - // close fileWriter before rename - w.fileWriter.Close() - - // Rename the file to its new found name - // even if occurs error,we MUST guarantee to restart new logger - err = os.Rename(w.Filename, fName) - if err != nil { - goto RESTART_LOGGER - } - - err = os.Chmod(fName, os.FileMode(rotatePerm)) - -RESTART_LOGGER: - - startLoggerErr := w.startLogger() - go w.deleteOldLog() - - if startLoggerErr != nil { - return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr) - } - if err != nil { - return fmt.Errorf("Rotate: %s", err) - } - return nil -} - -func (w *fileLogWriter) deleteOldLog() { - dir := filepath.Dir(w.Filename) - absolutePath, err := filepath.EvalSymlinks(w.Filename) - if err == nil { - dir = filepath.Dir(absolutePath) - } - filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { - defer func() { - if r := recover(); r != nil { - fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r) - } - }() - - if info == nil { - return - } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } - return - }) -} - -// Destroy close the file description, close file writer. -func (w *fileLogWriter) Destroy() { - w.fileWriter.Close() -} - -// Flush flush file logger. -// there are no buffering messages in file logger in memory. -// flush file means sync file from disk. -func (w *fileLogWriter) Flush() { - w.fileWriter.Sync() -} - -func init() { - Register(AdapterFile, newFileWriter) -} diff --git a/logs/file_test.go b/logs/file_test.go deleted file mode 100644 index 385eac43..00000000 --- a/logs/file_test.go +++ /dev/null @@ -1,420 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bufio" - "fmt" - "io/ioutil" - "os" - "strconv" - "testing" - "time" -) - -func TestFilePerm(t *testing.T) { - log := NewLogger(10000) - // use 0666 as test perm cause the default umask is 022 - log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`) - log.Debug("debug") - log.Informational("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - file, err := os.Stat("test.log") - if err != nil { - t.Fatal(err) - } - if file.Mode() != 0666 { - t.Fatal("unexpected log file permission") - } - os.Remove("test.log") -} - -func TestFile1(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test.log"}`) - log.Debug("debug") - log.Informational("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - f, err := os.Open("test.log") - if err != nil { - t.Fatal(err) - } - b := bufio.NewReader(f) - lineNum := 0 - for { - line, _, err := b.ReadLine() - if err != nil { - break - } - if len(line) > 0 { - lineNum++ - } - } - var expected = LevelDebug + 1 - if lineNum != expected { - t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") - } - os.Remove("test.log") -} - -func TestFile2(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError)) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - f, err := os.Open("test2.log") - if err != nil { - t.Fatal(err) - } - b := bufio.NewReader(f) - lineNum := 0 - for { - line, _, err := b.ReadLine() - if err != nil { - break - } - if len(line) > 0 { - lineNum++ - } - } - var expected = LevelError + 1 - if lineNum != expected { - t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") - } - os.Remove("test2.log") -} - -func TestFileDailyRotate_01(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" - b, err := exists(rotateName) - if !b || err != nil { - os.Remove("test3.log") - t.Fatal("rotate not generated") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func TestFileDailyRotate_02(t *testing.T) { - fn1 := "rotate_day.log" - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileRotate(t, fn1, fn2, true, false) -} - -func TestFileDailyRotate_03(t *testing.T) { - fn1 := "rotate_day.log" - fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" - os.Create(fn) - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileRotate(t, fn1, fn2, true, false) - os.Remove(fn) -} - -func TestFileDailyRotate_04(t *testing.T) { - fn1 := "rotate_day.log" - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileDailyRotate(t, fn1, fn2) -} - -func TestFileDailyRotate_05(t *testing.T) { - fn1 := "rotate_day.log" - fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" - os.Create(fn) - fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" - testFileDailyRotate(t, fn1, fn2) - os.Remove(fn) -} -func TestFileDailyRotate_06(t *testing.T) { //test file mode - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" - s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { - os.Remove(rotateName) - os.Remove("test3.log") - t.Fatal("rotate file mode error") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func TestFileHourlyRotate_01(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" - b, err := exists(rotateName) - if !b || err != nil { - os.Remove("test3.log") - t.Fatal("rotate not generated") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func TestFileHourlyRotate_02(t *testing.T) { - fn1 := "rotate_hour.log" - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileRotate(t, fn1, fn2, false, true) -} - -func TestFileHourlyRotate_03(t *testing.T) { - fn1 := "rotate_hour.log" - fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" - os.Create(fn) - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileRotate(t, fn1, fn2, false, true) - os.Remove(fn) -} - -func TestFileHourlyRotate_04(t *testing.T) { - fn1 := "rotate_hour.log" - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileHourlyRotate(t, fn1, fn2) -} - -func TestFileHourlyRotate_05(t *testing.T) { - fn1 := "rotate_hour.log" - fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" - os.Create(fn) - fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" - testFileHourlyRotate(t, fn1, fn2) - os.Remove(fn) -} - -func TestFileHourlyRotate_06(t *testing.T) { //test file mode - log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) - log.Debug("debug") - log.Info("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" - s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { - os.Remove(rotateName) - os.Remove("test3.log") - t.Fatal("rotate file mode error") - } - os.Remove(rotateName) - os.Remove("test3.log") -} - -func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { - fw := &fileLogWriter{ - Daily: daily, - MaxDays: 7, - Hourly: hourly, - MaxHours: 168, - Rotate: true, - Level: LevelTrace, - Perm: "0660", - RotatePerm: "0440", - } - - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } - - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } - - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) - - for _, file := range []string{fn1, fn2} { - _, err := os.Stat(file) - if err != nil { - t.Log(err) - t.FailNow() - } - os.Remove(file) - } - fw.Destroy() -} - -func testFileDailyRotate(t *testing.T, fn1, fn2 string) { - fw := &fileLogWriter{ - Daily: true, - MaxDays: 7, - Rotate: true, - Level: LevelTrace, - Perm: "0660", - RotatePerm: "0440", - } - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location()) - today = today.Add(-1 * time.Second) - fw.dailyRotate(today) - for _, file := range []string{fn1, fn2} { - _, err := os.Stat(file) - if err != nil { - t.FailNow() - } - content, err := ioutil.ReadFile(file) - if err != nil { - t.FailNow() - } - if len(content) > 0 { - t.FailNow() - } - os.Remove(file) - } - fw.Destroy() -} - -func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { - fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, - Rotate: true, - Level: LevelTrace, - Perm: "0660", - RotatePerm: "0440", - } - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() - hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location()) - hour = hour.Add(-1 * time.Second) - fw.hourlyRotate(hour) - for _, file := range []string{fn1, fn2} { - _, err := os.Stat(file) - if err != nil { - t.FailNow() - } - content, err := ioutil.ReadFile(file) - if err != nil { - t.FailNow() - } - if len(content) > 0 { - t.FailNow() - } - os.Remove(file) - } - fw.Destroy() -} -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return false, err -} - -func BenchmarkFile(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileAsynchronous(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - log.Async() - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileCallDepth(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - log.EnableFuncCallDepth(true) - log.SetLogFuncCallDepth(2) - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileAsynchronousCallDepth(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - log.EnableFuncCallDepth(true) - log.SetLogFuncCallDepth(2) - log.Async() - for i := 0; i < b.N; i++ { - log.Debug("debug") - } - os.Remove("test4.log") -} - -func BenchmarkFileOnGoroutine(b *testing.B) { - log := NewLogger(100000) - log.SetLogger("file", `{"filename":"test4.log"}`) - for i := 0; i < b.N; i++ { - go log.Debug("debug") - } - os.Remove("test4.log") -} diff --git a/logs/jianliao.go b/logs/jianliao.go deleted file mode 100644 index 88ba0f9a..00000000 --- a/logs/jianliao.go +++ /dev/null @@ -1,72 +0,0 @@ -package logs - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "time" -) - -// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook -type JLWriter struct { - AuthorName string `json:"authorname"` - Title string `json:"title"` - WebhookURL string `json:"webhookurl"` - RedirectURL string `json:"redirecturl,omitempty"` - ImageURL string `json:"imageurl,omitempty"` - Level int `json:"level"` -} - -// newJLWriter create jiaoliao writer. -func newJLWriter() Logger { - return &JLWriter{Level: LevelTrace} -} - -// Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) - - form := url.Values{} - form.Add("authorName", s.AuthorName) - form.Add("title", s.Title) - form.Add("text", text) - if s.RedirectURL != "" { - form.Add("redirectUrl", s.RedirectURL) - } - if s.ImageURL != "" { - form.Add("imageUrl", s.ImageURL) - } - - resp, err := http.PostForm(s.WebhookURL, form) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) - } - return nil -} - -// Flush implementing method. empty. -func (s *JLWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *JLWriter) Destroy() { -} - -func init() { - Register(AdapterJianLiao, newJLWriter) -} diff --git a/logs/log.go b/logs/log.go deleted file mode 100644 index 39c006d2..00000000 --- a/logs/log.go +++ /dev/null @@ -1,669 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package logs provide a general log interface -// Usage: -// -// import "github.com/astaxie/beego/logs" -// -// log := NewLogger(10000) -// log.SetLogger("console", "") -// -// > the first params stand for how many channel -// -// Use it like this: -// -// log.Trace("trace") -// log.Info("info") -// log.Warn("warning") -// log.Debug("debug") -// log.Critical("critical") -// -// more docs http://beego.me/docs/module/logs.md -package logs - -import ( - "fmt" - "log" - "os" - "path" - "runtime" - "strconv" - "strings" - "sync" - "time" -) - -// RFC5424 log message levels. -const ( - LevelEmergency = iota - LevelAlert - LevelCritical - LevelError - LevelWarning - LevelNotice - LevelInformational - LevelDebug -) - -// levelLogLogger is defined to implement log.Logger -// the real log level will be LevelEmergency -const levelLoggerImpl = -1 - -// Name for adapter with beego official support -const ( - AdapterConsole = "console" - AdapterFile = "file" - AdapterMultiFile = "multifile" - AdapterMail = "smtp" - AdapterConn = "conn" - AdapterEs = "es" - AdapterJianLiao = "jianliao" - AdapterSlack = "slack" - AdapterAliLS = "alils" -) - -// Legacy log level constants to ensure backwards compatibility. -const ( - LevelInfo = LevelInformational - LevelTrace = LevelDebug - LevelWarn = LevelWarning -) - -type newLoggerFunc func() Logger - -// Logger defines the behavior of a log provider. -type Logger interface { - Init(config string) error - WriteMsg(when time.Time, msg string, level int) error - Destroy() - Flush() -} - -var adapters = make(map[string]newLoggerFunc) -var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} - -// Register makes a log provide available by the provided name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, log newLoggerFunc) { - if log == nil { - panic("logs: Register provide is nil") - } - if _, dup := adapters[name]; dup { - panic("logs: Register called twice for provider " + name) - } - adapters[name] = log -} - -// BeeLogger is default logger in beego application. -// it can contain several providers and log message into all providers. -type BeeLogger struct { - lock sync.Mutex - level int - init bool - enableFuncCallDepth bool - loggerFuncCallDepth int - asynchronous bool - prefix string - msgChanLen int64 - msgChan chan *logMsg - signalChan chan string - wg sync.WaitGroup - outputs []*nameLogger -} - -const defaultAsyncMsgLen = 1e3 - -type nameLogger struct { - Logger - name string -} - -type logMsg struct { - level int - msg string - when time.Time -} - -var logMsgPool *sync.Pool - -// NewLogger returns a new BeeLogger. -// channelLen means the number of messages in chan(used where asynchronous is true). -// if the buffering chan is full, logger adapters write to file or other way. -func NewLogger(channelLens ...int64) *BeeLogger { - bl := new(BeeLogger) - bl.level = LevelDebug - bl.loggerFuncCallDepth = 2 - bl.msgChanLen = append(channelLens, 0)[0] - if bl.msgChanLen <= 0 { - bl.msgChanLen = defaultAsyncMsgLen - } - bl.signalChan = make(chan string, 1) - bl.setLogger(AdapterConsole) - return bl -} - -// Async set the log to asynchronous and start the goroutine -func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { - bl.lock.Lock() - defer bl.lock.Unlock() - if bl.asynchronous { - return bl - } - bl.asynchronous = true - if len(msgLen) > 0 && msgLen[0] > 0 { - bl.msgChanLen = msgLen[0] - } - bl.msgChan = make(chan *logMsg, bl.msgChanLen) - logMsgPool = &sync.Pool{ - New: func() interface{} { - return &logMsg{} - }, - } - bl.wg.Add(1) - go bl.startLogger() - return bl -} - -// SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { - config := append(configs, "{}")[0] - for _, l := range bl.outputs { - if l.name == adapterName { - return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) - } - } - - logAdapter, ok := adapters[adapterName] - if !ok { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) - } - - lg := logAdapter() - err := lg.Init(config) - if err != nil { - fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) - return err - } - bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) - return nil -} - -// SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. -func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - if !bl.init { - bl.outputs = []*nameLogger{} - bl.init = true - } - return bl.setLogger(adapterName, configs...) -} - -// DelLogger remove a logger adapter in BeeLogger. -func (bl *BeeLogger) DelLogger(adapterName string) error { - bl.lock.Lock() - defer bl.lock.Unlock() - outputs := []*nameLogger{} - for _, lg := range bl.outputs { - if lg.name == adapterName { - lg.Destroy() - } else { - outputs = append(outputs, lg) - } - } - if len(outputs) == len(bl.outputs) { - return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) - } - bl.outputs = outputs - return nil -} - -func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { - for _, l := range bl.outputs { - err := l.WriteMsg(when, msg, level) - if err != nil { - fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) - } - } -} - -func (bl *BeeLogger) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - // writeMsg will always add a '\n' character - if p[len(p)-1] == '\n' { - p = p[0 : len(p)-1] - } - // set levelLoggerImpl to ensure all log message will be write out - err = bl.writeMsg(levelLoggerImpl, string(p)) - if err == nil { - return len(p), err - } - return 0, err -} - -func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { - if !bl.init { - bl.lock.Lock() - bl.setLogger(AdapterConsole) - bl.lock.Unlock() - } - - if len(v) > 0 { - msg = fmt.Sprintf(msg, v...) - } - - msg = bl.prefix + " " + msg - - when := time.Now() - if bl.enableFuncCallDepth { - _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) - if !ok { - file = "???" - line = 0 - } - _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg - } - - //set level info in front of filename info - if logLevel == levelLoggerImpl { - // set to emergency to ensure all log will be print out correctly - logLevel = LevelEmergency - } else { - msg = levelPrefix[logLevel] + " " + msg - } - - if bl.asynchronous { - lm := logMsgPool.Get().(*logMsg) - lm.level = logLevel - lm.msg = msg - lm.when = when - if bl.outputs != nil { - bl.msgChan <- lm - } else { - logMsgPool.Put(lm) - } - } else { - bl.writeToLoggers(when, msg, logLevel) - } - return nil -} - -// SetLevel Set log message level. -// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), -// log providers will not even be sent the message. -func (bl *BeeLogger) SetLevel(l int) { - bl.level = l -} - -// GetLevel Get Current log message level. -func (bl *BeeLogger) GetLevel() int { - return bl.level -} - -// SetLogFuncCallDepth set log funcCallDepth -func (bl *BeeLogger) SetLogFuncCallDepth(d int) { - bl.loggerFuncCallDepth = d -} - -// GetLogFuncCallDepth return log funcCallDepth for wrapper -func (bl *BeeLogger) GetLogFuncCallDepth() int { - return bl.loggerFuncCallDepth -} - -// EnableFuncCallDepth enable log funcCallDepth -func (bl *BeeLogger) EnableFuncCallDepth(b bool) { - bl.enableFuncCallDepth = b -} - -// set prefix -func (bl *BeeLogger) SetPrefix(s string) { - bl.prefix = s -} - -// start logger chan reading. -// when chan is not empty, write logs. -func (bl *BeeLogger) startLogger() { - gameOver := false - for { - select { - case bm := <-bl.msgChan: - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - case sg := <-bl.signalChan: - // Now should only send "flush" or "close" to bl.signalChan - bl.flush() - if sg == "close" { - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil - gameOver = true - } - bl.wg.Done() - } - if gameOver { - break - } - } -} - -// Emergency Log EMERGENCY level message. -func (bl *BeeLogger) Emergency(format string, v ...interface{}) { - if LevelEmergency > bl.level { - return - } - bl.writeMsg(LevelEmergency, format, v...) -} - -// Alert Log ALERT level message. -func (bl *BeeLogger) Alert(format string, v ...interface{}) { - if LevelAlert > bl.level { - return - } - bl.writeMsg(LevelAlert, format, v...) -} - -// Critical Log CRITICAL level message. -func (bl *BeeLogger) Critical(format string, v ...interface{}) { - if LevelCritical > bl.level { - return - } - bl.writeMsg(LevelCritical, format, v...) -} - -// Error Log ERROR level message. -func (bl *BeeLogger) Error(format string, v ...interface{}) { - if LevelError > bl.level { - return - } - bl.writeMsg(LevelError, format, v...) -} - -// Warning Log WARNING level message. -func (bl *BeeLogger) Warning(format string, v ...interface{}) { - if LevelWarn > bl.level { - return - } - bl.writeMsg(LevelWarn, format, v...) -} - -// Notice Log NOTICE level message. -func (bl *BeeLogger) Notice(format string, v ...interface{}) { - if LevelNotice > bl.level { - return - } - bl.writeMsg(LevelNotice, format, v...) -} - -// Informational Log INFORMATIONAL level message. -func (bl *BeeLogger) Informational(format string, v ...interface{}) { - if LevelInfo > bl.level { - return - } - bl.writeMsg(LevelInfo, format, v...) -} - -// Debug Log DEBUG level message. -func (bl *BeeLogger) Debug(format string, v ...interface{}) { - if LevelDebug > bl.level { - return - } - bl.writeMsg(LevelDebug, format, v...) -} - -// Warn Log WARN level message. -// compatibility alias for Warning() -func (bl *BeeLogger) Warn(format string, v ...interface{}) { - if LevelWarn > bl.level { - return - } - bl.writeMsg(LevelWarn, format, v...) -} - -// Info Log INFO level message. -// compatibility alias for Informational() -func (bl *BeeLogger) Info(format string, v ...interface{}) { - if LevelInfo > bl.level { - return - } - bl.writeMsg(LevelInfo, format, v...) -} - -// Trace Log TRACE level message. -// compatibility alias for Debug() -func (bl *BeeLogger) Trace(format string, v ...interface{}) { - if LevelDebug > bl.level { - return - } - bl.writeMsg(LevelDebug, format, v...) -} - -// Flush flush all chan data. -func (bl *BeeLogger) Flush() { - if bl.asynchronous { - bl.signalChan <- "flush" - bl.wg.Wait() - bl.wg.Add(1) - return - } - bl.flush() -} - -// Close close logger, flush all chan data and destroy all adapters in BeeLogger. -func (bl *BeeLogger) Close() { - if bl.asynchronous { - bl.signalChan <- "close" - bl.wg.Wait() - close(bl.msgChan) - } else { - bl.flush() - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil - } - close(bl.signalChan) -} - -// Reset close all outputs, and set bl.outputs to nil -func (bl *BeeLogger) Reset() { - bl.Flush() - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil -} - -func (bl *BeeLogger) flush() { - if bl.asynchronous { - for { - if len(bl.msgChan) > 0 { - bm := <-bl.msgChan - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - continue - } - break - } - } - for _, l := range bl.outputs { - l.Flush() - } -} - -// beeLogger references the used application logger. -var beeLogger = NewLogger() - -// GetBeeLogger returns the default BeeLogger -func GetBeeLogger() *BeeLogger { - return beeLogger -} - -var beeLoggerMap = struct { - sync.RWMutex - logs map[string]*log.Logger -}{ - logs: map[string]*log.Logger{}, -} - -// GetLogger returns the default BeeLogger -func GetLogger(prefixes ...string) *log.Logger { - prefix := append(prefixes, "")[0] - if prefix != "" { - prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) - } - beeLoggerMap.RLock() - l, ok := beeLoggerMap.logs[prefix] - if ok { - beeLoggerMap.RUnlock() - return l - } - beeLoggerMap.RUnlock() - beeLoggerMap.Lock() - defer beeLoggerMap.Unlock() - l, ok = beeLoggerMap.logs[prefix] - if !ok { - l = log.New(beeLogger, prefix, 0) - beeLoggerMap.logs[prefix] = l - } - return l -} - -// Reset will remove all the adapter -func Reset() { - beeLogger.Reset() -} - -// Async set the beelogger with Async mode and hold msglen messages -func Async(msgLen ...int64) *BeeLogger { - return beeLogger.Async(msgLen...) -} - -// SetLevel sets the global log level used by the simple logger. -func SetLevel(l int) { - beeLogger.SetLevel(l) -} - -// SetPrefix sets the prefix -func SetPrefix(s string) { - beeLogger.SetPrefix(s) -} - -// EnableFuncCallDepth enable log funcCallDepth -func EnableFuncCallDepth(b bool) { - beeLogger.enableFuncCallDepth = b -} - -// SetLogFuncCall set the CallDepth, default is 4 -func SetLogFuncCall(b bool) { - beeLogger.EnableFuncCallDepth(b) - beeLogger.SetLogFuncCallDepth(4) -} - -// SetLogFuncCallDepth set log funcCallDepth -func SetLogFuncCallDepth(d int) { - beeLogger.loggerFuncCallDepth = d -} - -// SetLogger sets a new logger. -func SetLogger(adapter string, config ...string) error { - return beeLogger.SetLogger(adapter, config...) -} - -// Emergency logs a message at emergency level. -func Emergency(f interface{}, v ...interface{}) { - beeLogger.Emergency(formatLog(f, v...)) -} - -// Alert logs a message at alert level. -func Alert(f interface{}, v ...interface{}) { - beeLogger.Alert(formatLog(f, v...)) -} - -// Critical logs a message at critical level. -func Critical(f interface{}, v ...interface{}) { - beeLogger.Critical(formatLog(f, v...)) -} - -// Error logs a message at error level. -func Error(f interface{}, v ...interface{}) { - beeLogger.Error(formatLog(f, v...)) -} - -// Warning logs a message at warning level. -func Warning(f interface{}, v ...interface{}) { - beeLogger.Warn(formatLog(f, v...)) -} - -// Warn compatibility alias for Warning() -func Warn(f interface{}, v ...interface{}) { - beeLogger.Warn(formatLog(f, v...)) -} - -// Notice logs a message at notice level. -func Notice(f interface{}, v ...interface{}) { - beeLogger.Notice(formatLog(f, v...)) -} - -// Informational logs a message at info level. -func Informational(f interface{}, v ...interface{}) { - beeLogger.Info(formatLog(f, v...)) -} - -// Info compatibility alias for Warning() -func Info(f interface{}, v ...interface{}) { - beeLogger.Info(formatLog(f, v...)) -} - -// Debug logs a message at debug level. -func Debug(f interface{}, v ...interface{}) { - beeLogger.Debug(formatLog(f, v...)) -} - -// Trace logs a message at trace level. -// compatibility alias for Warning() -func Trace(f interface{}, v ...interface{}) { - beeLogger.Trace(formatLog(f, v...)) -} - -func formatLog(f interface{}, v ...interface{}) string { - var msg string - switch f.(type) { - case string: - msg = f.(string) - if len(v) == 0 { - return msg - } - if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { - //format string - } else { - //do not contain format char - msg += strings.Repeat(" %v", len(v)) - } - default: - msg = fmt.Sprint(f) - if len(v) == 0 { - return msg - } - msg += strings.Repeat(" %v", len(v)) - } - return fmt.Sprintf(msg, v...) -} diff --git a/logs/logger.go b/logs/logger.go deleted file mode 100644 index a28bff6f..00000000 --- a/logs/logger.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "io" - "runtime" - "sync" - "time" -) - -type logWriter struct { - sync.Mutex - writer io.Writer -} - -func newLogWriter(wr io.Writer) *logWriter { - return &logWriter{writer: wr} -} - -func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { - lg.Lock() - h, _, _ := formatTimeHeader(when) - n, err := lg.writer.Write(append(append(h, msg...), '\n')) - lg.Unlock() - return n, err -} - -const ( - y1 = `0123456789` - y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` - y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999` - y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` - mo1 = `000000000111` - mo2 = `123456789012` - d1 = `0000000001111111111222222222233` - d2 = `1234567890123456789012345678901` - h1 = `000000000011111111112222` - h2 = `012345678901234567890123` - mi1 = `000000000011111111112222222222333333333344444444445555555555` - mi2 = `012345678901234567890123456789012345678901234567890123456789` - s1 = `000000000011111111112222222222333333333344444444445555555555` - s2 = `012345678901234567890123456789012345678901234567890123456789` - ns1 = `0123456789` -) - -func formatTimeHeader(when time.Time) ([]byte, int, int) { - y, mo, d := when.Date() - h, mi, s := when.Clock() - ns := when.Nanosecond() / 1000000 - //len("2006/01/02 15:04:05.123 ")==24 - var buf [24]byte - - buf[0] = y1[y/1000%10] - buf[1] = y2[y/100] - buf[2] = y3[y-y/100*100] - buf[3] = y4[y-y/100*100] - buf[4] = '/' - buf[5] = mo1[mo-1] - buf[6] = mo2[mo-1] - buf[7] = '/' - buf[8] = d1[d-1] - buf[9] = d2[d-1] - buf[10] = ' ' - buf[11] = h1[h] - buf[12] = h2[h] - buf[13] = ':' - buf[14] = mi1[mi] - buf[15] = mi2[mi] - buf[16] = ':' - buf[17] = s1[s] - buf[18] = s2[s] - buf[19] = '.' - buf[20] = ns1[ns/100] - buf[21] = ns1[ns%100/10] - buf[22] = ns1[ns%10] - - buf[23] = ' ' - - return buf[0:], d, h -} - -var ( - green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) - white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) - yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) - red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) - blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) - magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) - cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) - - w32Green = string([]byte{27, 91, 52, 50, 109}) - w32White = string([]byte{27, 91, 52, 55, 109}) - w32Yellow = string([]byte{27, 91, 52, 51, 109}) - w32Red = string([]byte{27, 91, 52, 49, 109}) - w32Blue = string([]byte{27, 91, 52, 52, 109}) - w32Magenta = string([]byte{27, 91, 52, 53, 109}) - w32Cyan = string([]byte{27, 91, 52, 54, 109}) - - reset = string([]byte{27, 91, 48, 109}) -) - -var once sync.Once -var colorMap map[string]string - -func initColor() { - if runtime.GOOS == "windows" { - green = w32Green - white = w32White - yellow = w32Yellow - red = w32Red - blue = w32Blue - magenta = w32Magenta - cyan = w32Cyan - } - colorMap = map[string]string{ - //by color - "green": green, - "white": white, - "yellow": yellow, - "red": red, - //by method - "GET": blue, - "POST": cyan, - "PUT": yellow, - "DELETE": red, - "PATCH": green, - "HEAD": magenta, - "OPTIONS": white, - } -} - -// ColorByStatus return color by http code -// 2xx return Green -// 3xx return White -// 4xx return Yellow -// 5xx return Red -func ColorByStatus(code int) string { - once.Do(initColor) - switch { - case code >= 200 && code < 300: - return colorMap["green"] - case code >= 300 && code < 400: - return colorMap["white"] - case code >= 400 && code < 500: - return colorMap["yellow"] - default: - return colorMap["red"] - } -} - -// ColorByMethod return color by http code -func ColorByMethod(method string) string { - once.Do(initColor) - if c := colorMap[method]; c != "" { - return c - } - return reset -} - -// ResetColor return reset color -func ResetColor() string { - return reset -} diff --git a/logs/logger_test.go b/logs/logger_test.go deleted file mode 100644 index 15be500d..00000000 --- a/logs/logger_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "testing" - "time" -) - -func TestFormatHeader_0(t *testing.T) { - tm := time.Now() - if tm.Year() >= 2100 { - t.FailNow() - } - dur := time.Second - for { - if tm.Year() >= 2100 { - break - } - h, _, _ := formatTimeHeader(tm) - if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { - t.Log(tm) - t.FailNow() - } - tm = tm.Add(dur) - dur *= 2 - } -} - -func TestFormatHeader_1(t *testing.T) { - tm := time.Now() - year := tm.Year() - dur := time.Second - for { - if tm.Year() >= year+1 { - break - } - h, _, _ := formatTimeHeader(tm) - if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { - t.Log(tm) - t.FailNow() - } - tm = tm.Add(dur) - } -} diff --git a/logs/multifile.go b/logs/multifile.go deleted file mode 100644 index 90168274..00000000 --- a/logs/multifile.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "encoding/json" - "time" -) - -// A filesLogWriter manages several fileLogWriter -// filesLogWriter will write logs to the file in json configuration and write the same level log to correspond file -// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log -// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log -// the rotate attribute also acts like fileLogWriter -type multiFileLogWriter struct { - writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter - fullLogWriter *fileLogWriter - Separate []string `json:"separate"` -} - -var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} - -// Init file logger with json config. -// jsonConfig like: -// { -// "filename":"logs/beego.log", -// "maxLines":0, -// "maxsize":0, -// "daily":true, -// "maxDays":15, -// "rotate":true, -// "perm":0600, -// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], -// } - -func (f *multiFileLogWriter) Init(config string) error { - writer := newFileWriter().(*fileLogWriter) - err := writer.Init(config) - if err != nil { - return err - } - f.fullLogWriter = writer - f.writers[LevelDebug+1] = writer - - //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(config), f) - - jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(config), &jsonMap) - - for i := LevelEmergency; i < LevelDebug+1; i++ { - for _, v := range f.Separate { - if v == levelNames[i] { - jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." + levelNames[i] + f.fullLogWriter.suffix - jsonMap["level"] = i - bs, _ := json.Marshal(jsonMap) - writer = newFileWriter().(*fileLogWriter) - err := writer.Init(string(bs)) - if err != nil { - return err - } - f.writers[i] = writer - } - } - } - - return nil -} - -func (f *multiFileLogWriter) Destroy() { - for i := 0; i < len(f.writers); i++ { - if f.writers[i] != nil { - f.writers[i].Destroy() - } - } -} - -func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { - if f.fullLogWriter != nil { - f.fullLogWriter.WriteMsg(when, msg, level) - } - for i := 0; i < len(f.writers)-1; i++ { - if f.writers[i] != nil { - if level == f.writers[i].Level { - f.writers[i].WriteMsg(when, msg, level) - } - } - } - return nil -} - -func (f *multiFileLogWriter) Flush() { - for i := 0; i < len(f.writers); i++ { - if f.writers[i] != nil { - f.writers[i].Flush() - } - } -} - -// newFilesWriter create a FileLogWriter returning as LoggerInterface. -func newFilesWriter() Logger { - return &multiFileLogWriter{} -} - -func init() { - Register(AdapterMultiFile, newFilesWriter) -} diff --git a/logs/multifile_test.go b/logs/multifile_test.go deleted file mode 100644 index 57b96094..00000000 --- a/logs/multifile_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "bufio" - "os" - "strconv" - "strings" - "testing" -) - -func TestFiles_1(t *testing.T) { - log := NewLogger(10000) - log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`) - log.Debug("debug") - log.Informational("info") - log.Notice("notice") - log.Warning("warning") - log.Error("error") - log.Alert("alert") - log.Critical("critical") - log.Emergency("emergency") - fns := []string{""} - fns = append(fns, levelNames[0:]...) - name := "test" - suffix := ".log" - for _, fn := range fns { - - file := name + suffix - if fn != "" { - file = name + "." + fn + suffix - } - f, err := os.Open(file) - if err != nil { - t.Fatal(err) - } - b := bufio.NewReader(f) - lineNum := 0 - lastLine := "" - for { - line, _, err := b.ReadLine() - if err != nil { - break - } - if len(line) > 0 { - lastLine = string(line) - lineNum++ - } - } - var expected = 1 - if fn == "" { - expected = LevelDebug + 1 - } - if lineNum != expected { - t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines") - } - if lineNum == 1 { - if !strings.Contains(lastLine, fn) { - t.Fatal(file + " " + lastLine + " not contains the log msg " + fn) - } - } - os.Remove(file) - } - -} diff --git a/logs/slack.go b/logs/slack.go deleted file mode 100644 index 1cd2e5ae..00000000 --- a/logs/slack.go +++ /dev/null @@ -1,60 +0,0 @@ -package logs - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "time" -) - -// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook -type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` -} - -// newSLACKWriter create jiaoliao writer. -func newSLACKWriter() Logger { - return &SLACKWriter{Level: LevelTrace} -} - -// Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) - - form := url.Values{} - form.Add("payload", text) - - resp, err := http.PostForm(s.WebhookURL, form) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) - } - return nil -} - -// Flush implementing method. empty. -func (s *SLACKWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *SLACKWriter) Destroy() { -} - -func init() { - Register(AdapterSlack, newSLACKWriter) -} diff --git a/logs/smtp.go b/logs/smtp.go deleted file mode 100644 index 6208d7b8..00000000 --- a/logs/smtp.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -import ( - "crypto/tls" - "encoding/json" - "fmt" - "net" - "net/smtp" - "strings" - "time" -) - -// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. -type SMTPWriter struct { - Username string `json:"username"` - Password string `json:"password"` - Host string `json:"host"` - Subject string `json:"subject"` - FromAddress string `json:"fromAddress"` - RecipientAddresses []string `json:"sendTos"` - Level int `json:"level"` -} - -// NewSMTPWriter create smtp writer. -func newSMTPWriter() Logger { - return &SMTPWriter{Level: LevelTrace} -} - -// Init smtp writer with json config. -// config like: -// { -// "username":"example@gmail.com", -// "password:"password", -// "host":"smtp.gmail.com:465", -// "subject":"email title", -// "fromAddress":"from@example.com", -// "sendTos":["email1","email2"], -// "level":LevelError -// } -func (s *SMTPWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { - if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { - return nil - } - return smtp.PlainAuth( - "", - s.Username, - s.Password, - host, - ) -} - -func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { - client, err := smtp.Dial(hostAddressWithPort) - if err != nil { - return err - } - - host, _, _ := net.SplitHostPort(hostAddressWithPort) - tlsConn := &tls.Config{ - InsecureSkipVerify: true, - ServerName: host, - } - if err = client.StartTLS(tlsConn); err != nil { - return err - } - - if auth != nil { - if err = client.Auth(auth); err != nil { - return err - } - } - - if err = client.Mail(fromAddress); err != nil { - return err - } - - for _, rec := range recipients { - if err = client.Rcpt(rec); err != nil { - return err - } - } - - w, err := client.Data() - if err != nil { - return err - } - _, err = w.Write(msgContent) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - - return client.Quit() -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - hp := strings.Split(s.Host, ":") - - // Set up authentication information. - auth := s.getSMTPAuth(hp[0]) - - // Connect to the server, authenticate, set the sender and recipient, - // and send the email all in one step. - contentType := "Content-Type: text/plain" + "; charset=UTF-8" - mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) - - return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) -} - -// Flush implementing method. empty. -func (s *SMTPWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *SMTPWriter) Destroy() { -} - -func init() { - Register(AdapterMail, newSMTPWriter) -} diff --git a/logs/smtp_test.go b/logs/smtp_test.go deleted file mode 100644 index ebc8a952..00000000 --- a/logs/smtp_test.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package logs - -// it often failed. And we moved this to pkg/logs, -// so we ignore it -// func TestSmtp(t *testing.T) { -// log := NewLogger(10000) -// log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) -// log.Critical("sendmail critical") -// time.Sleep(time.Second * 30) -// } diff --git a/metric/prometheus.go b/metric/prometheus.go deleted file mode 100644 index 215896bd..00000000 --- a/metric/prometheus.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2020 astaxie -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "net/http" - "reflect" - "strconv" - "strings" - "time" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/logs" -) - -// Deprecated: we will removed this function in 2.1.0 -// please use pkg/web/filter/prometheus#FilterChain -func PrometheusMiddleWare(next http.Handler) http.Handler { - summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ - Name: "beego", - Subsystem: "http_request", - ConstLabels: map[string]string{ - "server": beego.BConfig.ServerName, - "env": beego.BConfig.RunMode, - "appname": beego.BConfig.AppName, - }, - Help: "The statics info for http request", - }, []string{"pattern", "method", "status", "duration"}) - - prometheus.MustRegister(summaryVec) - - registerBuildInfo() - - return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { - start := time.Now() - next.ServeHTTP(writer, q) - end := time.Now() - go report(end.Sub(start), writer, q, summaryVec) - }) -} - -func registerBuildInfo() { - buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Name: "beego", - Subsystem: "build_info", - Help: "The building information", - ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, - "build_version": beego.BuildVersion, - "build_revision": beego.BuildGitRevision, - "build_status": beego.BuildStatus, - "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), - "go_version": beego.GoVersion, - "git_branch": beego.GitBranch, - "start_time": time.Now().Format("2006-01-02 15:04:05"), - }, - }, []string{}) - - prometheus.MustRegister(buildInfo) - buildInfo.WithLabelValues().Set(1) -} - -func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { - ctrl := beego.BeeApp.Handlers - ctx := ctrl.GetContext() - ctx.Reset(writer, q) - defer ctrl.GiveBackContext(ctx) - - // We cannot read the status code from q.Response.StatusCode - // since the http server does not set q.Response. So q.Response is nil - // Thus, we use reflection to read the status from writer whose concrete type is http.response - responseVal := reflect.ValueOf(writer).Elem() - field := responseVal.FieldByName("status") - status := -1 - if field.IsValid() && field.Kind() == reflect.Int { - status = int(field.Int()) - } - ptn := "UNKNOWN" - if rt, found := ctrl.FindRouter(ctx); found { - ptn = rt.GetPattern() - } else { - logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) - } - ms := dur / time.Millisecond - vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) -} diff --git a/metric/prometheus_test.go b/metric/prometheus_test.go deleted file mode 100644 index d82a6dec..00000000 --- a/metric/prometheus_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2020 astaxie -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package metric - -import ( - "net/http" - "net/url" - "testing" - "time" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/astaxie/beego/context" -) - -func TestPrometheusMiddleWare(t *testing.T) { - middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) - writer := &context.Response{} - request := &http.Request{ - URL: &url.URL{ - Host: "localhost", - RawPath: "/a/b/c", - }, - Method: "POST", - } - vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) - - report(time.Second, writer, request, vec) - middleware.ServeHTTP(writer, request) -} diff --git a/migration/ddl.go b/migration/ddl.go deleted file mode 100644 index cd2c1c49..00000000 --- a/migration/ddl.go +++ /dev/null @@ -1,395 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migration - -import ( - "fmt" - - "github.com/astaxie/beego/logs" -) - -// Index struct defines the structure of Index Columns -type Index struct { - Name string -} - -// Unique struct defines a single unique key combination -type Unique struct { - Definition string - Columns []*Column -} - -//Column struct defines a single column of a table -type Column struct { - Name string - Inc string - Null string - Default string - Unsign string - DataType string - remove bool - Modify bool -} - -// Foreign struct defines a single foreign relationship -type Foreign struct { - ForeignTable string - ForeignColumn string - OnDelete string - OnUpdate string - Column -} - -// RenameColumn struct allows renaming of columns -type RenameColumn struct { - OldName string - OldNull string - OldDefault string - OldUnsign string - OldDataType string - NewName string - Column -} - -// CreateTable creates the table on system -func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { - m.TableName = tablename - m.Engine = engine - m.Charset = charset - m.ModifyType = "create" -} - -// AlterTable set the ModifyType to alter -func (m *Migration) AlterTable(tablename string) { - m.TableName = tablename - m.ModifyType = "alter" -} - -// NewCol creates a new standard column and attaches it to m struct -func (m *Migration) NewCol(name string) *Column { - col := &Column{Name: name} - m.AddColumns(col) - return col -} - -//PriCol creates a new primary column and attaches it to m struct -func (m *Migration) PriCol(name string) *Column { - col := &Column{Name: name} - m.AddColumns(col) - m.AddPrimary(col) - return col -} - -//UniCol creates / appends columns to specified unique key and attaches it to m struct -func (m *Migration) UniCol(uni, name string) *Column { - col := &Column{Name: name} - m.AddColumns(col) - - uniqueOriginal := &Unique{} - - for _, unique := range m.Uniques { - if unique.Definition == uni { - unique.AddColumnsToUnique(col) - uniqueOriginal = unique - } - } - if uniqueOriginal.Definition == "" { - unique := &Unique{Definition: uni} - unique.AddColumnsToUnique(col) - m.AddUnique(unique) - } - - return col -} - -//ForeignCol creates a new foreign column and returns the instance of column -func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { - - foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} - foreign.Name = colname - m.AddForeign(foreign) - return foreign -} - -//SetOnDelete sets the on delete of foreign -func (foreign *Foreign) SetOnDelete(del string) *Foreign { - foreign.OnDelete = "ON DELETE" + del - return foreign -} - -//SetOnUpdate sets the on update of foreign -func (foreign *Foreign) SetOnUpdate(update string) *Foreign { - foreign.OnUpdate = "ON UPDATE" + update - return foreign -} - -//Remove marks the columns to be removed. -//it allows reverse m to create the column. -func (c *Column) Remove() { - c.remove = true -} - -//SetAuto enables auto_increment of column (can be used once) -func (c *Column) SetAuto(inc bool) *Column { - if inc { - c.Inc = "auto_increment" - } - return c -} - -//SetNullable sets the column to be null -func (c *Column) SetNullable(null bool) *Column { - if null { - c.Null = "" - - } else { - c.Null = "NOT NULL" - } - return c -} - -//SetDefault sets the default value, prepend with "DEFAULT " -func (c *Column) SetDefault(def string) *Column { - c.Default = "DEFAULT " + def - return c -} - -//SetUnsigned sets the column to be unsigned int -func (c *Column) SetUnsigned(unsign bool) *Column { - if unsign { - c.Unsign = "UNSIGNED" - } - return c -} - -//SetDataType sets the dataType of the column -func (c *Column) SetDataType(dataType string) *Column { - c.DataType = dataType - return c -} - -//SetOldNullable allows reverting to previous nullable on reverse ms -func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { - if null { - c.OldNull = "" - - } else { - c.OldNull = "NOT NULL" - } - return c -} - -//SetOldDefault allows reverting to previous default on reverse ms -func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { - c.OldDefault = def - return c -} - -//SetOldUnsigned allows reverting to previous unsgined on reverse ms -func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { - if unsign { - c.OldUnsign = "UNSIGNED" - } - return c -} - -//SetOldDataType allows reverting to previous datatype on reverse ms -func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { - c.OldDataType = dataType - return c -} - -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) -func (c *Column) SetPrimary(m *Migration) *Column { - m.Primary = append(m.Primary, c) - return c -} - -//AddColumnsToUnique adds the columns to Unique Struct -func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { - - unique.Columns = append(unique.Columns, columns...) - - return unique -} - -//AddColumns adds columns to m struct -func (m *Migration) AddColumns(columns ...*Column) *Migration { - - m.Columns = append(m.Columns, columns...) - - return m -} - -//AddPrimary adds the column to primary in m struct -func (m *Migration) AddPrimary(primary *Column) *Migration { - m.Primary = append(m.Primary, primary) - return m -} - -//AddUnique adds the column to unique in m struct -func (m *Migration) AddUnique(unique *Unique) *Migration { - m.Uniques = append(m.Uniques, unique) - return m -} - -//AddForeign adds the column to foreign in m struct -func (m *Migration) AddForeign(foreign *Foreign) *Migration { - m.Foreigns = append(m.Foreigns, foreign) - return m -} - -//AddIndex adds the column to index in m struct -func (m *Migration) AddIndex(index *Index) *Migration { - m.Indexes = append(m.Indexes, index) - return m -} - -//RenameColumn allows renaming of columns -func (m *Migration) RenameColumn(from, to string) *RenameColumn { - rename := &RenameColumn{OldName: from, NewName: to} - m.Renames = append(m.Renames, rename) - return rename -} - -//GetSQL returns the generated sql depending on ModifyType -func (m *Migration) GetSQL() (sql string) { - sql = "" - switch m.ModifyType { - case "create": - { - sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName) - for index, column := range m.Columns { - sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - if len(m.Columns) > index+1 { - sql += "," - } - } - - if len(m.Primary) > 0 { - sql += fmt.Sprintf(",\n PRIMARY KEY( ") - } - for index, column := range m.Primary { - sql += fmt.Sprintf(" `%s`", column.Name) - if len(m.Primary) > index+1 { - sql += "," - } - - } - if len(m.Primary) > 0 { - sql += fmt.Sprintf(")") - } - - for _, unique := range m.Uniques { - sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition) - for index, column := range unique.Columns { - sql += fmt.Sprintf(" `%s`", column.Name) - if len(unique.Columns) > index+1 { - sql += "," - } - } - sql += fmt.Sprintf(")") - } - for _, foreign := range m.Foreigns { - sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) - sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name) - sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) - - } - sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset) - break - } - case "alter": - { - sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) - for index, column := range m.Columns { - if !column.remove { - logs.Info("col") - sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - } else { - sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) - } - - if len(m.Columns) > index+1 { - sql += "," - } - } - for index, column := range m.Renames { - sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - if len(m.Renames) > index+1 { - sql += "," - } - } - - for index, foreign := range m.Foreigns { - sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) - sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name) - sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) - if len(m.Foreigns) > index+1 { - sql += "," - } - } - sql += ";" - - break - } - case "reverse": - { - - sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName) - for index, column := range m.Columns { - if column.remove { - sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) - } else { - sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) - } - if len(m.Columns) > index+1 { - sql += "," - } - } - - if len(m.Primary) > 0 { - sql += fmt.Sprintf("\n DROP PRIMARY KEY,") - } - - for index, unique := range m.Uniques { - sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) - if len(m.Uniques) > index+1 { - sql += "," - } - - } - for index, column := range m.Renames { - sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) - if len(m.Renames) > index+1 { - sql += "," - } - } - - for _, foreign := range m.Foreigns { - sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) - sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) - sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name) - } - sql += ";" - } - case "delete": - { - sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName) - } - } - - return -} diff --git a/migration/doc.go b/migration/doc.go deleted file mode 100644 index 0c6564d4..00000000 --- a/migration/doc.go +++ /dev/null @@ -1,32 +0,0 @@ -// Package migration enables you to generate migrations back and forth. It generates both migrations. -// -// //Creates a table -// m.CreateTable("tablename","InnoDB","utf8"); -// -// //Alter a table -// m.AlterTable("tablename") -// -// Standard Column Methods -// * SetDataType -// * SetNullable -// * SetDefault -// * SetUnsigned (use only on integer types unless produces error) -// -// //Sets a primary column, multiple calls allowed, standard column methods available -// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) -// -// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index -// m.UniCol("index","column") -// -// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove -// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) -// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) -// -// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to -// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") -// m.RenameColumn("from","to")... -// -// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. -// //Supports standard column methods, automatic reverse. -// m.ForeignCol("local_col","foreign_col","foreign_table") -package migration diff --git a/migration/migration.go b/migration/migration.go deleted file mode 100644 index 5ddfd972..00000000 --- a/migration/migration.go +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package migration is used for migration -// -// The table structure is as follow: -// -// CREATE TABLE `migrations` ( -// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', -// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', -// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', -// `statements` longtext COMMENT 'SQL statements for this migration', -// `rollback_statements` longtext, -// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', -// PRIMARY KEY (`id_migration`) -// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; -package migration - -import ( - "errors" - "sort" - "strings" - "time" - - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/orm" -) - -// const the data format for the bee generate migration datatype -const ( - DateFormat = "20060102_150405" - DBDateFormat = "2006-01-02 15:04:05" -) - -// Migrationer is an interface for all Migration struct -type Migrationer interface { - Up() - Down() - Reset() - Exec(name, status string) error - GetCreated() int64 -} - -//Migration defines the migrations by either SQL or DDL -type Migration struct { - sqls []string - Created string - TableName string - Engine string - Charset string - ModifyType string - Columns []*Column - Indexes []*Index - Primary []*Column - Uniques []*Unique - Foreigns []*Foreign - Renames []*RenameColumn - RemoveColumns []*Column - RemoveIndexes []*Index - RemoveUniques []*Unique - RemoveForeigns []*Foreign -} - -var ( - migrationMap map[string]Migrationer -) - -func init() { - migrationMap = make(map[string]Migrationer) -} - -// Up implement in the Inheritance struct for upgrade -func (m *Migration) Up() { - - switch m.ModifyType { - case "reverse": - m.ModifyType = "alter" - case "delete": - m.ModifyType = "create" - } - m.sqls = append(m.sqls, m.GetSQL()) -} - -// Down implement in the Inheritance struct for down -func (m *Migration) Down() { - - switch m.ModifyType { - case "alter": - m.ModifyType = "reverse" - case "create": - m.ModifyType = "delete" - } - m.sqls = append(m.sqls, m.GetSQL()) -} - -//Migrate adds the SQL to the execution list -func (m *Migration) Migrate(migrationType string) { - m.ModifyType = migrationType - m.sqls = append(m.sqls, m.GetSQL()) -} - -// SQL add sql want to execute -func (m *Migration) SQL(sql string) { - m.sqls = append(m.sqls, sql) -} - -// Reset the sqls -func (m *Migration) Reset() { - m.sqls = make([]string, 0) -} - -// Exec execute the sql already add in the sql -func (m *Migration) Exec(name, status string) error { - o := orm.NewOrm() - for _, s := range m.sqls { - logs.Info("exec sql:", s) - r := o.Raw(s) - _, err := r.Exec() - if err != nil { - return err - } - } - return m.addOrUpdateRecord(name, status) -} - -func (m *Migration) addOrUpdateRecord(name, status string) error { - o := orm.NewOrm() - if status == "down" { - status = "rollback" - p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() - if err != nil { - return nil - } - _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) - return err - } - status = "update" - p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare() - if err != nil { - return err - } - _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status) - return err -} - -// GetCreated get the unixtime from the Created -func (m *Migration) GetCreated() int64 { - t, err := time.Parse(DateFormat, m.Created) - if err != nil { - return 0 - } - return t.Unix() -} - -// Register register the Migration in the map -func Register(name string, m Migrationer) error { - if _, ok := migrationMap[name]; ok { - return errors.New("already exist name:" + name) - } - migrationMap[name] = m - return nil -} - -// Upgrade upgrade the migration from lasttime -func Upgrade(lasttime int64) error { - sm := sortMap(migrationMap) - i := 0 - migs, _ := getAllMigrations() - for _, v := range sm { - if _, ok := migs[v.name]; !ok { - logs.Info("start upgrade", v.name) - v.m.Reset() - v.m.Up() - err := v.m.Exec(v.name, "up") - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - logs.Info("end upgrade:", v.name) - i++ - } - } - logs.Info("total success upgrade:", i, " migration") - time.Sleep(2 * time.Second) - return nil -} - -// Rollback rollback the migration by the name -func Rollback(name string) error { - if v, ok := migrationMap[name]; ok { - logs.Info("start rollback") - v.Reset() - v.Down() - err := v.Exec(name, "down") - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - logs.Info("end rollback") - time.Sleep(2 * time.Second) - return nil - } - logs.Error("not exist the migrationMap name:" + name) - time.Sleep(2 * time.Second) - return errors.New("not exist the migrationMap name:" + name) -} - -// Reset reset all migration -// run all migration's down function -func Reset() error { - sm := sortMap(migrationMap) - i := 0 - for j := len(sm) - 1; j >= 0; j-- { - v := sm[j] - if isRollBack(v.name) { - logs.Info("skip the", v.name) - time.Sleep(1 * time.Second) - continue - } - logs.Info("start reset:", v.name) - v.m.Reset() - v.m.Down() - err := v.m.Exec(v.name, "down") - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - i++ - logs.Info("end reset:", v.name) - } - logs.Info("total success reset:", i, " migration") - time.Sleep(2 * time.Second) - return nil -} - -// Refresh first Reset, then Upgrade -func Refresh() error { - err := Reset() - if err != nil { - logs.Error("execute error:", err) - time.Sleep(2 * time.Second) - return err - } - err = Upgrade(0) - return err -} - -type dataSlice []data - -type data struct { - created int64 - name string - m Migrationer -} - -// Len is part of sort.Interface. -func (d dataSlice) Len() int { - return len(d) -} - -// Swap is part of sort.Interface. -func (d dataSlice) Swap(i, j int) { - d[i], d[j] = d[j], d[i] -} - -// Less is part of sort.Interface. We use count as the value to sort by -func (d dataSlice) Less(i, j int) bool { - return d[i].created < d[j].created -} - -func sortMap(m map[string]Migrationer) dataSlice { - s := make(dataSlice, 0, len(m)) - for k, v := range m { - d := data{} - d.created = v.GetCreated() - d.name = k - d.m = v - s = append(s, d) - } - sort.Sort(s) - return s -} - -func isRollBack(name string) bool { - o := orm.NewOrm() - var maps []orm.Params - num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) - if err != nil { - logs.Info("get name has error", err) - return false - } - if num <= 0 { - return false - } - if maps[0]["status"] == "rollback" { - return true - } - return false -} -func getAllMigrations() (map[string]string, error) { - o := orm.NewOrm() - var maps []orm.Params - migs := make(map[string]string) - num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) - if err != nil { - logs.Info("get name has error", err) - return migs, err - } - if num > 0 { - for _, v := range maps { - name := v["name"].(string) - migs[name] = v["status"].(string) - } - } - return migs, nil -} diff --git a/mime.go b/mime.go deleted file mode 100644 index ca2878ab..00000000 --- a/mime.go +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -var mimemaps = map[string]string{ - ".3dm": "x-world/x-3dmf", - ".3dmf": "x-world/x-3dmf", - ".7z": "application/x-7z-compressed", - ".a": "application/octet-stream", - ".aab": "application/x-authorware-bin", - ".aam": "application/x-authorware-map", - ".aas": "application/x-authorware-seg", - ".abc": "text/vndabc", - ".ace": "application/x-ace-compressed", - ".acgi": "text/html", - ".afl": "video/animaflex", - ".ai": "application/postscript", - ".aif": "audio/aiff", - ".aifc": "audio/aiff", - ".aiff": "audio/aiff", - ".aim": "application/x-aim", - ".aip": "text/x-audiosoft-intra", - ".alz": "application/x-alz-compressed", - ".ani": "application/x-navi-animation", - ".aos": "application/x-nokia-9000-communicator-add-on-software", - ".aps": "application/mime", - ".apk": "application/vnd.android.package-archive", - ".arc": "application/x-arc-compressed", - ".arj": "application/arj", - ".art": "image/x-jg", - ".asf": "video/x-ms-asf", - ".asm": "text/x-asm", - ".asp": "text/asp", - ".asx": "application/x-mplayer2", - ".au": "audio/basic", - ".avi": "video/x-msvideo", - ".avs": "video/avs-video", - ".bcpio": "application/x-bcpio", - ".bin": "application/mac-binary", - ".bmp": "image/bmp", - ".boo": "application/book", - ".book": "application/book", - ".boz": "application/x-bzip2", - ".bsh": "application/x-bsh", - ".bz2": "application/x-bzip2", - ".bz": "application/x-bzip", - ".c++": "text/plain", - ".c": "text/x-c", - ".cab": "application/vnd.ms-cab-compressed", - ".cat": "application/vndms-pkiseccat", - ".cc": "text/x-c", - ".ccad": "application/clariscad", - ".cco": "application/x-cocoa", - ".cdf": "application/cdf", - ".cer": "application/pkix-cert", - ".cha": "application/x-chat", - ".chat": "application/x-chat", - ".chrt": "application/vnd.kde.kchart", - ".class": "application/java", - ".com": "text/plain", - ".conf": "text/plain", - ".cpio": "application/x-cpio", - ".cpp": "text/x-c", - ".cpt": "application/mac-compactpro", - ".crl": "application/pkcs-crl", - ".crt": "application/pkix-cert", - ".crx": "application/x-chrome-extension", - ".csh": "text/x-scriptcsh", - ".css": "text/css", - ".csv": "text/csv", - ".cxx": "text/plain", - ".dar": "application/x-dar", - ".dcr": "application/x-director", - ".deb": "application/x-debian-package", - ".deepv": "application/x-deepv", - ".def": "text/plain", - ".der": "application/x-x509-ca-cert", - ".dif": "video/x-dv", - ".dir": "application/x-director", - ".divx": "video/divx", - ".dl": "video/dl", - ".dmg": "application/x-apple-diskimage", - ".doc": "application/msword", - ".dot": "application/msword", - ".dp": "application/commonground", - ".drw": "application/drafting", - ".dump": "application/octet-stream", - ".dv": "video/x-dv", - ".dvi": "application/x-dvi", - ".dwf": "drawing/x-dwf=(old)", - ".dwg": "application/acad", - ".dxf": "application/dxf", - ".dxr": "application/x-director", - ".el": "text/x-scriptelisp", - ".elc": "application/x-bytecodeelisp=(compiled=elisp)", - ".eml": "message/rfc822", - ".env": "application/x-envoy", - ".eps": "application/postscript", - ".es": "application/x-esrehber", - ".etx": "text/x-setext", - ".evy": "application/envoy", - ".exe": "application/octet-stream", - ".f77": "text/x-fortran", - ".f90": "text/x-fortran", - ".f": "text/x-fortran", - ".fdf": "application/vndfdf", - ".fif": "application/fractals", - ".fli": "video/fli", - ".flo": "image/florian", - ".flv": "video/x-flv", - ".flx": "text/vndfmiflexstor", - ".fmf": "video/x-atomic3d-feature", - ".for": "text/x-fortran", - ".fpx": "image/vndfpx", - ".frl": "application/freeloader", - ".funk": "audio/make", - ".g3": "image/g3fax", - ".g": "text/plain", - ".gif": "image/gif", - ".gl": "video/gl", - ".gsd": "audio/x-gsm", - ".gsm": "audio/x-gsm", - ".gsp": "application/x-gsp", - ".gss": "application/x-gss", - ".gtar": "application/x-gtar", - ".gz": "application/x-compressed", - ".gzip": "application/x-gzip", - ".h": "text/x-h", - ".hdf": "application/x-hdf", - ".help": "application/x-helpfile", - ".hgl": "application/vndhp-hpgl", - ".hh": "text/x-h", - ".hlb": "text/x-script", - ".hlp": "application/hlp", - ".hpg": "application/vndhp-hpgl", - ".hpgl": "application/vndhp-hpgl", - ".hqx": "application/binhex", - ".hta": "application/hta", - ".htc": "text/x-component", - ".htm": "text/html", - ".html": "text/html", - ".htmls": "text/html", - ".htt": "text/webviewhtml", - ".htx": "text/html", - ".ice": "x-conference/x-cooltalk", - ".ico": "image/x-icon", - ".ics": "text/calendar", - ".icz": "text/calendar", - ".idc": "text/plain", - ".ief": "image/ief", - ".iefs": "image/ief", - ".iges": "application/iges", - ".igs": "application/iges", - ".ima": "application/x-ima", - ".imap": "application/x-httpd-imap", - ".inf": "application/inf", - ".ins": "application/x-internett-signup", - ".ip": "application/x-ip2", - ".isu": "video/x-isvideo", - ".it": "audio/it", - ".iv": "application/x-inventor", - ".ivr": "i-world/i-vrml", - ".ivy": "application/x-livescreen", - ".jam": "audio/x-jam", - ".jav": "text/x-java-source", - ".java": "text/x-java-source", - ".jcm": "application/x-java-commerce", - ".jfif-tbnl": "image/jpeg", - ".jfif": "image/jpeg", - ".jnlp": "application/x-java-jnlp-file", - ".jpe": "image/jpeg", - ".jpeg": "image/jpeg", - ".jpg": "image/jpeg", - ".jps": "image/x-jps", - ".js": "application/javascript", - ".json": "application/json", - ".jut": "image/jutvision", - ".kar": "audio/midi", - ".karbon": "application/vnd.kde.karbon", - ".kfo": "application/vnd.kde.kformula", - ".flw": "application/vnd.kde.kivio", - ".kml": "application/vnd.google-earth.kml+xml", - ".kmz": "application/vnd.google-earth.kmz", - ".kon": "application/vnd.kde.kontour", - ".kpr": "application/vnd.kde.kpresenter", - ".kpt": "application/vnd.kde.kpresenter", - ".ksp": "application/vnd.kde.kspread", - ".kwd": "application/vnd.kde.kword", - ".kwt": "application/vnd.kde.kword", - ".ksh": "text/x-scriptksh", - ".la": "audio/nspaudio", - ".lam": "audio/x-liveaudio", - ".latex": "application/x-latex", - ".lha": "application/lha", - ".lhx": "application/octet-stream", - ".list": "text/plain", - ".lma": "audio/nspaudio", - ".log": "text/plain", - ".lsp": "text/x-scriptlisp", - ".lst": "text/plain", - ".lsx": "text/x-la-asf", - ".ltx": "application/x-latex", - ".lzh": "application/octet-stream", - ".lzx": "application/lzx", - ".m1v": "video/mpeg", - ".m2a": "audio/mpeg", - ".m2v": "video/mpeg", - ".m3u": "audio/x-mpegurl", - ".m": "text/x-m", - ".man": "application/x-troff-man", - ".manifest": "text/cache-manifest", - ".map": "application/x-navimap", - ".mar": "text/plain", - ".mbd": "application/mbedlet", - ".mc$": "application/x-magic-cap-package-10", - ".mcd": "application/mcad", - ".mcf": "text/mcf", - ".mcp": "application/netmc", - ".me": "application/x-troff-me", - ".mht": "message/rfc822", - ".mhtml": "message/rfc822", - ".mid": "application/x-midi", - ".midi": "application/x-midi", - ".mif": "application/x-frame", - ".mime": "message/rfc822", - ".mjf": "audio/x-vndaudioexplosionmjuicemediafile", - ".mjpg": "video/x-motion-jpeg", - ".mm": "application/base64", - ".mme": "application/base64", - ".mod": "audio/mod", - ".moov": "video/quicktime", - ".mov": "video/quicktime", - ".movie": "video/x-sgi-movie", - ".mp2": "audio/mpeg", - ".mp3": "audio/mpeg3", - ".mp4": "video/mp4", - ".mpa": "audio/mpeg", - ".mpc": "application/x-project", - ".mpe": "video/mpeg", - ".mpeg": "video/mpeg", - ".mpg": "video/mpeg", - ".mpga": "audio/mpeg", - ".mpp": "application/vndms-project", - ".mpt": "application/x-project", - ".mpv": "application/x-project", - ".mpx": "application/x-project", - ".mrc": "application/marc", - ".ms": "application/x-troff-ms", - ".mv": "video/x-sgi-movie", - ".my": "audio/make", - ".mzz": "application/x-vndaudioexplosionmzz", - ".nap": "image/naplps", - ".naplps": "image/naplps", - ".nc": "application/x-netcdf", - ".ncm": "application/vndnokiaconfiguration-message", - ".nif": "image/x-niff", - ".niff": "image/x-niff", - ".nix": "application/x-mix-transfer", - ".nsc": "application/x-conference", - ".nvd": "application/x-navidoc", - ".o": "application/octet-stream", - ".oda": "application/oda", - ".odb": "application/vnd.oasis.opendocument.database", - ".odc": "application/vnd.oasis.opendocument.chart", - ".odf": "application/vnd.oasis.opendocument.formula", - ".odg": "application/vnd.oasis.opendocument.graphics", - ".odi": "application/vnd.oasis.opendocument.image", - ".odm": "application/vnd.oasis.opendocument.text-master", - ".odp": "application/vnd.oasis.opendocument.presentation", - ".ods": "application/vnd.oasis.opendocument.spreadsheet", - ".odt": "application/vnd.oasis.opendocument.text", - ".oga": "audio/ogg", - ".ogg": "audio/ogg", - ".ogv": "video/ogg", - ".omc": "application/x-omc", - ".omcd": "application/x-omcdatamaker", - ".omcr": "application/x-omcregerator", - ".otc": "application/vnd.oasis.opendocument.chart-template", - ".otf": "application/vnd.oasis.opendocument.formula-template", - ".otg": "application/vnd.oasis.opendocument.graphics-template", - ".oth": "application/vnd.oasis.opendocument.text-web", - ".oti": "application/vnd.oasis.opendocument.image-template", - ".otm": "application/vnd.oasis.opendocument.text-master", - ".otp": "application/vnd.oasis.opendocument.presentation-template", - ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", - ".ott": "application/vnd.oasis.opendocument.text-template", - ".p10": "application/pkcs10", - ".p12": "application/pkcs-12", - ".p7a": "application/x-pkcs7-signature", - ".p7c": "application/pkcs7-mime", - ".p7m": "application/pkcs7-mime", - ".p7r": "application/x-pkcs7-certreqresp", - ".p7s": "application/pkcs7-signature", - ".p": "text/x-pascal", - ".part": "application/pro_eng", - ".pas": "text/pascal", - ".pbm": "image/x-portable-bitmap", - ".pcl": "application/vndhp-pcl", - ".pct": "image/x-pict", - ".pcx": "image/x-pcx", - ".pdb": "chemical/x-pdb", - ".pdf": "application/pdf", - ".pfunk": "audio/make", - ".pgm": "image/x-portable-graymap", - ".pic": "image/pict", - ".pict": "image/pict", - ".pkg": "application/x-newton-compatible-pkg", - ".pko": "application/vndms-pkipko", - ".pl": "text/x-scriptperl", - ".plx": "application/x-pixclscript", - ".pm4": "application/x-pagemaker", - ".pm5": "application/x-pagemaker", - ".pm": "text/x-scriptperl-module", - ".png": "image/png", - ".pnm": "application/x-portable-anymap", - ".pot": "application/mspowerpoint", - ".pov": "model/x-pov", - ".ppa": "application/vndms-powerpoint", - ".ppm": "image/x-portable-pixmap", - ".pps": "application/mspowerpoint", - ".ppt": "application/mspowerpoint", - ".ppz": "application/mspowerpoint", - ".pre": "application/x-freelance", - ".prt": "application/pro_eng", - ".ps": "application/postscript", - ".psd": "application/octet-stream", - ".pvu": "paleovu/x-pv", - ".pwz": "application/vndms-powerpoint", - ".py": "text/x-scriptphyton", - ".pyc": "application/x-bytecodepython", - ".qcp": "audio/vndqcelp", - ".qd3": "x-world/x-3dmf", - ".qd3d": "x-world/x-3dmf", - ".qif": "image/x-quicktime", - ".qt": "video/quicktime", - ".qtc": "video/x-qtc", - ".qti": "image/x-quicktime", - ".qtif": "image/x-quicktime", - ".ra": "audio/x-pn-realaudio", - ".ram": "audio/x-pn-realaudio", - ".rar": "application/x-rar-compressed", - ".ras": "application/x-cmu-raster", - ".rast": "image/cmu-raster", - ".rexx": "text/x-scriptrexx", - ".rf": "image/vndrn-realflash", - ".rgb": "image/x-rgb", - ".rm": "application/vndrn-realmedia", - ".rmi": "audio/mid", - ".rmm": "audio/x-pn-realaudio", - ".rmp": "audio/x-pn-realaudio", - ".rng": "application/ringing-tones", - ".rnx": "application/vndrn-realplayer", - ".roff": "application/x-troff", - ".rp": "image/vndrn-realpix", - ".rpm": "audio/x-pn-realaudio-plugin", - ".rt": "text/vndrn-realtext", - ".rtf": "text/richtext", - ".rtx": "text/richtext", - ".rv": "video/vndrn-realvideo", - ".s": "text/x-asm", - ".s3m": "audio/s3m", - ".s7z": "application/x-7z-compressed", - ".saveme": "application/octet-stream", - ".sbk": "application/x-tbook", - ".scm": "text/x-scriptscheme", - ".sdml": "text/plain", - ".sdp": "application/sdp", - ".sdr": "application/sounder", - ".sea": "application/sea", - ".set": "application/set", - ".sgm": "text/x-sgml", - ".sgml": "text/x-sgml", - ".sh": "text/x-scriptsh", - ".shar": "application/x-bsh", - ".shtml": "text/x-server-parsed-html", - ".sid": "audio/x-psid", - ".skd": "application/x-koan", - ".skm": "application/x-koan", - ".skp": "application/x-koan", - ".skt": "application/x-koan", - ".sit": "application/x-stuffit", - ".sitx": "application/x-stuffitx", - ".sl": "application/x-seelogo", - ".smi": "application/smil", - ".smil": "application/smil", - ".snd": "audio/basic", - ".sol": "application/solids", - ".spc": "text/x-speech", - ".spl": "application/futuresplash", - ".spr": "application/x-sprite", - ".sprite": "application/x-sprite", - ".spx": "audio/ogg", - ".src": "application/x-wais-source", - ".ssi": "text/x-server-parsed-html", - ".ssm": "application/streamingmedia", - ".sst": "application/vndms-pkicertstore", - ".step": "application/step", - ".stl": "application/sla", - ".stp": "application/step", - ".sv4cpio": "application/x-sv4cpio", - ".sv4crc": "application/x-sv4crc", - ".svf": "image/vnddwg", - ".svg": "image/svg+xml", - ".svr": "application/x-world", - ".swf": "application/x-shockwave-flash", - ".t": "application/x-troff", - ".talk": "text/x-speech", - ".tar": "application/x-tar", - ".tbk": "application/toolbook", - ".tcl": "text/x-scripttcl", - ".tcsh": "text/x-scripttcsh", - ".tex": "application/x-tex", - ".texi": "application/x-texinfo", - ".texinfo": "application/x-texinfo", - ".text": "text/plain", - ".tgz": "application/gnutar", - ".tif": "image/tiff", - ".tiff": "image/tiff", - ".tr": "application/x-troff", - ".tsi": "audio/tsp-audio", - ".tsp": "application/dsptype", - ".tsv": "text/tab-separated-values", - ".turbot": "image/florian", - ".txt": "text/plain", - ".uil": "text/x-uil", - ".uni": "text/uri-list", - ".unis": "text/uri-list", - ".unv": "application/i-deas", - ".uri": "text/uri-list", - ".uris": "text/uri-list", - ".ustar": "application/x-ustar", - ".uu": "text/x-uuencode", - ".uue": "text/x-uuencode", - ".vcd": "application/x-cdlink", - ".vcf": "text/x-vcard", - ".vcard": "text/x-vcard", - ".vcs": "text/x-vcalendar", - ".vda": "application/vda", - ".vdo": "video/vdo", - ".vew": "application/groupwise", - ".viv": "video/vivo", - ".vivo": "video/vivo", - ".vmd": "application/vocaltec-media-desc", - ".vmf": "application/vocaltec-media-file", - ".voc": "audio/voc", - ".vos": "video/vosaic", - ".vox": "audio/voxware", - ".vqe": "audio/x-twinvq-plugin", - ".vqf": "audio/x-twinvq", - ".vql": "audio/x-twinvq-plugin", - ".vrml": "application/x-vrml", - ".vrt": "x-world/x-vrt", - ".vsd": "application/x-visio", - ".vst": "application/x-visio", - ".vsw": "application/x-visio", - ".w60": "application/wordperfect60", - ".w61": "application/wordperfect61", - ".w6w": "application/msword", - ".wav": "audio/wav", - ".wb1": "application/x-qpro", - ".wbmp": "image/vnd.wap.wbmp", - ".web": "application/vndxara", - ".wiz": "application/msword", - ".wk1": "application/x-123", - ".wmf": "windows/metafile", - ".wml": "text/vnd.wap.wml", - ".wmlc": "application/vnd.wap.wmlc", - ".wmls": "text/vnd.wap.wmlscript", - ".wmlsc": "application/vnd.wap.wmlscriptc", - ".word": "application/msword", - ".wp5": "application/wordperfect", - ".wp6": "application/wordperfect", - ".wp": "application/wordperfect", - ".wpd": "application/wordperfect", - ".wq1": "application/x-lotus", - ".wri": "application/mswrite", - ".wrl": "application/x-world", - ".wrz": "model/vrml", - ".wsc": "text/scriplet", - ".wsrc": "application/x-wais-source", - ".wtk": "application/x-wintalk", - ".x-png": "image/png", - ".xbm": "image/x-xbitmap", - ".xdr": "video/x-amt-demorun", - ".xgz": "xgl/drawing", - ".xif": "image/vndxiff", - ".xl": "application/excel", - ".xla": "application/excel", - ".xlb": "application/excel", - ".xlc": "application/excel", - ".xld": "application/excel", - ".xlk": "application/excel", - ".xll": "application/excel", - ".xlm": "application/excel", - ".xls": "application/excel", - ".xlt": "application/excel", - ".xlv": "application/excel", - ".xlw": "application/excel", - ".xm": "audio/xm", - ".xml": "text/xml", - ".xmz": "xgl/movie", - ".xpix": "application/x-vndls-xpix", - ".xpm": "image/x-xpixmap", - ".xsr": "video/x-amt-showrun", - ".xwd": "image/x-xwd", - ".xyz": "chemical/x-pdb", - ".z": "application/x-compress", - ".zip": "application/zip", - ".zoo": "application/octet-stream", - ".zsh": "text/x-scriptzsh", - ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - ".docm": "application/vnd.ms-word.document.macroEnabled.12", - ".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - ".dotm": "application/vnd.ms-word.template.macroEnabled.12", - ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12", - ".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template", - ".xltm": "application/vnd.ms-excel.template.macroEnabled.12", - ".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12", - ".xlam": "application/vnd.ms-excel.addin.macroEnabled.12", - ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", - ".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12", - ".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow", - ".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12", - ".potx": "application/vnd.openxmlformats-officedocument.presentationml.template", - ".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12", - ".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12", - ".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide", - ".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12", - ".thmx": "application/vnd.ms-officetheme", - ".onetoc": "application/onenote", - ".onetoc2": "application/onenote", - ".onetmp": "application/onenote", - ".onepkg": "application/onenote", - ".key": "application/x-iwork-keynote-sffkey", - ".kth": "application/x-iwork-keynote-sffkth", - ".nmbtemplate": "application/x-iwork-numbers-sfftemplate", - ".numbers": "application/x-iwork-numbers-sffnumbers", - ".pages": "application/x-iwork-pages-sffpages", - ".template": "application/x-iwork-pages-sfftemplate", - ".xpi": "application/x-xpinstall", - ".oex": "application/x-opera-extension", - ".mustache": "text/html", -} diff --git a/namespace.go b/namespace.go deleted file mode 100644 index a6962994..00000000 --- a/namespace.go +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "strings" - - beecontext "github.com/astaxie/beego/context" -) - -type namespaceCond func(*beecontext.Context) bool - -// LinkNamespace used as link action -// Deprecated: using pkg/, we will delete this in v2.1.0 -type LinkNamespace func(*Namespace) - -// Namespace is store all the info -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Namespace struct { - prefix string - handlers *ControllerRegister -} - -// NewNamespace get new Namespace -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { - ns := &Namespace{ - prefix: prefix, - handlers: NewControllerRegister(), - } - for _, p := range params { - p(ns) - } - return ns -} - -// Cond set condition function -// if cond return true can run this namespace, else can't -// usage: -// ns.Cond(func (ctx *context.Context) bool{ -// if ctx.Input.Domain() == "api.beego.me" { -// return true -// } -// return false -// }) -// Cond as the first filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Cond(cond namespaceCond) *Namespace { - fn := func(ctx *beecontext.Context) { - if !cond(ctx) { - exception("405", ctx) - } - } - if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { - mr := new(FilterRouter) - mr.tree = NewTree() - mr.pattern = "*" - mr.filterFunc = fn - mr.tree.AddRouter("*", true) - n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...) - } else { - n.handlers.InsertFilter("*", BeforeRouter, fn) - } - return n -} - -// Filter add filter in the Namespace -// action has before & after -// FilterFunc -// usage: -// Filter("before", func (ctx *context.Context){ -// _, ok := ctx.Input.Session("uid").(int) -// if !ok && ctx.Request.RequestURI != "/login" { -// ctx.Redirect(302, "/login") -// } -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { - var a int - if action == "before" { - a = BeforeRouter - } else if action == "after" { - a = FinishRouter - } - for _, f := range filter { - n.handlers.InsertFilter("*", a, f) - } - return n -} - -// Router same as beego.Rourer -// refer: https://godoc.org/github.com/astaxie/beego#Router -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) - return n -} - -// AutoRouter same as beego.AutoRouter -// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { - n.handlers.AddAuto(c) - return n -} - -// AutoPrefix same as beego.AutoPrefix -// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { - n.handlers.AddAutoPrefix(prefix, c) - return n -} - -// Get same as beego.Get -// refer: https://godoc.org/github.com/astaxie/beego#Get -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { - n.handlers.Get(rootpath, f) - return n -} - -// Post same as beego.Post -// refer: https://godoc.org/github.com/astaxie/beego#Post -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { - n.handlers.Post(rootpath, f) - return n -} - -// Delete same as beego.Delete -// refer: https://godoc.org/github.com/astaxie/beego#Delete -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { - n.handlers.Delete(rootpath, f) - return n -} - -// Put same as beego.Put -// refer: https://godoc.org/github.com/astaxie/beego#Put -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { - n.handlers.Put(rootpath, f) - return n -} - -// Head same as beego.Head -// refer: https://godoc.org/github.com/astaxie/beego#Head -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { - n.handlers.Head(rootpath, f) - return n -} - -// Options same as beego.Options -// refer: https://godoc.org/github.com/astaxie/beego#Options -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { - n.handlers.Options(rootpath, f) - return n -} - -// Patch same as beego.Patch -// refer: https://godoc.org/github.com/astaxie/beego#Patch -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { - n.handlers.Patch(rootpath, f) - return n -} - -// Any same as beego.Any -// refer: https://godoc.org/github.com/astaxie/beego#Any -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { - n.handlers.Any(rootpath, f) - return n -} - -// Handler same as beego.Handler -// refer: https://godoc.org/github.com/astaxie/beego#Handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { - n.handlers.Handler(rootpath, h) - return n -} - -// Include add include class -// refer: https://godoc.org/github.com/astaxie/beego#Include -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { - n.handlers.Include(cList...) - return n -} - -// Namespace add nest Namespace -// usage: -//ns := beego.NewNamespace(“/v1”). -//Namespace( -// beego.NewNamespace("/shop"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("shopinfo")) -// }), -// beego.NewNamespace("/order"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("orderinfo")) -// }), -// beego.NewNamespace("/crm"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("crminfo")) -// }), -//) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { - for _, ni := range ns { - for k, v := range ni.handlers.routers { - if _, ok := n.handlers.routers[k]; ok { - addPrefix(v, ni.prefix) - n.handlers.routers[k].AddTree(ni.prefix, v) - } else { - t := NewTree() - t.AddTree(ni.prefix, v) - addPrefix(t, ni.prefix) - n.handlers.routers[k] = t - } - } - if ni.handlers.enableFilter { - for pos, filterList := range ni.handlers.filters { - for _, mr := range filterList { - t := NewTree() - t.AddTree(ni.prefix, mr.tree) - mr.tree = t - n.handlers.insertFilterRouter(pos, mr) - } - } - } - } - return n -} - -// AddNamespace register Namespace into beego.Handler -// support multi Namespace -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddNamespace(nl ...*Namespace) { - for _, n := range nl { - for k, v := range n.handlers.routers { - if _, ok := BeeApp.Handlers.routers[k]; ok { - addPrefix(v, n.prefix) - BeeApp.Handlers.routers[k].AddTree(n.prefix, v) - } else { - t := NewTree() - t.AddTree(n.prefix, v) - addPrefix(t, n.prefix) - BeeApp.Handlers.routers[k] = t - } - } - if n.handlers.enableFilter { - for pos, filterList := range n.handlers.filters { - for _, mr := range filterList { - t := NewTree() - t.AddTree(n.prefix, mr.tree) - mr.tree = t - BeeApp.Handlers.insertFilterRouter(pos, mr) - } - } - } - } -} - -func addPrefix(t *Tree, prefix string) { - for _, v := range t.fixrouters { - addPrefix(v, prefix) - } - if t.wildcard != nil { - addPrefix(t.wildcard, prefix) - } - for _, l := range t.leaves { - if c, ok := l.runObject.(*ControllerInfo); ok { - if !strings.HasPrefix(c.pattern, prefix) { - c.pattern = prefix + c.pattern - } - } - } -} - -// NSCond is Namespace Condition -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSCond(cond namespaceCond) LinkNamespace { - return func(ns *Namespace) { - ns.Cond(cond) - } -} - -// NSBefore Namespace BeforeRouter filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSBefore(filterList ...FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Filter("before", filterList...) - } -} - -// NSAfter add Namespace FinishRouter filter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAfter(filterList ...FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Filter("after", filterList...) - } -} - -// NSInclude Namespace Include ControllerInterface -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSInclude(cList ...ControllerInterface) LinkNamespace { - return func(ns *Namespace) { - ns.Include(cList...) - } -} - -// NSRouter call Namespace Router -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { - return func(ns *Namespace) { - ns.Router(rootpath, c, mappingMethods...) - } -} - -// NSGet call Namespace Get -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSGet(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Get(rootpath, f) - } -} - -// NSPost call Namespace Post -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSPost(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Post(rootpath, f) - } -} - -// NSHead call Namespace Head -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSHead(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Head(rootpath, f) - } -} - -// NSPut call Namespace Put -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSPut(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Put(rootpath, f) - } -} - -// NSDelete call Namespace Delete -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSDelete(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Delete(rootpath, f) - } -} - -// NSAny call Namespace Any -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAny(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Any(rootpath, f) - } -} - -// NSOptions call Namespace Options -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSOptions(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Options(rootpath, f) - } -} - -// NSPatch call Namespace Patch -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSPatch(rootpath string, f FilterFunc) LinkNamespace { - return func(ns *Namespace) { - ns.Patch(rootpath, f) - } -} - -// NSAutoRouter call Namespace AutoRouter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAutoRouter(c ControllerInterface) LinkNamespace { - return func(ns *Namespace) { - ns.AutoRouter(c) - } -} - -// NSAutoPrefix call Namespace AutoPrefix -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { - return func(ns *Namespace) { - ns.AutoPrefix(prefix, c) - } -} - -// NSNamespace add sub Namespace -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { - return func(ns *Namespace) { - n := NewNamespace(prefix, params...) - ns.Namespace(n) - } -} - -// NSHandler add handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NSHandler(rootpath string, h http.Handler) LinkNamespace { - return func(ns *Namespace) { - ns.Handler(rootpath, h) - } -} diff --git a/namespace_test.go b/namespace_test.go deleted file mode 100644 index b3f20dff..00000000 --- a/namespace_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/astaxie/beego/context" -) - -func TestNamespaceGet(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/user", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Get("/user", func(ctx *context.Context) { - ctx.Output.Body([]byte("v1_user")) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "v1_user" { - t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespacePost(t *testing.T) { - r, _ := http.NewRequest("POST", "/v1/user/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceNest(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/admin/order", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Namespace( - NewNamespace("/admin"). - Get("/order", func(ctx *context.Context) { - ctx.Output.Body([]byte("order")) - }), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "order" { - t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceNestParam(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Namespace( - NewNamespace("/admin"). - Get("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceRouter(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/api/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Router("/api/list", &TestController{}, "*:List") - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceAutoFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/test/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.AutoRouter(&TestController{}) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestNamespaceFilter(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/user/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Filter("before", func(ctx *context.Context) { - ctx.Output.Body([]byte("this is Filter")) - }). - Get("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "this is Filter" { - t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceCond(t *testing.T) { - r, _ := http.NewRequest("GET", "/v2/test/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v2") - ns.Cond(func(ctx *context.Context) bool { - return ctx.Input.Domain() == "beego.me" - }). - AutoRouter(&TestController{}) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Code != 405 { - t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) - } -} - -func TestNamespaceInside(t *testing.T) { - r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) - w := httptest.NewRecorder() - ns := NewNamespace("/v3", - NSAutoRouter(&TestController{}), - NSNamespace("/shop", - NSGet("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), - ), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) - } -} diff --git a/orm/README.md b/orm/README.md deleted file mode 100644 index 6e808d2a..00000000 --- a/orm/README.md +++ /dev/null @@ -1,159 +0,0 @@ -# beego orm - -[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) - -A powerful orm framework for go. - -It is heavily influenced by Django ORM, SQLAlchemy. - -**Support Database:** - -* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) -* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq) -* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - -Passed all test, but need more feedback. - -**Features:** - -* full go type support -* easy for usage, simple CRUD operation -* auto join with relation table -* cross DataBase compatible query -* Raw SQL query / mapper without orm model -* full test keep stable and strong - -more features please read the docs - -**Install:** - - go get github.com/astaxie/beego/orm - -## Changelog - -* 2013-08-19: support table auto create -* 2013-08-13: update test for database types -* 2013-08-13: go type support, such as int8, uint8, byte, rune -* 2013-08-13: date / datetime timezone support very well - -## Quick Start - -#### Simple Usage - -```go -package main - -import ( - "fmt" - "github.com/astaxie/beego/orm" - _ "github.com/go-sql-driver/mysql" // import your used driver -) - -// Model Struct -type User struct { - Id int `orm:"auto"` - Name string `orm:"size(100)"` -} - -func init() { - // register model - orm.RegisterModel(new(User)) - - // set default database - orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) - - // create table - orm.RunSyncdb("default", false, true) -} - -func main() { - o := orm.NewOrm() - - user := User{Name: "slene"} - - // insert - id, err := o.Insert(&user) - - // update - user.Name = "astaxie" - num, err := o.Update(&user) - - // read one - u := User{Id: user.Id} - err = o.Read(&u) - - // delete - num, err = o.Delete(&u) -} -``` - -#### Next with relation - -```go -type Post struct { - Id int `orm:"auto"` - Title string `orm:"size(100)"` - User *User `orm:"rel(fk)"` -} - -var posts []*Post -qs := o.QueryTable("post") -num, err := qs.Filter("User__Name", "slene").All(&posts) -``` - -#### Use Raw sql - -If you don't like ORM,use Raw SQL to query / mapping without ORM setting - -```go -var maps []Params -num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps) -if num > 0 { - fmt.Println(maps[0]["id"]) -} -``` - -#### Transaction - -```go -o.Begin() -... -user := User{Name: "slene"} -id, err := o.Insert(&user) -if err == nil { - o.Commit() -} else { - o.Rollback() -} - -``` - -#### Debug Log Queries - -In development env, you can simple use - -```go -func main() { - orm.Debug = true -... -``` - -enable log queries. - -output include all queries, such as exec / prepare / transaction. - -like this: - -```go -[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene` -... -``` - -note: not recommend use this in product env. - -## Docs - -more details and examples in docs and test - -[documents](http://beego.me/docs/mvc/model/overview.md) - diff --git a/orm/cmd.go b/orm/cmd.go deleted file mode 100644 index 0ff4dc40..00000000 --- a/orm/cmd.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "flag" - "fmt" - "os" - "strings" -) - -type commander interface { - Parse([]string) - Run() error -} - -var ( - commands = make(map[string]commander) -) - -// print help. -func printHelp(errs ...string) { - content := `orm command usage: - - syncdb - auto create tables - sqlall - print sql of create tables - help - print this help -` - - if len(errs) > 0 { - fmt.Println(errs[0]) - } - fmt.Println(content) - os.Exit(2) -} - -// RunCommand listen for orm command and then run it if command arguments passed. -func RunCommand() { - if len(os.Args) < 2 || os.Args[1] != "orm" { - return - } - - BootStrap() - - args := argString(os.Args[2:]) - name := args.Get(0) - - if name == "help" { - printHelp() - } - - if cmd, ok := commands[name]; ok { - cmd.Parse(os.Args[3:]) - cmd.Run() - os.Exit(0) - } else { - if name == "" { - printHelp() - } else { - printHelp(fmt.Sprintf("unknown command %s", name)) - } - } -} - -// sync database struct command interface. -type commandSyncDb struct { - al *alias - force bool - verbose bool - noInfo bool - rtOnError bool -} - -// parse orm command line arguments. -func (d *commandSyncDb) Parse(args []string) { - var name string - - flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError) - flagSet.StringVar(&name, "db", "default", "DataBase alias name") - flagSet.BoolVar(&d.force, "force", false, "drop tables before create") - flagSet.BoolVar(&d.verbose, "v", false, "verbose info") - flagSet.Parse(args) - - d.al = getDbAlias(name) -} - -// run orm line command. -func (d *commandSyncDb) Run() error { - var drops []string - if d.force { - drops = getDbDropSQL(d.al) - } - - db := d.al.DB - - if d.force { - for i, mi := range modelCache.allOrdered() { - query := drops[i] - if !d.noInfo { - fmt.Printf("drop table `%s`\n", mi.table) - } - _, err := db.Exec(query) - if d.verbose { - fmt.Printf(" %s\n\n", query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - } - - sqls, indexes := getDbCreateSQL(d.al) - - tables, err := d.al.DbBaser.GetTables(db) - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - - for i, mi := range modelCache.allOrdered() { - if tables[mi.table] { - if !d.noInfo { - fmt.Printf("table `%s` already exists, skip\n", mi.table) - } - - var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - - for _, fi := range mi.fields.fieldsDB { - if _, ok := columns[fi.column]; !ok { - fields = append(fields, fi) - } - } - - for _, fi := range fields { - query := getColumnAddQuery(d.al, fi) - - if !d.noInfo { - fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) - } - - _, err := db.Exec(query) - if d.verbose { - fmt.Printf(" %s\n", query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - - for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { - if !d.noInfo { - fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) - } - - query := idx.SQL - _, err := db.Exec(query) - if d.verbose { - fmt.Printf(" %s\n", query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - } - - continue - } - - if !d.noInfo { - fmt.Printf("create table `%s` \n", mi.table) - } - - queries := []string{sqls[i]} - for _, idx := range indexes[mi.table] { - queries = append(queries, idx.SQL) - } - - for _, query := range queries { - _, err := db.Exec(query) - if d.verbose { - query = " " + strings.Join(strings.Split(query, "\n"), "\n ") - fmt.Println(query) - } - if err != nil { - if d.rtOnError { - return err - } - fmt.Printf(" %s\n", err.Error()) - } - } - if d.verbose { - fmt.Println("") - } - } - - return nil -} - -// database creation commander interface implement. -type commandSQLAll struct { - al *alias -} - -// parse orm command line arguments. -func (d *commandSQLAll) Parse(args []string) { - var name string - - flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) - flagSet.StringVar(&name, "db", "default", "DataBase alias name") - flagSet.Parse(args) - - d.al = getDbAlias(name) -} - -// run orm line command. -func (d *commandSQLAll) Run() error { - sqls, indexes := getDbCreateSQL(d.al) - var all []string - for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} - for _, idx := range indexes[mi.table] { - queries = append(queries, idx.SQL) - } - sql := strings.Join(queries, "\n") - all = append(all, sql) - } - fmt.Println(strings.Join(all, "\n\n")) - - return nil -} - -func init() { - commands["syncdb"] = new(commandSyncDb) - commands["sqlall"] = new(commandSQLAll) -} - -// RunSyncdb run syncdb command line. -// name means table's alias name. default is "default". -// force means run next sql if the current is error. -// verbose means show all info when running command or not. -func RunSyncdb(name string, force bool, verbose bool) error { - BootStrap() - - al := getDbAlias(name) - cmd := new(commandSyncDb) - cmd.al = al - cmd.force = force - cmd.noInfo = !verbose - cmd.verbose = verbose - cmd.rtOnError = true - return cmd.Run() -} diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go deleted file mode 100644 index 692a079f..00000000 --- a/orm/cmd_utils.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "strings" -) - -type dbIndex struct { - Table string - Name string - SQL string -} - -// create database drop sql. -func getDbDropSQL(al *alias) (sqls []string) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - - for _, mi := range modelCache.allOrdered() { - sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) - } - return sqls -} - -// get database column type string. -func getColumnTyp(al *alias, fi *fieldInfo) (col string) { - T := al.DbBaser.DbTypes() - fieldType := fi.fieldType - fieldSize := fi.size - -checkColumn: - switch fieldType { - case TypeBooleanField: - col = T["bool"] - case TypeVarCharField: - if al.Driver == DRPostgres && fi.toText { - col = T["string-text"] - } else { - col = fmt.Sprintf(T["string"], fieldSize) - } - case TypeCharField: - col = fmt.Sprintf(T["string-char"], fieldSize) - case TypeTextField: - col = T["string-text"] - case TypeTimeField: - col = T["time.Time-clock"] - case TypeDateField: - col = T["time.Time-date"] - case TypeDateTimeField: - col = T["time.Time"] - case TypeBitField: - col = T["int8"] - case TypeSmallIntegerField: - col = T["int16"] - case TypeIntegerField: - col = T["int32"] - case TypeBigIntegerField: - if al.Driver == DRSqlite { - fieldType = TypeIntegerField - goto checkColumn - } - col = T["int64"] - case TypePositiveBitField: - col = T["uint8"] - case TypePositiveSmallIntegerField: - col = T["uint16"] - case TypePositiveIntegerField: - col = T["uint32"] - case TypePositiveBigIntegerField: - col = T["uint64"] - case TypeFloatField: - col = T["float64"] - case TypeDecimalField: - s := T["float64-decimal"] - if !strings.Contains(s, "%d") { - col = s - } else { - col = fmt.Sprintf(s, fi.digits, fi.decimals) - } - case TypeJSONField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["json"] - case TypeJsonbField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["jsonb"] - case RelForeignKey, RelOneToOne: - fieldType = fi.relModelInfo.fields.pk.fieldType - fieldSize = fi.relModelInfo.fields.pk.size - goto checkColumn - } - - return -} - -// create alter sql string. -func getColumnAddQuery(al *alias, fi *fieldInfo) string { - Q := al.DbBaser.TableQuote() - typ := getColumnTyp(al, fi) - - if !fi.null { - typ += " " + "NOT NULL" - } - - return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", - Q, fi.mi.table, Q, - Q, fi.column, Q, - typ, getColumnDefault(fi), - ) -} - -// create database creation string. -func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - T := al.DbBaser.DbTypes() - sep := fmt.Sprintf("%s, %s", Q, Q) - - tableIndexes = make(map[string][]dbIndex) - - for _, mi := range modelCache.allOrdered() { - sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) - sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) - - columns := make([]string, 0, len(mi.fields.fieldsDB)) - - sqlIndexes := [][]string{} - - for _, fi := range mi.fields.fieldsDB { - - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := getColumnTyp(al, fi) - - if fi.auto { - switch al.Driver { - case DRSqlite, DRPostgres: - column += T["auto"] - default: - column += col + " " + T["auto"] - } - } else if fi.pk { - column += col + " " + T["pk"] - } else { - column += col - - if !fi.null { - column += " " + "NOT NULL" - } - - //if fi.initial.String() != "" { - // column += " DEFAULT " + fi.initial.String() - //} - - // Append attribute DEFAULT - column += getColumnDefault(fi) - - if fi.unique { - column += " " + "UNIQUE" - } - - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) - } - } - - if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) - } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) - } - - columns = append(columns, column) - } - - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) - } - for _, names := range allnames { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) - } - } - column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) - columns = append(columns, column) - } - } - - sql += strings.Join(columns, ",\n") - sql += "\n)" - - if al.Driver == DRMySQL { - var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) - } - if engine == "" { - engine = al.Engine - } - sql += " ENGINE=" + engine - } - - sql += ";" - sqls = append(sqls, sql) - - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) - } - } - sqlIndexes = append(sqlIndexes, cols) - } - } - - for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") - cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - - index := dbIndex{} - index.Table = mi.table - index.Name = name - index.SQL = sql - - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) - } - - } - - return -} - -// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands -func getColumnDefault(fi *fieldInfo) string { - var ( - v, t, d string - ) - - // Skip default attribute if field is in relations - if fi.rel || fi.reverse { - return v - } - - t = " DEFAULT '%s' " - - // These defaults will be useful if there no config value orm:"default" and NOT NULL is on - switch fi.fieldType { - case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: - return v - - case TypeBitField, TypeSmallIntegerField, TypeIntegerField, - TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, - TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, - TypeDecimalField: - t = " DEFAULT %s " - d = "0" - case TypeBooleanField: - t = " DEFAULT %s " - d = "FALSE" - case TypeJSONField, TypeJsonbField: - d = "{}" - } - - if fi.colDefault { - if !fi.initial.Exist() { - v = fmt.Sprintf(t, "") - } else { - v = fmt.Sprintf(t, fi.initial.String()) - } - } else { - if !fi.null { - v = fmt.Sprintf(t, d) - } - } - - return v -} diff --git a/orm/db.go b/orm/db.go deleted file mode 100644 index 5d175bf1..00000000 --- a/orm/db.go +++ /dev/null @@ -1,1908 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" -) - -var ( - // ErrMissPK missing pk error - ErrMissPK = errors.New("missed pk value") -) - -var ( - operators = map[string]bool{ - "exact": true, - "iexact": true, - "contains": true, - "icontains": true, - // "regex": true, - // "iregex": true, - "gt": true, - "gte": true, - "lt": true, - "lte": true, - "eq": true, - "nq": true, - "ne": true, - ">": true, - ">=": true, - "<": true, - "<=": true, - "=": true, - "!=": true, - "startswith": true, - "endswith": true, - "istartswith": true, - "iendswith": true, - "in": true, - "between": true, - // "year": true, - // "month": true, - // "day": true, - // "week_day": true, - "isnull": true, - // "search": true, - } -) - -// an instance of dbBaser interface/ -type dbBase struct { - ins dbBaser -} - -// check dbBase implements dbBaser interface. -var _ dbBaser = new(dbBase) - -// get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { - if names == nil { - ns := make([]string, 0, len(cols)) - names = &ns - } - values = make([]interface{}, 0, len(cols)) - - for _, column := range cols { - var fi *fieldInfo - if fi, _ = mi.fields.GetByAny(column); fi != nil { - column = fi.column - } else { - panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) - } - if !fi.dbcol || fi.auto && skipAuto { - continue - } - value, err := d.collectFieldValue(mi, fi, ind, insert, tz) - if err != nil { - return nil, nil, err - } - - // ignore empty value auto field - if insert && fi.auto { - if fi.fieldType&IsPositiveIntegerField > 0 { - if vu, ok := value.(uint64); !ok || vu == 0 { - continue - } - } else { - if vu, ok := value.(int64); !ok || vu == 0 { - continue - } - } - autoFields = append(autoFields, fi.column) - } - - *names, values = append(*names, column), append(values, value) - } - - return -} - -// get one field value in struct column as interface. -func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { - var value interface{} - if fi.pk { - _, value, _ = getExistPk(mi, ind) - } else { - field := ind.FieldByIndex(fi.fieldIndex) - if fi.isFielder { - f := field.Addr().Interface().(Fielder) - value = f.RawValue() - } else { - switch fi.fieldType { - case TypeBooleanField: - if nb, ok := field.Interface().(sql.NullBool); ok { - value = nil - if nb.Valid { - value = nb.Bool - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Bool() - } - } else { - value = field.Bool() - } - case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: - if ns, ok := field.Interface().(sql.NullString); ok { - value = nil - if ns.Valid { - value = ns.String - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().String() - } - } else { - value = field.String() - } - case TypeFloatField, TypeDecimalField: - if nf, ok := field.Interface().(sql.NullFloat64); ok { - value = nil - if nf.Valid { - value = nf.Float64 - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Float() - } - } else { - vu := field.Interface() - if _, ok := vu.(float32); ok { - value, _ = StrTo(ToStr(vu)).Float64() - } else { - value = field.Float() - } - } - case TypeTimeField, TypeDateField, TypeDateTimeField: - value = field.Interface() - if t, ok := value.(time.Time); ok { - d.ins.TimeToDB(&t, tz) - if t.IsZero() { - value = nil - } else { - value = t - } - } - default: - switch { - case fi.fieldType&IsPositiveIntegerField > 0: - if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Uint() - } - } else { - value = field.Uint() - } - case fi.fieldType&IsIntegerField > 0: - if ni, ok := field.Interface().(sql.NullInt64); ok { - value = nil - if ni.Valid { - value = ni.Int64 - } - } else if field.Kind() == reflect.Ptr { - if field.IsNil() { - value = nil - } else { - value = field.Elem().Int() - } - } else { - value = field.Int() - } - case fi.fieldType&IsRelField > 0: - if field.IsNil() { - value = nil - } else { - if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { - value = vu - } else { - value = nil - } - } - if !fi.null && value == nil { - return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) - } - } - } - } - switch fi.fieldType { - case TypeTimeField, TypeDateField, TypeDateTimeField: - if fi.autoNow || fi.autoNowAdd && insert { - if insert { - if t, ok := value.(time.Time); ok && !t.IsZero() { - break - } - } - tnow := time.Now() - d.ins.TimeToDB(&tnow, tz) - value = tnow - if fi.isFielder { - f := field.Addr().Interface().(Fielder) - f.SetRaw(tnow.In(DefaultTimeLoc)) - } else if field.Kind() == reflect.Ptr { - v := tnow.In(DefaultTimeLoc) - field.Set(reflect.ValueOf(&v)) - } else { - field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) - } - } - case TypeJSONField, TypeJsonbField: - if s, ok := value.(string); (ok && len(s) == 0) || value == nil { - if fi.colDefault && fi.initial.Exist() { - value = fi.initial.String() - } else { - value = nil - } - } - } - } - return value, nil -} - -// create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { - Q := d.ins.TableQuote() - - dbcols := make([]string, 0, len(mi.fields.dbcols)) - marks := make([]string, 0, len(mi.fields.dbcols)) - for _, fi := range mi.fields.fieldsDB { - if !fi.auto { - dbcols = append(dbcols, fi.column) - marks = append(marks, "?") - } - } - qmarks := strings.Join(marks, ", ") - sep := fmt.Sprintf("%s, %s", Q, Q) - columns := strings.Join(dbcols, sep) - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) - - d.ins.HasReturningID(mi, &query) - - stmt, err := q.Prepare(query) - return stmt, query, err -} - -// insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) - if err != nil { - return 0, err - } - - if d.ins.HasReturningID(mi, nil) { - row := stmt.QueryRow(values...) - var id int64 - err := row.Scan(&id) - return id, err - } - res, err := stmt.Exec(values...) - if err == nil { - return res.LastInsertId() - } - return 0, err -} - -// query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { - var whereCols []string - var args []interface{} - - // if specify cols length > 0, then use it for where condition. - if len(cols) > 0 { - var err error - whereCols = make([]string, 0, len(cols)) - args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) - if err != nil { - return err - } - } else { - // default use pk value as where condtion. - pkColumn, pkValue, ok := getExistPk(mi, ind) - if !ok { - return ErrMissPK - } - whereCols = []string{pkColumn} - args = append(args, pkValue) - } - - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s, %s", Q, Q) - sels := strings.Join(mi.fields.dbcols, sep) - colsNum := len(mi.fields.dbcols) - - sep = fmt.Sprintf("%s = ? AND %s", Q, Q) - wheres := strings.Join(whereCols, sep) - - forUpdate := "" - if isForUpdate { - forUpdate = "FOR UPDATE" - } - - query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) - - refs := make([]interface{}, colsNum) - for i := range refs { - var ref interface{} - refs[i] = &ref - } - - d.ins.ReplaceMarks(&query) - - row := q.QueryRow(query, args...) - if err := row.Scan(refs...); err != nil { - if err == sql.ErrNoRows { - return ErrNoRows - } - return err - } - elm := reflect.New(mi.addrField.Elem().Type()) - mind := reflect.Indirect(elm) - d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) - ind.Set(mind) - return nil -} - -// execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names := make([]string, 0, len(mi.fields.dbcols)) - values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) - if err != nil { - return 0, err - } - - id, err := d.InsertValue(q, mi, false, names, values) - if err != nil { - return 0, err - } - - if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) - } - return id, err -} - -// multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { - var ( - cnt int64 - nums int - values []interface{} - names []string - ) - - // typ := reflect.Indirect(mi.addrField).Type() - - length, autoFields := sind.Len(), make([]string, 0, 1) - - for i := 1; i <= length; i++ { - - ind := reflect.Indirect(sind.Index(i - 1)) - - // Is this needed ? - // if !ind.Type().AssignableTo(typ) { - // return cnt, ErrArgs - // } - - if i == 1 { - var ( - vus []interface{} - err error - ) - vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) - if err != nil { - return cnt, err - } - values = make([]interface{}, bulk*len(vus)) - nums += copy(values, vus) - } else { - vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) - if err != nil { - return cnt, err - } - - if len(vus) != len(names) { - return cnt, ErrArgs - } - - nums += copy(values[nums:], vus) - } - - if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) - if err != nil { - return cnt, err - } - cnt += num - nums = 0 - } - } - - var err error - if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) - } - - return cnt, err -} - -// execute insert sql with given struct and given values. -// insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { - Q := d.ins.TableQuote() - - marks := make([]string, len(names)) - for i := range marks { - marks[i] = "?" - } - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti && multi > 1 { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err -} - -// InsertOrUpdate a row -// If your primary key or unique column conflict will update -// If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - args0 := "" - iouStr := "" - argsMap := map[string]string{} - switch a.Driver { - case DRMySQL: - iouStr = "ON DUPLICATE KEY UPDATE" - case DRPostgres: - if len(args) == 0 { - return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) - } - args0 = strings.ToLower(args[0]) - iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) - default: - return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) - } - - //Get on the key-value pairs - for _, v := range args { - kv := strings.Split(v, "=") - if len(kv) == 2 { - argsMap[strings.ToLower(kv[0])] = kv[1] - } - } - - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) - Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - - if err != nil { - return 0, err - } - - marks := make([]string, len(names)) - updateValues := make([]interface{}, 0) - updates := make([]string, len(names)) - var conflitValue interface{} - for i, v := range names { - // identifier in database may not be case-sensitive, so quote it - v = fmt.Sprintf("%s%s%s", Q, v, Q) - marks[i] = "?" - valueStr := argsMap[strings.ToLower(v)] - if v == args0 { - conflitValue = values[i] - } - if valueStr != "" { - switch a.Driver { - case DRMySQL: - updates[i] = v + "=" + valueStr - case DRPostgres: - if conflitValue != nil { - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) - updateValues = append(updateValues, conflitValue) - } else { - return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) - } - } - } else { - updates[i] = v + "=?" - updateValues = append(updateValues, values[i]) - } - } - - values = append(values, updateValues...) - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - qupdates := strings.Join(updates, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - //conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - - row := q.QueryRow(query, values...) - var id int64 - err = row.Scan(&id) - if err != nil && err.Error() == `pq: syntax error at or near "ON"` { - err = fmt.Errorf("postgres version must 9.5 or higher") - } - return id, err -} - -// execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { - pkName, pkValue, ok := getExistPk(mi, ind) - if !ok { - return 0, ErrMissPK - } - - var setNames []string - - // if specify cols length is zero, then commit all columns. - if len(cols) == 0 { - cols = mi.fields.dbcols - setNames = make([]string, 0, len(mi.fields.dbcols)-1) - } else { - setNames = make([]string, 0, len(cols)) - } - - setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) - if err != nil { - return 0, err - } - - var findAutoNowAdd, findAutoNow bool - var index int - for i, col := range setNames { - if mi.fields.GetByColumn(col).autoNowAdd { - index = i - findAutoNowAdd = true - } - if mi.fields.GetByColumn(col).autoNow { - findAutoNow = true - } - } - if findAutoNowAdd { - setNames = append(setNames[0:index], setNames[index+1:]...) - setValues = append(setValues[0:index], setValues[index+1:]...) - } - - if !findAutoNow { - for col, info := range mi.fields.columns { - if info.autoNow { - setNames = append(setNames, col) - setValues = append(setValues, time.Now()) - } - } - } - - setValues = append(setValues, pkValue) - - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s = ?, %s", Q, Q) - setColumns := strings.Join(setNames, sep) - - query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) - - d.ins.ReplaceMarks(&query) - - res, err := q.Exec(query, setValues...) - if err == nil { - return res.RowsAffected() - } - return 0, err -} - -// execute delete sql dbQuerier with given struct reflect.Value. -// delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { - var whereCols []string - var args []interface{} - // if specify cols length > 0, then use it for where condition. - if len(cols) > 0 { - var err error - whereCols = make([]string, 0, len(cols)) - args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) - if err != nil { - return 0, err - } - } else { - // default use pk value as where condtion. - pkColumn, pkValue, ok := getExistPk(mi, ind) - if !ok { - return 0, ErrMissPK - } - whereCols = []string{pkColumn} - args = append(args, pkValue) - } - - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s = ? AND %s", Q, Q) - wheres := strings.Join(whereCols, sep) - - query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) - - d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) - if err == nil { - num, err := res.RowsAffected() - if err != nil { - return 0, err - } - if num > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) - } - } - err := d.deleteRels(q, mi, args, tz) - if err != nil { - return num, err - } - } - return num, err - } - return 0, err -} - -// update table-related record by querySet. -// need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { - columns := make([]string, 0, len(params)) - values := make([]interface{}, 0, len(params)) - for col, val := range params { - if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { - panic(fmt.Errorf("wrong field/column name `%s`", col)) - } else { - columns = append(columns, fi.column) - values = append(values, val) - } - } - - if len(columns) == 0 { - panic(fmt.Errorf("update params cannot empty")) - } - - tables := newDbTables(mi, d.ins) - if qs != nil { - tables.parseRelated(qs.related, qs.relDepth) - } - - where, args := tables.getCondSQL(cond, false, tz) - - values = append(values, args...) - - join := tables.getJoinSQL() - - var query, T string - - Q := d.ins.TableQuote() - - if d.ins.SupportUpdateJoin() { - T = "T0." - } - - cols := make([]string, 0, len(columns)) - - for i, v := range columns { - col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) - if c, ok := values[i].(colValue); ok { - switch c.opt { - case ColAdd: - cols = append(cols, col+" = "+col+" + ?") - case ColMinus: - cols = append(cols, col+" = "+col+" - ?") - case ColMultiply: - cols = append(cols, col+" = "+col+" * ?") - case ColExcept: - cols = append(cols, col+" = "+col+" / ?") - case ColBitAnd: - cols = append(cols, col+" = "+col+" & ?") - case ColBitRShift: - cols = append(cols, col+" = "+col+" >> ?") - case ColBitLShift: - cols = append(cols, col+" = "+col+" << ?") - case ColBitXOR: - cols = append(cols, col+" = "+col+" ^ ?") - case ColBitOr: - cols = append(cols, col+" = "+col+" | ?") - } - values[i] = c.value - } else { - cols = append(cols, col+" = ?") - } - } - - sets := strings.Join(cols, ", ") + " " - - if d.ins.SupportUpdateJoin() { - query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) - } else { - supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) - query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) - } - - d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } - if err == nil { - return res.RowsAffected() - } - return 0, err -} - -// delete related records. -// do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { - for _, fi := range mi.fields.fieldsReverse { - fi = fi.reverseFieldInfo - switch fi.onDelete { - case odCascade: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) - if err != nil { - return err - } - case odSetDefault, odSetNULL: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - params := Params{fi.column: nil} - if fi.onDelete == odSetDefault { - params[fi.column] = fi.initial.String() - } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) - if err != nil { - return err - } - case odDoNothing: - } - } - return nil -} - -// delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { - tables := newDbTables(mi, d.ins) - tables.skipEnd = true - - if qs != nil { - tables.parseRelated(qs.related, qs.relDepth) - } - - if cond == nil || cond.IsEmpty() { - panic(fmt.Errorf("delete operation cannot execute without condition")) - } - - Q := d.ins.TableQuote() - - where, args := tables.getCondSQL(cond, false, tz) - join := tables.getJoinSQL() - - cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) - - d.ins.ReplaceMarks(&query) - - var rs *sql.Rows - r, err := q.Query(query, args...) - if err != nil { - return 0, err - } - rs = r - defer rs.Close() - - var ref interface{} - args = make([]interface{}, 0) - cnt := 0 - for rs.Next() { - if err := rs.Scan(&ref); err != nil { - return 0, err - } - pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) - if err != nil { - return 0, err - } - args = append(args, pkValue) - cnt++ - } - - if cnt == 0 { - return 0, nil - } - - marks := make([]string, len(args)) - for i := range marks { - marks[i] = "?" - } - sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) - - d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } - if err == nil { - num, err := res.RowsAffected() - if err != nil { - return 0, err - } - if num > 0 { - err := d.deleteRels(q, mi, args, tz) - if err != nil { - return num, err - } - } - return num, nil - } - return 0, err -} - -// read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { - - val := reflect.ValueOf(container) - ind := reflect.Indirect(val) - - errTyp := true - one := true - isPtr := true - - if val.Kind() == reflect.Ptr { - fn := "" - if ind.Kind() == reflect.Slice { - one = false - typ := ind.Type().Elem() - switch typ.Kind() { - case reflect.Ptr: - fn = getFullName(typ.Elem()) - case reflect.Struct: - isPtr = false - fn = getFullName(typ) - } - } else { - fn = getFullName(ind.Type()) - } - errTyp = fn != mi.fullName - } - - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } - } - - rlimit := qs.limit - offset := qs.offset - - Q := d.ins.TableQuote() - - var tCols []string - if len(cols) > 0 { - hasRel := len(qs.related) > 0 || qs.relDepth > 0 - tCols = make([]string, 0, len(cols)) - var maps map[string]bool - if hasRel { - maps = make(map[string]bool) - } - for _, col := range cols { - if fi, ok := mi.fields.GetByAny(col); ok { - tCols = append(tCols, fi.column) - if hasRel { - maps[fi.column] = true - } - } else { - return 0, fmt.Errorf("wrong field/column name `%s`", col) - } - } - if hasRel { - for _, fi := range mi.fields.fieldsDB { - if fi.fieldType&IsRelField > 0 { - if !maps[fi.column] { - tCols = append(tCols, fi.column) - } - } - } - } - } else { - tCols = mi.fields.dbcols - } - - colsNum := len(tCols) - sep := fmt.Sprintf("%s, T0.%s", Q, Q) - sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) - - tables := newDbTables(mi, d.ins) - tables.parseRelated(qs.related, qs.relDepth) - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, offset, rlimit) - join := tables.getJoinSQL() - - for _, tbl := range tables.tables { - if tbl.sel { - colsNum += len(tbl.mi.fields.dbcols) - sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) - sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) - } - } - - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) - - if qs.forupdate { - query += " FOR UPDATE" - } - - d.ins.ReplaceMarks(&query) - - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } - } - - refs := make([]interface{}, colsNum) - for i := range refs { - var ref interface{} - refs[i] = &ref - } - - defer rs.Close() - - slice := ind - - var cnt int64 - for rs.Next() { - if one && cnt == 0 || !one { - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - elm := reflect.New(mi.addrField.Elem().Type()) - mind := reflect.Indirect(elm) - - cacheV := make(map[string]*reflect.Value) - cacheM := make(map[string]*modelInfo) - trefs := refs - - d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) - trefs = refs[len(tCols):] - - for _, tbl := range tables.tables { - // loop selected tables - if tbl.sel { - last := mind - names := "" - mmi := mi - // loop cascade models - for _, name := range tbl.names { - names += name - if val, ok := cacheV[names]; ok { - last = *val - mmi = cacheM[names] - } else { - fi := mmi.fields.GetByName(name) - lastm := mmi - mmi = fi.relModelInfo - field := last - if last.Kind() != reflect.Invalid { - field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) - if field.IsValid() { - d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) - for _, fi := range mmi.fields.fieldsReverse { - if fi.inModel && fi.reverseFieldInfo.mi == lastm { - if fi.reverseFieldInfo != nil { - f := field.FieldByIndex(fi.fieldIndex) - if f.Kind() == reflect.Ptr { - f.Set(last.Addr()) - } - } - } - } - last = field - } - } - cacheV[names] = &field - cacheM[names] = mmi - } - } - trefs = trefs[len(mmi.fields.dbcols):] - } - } - - if one { - ind.Set(mind) - } else { - if cnt == 0 { - // you can use a empty & caped container list - // orm will not replace it - if ind.Len() != 0 { - // if container is not empty - // create a new one - slice = reflect.New(ind.Type()).Elem() - } - } - - if isPtr { - slice = reflect.Append(slice, mind.Addr()) - } else { - slice = reflect.Append(slice, mind) - } - } - } - cnt++ - } - - if !one { - if cnt > 0 { - ind.Set(slice) - } else { - // when a result is empty and container is nil - // to set a empty container - if ind.IsNil() { - ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) - } - } - } - - return cnt, nil -} - -// excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { - tables := newDbTables(mi, d.ins) - tables.parseRelated(qs.related, qs.relDepth) - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - tables.getOrderSQL(qs.orders) - join := tables.getJoinSQL() - - Q := d.ins.TableQuote() - - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) - - if groupBy != "" { - query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) - } - - d.ins.ReplaceMarks(&query) - - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } - err = row.Scan(&cnt) - return -} - -// generate sql with replacing operator string placeholders and replaced values. -func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { - var sql string - params := getFlatParams(fi, args, tz) - - if len(params) == 0 { - panic(fmt.Errorf("operator `%s` need at least one args", operator)) - } - arg := params[0] - - switch operator { - case "in": - marks := make([]string, len(params)) - for i := range marks { - marks[i] = "?" - } - sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - case "between": - if len(params) != 2 { - panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) - } - sql = "BETWEEN ? AND ?" - default: - if len(params) > 1 { - panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) - } - sql = d.ins.OperatorSQL(operator) - switch operator { - case "exact": - if arg == nil { - params[0] = "IS NULL" - } - case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": - param := strings.Replace(ToStr(arg), `%`, `\%`, -1) - switch operator { - case "iexact": - case "contains", "icontains": - param = fmt.Sprintf("%%%s%%", param) - case "startswith", "istartswith": - param = fmt.Sprintf("%s%%", param) - case "endswith", "iendswith": - param = fmt.Sprintf("%%%s", param) - } - params[0] = param - case "isnull": - if b, ok := arg.(bool); ok { - if b { - sql = "IS NULL" - } else { - sql = "IS NOT NULL" - } - params = nil - } else { - panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) - } - } - } - return sql, params -} - -// gernerate sql string with inner function, such as UPPER(text). -func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { - // default not use -} - -// set values to struct column. -func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { - for i, column := range cols { - val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() - - fi := mi.fields.GetByColumn(column) - - field := ind.FieldByIndex(fi.fieldIndex) - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) - } - - _, err = d.setFieldValue(fi, value, field) - - if err != nil { - panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) - } - } -} - -// convert value from database result to value following in field type. -func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { - if val == nil { - return nil, nil - } - - var value interface{} - var tErr error - - var str *StrTo - switch v := val.(type) { - case []byte: - s := StrTo(string(v)) - str = &s - case string: - s := StrTo(v) - str = &s - } - - fieldType := fi.fieldType - -setValue: - switch { - case fieldType == TypeBooleanField: - if str == nil { - switch v := val.(type) { - case int64: - b := v == 1 - value = b - default: - s := StrTo(ToStr(v)) - str = &s - } - } - if str != nil { - b, err := str.Bool() - if err != nil { - tErr = err - goto end - } - value = b - } - case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: - if str == nil { - value = ToStr(val) - } else { - value = str.String() - } - case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: - if str == nil { - switch t := val.(type) { - case time.Time: - d.ins.TimeFromDB(&t, tz) - value = t - default: - s := StrTo(ToStr(t)) - str = &s - } - } - if str != nil { - s := str.String() - var ( - t time.Time - err error - ) - if len(s) >= 19 { - s = s[:19] - t, err = time.ParseInLocation(formatDateTime, s, tz) - } else if len(s) >= 10 { - if len(s) > 10 { - s = s[:10] - } - t, err = time.ParseInLocation(formatDate, s, tz) - } else if len(s) >= 8 { - if len(s) > 8 { - s = s[:8] - } - t, err = time.ParseInLocation(formatTime, s, tz) - } - t = t.In(DefaultTimeLoc) - - if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { - tErr = err - goto end - } - value = t - } - case fieldType&IsIntegerField > 0: - if str == nil { - s := StrTo(ToStr(val)) - str = &s - } - if str != nil { - var err error - switch fieldType { - case TypeBitField: - _, err = str.Int8() - case TypeSmallIntegerField: - _, err = str.Int16() - case TypeIntegerField: - _, err = str.Int32() - case TypeBigIntegerField: - _, err = str.Int64() - case TypePositiveBitField: - _, err = str.Uint8() - case TypePositiveSmallIntegerField: - _, err = str.Uint16() - case TypePositiveIntegerField: - _, err = str.Uint32() - case TypePositiveBigIntegerField: - _, err = str.Uint64() - } - if err != nil { - tErr = err - goto end - } - if fieldType&IsPositiveIntegerField > 0 { - v, _ := str.Uint64() - value = v - } else { - v, _ := str.Int64() - value = v - } - } - case fieldType == TypeFloatField || fieldType == TypeDecimalField: - if str == nil { - switch v := val.(type) { - case float64: - value = v - default: - s := StrTo(ToStr(v)) - str = &s - } - } - if str != nil { - v, err := str.Float64() - if err != nil { - tErr = err - goto end - } - value = v - } - case fieldType&IsRelField > 0: - fi = fi.relModelInfo.fields.pk - fieldType = fi.fieldType - goto setValue - } - -end: - if tErr != nil { - err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) - return nil, err - } - - return value, nil - -} - -// set one value to struct column field. -func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { - - fieldType := fi.fieldType - isNative := !fi.isFielder - -setValue: - switch { - case fieldType == TypeBooleanField: - if isNative { - if nb, ok := field.Interface().(sql.NullBool); ok { - if value == nil { - nb.Valid = false - } else { - nb.Bool = value.(bool) - nb.Valid = true - } - field.Set(reflect.ValueOf(nb)) - } else if field.Kind() == reflect.Ptr { - if value != nil { - v := value.(bool) - field.Set(reflect.ValueOf(&v)) - } - } else { - if value == nil { - value = false - } - field.SetBool(value.(bool)) - } - } - case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: - if isNative { - if ns, ok := field.Interface().(sql.NullString); ok { - if value == nil { - ns.Valid = false - } else { - ns.String = value.(string) - ns.Valid = true - } - field.Set(reflect.ValueOf(ns)) - } else if field.Kind() == reflect.Ptr { - if value != nil { - v := value.(string) - field.Set(reflect.ValueOf(&v)) - } - } else { - if value == nil { - value = "" - } - field.SetString(value.(string)) - } - } - case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: - if isNative { - if value == nil { - value = time.Time{} - } else if field.Kind() == reflect.Ptr { - if value != nil { - v := value.(time.Time) - field.Set(reflect.ValueOf(&v)) - } - } else { - field.Set(reflect.ValueOf(value)) - } - } - case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: - if value != nil { - v := uint8(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := uint16(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - if field.Type() == reflect.TypeOf(new(uint)) { - v := uint(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } else { - v := uint32(value.(uint64)) - field.Set(reflect.ValueOf(&v)) - } - } - case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := value.(uint64) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypeBitField && field.Kind() == reflect.Ptr: - if value != nil { - v := int8(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := int16(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } - case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - if field.Type() == reflect.TypeOf(new(int)) { - v := int(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } else { - v := int32(value.(int64)) - field.Set(reflect.ValueOf(&v)) - } - } - case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: - if value != nil { - v := value.(int64) - field.Set(reflect.ValueOf(&v)) - } - case fieldType&IsIntegerField > 0: - if fieldType&IsPositiveIntegerField > 0 { - if isNative { - if value == nil { - value = uint64(0) - } - field.SetUint(value.(uint64)) - } - } else { - if isNative { - if ni, ok := field.Interface().(sql.NullInt64); ok { - if value == nil { - ni.Valid = false - } else { - ni.Int64 = value.(int64) - ni.Valid = true - } - field.Set(reflect.ValueOf(ni)) - } else { - if value == nil { - value = int64(0) - } - field.SetInt(value.(int64)) - } - } - } - case fieldType == TypeFloatField || fieldType == TypeDecimalField: - if isNative { - if nf, ok := field.Interface().(sql.NullFloat64); ok { - if value == nil { - nf.Valid = false - } else { - nf.Float64 = value.(float64) - nf.Valid = true - } - field.Set(reflect.ValueOf(nf)) - } else if field.Kind() == reflect.Ptr { - if value != nil { - if field.Type() == reflect.TypeOf(new(float32)) { - v := float32(value.(float64)) - field.Set(reflect.ValueOf(&v)) - } else { - v := value.(float64) - field.Set(reflect.ValueOf(&v)) - } - } - } else { - - if value == nil { - value = float64(0) - } - field.SetFloat(value.(float64)) - } - } - case fieldType&IsRelField > 0: - if value != nil { - fieldType = fi.relModelInfo.fields.pk.fieldType - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - field.Set(mf) - f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) - field = f - goto setValue - } - } - - if !isNative { - fd := field.Addr().Interface().(Fielder) - err := fd.SetRaw(value) - if err != nil { - err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) - return nil, err - } - } - - return value, nil -} - -// query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { - - var ( - maps []Params - lists []ParamsList - list ParamsList - ) - - typ := 0 - switch v := container.(type) { - case *[]Params: - d := *v - if len(d) == 0 { - maps = d - } - typ = 1 - case *[]ParamsList: - d := *v - if len(d) == 0 { - lists = d - } - typ = 2 - case *ParamsList: - d := *v - if len(d) == 0 { - list = d - } - typ = 3 - default: - panic(fmt.Errorf("unsupport read values type `%T`", container)) - } - - tables := newDbTables(mi, d.ins) - - var ( - cols []string - infos []*fieldInfo - ) - - hasExprs := len(exprs) > 0 - - Q := d.ins.TableQuote() - - if hasExprs { - cols = make([]string, 0, len(exprs)) - infos = make([]*fieldInfo, 0, len(exprs)) - for _, ex := range exprs { - index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", ex)) - } - cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) - infos = append(infos, fi) - } - } else { - cols = make([]string, 0, len(mi.fields.dbcols)) - infos = make([]*fieldInfo, 0, len(exprs)) - for _, fi := range mi.fields.fieldsDB { - cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) - infos = append(infos, fi) - } - } - - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, qs.offset, qs.limit) - join := tables.getJoinSQL() - - sels := strings.Join(cols, ", ") - - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) - - d.ins.ReplaceMarks(&query) - - rs, err := q.Query(query, args...) - if err != nil { - return 0, err - } - refs := make([]interface{}, len(cols)) - for i := range refs { - var ref interface{} - refs[i] = &ref - } - - defer rs.Close() - - var ( - cnt int64 - columns []string - ) - for rs.Next() { - if cnt == 0 { - cols, err := rs.Columns() - if err != nil { - return 0, err - } - columns = cols - } - - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - switch typ { - case 1: - params := make(Params, len(cols)) - for i, ref := range refs { - fi := infos[i] - - val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) - } - - params[columns[i]] = value - } - maps = append(maps, params) - case 2: - params := make(ParamsList, 0, len(cols)) - for i, ref := range refs { - fi := infos[i] - - val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) - } - - params = append(params, value) - } - lists = append(lists, params) - case 3: - for i, ref := range refs { - fi := infos[i] - - val := reflect.Indirect(reflect.ValueOf(ref)).Interface() - - value, err := d.convertValueFromDB(fi, val, tz) - if err != nil { - panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) - } - - list = append(list, value) - } - } - - cnt++ - } - - switch v := container.(type) { - case *[]Params: - *v = maps - case *[]ParamsList: - *v = lists - case *ParamsList: - *v = list - } - - return cnt, nil -} - -func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { - return 0, nil -} - -// flag of update joined record. -func (d *dbBase) SupportUpdateJoin() bool { - return true -} - -func (d *dbBase) MaxLimit() uint64 { - return 18446744073709551615 -} - -// return quote. -func (d *dbBase) TableQuote() string { - return "`" -} - -// replace value placeholder in parametered sql string. -func (d *dbBase) ReplaceMarks(query *string) { - // default use `?` as mark, do nothing -} - -// flag of RETURNING sql. -func (d *dbBase) HasReturningID(*modelInfo, *string) bool { - return false -} - -// sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { - return nil -} - -// convert time from db. -func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { - *t = t.In(tz) -} - -// convert time to db. -func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { - *t = t.In(tz) -} - -// get database types. -func (d *dbBase) DbTypes() map[string]string { - return nil -} - -// gt all tables. -func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { - tables := make(map[string]bool) - query := d.ins.ShowTablesQuery() - rows, err := db.Query(query) - if err != nil { - return tables, err - } - - defer rows.Close() - - for rows.Next() { - var table string - err := rows.Scan(&table) - if err != nil { - return tables, err - } - if table != "" { - tables[table] = true - } - } - - return tables, nil -} - -// get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { - columns := make(map[string][3]string) - query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) - if err != nil { - return columns, err - } - - defer rows.Close() - - for rows.Next() { - var ( - name string - typ string - null string - ) - err := rows.Scan(&name, &typ, &null) - if err != nil { - return columns, err - } - columns[name] = [3]string{name, typ, null} - } - - return columns, nil -} - -// not implement. -func (d *dbBase) OperatorSQL(operator string) string { - panic(ErrNotImplement) -} - -// not implement. -func (d *dbBase) ShowTablesQuery() string { - panic(ErrNotImplement) -} - -// not implement. -func (d *dbBase) ShowColumnsQuery(table string) string { - panic(ErrNotImplement) -} - -// not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { - panic(ErrNotImplement) -} diff --git a/orm/db_alias.go b/orm/db_alias.go deleted file mode 100644 index d3dbc595..00000000 --- a/orm/db_alias.go +++ /dev/null @@ -1,487 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "fmt" - lru "github.com/hashicorp/golang-lru" - "reflect" - "sync" - "time" -) - -// DriverType database driver constant int. -type DriverType int - -// Enum the Database driver -const ( - _ DriverType = iota // int enum type - DRMySQL // mysql - DRSqlite // sqlite - DROracle // oracle - DRPostgres // pgsql - DRTiDB // TiDB -) - -// database driver string. -type driver string - -// get type constant int of current driver.. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d driver) Type() DriverType { - a, _ := dataBaseCache.get(string(d)) - return a.Driver -} - -// get name of current driver -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d driver) Name() string { - return string(d) -} - -// check driver iis implemented Driver interface or not. -var _ Driver = new(driver) - -var ( - dataBaseCache = &_dbCache{cache: make(map[string]*alias)} - drivers = map[string]DriverType{ - "mysql": DRMySQL, - "postgres": DRPostgres, - "sqlite3": DRSqlite, - "tidb": DRTiDB, - "oracle": DROracle, - "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora - } - dbBasers = map[DriverType]dbBaser{ - DRMySQL: newdbBaseMysql(), - DRSqlite: newdbBaseSqlite(), - DROracle: newdbBaseOracle(), - DRPostgres: newdbBasePostgres(), - DRTiDB: newdbBaseTidb(), - } -) - -// database alias cacher. -type _dbCache struct { - mux sync.RWMutex - cache map[string]*alias -} - -// add database alias with original name. -func (ac *_dbCache) add(name string, al *alias) (added bool) { - ac.mux.Lock() - defer ac.mux.Unlock() - if _, ok := ac.cache[name]; !ok { - ac.cache[name] = al - added = true - } - return -} - -// get database alias if cached. -func (ac *_dbCache) get(name string) (al *alias, ok bool) { - ac.mux.RLock() - defer ac.mux.RUnlock() - al, ok = ac.cache[name] - return -} - -// get default alias. -func (ac *_dbCache) getDefault() (al *alias) { - al, _ = ac.get("default") - return -} - -type DB struct { - *sync.RWMutex - DB *sql.DB - stmtDecorators *lru.Cache -} - -// Begin start a transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Begin() (*sql.Tx, error) { - return d.DB.Begin() -} - -// BeginTx start a transaction with context and those options -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - return d.DB.BeginTx(ctx, opts) -} - -// su must call release to release *sql.Stmt after using -func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { - d.RLock() - c, ok := d.stmtDecorators.Get(query) - if ok { - c.(*stmtDecorator).acquire() - d.RUnlock() - return c.(*stmtDecorator), nil - } - d.RUnlock() - - d.Lock() - c, ok = d.stmtDecorators.Get(query) - if ok { - c.(*stmtDecorator).acquire() - d.Unlock() - return c.(*stmtDecorator), nil - } - - stmt, err := d.Prepare(query) - if err != nil { - d.Unlock() - return nil, err - } - sd := newStmtDecorator(stmt) - sd.acquire() - d.stmtDecorators.Add(query, sd) - d.Unlock() - - return sd, nil -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Prepare(query string) (*sql.Stmt, error) { - return d.DB.Prepare(query) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - return d.DB.PrepareContext(ctx, query) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Exec(args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.ExecContext(ctx, args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.Query(args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - sd, err := d.getStmtDecorator(query) - if err != nil { - return nil, err - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryContext(ctx, args...) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRow(args...) - -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - sd, err := d.getStmtDecorator(query) - if err != nil { - panic(err) - } - stmt := sd.getStmt() - defer sd.release() - return stmt.QueryRowContext(ctx, args) -} - -type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - DB *DB - DbBaser dbBaser - TZ *time.Location - Engine string -} - -func detectTZ(al *alias) { - // orm timezone system match database - // default use Local - al.TZ = DefaultTimeLoc - - if al.DriverName == "sphinx" { - return - } - - switch al.Driver { - case DRMySQL: - row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") - var tz string - row.Scan(&tz) - if len(tz) >= 8 { - if tz[0] != '-' { - tz = "+" + tz - } - t, err := time.Parse("-07:00:00", tz) - if err == nil { - if t.Location().String() != "" { - al.TZ = t.Location() - } - } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) - } - } - - // get default engine from current database - row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") - var engine string - var tx bool - row.Scan(&engine, &tx) - - if engine != "" { - al.Engine = engine - } else { - al.Engine = "INNODB" - } - - case DRSqlite, DROracle: - al.TZ = time.UTC - - case DRPostgres: - row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") - var tz string - row.Scan(&tz) - loc, err := time.LoadLocation(tz) - if err == nil { - al.TZ = loc - } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) - } - } -} - -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - err := db.Ping() - if err != nil { - return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) - } - - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - - return al, nil -} - -// AddAliasWthDB add a aliasName for the drivename -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) - return err -} - -// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - var ( - err error - db *sql.DB - al *alias - ) - - db, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - - al, err = addAliasWthDB(aliasName, driverName, db) - if err != nil { - goto end - } - - al.DataSource = dataSource - - detectTZ(al) - - for i, v := range params { - switch i { - case 0: - SetMaxIdleConns(al.Name, v) - case 1: - SetMaxOpenConns(al.Name, v) - } - } - -end: - if err != nil { - if db != nil { - db.Close() - } - DebugLog.Println(err.Error()) - } - - return err -} - -// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func RegisterDriver(driverName string, typ DriverType) error { - if t, ok := drivers[driverName]; !ok { - drivers[driverName] = typ - } else { - if t != typ { - return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) - } - } - return nil -} - -// SetDataBaseTZ Change the database default used timezone -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func SetDataBaseTZ(aliasName string, tz *time.Location) error { - if al, ok := dataBaseCache.get(aliasName); ok { - al.TZ = tz - } else { - return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) - } - return nil -} - -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) - al.MaxIdleConns = maxIdleConns - al.DB.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) - al.MaxOpenConns = maxOpenConns - al.DB.DB.SetMaxOpenConns(maxOpenConns) - // for tip go 1.2 - if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { - fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) - } -} - -// GetDB Get *sql.DB from registered database by db alias name. -// Use "default" as alias name if you not set. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func GetDB(aliasNames ...string) (*sql.DB, error) { - var name string - if len(aliasNames) > 0 { - name = aliasNames[0] - } else { - name = "default" - } - al, ok := dataBaseCache.get(name) - if ok { - return al.DB.DB, nil - } - return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) -} - -type stmtDecorator struct { - wg sync.WaitGroup - stmt *sql.Stmt -} - -func (s *stmtDecorator) getStmt() *sql.Stmt { - return s.stmt -} - -// acquire will add one -// since this method will be used inside read lock scope, -// so we can not do more things here -// we should think about refactor this -func (s *stmtDecorator) acquire() { - s.wg.Add(1) -} - -func (s *stmtDecorator) release() { - s.wg.Done() -} - -//garbage recycle for stmt -func (s *stmtDecorator) destroy() { - go func() { - s.wg.Wait() - _ = s.stmt.Close() - }() -} - -func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { - return &stmtDecorator{ - stmt: sqlStmt, - } -} - -func newStmtDecoratorLruWithEvict() *lru.Cache { - cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { - value.(*stmtDecorator).destroy() - }) - return cache -} diff --git a/orm/db_mysql.go b/orm/db_mysql.go deleted file mode 100644 index 36f6f566..00000000 --- a/orm/db_mysql.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "reflect" - "strings" -) - -// mysql operators. -var mysqlOperators = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ?", - "contains": "LIKE BINARY ?", - "icontains": "LIKE ?", - // "regex": "REGEXP BINARY ?", - // "iregex": "REGEXP ?", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "eq": "= ?", - "=": "= ?", - "ne": "!= ?", - "!=": "!= ?", - "startswith": "LIKE BINARY ?", - "endswith": "LIKE BINARY ?", - "istartswith": "LIKE ?", - "iendswith": "LIKE ?", -} - -// mysql column field types. -var mysqlTypes = map[string]string{ - "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "longtext", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", -} - -// mysql dbBaser implementation. -type dbBaseMysql struct { - dbBase -} - -var _ dbBaser = new(dbBaseMysql) - -// get mysql operator. -func (d *dbBaseMysql) OperatorSQL(operator string) string { - return mysqlOperators[operator] -} - -// get mysql table field types. -func (d *dbBaseMysql) DbTypes() map[string]string { - return mysqlTypes -} - -// show table sql for mysql. -func (d *dbBaseMysql) ShowTablesQuery() string { - return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" -} - -// show columns sql of table for mysql. -func (d *dbBaseMysql) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ - "WHERE table_schema = DATABASE() AND table_name = '%s'", table) -} - -// execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ - "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// InsertOrUpdate a row -// If your primary key or unique column conflict will update -// If no will insert -// Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { - var iouStr string - argsMap := map[string]string{} - - iouStr = "ON DUPLICATE KEY UPDATE" - - //Get on the key-value pairs - for _, v := range args { - kv := strings.Split(v, "=") - if len(kv) == 2 { - argsMap[strings.ToLower(kv[0])] = kv[1] - } - } - - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) - Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - - if err != nil { - return 0, err - } - - marks := make([]string, len(names)) - updateValues := make([]interface{}, 0) - updates := make([]string, len(names)) - - for i, v := range names { - marks[i] = "?" - valueStr := argsMap[strings.ToLower(v)] - if valueStr != "" { - updates[i] = "`" + v + "`" + "=" + valueStr - } else { - updates[i] = "`" + v + "`" + "=?" - updateValues = append(updateValues, values[i]) - } - } - - values = append(values, updateValues...) - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - qupdates := strings.Join(updates, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - //conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - - row := q.QueryRow(query, values...) - var id int64 - err = row.Scan(&id) - return id, err -} - -// create new mysql dbBaser. -func newdbBaseMysql() dbBaser { - b := new(dbBaseMysql) - b.ins = b - return b -} diff --git a/orm/db_oracle.go b/orm/db_oracle.go deleted file mode 100644 index ed2ec74c..00000000 --- a/orm/db_oracle.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strings" -) - -// oracle operators. -var oracleOperators = map[string]string{ - "exact": "= ?", - "=": "= ?", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "//iendswith": "LIKE ?", -} - -// oracle column field types. -var oracleTypes = map[string]string{ - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "VARCHAR2(%d)", - "string-char": "CHAR(%d)", - "string-text": "VARCHAR2(%d)", - "time.Time-date": "DATE", - "time.Time": "TIMESTAMP", - "int8": "INTEGER", - "int16": "INTEGER", - "int32": "INTEGER", - "int64": "INTEGER", - "uint8": "INTEGER", - "uint16": "INTEGER", - "uint32": "INTEGER", - "uint64": "INTEGER", - "float64": "NUMBER", - "float64-decimal": "NUMBER(%d, %d)", -} - -// oracle dbBaser -type dbBaseOracle struct { - dbBase -} - -var _ dbBaser = new(dbBaseOracle) - -// create oracle dbBaser. -func newdbBaseOracle() dbBaser { - b := new(dbBaseOracle) - b.ins = b - return b -} - -// OperatorSQL get oracle operator. -func (d *dbBaseOracle) OperatorSQL(operator string) string { - return oracleOperators[operator] -} - -// DbTypes get oracle table field types. -func (d *dbBaseOracle) DbTypes() map[string]string { - return oracleTypes -} - -//ShowTablesQuery show all the tables in database -func (d *dbBaseOracle) ShowTablesQuery() string { - return "SELECT TABLE_NAME FROM USER_TABLES" -} - -// Oracle -func (d *dbBaseOracle) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ - "WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) -} - -// check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ - "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ - "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) - - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// execute insert sql with given struct and given values. -// insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { - Q := d.ins.TableQuote() - - marks := make([]string, len(names)) - for i := range marks { - marks[i] = ":" + names[i] - } - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) - - d.ins.ReplaceMarks(&query) - - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err -} diff --git a/orm/db_postgres.go b/orm/db_postgres.go deleted file mode 100644 index 7eb88d7a..00000000 --- a/orm/db_postgres.go +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" -) - -// postgresql operators. -var postgresOperators = map[string]string{ - "exact": "= ?", - "iexact": "= UPPER(?)", - "contains": "LIKE ?", - "icontains": "LIKE UPPER(?)", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "eq": "= ?", - "=": "= ?", - "ne": "!= ?", - "!=": "!= ?", - "startswith": "LIKE ?", - "endswith": "LIKE ?", - "istartswith": "LIKE UPPER(?)", - "iendswith": "LIKE UPPER(?)", -} - -// postgresql column field types. -var postgresTypes = map[string]string{ - "auto": "serial NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "timestamp with time zone", - "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, - "uint16": `integer CHECK("%COL%" >= 0)`, - "uint32": `bigint CHECK("%COL%" >= 0)`, - "uint64": `bigint CHECK("%COL%" >= 0)`, - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", - "json": "json", - "jsonb": "jsonb", -} - -// postgresql dbBaser. -type dbBasePostgres struct { - dbBase -} - -var _ dbBaser = new(dbBasePostgres) - -// get postgresql operator. -func (d *dbBasePostgres) OperatorSQL(operator string) string { - return postgresOperators[operator] -} - -// generate functioned sql string, such as contains(text). -func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { - switch operator { - case "contains", "startswith", "endswith": - *leftCol = fmt.Sprintf("%s::text", *leftCol) - case "iexact", "icontains", "istartswith", "iendswith": - *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) - } -} - -// postgresql unsupports updating joined record. -func (d *dbBasePostgres) SupportUpdateJoin() bool { - return false -} - -func (d *dbBasePostgres) MaxLimit() uint64 { - return 0 -} - -// postgresql quote is ". -func (d *dbBasePostgres) TableQuote() string { - return `"` -} - -// postgresql value placeholder is $n. -// replace default ? to $n. -func (d *dbBasePostgres) ReplaceMarks(query *string) { - q := *query - num := 0 - for _, c := range q { - if c == '?' { - num++ - } - } - if num == 0 { - return - } - data := make([]byte, 0, len(q)+num) - num = 1 - for i := 0; i < len(q); i++ { - c := q[i] - if c == '?' { - data = append(data, '$') - data = append(data, []byte(strconv.Itoa(num))...) - num++ - } else { - data = append(data, c) - } - } - *query = string(data) -} - -// make returning sql support for postgresql. -func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { - fi := mi.fields.pk - if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { - return false - } - - if query != nil { - *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) - } - return true -} - -// sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { - if len(autoFields) == 0 { - return nil - } - - Q := d.ins.TableQuote() - for _, name := range autoFields { - query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", - mi.table, name, - Q, name, Q, - Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { - return err - } - } - return nil -} - -// show table sql for postgresql. -func (d *dbBasePostgres) ShowTablesQuery() string { - return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" -} - -// show table columns sql for postgresql. -func (d *dbBasePostgres) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) -} - -// get column types of postgresql. -func (d *dbBasePostgres) DbTypes() map[string]string { - return postgresTypes -} - -// check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { - query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// create new postgresql dbBaser. -func newdbBasePostgres() dbBaser { - b := new(dbBasePostgres) - b.ins = b - return b -} diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go deleted file mode 100644 index bd9f5d3b..00000000 --- a/orm/db_sqlite.go +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "time" -) - -// sqlite operators. -var sqliteOperators = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ? ESCAPE '\\'", - "contains": "LIKE ? ESCAPE '\\'", - "icontains": "LIKE ? ESCAPE '\\'", - "gt": "> ?", - ">": "> ?", - "gte": ">= ?", - ">=": ">= ?", - "lt": "< ?", - "<": "< ?", - "lte": "<= ?", - "<=": "<= ?", - "eq": "= ?", - "=": "= ?", - "ne": "!= ?", - "!=": "!= ?", - "startswith": "LIKE ? ESCAPE '\\'", - "endswith": "LIKE ? ESCAPE '\\'", - "istartswith": "LIKE ? ESCAPE '\\'", - "iendswith": "LIKE ? ESCAPE '\\'", -} - -// sqlite column types. -var sqliteTypes = map[string]string{ - "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "character(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "real", - "float64-decimal": "decimal", -} - -// sqlite dbBaser. -type dbBaseSqlite struct { - dbBase -} - -var _ dbBaser = new(dbBaseSqlite) - -// override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { - if isForUpdate { - DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") - } - return d.dbBase.Read(q, mi, ind, tz, cols, false) -} - -// get sqlite operator. -func (d *dbBaseSqlite) OperatorSQL(operator string) string { - return sqliteOperators[operator] -} - -// generate functioned sql for sqlite. -// only support DATE(text). -func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { - if fi.fieldType == TypeDateField { - *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) - } -} - -// unable updating joined record in sqlite. -func (d *dbBaseSqlite) SupportUpdateJoin() bool { - return false -} - -// max int in sqlite. -func (d *dbBaseSqlite) MaxLimit() uint64 { - return 9223372036854775807 -} - -// get column types in sqlite. -func (d *dbBaseSqlite) DbTypes() map[string]string { - return sqliteTypes -} - -// get show tables sql in sqlite. -func (d *dbBaseSqlite) ShowTablesQuery() string { - return "SELECT name FROM sqlite_master WHERE type = 'table'" -} - -// get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { - query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) - if err != nil { - return nil, err - } - - columns := make(map[string][3]string) - for rows.Next() { - var tmp, name, typ, null sql.NullString - err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) - if err != nil { - return nil, err - } - columns[name.String] = [3]string{name.String, typ.String, null.String} - } - - return columns, nil -} - -// get show columns sql in sqlite. -func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { - return fmt.Sprintf("pragma table_info('%s')", table) -} - -// check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { - query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) - if err != nil { - panic(err) - } - defer rows.Close() - for rows.Next() { - var tmp, index sql.NullString - rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) - if name == index.String { - return true - } - } - return false -} - -// create new sqlite dbBaser. -func newdbBaseSqlite() dbBaser { - b := new(dbBaseSqlite) - b.ins = b - return b -} diff --git a/orm/db_tables.go b/orm/db_tables.go deleted file mode 100644 index 4b21a6fc..00000000 --- a/orm/db_tables.go +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strings" - "time" -) - -// table info struct. -type dbTable struct { - id int - index string - name string - names []string - sel bool - inner bool - mi *modelInfo - fi *fieldInfo - jtl *dbTable -} - -// tables collection struct, contains some tables. -type dbTables struct { - tablesM map[string]*dbTable - tables []*dbTable - mi *modelInfo - base dbBaser - skipEnd bool -} - -// set table info to collection. -// if not exist, create new. -func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { - name := strings.Join(names, ExprSep) - if j, ok := t.tablesM[name]; ok { - j.name = name - j.mi = mi - j.fi = fi - j.inner = inner - } else { - i := len(t.tables) + 1 - jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} - t.tablesM[name] = jt - t.tables = append(t.tables, jt) - } - return t.tablesM[name] -} - -// add table info to collection. -func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { - name := strings.Join(names, ExprSep) - if _, ok := t.tablesM[name]; !ok { - i := len(t.tables) + 1 - jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} - t.tablesM[name] = jt - t.tables = append(t.tables, jt) - return jt, true - } - return t.tablesM[name], false -} - -// get table info in collection. -func (t *dbTables) get(name string) (*dbTable, bool) { - j, ok := t.tablesM[name] - return j, ok -} - -// get related fields info in recursive depth loop. -// loop once, depth decreases one. -func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { - if depth < 0 || fi.fieldType == RelManyToMany { - return related - } - - if prefix == "" { - prefix = fi.name - } else { - prefix = prefix + ExprSep + fi.name - } - related = append(related, prefix) - - depth-- - for _, fi := range fi.relModelInfo.fields.fieldsRel { - related = t.loopDepth(depth, prefix, fi, related) - } - - return related -} - -// parse related fields. -func (t *dbTables) parseRelated(rels []string, depth int) { - - relsNum := len(rels) - related := make([]string, relsNum) - copy(related, rels) - - relDepth := depth - - if relsNum != 0 { - relDepth = 0 - } - - relDepth-- - for _, fi := range t.mi.fields.fieldsRel { - related = t.loopDepth(relDepth, "", fi, related) - } - - for i, s := range related { - var ( - exs = strings.Split(s, ExprSep) - names = make([]string, 0, len(exs)) - mmi = t.mi - cancel = true - jtl *dbTable - ) - - inner := true - - for _, ex := range exs { - if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { - names = append(names, fi.name) - mmi = fi.relModelInfo - - if fi.null || t.skipEnd { - inner = false - } - - jt := t.set(names, mmi, fi, inner) - jt.jtl = jtl - - if fi.reverse { - cancel = false - } - - if cancel { - jt.sel = depth > 0 - - if i < relsNum { - jt.sel = true - } - } - - jtl = jt - - } else { - panic(fmt.Errorf("unknown model/table name `%s`", ex)) - } - } - } -} - -// generate join string. -func (t *dbTables) getJoinSQL() (join string) { - Q := t.base.TableQuote() - - for _, jt := range t.tables { - if jt.inner { - join += "INNER JOIN " - } else { - join += "LEFT OUTER JOIN " - } - var ( - table string - t1, t2 string - c1, c2 string - ) - t1 = "T0" - if jt.jtl != nil { - t1 = jt.jtl.index - } - t2 = jt.index - table = jt.mi.table - - switch { - case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: - c1 = jt.fi.mi.fields.pk.column - for _, ffi := range jt.mi.fields.fieldsRel { - if jt.fi.mi == ffi.relModelInfo { - c2 = ffi.column - break - } - } - default: - c1 = jt.fi.column - c2 = jt.fi.relModelInfo.fields.pk.column - - if jt.fi.reverse { - c1 = jt.mi.fields.pk.column - c2 = jt.fi.reverseFieldInfo.column - } - } - - join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, - t2, Q, c2, Q, t1, Q, c1, Q) - } - return -} - -// parse orm model struct field tag expression. -func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { - var ( - jtl *dbTable - fi *fieldInfo - fiN *fieldInfo - mmi = mi - ) - - num := len(exprs) - 1 - var names []string - - inner := true - -loopFor: - for i, ex := range exprs { - - var ok, okN bool - - if fiN != nil { - fi = fiN - ok = true - fiN = nil - } - - if i == 0 { - fi, ok = mmi.fields.GetByAny(ex) - } - - _ = okN - - if ok { - - isRel := fi.rel || fi.reverse - - names = append(names, fi.name) - - switch { - case fi.rel: - mmi = fi.relModelInfo - if fi.fieldType == RelManyToMany { - mmi = fi.relThroughModelInfo - } - case fi.reverse: - mmi = fi.reverseFieldInfo.mi - } - - if i < num { - fiN, okN = mmi.fields.GetByAny(exprs[i+1]) - } - - if isRel && (!fi.mi.isThrough || num != i) { - if fi.null || t.skipEnd { - inner = false - } - - if t.skipEnd && okN || !t.skipEnd { - if t.skipEnd && okN && fiN.pk { - goto loopEnd - } - - jt, _ := t.add(names, mmi, fi, inner) - jt.jtl = jtl - jtl = jt - } - - } - - if num != i { - continue - } - - loopEnd: - - if i == 0 || jtl == nil { - index = "T0" - } else { - index = jtl.index - } - - info = fi - - if jtl == nil { - name = fi.name - } else { - name = jtl.name + ExprSep + fi.name - } - - switch { - case fi.rel: - - case fi.reverse: - switch fi.reverseFieldInfo.fieldType { - case RelOneToOne, RelForeignKey: - index = jtl.index - info = fi.reverseFieldInfo.mi.fields.pk - name = info.name - } - } - - break loopFor - - } else { - index = "" - name = "" - info = nil - success = false - return - } - } - - success = index != "" && info != nil - return -} - -// generate condition sql. -func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { - if cond == nil || cond.IsEmpty() { - return - } - - Q := t.base.TableQuote() - - mi := t.mi - - for i, p := range cond.params { - if i > 0 { - if p.isOr { - where += "OR " - } else { - where += "AND " - } - } - if p.isNot { - where += "NOT " - } - if p.isCond { - w, ps := t.getCondSQL(p.cond, true, tz) - if w != "" { - w = fmt.Sprintf("( %s) ", w) - } - where += w - params = append(params, ps...) - } else { - exprs := p.exprs - - num := len(exprs) - 1 - operator := "" - if operators[exprs[num]] { - operator = exprs[num] - exprs = exprs[:num] - } - - index, _, fi, suc := t.parseExprs(mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) - } - - if operator == "" { - operator = "exact" - } - - var operSQL string - var args []interface{} - if p.isRaw { - operSQL = p.sql - } else { - operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) - } - - leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) - t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) - - where += fmt.Sprintf("%s %s ", leftCol, operSQL) - params = append(params, args...) - - } - } - - if !sub && where != "" { - where = "WHERE " + where - } - - return -} - -// generate group sql. -func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { - if len(groups) == 0 { - return - } - - Q := t.base.TableQuote() - - groupSqls := make([]string, 0, len(groups)) - for _, group := range groups { - exprs := strings.Split(group, ExprSep) - - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } - - groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) - } - - groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) - return -} - -// generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { - if len(orders) == 0 { - return - } - - Q := t.base.TableQuote() - - orderSqls := make([]string, 0, len(orders)) - for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) - - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } - - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) - } - - orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) - return -} - -// generate limit sql. -func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { - if limit == 0 { - limit = int64(DefaultRowsLimit) - } - if limit < 0 { - // no limit - if offset > 0 { - maxLimit := t.base.MaxLimit() - if maxLimit == 0 { - limits = fmt.Sprintf("OFFSET %d", offset) - } else { - limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) - } - } - } else if offset <= 0 { - limits = fmt.Sprintf("LIMIT %d", limit) - } else { - limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) - } - return -} - -// crete new tables collection. -func newDbTables(mi *modelInfo, base dbBaser) *dbTables { - tables := &dbTables{} - tables.tablesM = make(map[string]*dbTable) - tables.mi = mi - tables.base = base - return tables -} diff --git a/orm/db_tidb.go b/orm/db_tidb.go deleted file mode 100644 index 6020a488..00000000 --- a/orm/db_tidb.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2015 TiDB Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" -) - -// mysql dbBaser implementation. -type dbBaseTidb struct { - dbBase -} - -var _ dbBaser = new(dbBaseTidb) - -// get mysql operator. -func (d *dbBaseTidb) OperatorSQL(operator string) string { - return mysqlOperators[operator] -} - -// get mysql table field types. -func (d *dbBaseTidb) DbTypes() map[string]string { - return mysqlTypes -} - -// show table sql for mysql. -func (d *dbBaseTidb) ShowTablesQuery() string { - return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" -} - -// show columns sql of table for mysql. -func (d *dbBaseTidb) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ - "WHERE table_schema = DATABASE() AND table_name = '%s'", table) -} - -// execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ - "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) - var cnt int - row.Scan(&cnt) - return cnt > 0 -} - -// create new mysql dbBaser. -func newdbBaseTidb() dbBaser { - b := new(dbBaseTidb) - b.ins = b - return b -} diff --git a/orm/db_utils.go b/orm/db_utils.go deleted file mode 100644 index 7ae10ca5..00000000 --- a/orm/db_utils.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "reflect" - "time" -) - -// get table alias. -func getDbAlias(name string) *alias { - if al, ok := dataBaseCache.get(name); ok { - return al - } - panic(fmt.Errorf("unknown DataBase alias name %s", name)) -} - -// get pk column info. -func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { - fi := mi.fields.pk - - v := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsPositiveIntegerField > 0 { - vu := v.Uint() - exist = vu > 0 - value = vu - } else if fi.fieldType&IsIntegerField > 0 { - vu := v.Int() - exist = true - value = vu - } else if fi.fieldType&IsRelField > 0 { - _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) - } else { - vu := v.String() - exist = vu != "" - value = vu - } - - column = fi.column - return -} - -// get fields description as flatted string. -func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { - -outFor: - for _, arg := range args { - val := reflect.ValueOf(arg) - - if arg == nil { - params = append(params, arg) - continue - } - - kind := val.Kind() - if kind == reflect.Ptr { - val = val.Elem() - kind = val.Kind() - arg = val.Interface() - } - - switch kind { - case reflect.String: - v := val.String() - if fi != nil { - if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { - var t time.Time - var err error - if len(v) >= 19 { - s := v[:19] - t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) - } else if len(v) >= 10 { - s := v - if len(v) > 10 { - s = v[:10] - } - t, err = time.ParseInLocation(formatDate, s, tz) - } else { - s := v - if len(s) > 8 { - s = v[:8] - } - t, err = time.ParseInLocation(formatTime, s, tz) - } - if err == nil { - if fi.fieldType == TypeDateField { - v = t.In(tz).Format(formatDate) - } else if fi.fieldType == TypeDateTimeField { - v = t.In(tz).Format(formatDateTime) - } else { - v = t.In(tz).Format(formatTime) - } - } - } - } - arg = v - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - arg = val.Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - arg = val.Uint() - case reflect.Float32: - arg, _ = StrTo(ToStr(arg)).Float64() - case reflect.Float64: - arg = val.Float() - case reflect.Bool: - arg = val.Bool() - case reflect.Slice, reflect.Array: - if _, ok := arg.([]byte); ok { - continue outFor - } - - var args []interface{} - for i := 0; i < val.Len(); i++ { - v := val.Index(i) - - var vu interface{} - if v.CanInterface() { - vu = v.Interface() - } - - if vu == nil { - continue - } - - args = append(args, vu) - } - - if len(args) > 0 { - p := getFlatParams(fi, args, tz) - params = append(params, p...) - } - continue outFor - case reflect.Struct: - if v, ok := arg.(time.Time); ok { - if fi != nil && fi.fieldType == TypeDateField { - arg = v.In(tz).Format(formatDate) - } else if fi != nil && fi.fieldType == TypeDateTimeField { - arg = v.In(tz).Format(formatDateTime) - } else if fi != nil && fi.fieldType == TypeTimeField { - arg = v.In(tz).Format(formatTime) - } else { - arg = v.In(tz).Format(formatDateTime) - } - } else { - typ := val.Type() - name := getFullName(typ) - var value interface{} - if mmi, ok := modelCache.getByFullName(name); ok { - if _, vu, exist := getExistPk(mmi, val); exist { - value = vu - } - } - arg = value - - if arg == nil { - panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) - } - } - } - - params = append(params, arg) - } - return -} diff --git a/orm/models.go b/orm/models.go deleted file mode 100644 index 4776bcba..00000000 --- a/orm/models.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "sync" -) - -const ( - odCascade = "cascade" - odSetNULL = "set_null" - odSetDefault = "set_default" - odDoNothing = "do_nothing" - defaultStructTagName = "orm" - defaultStructTagDelim = ";" -) - -var ( - modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), - } -) - -// model info collection -type _modelCache struct { - sync.RWMutex // only used outsite for bootStrap - orders []string - cache map[string]*modelInfo - cacheByFullName map[string]*modelInfo - done bool -} - -// get all model info -func (mc *_modelCache) all() map[string]*modelInfo { - m := make(map[string]*modelInfo, len(mc.cache)) - for k, v := range mc.cache { - m[k] = v - } - return m -} - -// get ordered model info -func (mc *_modelCache) allOrdered() []*modelInfo { - m := make([]*modelInfo, 0, len(mc.orders)) - for _, table := range mc.orders { - m = append(m, mc.cache[table]) - } - return m -} - -// get model info by table name -func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { - mi, ok = mc.cache[table] - return -} - -// get model info by full name -func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { - mi, ok = mc.cacheByFullName[name] - return -} - -// set model info to collection -func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { - mii := mc.cache[table] - mc.cache[table] = mi - mc.cacheByFullName[mi.fullName] = mi - if mii == nil { - mc.orders = append(mc.orders, table) - } - return mii -} - -// clean all model info. -func (mc *_modelCache) clean() { - mc.orders = make([]string, 0) - mc.cache = make(map[string]*modelInfo) - mc.cacheByFullName = make(map[string]*modelInfo) - mc.done = false -} - -// ResetModelCache Clean model cache. Then you can re-RegisterModel. -// Common use this api for test case. -func ResetModelCache() { - modelCache.clean() -} diff --git a/orm/models_boot.go b/orm/models_boot.go deleted file mode 100644 index 8c56b3c4..00000000 --- a/orm/models_boot.go +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" - "runtime/debug" - "strings" -) - -// register models. -// PrefixOrSuffix means table name prefix or suffix. -// isPrefix whether the prefix is prefix or suffix -func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { - val := reflect.ValueOf(model) - typ := reflect.Indirect(val).Type() - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - // For this case: - // u := &User{} - // registerModel(&u) - if typ.Kind() == reflect.Ptr { - panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) - } - - table := getTableName(val) - - if PrefixOrSuffix != "" { - if isPrefix { - table = PrefixOrSuffix + table - } else { - table = table + PrefixOrSuffix - } - } - // models's fullname is pkgpath + struct name - name := getFullName(typ) - if _, ok := modelCache.getByFullName(name); ok { - fmt.Printf(" model `%s` repeat register, must be unique\n", name) - os.Exit(2) - } - - if _, ok := modelCache.get(table); ok { - fmt.Printf(" table name `%s` repeat register, must be unique\n", table) - os.Exit(2) - } - - mi := newModelInfo(val) - if mi.fields.pk == nil { - outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi - break outFor - } - } - } - - if mi.fields.pk == nil { - fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - os.Exit(2) - } - - } - - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true - - modelCache.set(table, mi) -} - -// bootstrap models -func bootStrap() { - if modelCache.done { - return - } - var ( - err error - models map[string]*modelInfo - ) - if dataBaseCache.getDefault() == nil { - err = fmt.Errorf("must have one register DataBase alias named `default`") - goto end - } - - // set rel and reverse model - // RelManyToMany set the relTable - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { - elm = elm.Elem() - } - // check the rel or reverse model already register - name := getFullName(elm) - mii, ok := modelCache.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) - goto end - } - fi.relModelInfo = mii - - switch fi.fieldType { - case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) - goto end - } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table - } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) - goto end - } - } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable - } - if v := modelCache.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) - goto end - } - fi.relTable = i.table - fi.relThroughModelInfo = i - } - - fi.relThroughModelInfo.isThrough = true - } - } - } - } - - // check the rel filed while the relModelInfo also has filed point to current model - // if not exist, add a new field to the relModelInfo - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { - inModel = true - break - } - } - if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne - } else { - ffi.fieldType = RelReverseMany - } - if !rmi.fields.Add(ffi) { - added := false - for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { - break - } - } - if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) - } - } - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { - case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi - } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - } - } - } - if fi.reverseFieldInfoTwo == nil { - err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) - goto end - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { - case RelReverseOne: - found := false - mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - break mForA - } - } - if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - case RelReverseMany: - found := false - mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - - break mForB - } - } - if !found { - mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { - found = true - - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi - - break mForC - } - } - } - if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - } - } - } - -end: - if err != nil { - fmt.Println(err) - debug.PrintStack() - os.Exit(2) - } -} - -// RegisterModel register models -func RegisterModel(models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModel must be run before BootStrap")) - } - RegisterModelWithPrefix("", models...) -} - -// RegisterModelWithPrefix register models with a prefix -func RegisterModelWithPrefix(prefix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(prefix, model, true) - } -} - -// RegisterModelWithSuffix register models with a suffix -func RegisterModelWithSuffix(suffix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(suffix, model, false) - } -} - -// BootStrap bootstrap models. -// make all model parsed and can not add more models -func BootStrap() { - modelCache.Lock() - defer modelCache.Unlock() - if modelCache.done { - return - } - bootStrap() - modelCache.done = true -} diff --git a/orm/models_fields.go b/orm/models_fields.go deleted file mode 100644 index b4fad94f..00000000 --- a/orm/models_fields.go +++ /dev/null @@ -1,783 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "time" -) - -// Define the Type enum -const ( - TypeBooleanField = 1 << iota - TypeVarCharField - TypeCharField - TypeTextField - TypeTimeField - TypeDateField - TypeDateTimeField - TypeBitField - TypeSmallIntegerField - TypeIntegerField - TypeBigIntegerField - TypePositiveBitField - TypePositiveSmallIntegerField - TypePositiveIntegerField - TypePositiveBigIntegerField - TypeFloatField - TypeDecimalField - TypeJSONField - TypeJsonbField - RelForeignKey - RelOneToOne - RelManyToMany - RelReverseOne - RelReverseMany -) - -// Define some logic enum -const ( - IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7 - IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 - IsRelField = ^-RelReverseMany >> 18 << 19 - IsFieldType = ^-RelReverseMany<<1 + 1 -) - -// BooleanField A true/false field. -type BooleanField bool - -// Value return the BooleanField -func (e BooleanField) Value() bool { - return bool(e) -} - -// Set will set the BooleanField -func (e *BooleanField) Set(d bool) { - *e = BooleanField(d) -} - -// String format the Bool to string -func (e *BooleanField) String() string { - return strconv.FormatBool(e.Value()) -} - -// FieldType return BooleanField the type -func (e *BooleanField) FieldType() int { - return TypeBooleanField -} - -// SetRaw set the interface to bool -func (e *BooleanField) SetRaw(value interface{}) error { - switch d := value.(type) { - case bool: - e.Set(d) - case string: - v, err := StrTo(d).Bool() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the current value -func (e *BooleanField) RawValue() interface{} { - return e.Value() -} - -// verify the BooleanField implement the Fielder interface -var _ Fielder = new(BooleanField) - -// CharField A string field -// required values tag: size -// The size is enforced at the database level and in models’s validation. -// eg: `orm:"size(120)"` -type CharField string - -// Value return the CharField's Value -func (e CharField) Value() string { - return string(e) -} - -// Set CharField value -func (e *CharField) Set(d string) { - *e = CharField(d) -} - -// String return the CharField -func (e *CharField) String() string { - return e.Value() -} - -// FieldType return the enum type -func (e *CharField) FieldType() int { - return TypeVarCharField -} - -// SetRaw set the interface to string -func (e *CharField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - e.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the CharField value -func (e *CharField) RawValue() interface{} { - return e.Value() -} - -// verify CharField implement Fielder -var _ Fielder = new(CharField) - -// TimeField A time, represented in go by a time.Time instance. -// only time values like 10:00:00 -// Has a few extra, optional attr tag: -// -// auto_now: -// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// auto_now_add: -// Automatically set the field to now when the object is first created. Useful for creation of timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// eg: `orm:"auto_now"` or `orm:"auto_now_add"` -type TimeField time.Time - -// Value return the time.Time -func (e TimeField) Value() time.Time { - return time.Time(e) -} - -// Set set the TimeField's value -func (e *TimeField) Set(d time.Time) { - *e = TimeField(d) -} - -// String convert time to string -func (e *TimeField) String() string { - return e.Value().String() -} - -// FieldType return enum type Date -func (e *TimeField) FieldType() int { - return TypeDateField -} - -// SetRaw convert the interface to time.Time. Allow string and time.Time -func (e *TimeField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatTime) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return time value -func (e *TimeField) RawValue() interface{} { - return e.Value() -} - -var _ Fielder = new(TimeField) - -// DateField A date, represented in go by a time.Time instance. -// only date values like 2006-01-02 -// Has a few extra, optional attr tag: -// -// auto_now: -// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// auto_now_add: -// Automatically set the field to now when the object is first created. Useful for creation of timestamps. -// Note that the current date is always used; it’s not just a default value that you can override. -// -// eg: `orm:"auto_now"` or `orm:"auto_now_add"` -type DateField time.Time - -// Value return the time.Time -func (e DateField) Value() time.Time { - return time.Time(e) -} - -// Set set the DateField's value -func (e *DateField) Set(d time.Time) { - *e = DateField(d) -} - -// String convert datetime to string -func (e *DateField) String() string { - return e.Value().String() -} - -// FieldType return enum type Date -func (e *DateField) FieldType() int { - return TypeDateField -} - -// SetRaw convert the interface to time.Time. Allow string and time.Time -func (e *DateField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatDate) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return Date value -func (e *DateField) RawValue() interface{} { - return e.Value() -} - -// verify DateField implement fielder interface -var _ Fielder = new(DateField) - -// DateTimeField A date, represented in go by a time.Time instance. -// datetime values like 2006-01-02 15:04:05 -// Takes the same extra arguments as DateField. -type DateTimeField time.Time - -// Value return the datetime value -func (e DateTimeField) Value() time.Time { - return time.Time(e) -} - -// Set set the time.Time to datetime -func (e *DateTimeField) Set(d time.Time) { - *e = DateTimeField(d) -} - -// String return the time's String -func (e *DateTimeField) String() string { - return e.Value().String() -} - -// FieldType return the enum TypeDateTimeField -func (e *DateTimeField) FieldType() int { - return TypeDateTimeField -} - -// SetRaw convert the string or time.Time to DateTimeField -func (e *DateTimeField) SetRaw(value interface{}) error { - switch d := value.(type) { - case time.Time: - e.Set(d) - case string: - v, err := timeParse(d, formatDateTime) - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the datetime value -func (e *DateTimeField) RawValue() interface{} { - return e.Value() -} - -// verify datetime implement fielder -var _ Fielder = new(DateTimeField) - -// FloatField A floating-point number represented in go by a float32 value. -type FloatField float64 - -// Value return the FloatField value -func (e FloatField) Value() float64 { - return float64(e) -} - -// Set the Float64 -func (e *FloatField) Set(d float64) { - *e = FloatField(d) -} - -// String return the string -func (e *FloatField) String() string { - return ToStr(e.Value(), -1, 32) -} - -// FieldType return the enum type -func (e *FloatField) FieldType() int { - return TypeFloatField -} - -// SetRaw converter interface Float64 float32 or string to FloatField -func (e *FloatField) SetRaw(value interface{}) error { - switch d := value.(type) { - case float32: - e.Set(float64(d)) - case float64: - e.Set(d) - case string: - v, err := StrTo(d).Float64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the FloatField value -func (e *FloatField) RawValue() interface{} { - return e.Value() -} - -// verify FloatField implement Fielder -var _ Fielder = new(FloatField) - -// SmallIntegerField -32768 to 32767 -type SmallIntegerField int16 - -// Value return int16 value -func (e SmallIntegerField) Value() int16 { - return int16(e) -} - -// Set the SmallIntegerField value -func (e *SmallIntegerField) Set(d int16) { - *e = SmallIntegerField(d) -} - -// String convert smallint to string -func (e *SmallIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type SmallIntegerField -func (e *SmallIntegerField) FieldType() int { - return TypeSmallIntegerField -} - -// SetRaw convert interface int16/string to int16 -func (e *SmallIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int16: - e.Set(d) - case string: - v, err := StrTo(d).Int16() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return smallint value -func (e *SmallIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify SmallIntegerField implement Fielder -var _ Fielder = new(SmallIntegerField) - -// IntegerField -2147483648 to 2147483647 -type IntegerField int32 - -// Value return the int32 -func (e IntegerField) Value() int32 { - return int32(e) -} - -// Set IntegerField value -func (e *IntegerField) Set(d int32) { - *e = IntegerField(d) -} - -// String convert Int32 to string -func (e *IntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return the enum type -func (e *IntegerField) FieldType() int { - return TypeIntegerField -} - -// SetRaw convert interface int32/string to int32 -func (e *IntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int32: - e.Set(d) - case string: - v, err := StrTo(d).Int32() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return IntegerField value -func (e *IntegerField) RawValue() interface{} { - return e.Value() -} - -// verify IntegerField implement Fielder -var _ Fielder = new(IntegerField) - -// BigIntegerField -9223372036854775808 to 9223372036854775807. -type BigIntegerField int64 - -// Value return int64 -func (e BigIntegerField) Value() int64 { - return int64(e) -} - -// Set the BigIntegerField value -func (e *BigIntegerField) Set(d int64) { - *e = BigIntegerField(d) -} - -// String convert BigIntegerField to string -func (e *BigIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *BigIntegerField) FieldType() int { - return TypeBigIntegerField -} - -// SetRaw convert interface int64/string to int64 -func (e *BigIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case int64: - e.Set(d) - case string: - v, err := StrTo(d).Int64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return BigIntegerField value -func (e *BigIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify BigIntegerField implement Fielder -var _ Fielder = new(BigIntegerField) - -// PositiveSmallIntegerField 0 to 65535 -type PositiveSmallIntegerField uint16 - -// Value return uint16 -func (e PositiveSmallIntegerField) Value() uint16 { - return uint16(e) -} - -// Set PositiveSmallIntegerField value -func (e *PositiveSmallIntegerField) Set(d uint16) { - *e = PositiveSmallIntegerField(d) -} - -// String convert uint16 to string -func (e *PositiveSmallIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveSmallIntegerField) FieldType() int { - return TypePositiveSmallIntegerField -} - -// SetRaw convert Interface uint16/string to uint16 -func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint16: - e.Set(d) - case string: - v, err := StrTo(d).Uint16() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue returns PositiveSmallIntegerField value -func (e *PositiveSmallIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify PositiveSmallIntegerField implement Fielder -var _ Fielder = new(PositiveSmallIntegerField) - -// PositiveIntegerField 0 to 4294967295 -type PositiveIntegerField uint32 - -// Value return PositiveIntegerField value. Uint32 -func (e PositiveIntegerField) Value() uint32 { - return uint32(e) -} - -// Set the PositiveIntegerField value -func (e *PositiveIntegerField) Set(d uint32) { - *e = PositiveIntegerField(d) -} - -// String convert PositiveIntegerField to string -func (e *PositiveIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveIntegerField) FieldType() int { - return TypePositiveIntegerField -} - -// SetRaw convert interface uint32/string to Uint32 -func (e *PositiveIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint32: - e.Set(d) - case string: - v, err := StrTo(d).Uint32() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return the PositiveIntegerField Value -func (e *PositiveIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify PositiveIntegerField implement Fielder -var _ Fielder = new(PositiveIntegerField) - -// PositiveBigIntegerField 0 to 18446744073709551615 -type PositiveBigIntegerField uint64 - -// Value return uint64 -func (e PositiveBigIntegerField) Value() uint64 { - return uint64(e) -} - -// Set PositiveBigIntegerField value -func (e *PositiveBigIntegerField) Set(d uint64) { - *e = PositiveBigIntegerField(d) -} - -// String convert PositiveBigIntegerField to string -func (e *PositiveBigIntegerField) String() string { - return ToStr(e.Value()) -} - -// FieldType return enum type -func (e *PositiveBigIntegerField) FieldType() int { - return TypePositiveIntegerField -} - -// SetRaw convert interface uint64/string to Uint64 -func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { - switch d := value.(type) { - case uint64: - e.Set(d) - case string: - v, err := StrTo(d).Uint64() - if err == nil { - e.Set(v) - } - return err - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return PositiveBigIntegerField value -func (e *PositiveBigIntegerField) RawValue() interface{} { - return e.Value() -} - -// verify PositiveBigIntegerField implement Fielder -var _ Fielder = new(PositiveBigIntegerField) - -// TextField A large text field. -type TextField string - -// Value return TextField value -func (e TextField) Value() string { - return string(e) -} - -// Set the TextField value -func (e *TextField) Set(d string) { - *e = TextField(d) -} - -// String convert TextField to string -func (e *TextField) String() string { - return e.Value() -} - -// FieldType return enum type -func (e *TextField) FieldType() int { - return TypeTextField -} - -// SetRaw convert interface string to string -func (e *TextField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - e.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return TextField value -func (e *TextField) RawValue() interface{} { - return e.Value() -} - -// verify TextField implement Fielder -var _ Fielder = new(TextField) - -// JSONField postgres json field. -type JSONField string - -// Value return JSONField value -func (j JSONField) Value() string { - return string(j) -} - -// Set the JSONField value -func (j *JSONField) Set(d string) { - *j = JSONField(d) -} - -// String convert JSONField to string -func (j *JSONField) String() string { - return j.Value() -} - -// FieldType return enum type -func (j *JSONField) FieldType() int { - return TypeJSONField -} - -// SetRaw convert interface string to string -func (j *JSONField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - j.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return JSONField value -func (j *JSONField) RawValue() interface{} { - return j.Value() -} - -// verify JSONField implement Fielder -var _ Fielder = new(JSONField) - -// JsonbField postgres json field. -type JsonbField string - -// Value return JsonbField value -func (j JsonbField) Value() string { - return string(j) -} - -// Set the JsonbField value -func (j *JsonbField) Set(d string) { - *j = JsonbField(d) -} - -// String convert JsonbField to string -func (j *JsonbField) String() string { - return j.Value() -} - -// FieldType return enum type -func (j *JsonbField) FieldType() int { - return TypeJsonbField -} - -// SetRaw convert interface string to string -func (j *JsonbField) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - j.Set(d) - default: - return fmt.Errorf(" unknown value `%s`", value) - } - return nil -} - -// RawValue return JsonbField value -func (j *JsonbField) RawValue() interface{} { - return j.Value() -} - -// verify JsonbField implement Fielder -var _ Fielder = new(JsonbField) diff --git a/orm/models_info_f.go b/orm/models_info_f.go deleted file mode 100644 index 7044b0bd..00000000 --- a/orm/models_info_f.go +++ /dev/null @@ -1,473 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -var errSkipField = errors.New("skip field") - -// field info collection -type fields struct { - pk *fieldInfo - columns map[string]*fieldInfo - fields map[string]*fieldInfo - fieldsLow map[string]*fieldInfo - fieldsByType map[int][]*fieldInfo - fieldsRel []*fieldInfo - fieldsReverse []*fieldInfo - fieldsDB []*fieldInfo - rels []*fieldInfo - orders []string - dbcols []string -} - -// add field info -func (f *fields) Add(fi *fieldInfo) (added bool) { - if f.fields[fi.name] == nil && f.columns[fi.column] == nil { - f.columns[fi.column] = fi - f.fields[fi.name] = fi - f.fieldsLow[strings.ToLower(fi.name)] = fi - } else { - return - } - if _, ok := f.fieldsByType[fi.fieldType]; !ok { - f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) - } - f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) - f.orders = append(f.orders, fi.column) - if fi.dbcol { - f.dbcols = append(f.dbcols, fi.column) - f.fieldsDB = append(f.fieldsDB, fi) - } - if fi.rel { - f.fieldsRel = append(f.fieldsRel, fi) - } - if fi.reverse { - f.fieldsReverse = append(f.fieldsReverse, fi) - } - return true -} - -// get field info by name -func (f *fields) GetByName(name string) *fieldInfo { - return f.fields[name] -} - -// get field info by column name -func (f *fields) GetByColumn(column string) *fieldInfo { - return f.columns[column] -} - -// get field info by string, name is prior -func (f *fields) GetByAny(name string) (*fieldInfo, bool) { - if fi, ok := f.fields[name]; ok { - return fi, ok - } - if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { - return fi, ok - } - if fi, ok := f.columns[name]; ok { - return fi, ok - } - return nil, false -} - -// create new field info collection -func newFields() *fields { - f := new(fields) - f.fields = make(map[string]*fieldInfo) - f.fieldsLow = make(map[string]*fieldInfo) - f.columns = make(map[string]*fieldInfo) - f.fieldsByType = make(map[int][]*fieldInfo) - return f -} - -// single field info -type fieldInfo struct { - mi *modelInfo - fieldIndex []int - fieldType int - dbcol bool // table column fk and onetoone - inModel bool - name string - fullName string - column string - addrValue reflect.Value - sf reflect.StructField - auto bool - pk bool - null bool - index bool - unique bool - colDefault bool // whether has default tag - initial StrTo // store the default value - size int - toText bool - autoNow bool - autoNowAdd bool - rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true - reverse bool - reverseField string - reverseFieldInfo *fieldInfo - reverseFieldInfoTwo *fieldInfo - reverseFieldInfoM2M *fieldInfo - relTable string - relThrough string - relThroughModelInfo *modelInfo - relModelInfo *modelInfo - digits int - decimals int - isFielder bool // implement Fielder interface - onDelete string - description string -} - -// new field info -func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { - var ( - tag string - tagValue string - initial StrTo // store the default value - fieldType int - attrs map[string]bool - tags map[string]string - addrField reflect.Value - ) - - fi = new(fieldInfo) - - // if field which CanAddr is the follow type - // A value is addressable if it is an element of a slice, - // an element of an addressable array, a field of an - // addressable struct, or the result of dereferencing a pointer. - addrField = field - if field.CanAddr() && field.Kind() != reflect.Ptr { - addrField = field.Addr() - if _, ok := addrField.Interface().(Fielder); !ok { - if field.Kind() == reflect.Slice { - addrField = field - } - } - } - - attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) - - if _, ok := attrs["-"]; ok { - return nil, errSkipField - } - - digits := tags["digits"] - decimals := tags["decimals"] - size := tags["size"] - onDelete := tags["on_delete"] - - initial.Clear() - if v, ok := tags["default"]; ok { - initial.Set(v) - } - -checkType: - switch f := addrField.Interface().(type) { - case Fielder: - fi.isFielder = true - if field.Kind() == reflect.Ptr { - err = fmt.Errorf("the model Fielder can not be use ptr") - goto end - } - fieldType = f.FieldType() - if fieldType&IsRelField > 0 { - err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") - goto end - } - default: - tag = "rel" - tagValue = tags[tag] - if tagValue != "" { - switch tagValue { - case "fk": - fieldType = RelForeignKey - break checkType - case "one": - fieldType = RelOneToOne - break checkType - case "m2m": - fieldType = RelManyToMany - if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv - } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv - } - break checkType - default: - err = fmt.Errorf("rel only allow these value: fk, one, m2m") - goto wrongTag - } - } - tag = "reverse" - tagValue = tags[tag] - if tagValue != "" { - switch tagValue { - case "one": - fieldType = RelReverseOne - break checkType - case "many": - fieldType = RelReverseMany - if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv - } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv - } - break checkType - default: - err = fmt.Errorf("reverse only allow these value: one, many") - goto wrongTag - } - } - - fieldType, err = getFieldType(addrField) - if err != nil { - goto end - } - if fieldType == TypeVarCharField { - switch tags["type"] { - case "char": - fieldType = TypeCharField - case "text": - fieldType = TypeTextField - case "json": - fieldType = TypeJSONField - case "jsonb": - fieldType = TypeJsonbField - } - } - if fieldType == TypeFloatField && (digits != "" || decimals != "") { - fieldType = TypeDecimalField - } - if fieldType == TypeDateTimeField && tags["type"] == "date" { - fieldType = TypeDateField - } - if fieldType == TypeTimeField && tags["type"] == "time" { - fieldType = TypeTimeField - } - } - - // check the rel and reverse type - // rel should Ptr - // reverse should slice []*struct - switch fieldType { - case RelForeignKey, RelOneToOne, RelReverseOne: - if field.Kind() != reflect.Ptr { - err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) - goto end - } - case RelManyToMany, RelReverseMany: - if field.Kind() != reflect.Slice { - err = fmt.Errorf("rel/reverse:many field must be slice") - goto end - } else { - if field.Type().Elem().Kind() != reflect.Ptr { - err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) - goto end - } - } - } - - if fieldType&IsFieldType == 0 { - err = fmt.Errorf("wrong field type") - goto end - } - - fi.fieldType = fieldType - fi.name = sf.Name - fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) - fi.addrValue = addrField - fi.sf = sf - fi.fullName = mi.fullName + mName + "." + sf.Name - - fi.description = tags["description"] - fi.null = attrs["null"] - fi.index = attrs["index"] - fi.auto = attrs["auto"] - fi.pk = attrs["pk"] - fi.unique = attrs["unique"] - - // Mark object property if there is attribute "default" in the orm configuration - if _, ok := tags["default"]; ok { - fi.colDefault = true - } - - switch fieldType { - case RelManyToMany, RelReverseMany, RelReverseOne: - fi.null = false - fi.index = false - fi.auto = false - fi.pk = false - fi.unique = false - default: - fi.dbcol = true - } - - switch fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - fi.rel = true - if fieldType == RelOneToOne { - fi.unique = true - } - case RelReverseMany, RelReverseOne: - fi.reverse = true - } - - if fi.rel && fi.dbcol { - switch onDelete { - case odCascade, odDoNothing: - case odSetDefault: - if !initial.Exist() { - err = errors.New("on_delete: set_default need set field a default value") - goto end - } - case odSetNULL: - if !fi.null { - err = errors.New("on_delete: set_null need set field null") - goto end - } - default: - if onDelete == "" { - onDelete = odCascade - } else { - err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) - goto end - } - } - - fi.onDelete = onDelete - } - - switch fieldType { - case TypeBooleanField: - case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: - if size != "" { - v, e := StrTo(size).Int32() - if e != nil { - err = fmt.Errorf("wrong size value `%s`", size) - } else { - fi.size = int(v) - } - } else { - fi.size = 255 - fi.toText = true - } - case TypeTextField: - fi.index = false - fi.unique = false - case TypeTimeField, TypeDateField, TypeDateTimeField: - if attrs["auto_now"] { - fi.autoNow = true - } else if attrs["auto_now_add"] { - fi.autoNowAdd = true - } - case TypeFloatField: - case TypeDecimalField: - d1 := digits - d2 := decimals - v1, er1 := StrTo(d1).Int8() - v2, er2 := StrTo(d2).Int8() - if er1 != nil || er2 != nil { - err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) - goto end - } - fi.digits = int(v1) - fi.decimals = int(v2) - default: - switch { - case fieldType&IsIntegerField > 0: - case fieldType&IsRelField > 0: - } - } - - if fieldType&IsIntegerField == 0 { - if fi.auto { - err = fmt.Errorf("non-integer type cannot set auto") - goto end - } - } - - if fi.auto || fi.pk { - if fi.auto { - switch addrField.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - default: - err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) - goto end - } - fi.pk = true - } - fi.null = false - fi.index = false - fi.unique = false - } - - if fi.unique { - fi.index = false - } - - // can not set default for these type - if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { - initial.Clear() - } - - if initial.Exist() { - v := initial - switch fieldType { - case TypeBooleanField: - _, err = v.Bool() - case TypeFloatField, TypeDecimalField: - _, err = v.Float64() - case TypeBitField: - _, err = v.Int8() - case TypeSmallIntegerField: - _, err = v.Int16() - case TypeIntegerField: - _, err = v.Int32() - case TypeBigIntegerField: - _, err = v.Int64() - case TypePositiveBitField: - _, err = v.Uint8() - case TypePositiveSmallIntegerField: - _, err = v.Uint16() - case TypePositiveIntegerField: - _, err = v.Uint32() - case TypePositiveBigIntegerField: - _, err = v.Uint64() - } - if err != nil { - tag, tagValue = "default", tags["default"] - goto wrongTag - } - } - - fi.initial = initial -end: - if err != nil { - return nil, err - } - return -wrongTag: - return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) -} diff --git a/orm/models_info_m.go b/orm/models_info_m.go deleted file mode 100644 index a4d733b6..00000000 --- a/orm/models_info_m.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" -) - -// single model info -type modelInfo struct { - pkg string - name string - fullName string - table string - model interface{} - fields *fields - manual bool - addrField reflect.Value //store the original struct value - uniques []string - isThrough bool -} - -// new model info -func newModelInfo(val reflect.Value) (mi *modelInfo) { - mi = &modelInfo{} - mi.fields = newFields() - ind := reflect.Indirect(val) - mi.addrField = val - mi.name = ind.Type().Name() - mi.fullName = getFullName(ind.Type()) - addModelFields(mi, ind, "", []int{}) - return -} - -// index: FieldByIndex returns the nested field corresponding to index -func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { - var ( - err error - fi *fieldInfo - sf reflect.StructField - ) - - for i := 0; i < ind.NumField(); i++ { - field := ind.Field(i) - sf = ind.Type().Field(i) - // if the field is unexported skip - if sf.PkgPath != "" { - continue - } - // add anonymous struct fields - if sf.Anonymous { - addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) - continue - } - - fi, err = newFieldInfo(mi, field, sf, mName) - if err == errSkipField { - err = nil - continue - } else if err != nil { - break - } - //record current field index - fi.fieldIndex = append(fi.fieldIndex, index...) - fi.fieldIndex = append(fi.fieldIndex, i) - fi.mi = mi - fi.inModel = true - if !mi.fields.Add(fi) { - err = fmt.Errorf("duplicate column name: %s", fi.column) - break - } - if fi.pk { - if mi.fields.pk != nil { - err = fmt.Errorf("one model must have one pk field only") - break - } else { - mi.fields.pk = fi - } - } - } - - if err != nil { - fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) - os.Exit(2) - } -} - -// combine related model info to new model info. -// prepare for relation models query. -func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { - mi = new(modelInfo) - mi.fields = newFields() - mi.table = m1.table + "_" + m2.table + "s" - mi.name = camelString(mi.table) - mi.fullName = m1.pkg + "." + mi.name - - fa := new(fieldInfo) // pk - f1 := new(fieldInfo) // m1 table RelForeignKey - f2 := new(fieldInfo) // m2 table RelForeignKey - fa.fieldType = TypeBigIntegerField - fa.auto = true - fa.pk = true - fa.dbcol = true - fa.name = "Id" - fa.column = "id" - fa.fullName = mi.fullName + "." + fa.name - - f1.dbcol = true - f2.dbcol = true - f1.fieldType = RelForeignKey - f2.fieldType = RelForeignKey - f1.name = camelString(m1.table) - f2.name = camelString(m2.table) - f1.fullName = mi.fullName + "." + f1.name - f2.fullName = mi.fullName + "." + f2.name - f1.column = m1.table + "_id" - f2.column = m2.table + "_id" - f1.rel = true - f2.rel = true - f1.relTable = m1.table - f2.relTable = m2.table - f1.relModelInfo = m1 - f2.relModelInfo = m2 - f1.mi = mi - f2.mi = mi - - mi.fields.Add(fa) - mi.fields.Add(f1) - mi.fields.Add(f2) - mi.fields.pk = fa - - mi.uniques = []string{f1.column, f2.column} - return -} diff --git a/orm/models_test.go b/orm/models_test.go deleted file mode 100644 index e3a635f2..00000000 --- a/orm/models_test.go +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "strings" - "time" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - // As tidb can't use go get, so disable the tidb testing now - // _ "github.com/pingcap/tidb" -) - -// A slice string field. -type SliceStringField []string - -func (e SliceStringField) Value() []string { - return []string(e) -} - -func (e *SliceStringField) Set(d []string) { - *e = SliceStringField(d) -} - -func (e *SliceStringField) Add(v string) { - *e = append(*e, v) -} - -func (e *SliceStringField) String() string { - return strings.Join(e.Value(), ",") -} - -func (e *SliceStringField) FieldType() int { - return TypeVarCharField -} - -func (e *SliceStringField) SetRaw(value interface{}) error { - switch d := value.(type) { - case []string: - e.Set(d) - case string: - if len(d) > 0 { - parts := strings.Split(d, ",") - v := make([]string, 0, len(parts)) - for _, p := range parts { - v = append(v, strings.TrimSpace(p)) - } - e.Set(v) - } - default: - return fmt.Errorf(" unknown value `%v`", value) - } - return nil -} - -func (e *SliceStringField) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(SliceStringField) - -// A json field. -type JSONFieldTest struct { - Name string - Data string -} - -func (e *JSONFieldTest) String() string { - data, _ := json.Marshal(e) - return string(data) -} - -func (e *JSONFieldTest) FieldType() int { - return TypeTextField -} - -func (e *JSONFieldTest) SetRaw(value interface{}) error { - switch d := value.(type) { - case string: - return json.Unmarshal([]byte(d), e) - default: - return fmt.Errorf(" unknown value `%v`", value) - } -} - -func (e *JSONFieldTest) RawValue() interface{} { - return e.String() -} - -var _ Fielder = new(JSONFieldTest) - -type Data struct { - ID int `orm:"column(id)"` - Boolean bool - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - JSON string `orm:"type(json);default({\"name\":\"json\"})"` - Jsonb string `orm:"type(jsonb)"` - Time time.Time `orm:"type(time)"` - Date time.Time `orm:"type(date)"` - DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune - Int int - Int8 int8 - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 `orm:"digits(8);decimals(4)"` -} - -type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` -} - -type String string -type Boolean bool -type Byte byte -type Rune rune -type Int int -type Int8 int8 -type Int16 int16 -type Int32 int32 -type Int64 int64 -type Uint uint -type Uint8 uint8 -type Uint16 uint16 -type Uint32 uint32 -type Uint64 uint64 -type Float32 float64 -type Float64 float64 - -type DataCustom struct { - ID int `orm:"column(id)"` - Boolean Boolean - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` - Byte Byte - Rune Rune - Int Int - Int8 Int8 - Int16 Int16 - Int32 Int32 - Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 - Uint32 Uint32 - Uint64 Uint64 - Float32 Float32 - Float64 Float64 - Decimal Float64 `orm:"digits(8);decimals(4)"` -} - -// only for mysql -type UserBig struct { - ID uint64 `orm:"column(id)"` - Name string -} - -type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool -} - -func (u *User) TableIndex() [][]string { - return [][]string{ - {"Id", "UserName"}, - {"Id", "Created"}, - } -} - -func (u *User) TableUnique() [][]string { - return [][]string{ - {"UserName", "Email"}, - } -} - -func NewUser() *User { - obj := new(User) - return obj -} - -type Profile struct { - ID int `orm:"column(id)"` - Age int16 - Money float64 - User *User `orm:"reverse(one)" json:"-"` - BestPost *Post `orm:"rel(one);null"` -} - -func (u *Profile) TableName() string { - return "user_profile" -} - -func NewProfile() *Profile { - obj := new(Profile) - return obj -} - -type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` -} - -func (u *Post) TableIndex() [][]string { - return [][]string{ - {"Id", "Created"}, - } -} - -func NewPost() *Post { - obj := new(Post) - return obj -} - -type Tag struct { - ID int `orm:"column(id)"` - Name string `orm:"size(30)"` - BestPost *Post `orm:"rel(one);null"` - Posts []*Post `orm:"reverse(many)" json:"-"` -} - -func NewTag() *Tag { - obj := new(Tag) - return obj -} - -type PostTags struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk)"` - Tag *Tag `orm:"rel(fk)"` -} - -func (m *PostTags) TableName() string { - return "prefix_post_tags" -} - -type Comment struct { - ID int `orm:"column(id)"` - Post *Post `orm:"rel(fk);column(post)"` - Content string `orm:"type(text)"` - Parent *Comment `orm:"null;rel(fk)"` - Created time.Time `orm:"auto_now_add"` -} - -func NewComment() *Comment { - obj := new(Comment) - return obj -} - -type Group struct { - ID int `orm:"column(gid);size(32)"` - Name string - Permissions []*Permission `orm:"reverse(many)" json:"-"` -} - -type Permission struct { - ID int `orm:"column(id)"` - Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` -} - -type GroupPermissions struct { - ID int `orm:"column(id)"` - Group *Group `orm:"rel(fk)"` - Permission *Permission `orm:"rel(fk)"` -} - -type ModelID struct { - ID int64 -} - -type ModelBase struct { - ModelID - - Created time.Time `orm:"auto_now_add;type(datetime)"` - Updated time.Time `orm:"auto_now;type(datetime)"` -} - -type InLine struct { - // Common Fields - ModelBase - - // Other Fields - Name string `orm:"unique"` - Email string -} - -func NewInLine() *InLine { - return new(InLine) -} - -type InLineOneToOne struct { - // Common Fields - ModelBase - - Note string - InLine *InLine `orm:"rel(fk);column(inline)"` -} - -func NewInLineOneToOne() *InLineOneToOne { - return new(InLineOneToOne) -} - -type IntegerPk struct { - ID int64 `orm:"pk"` - Value string -} - -type UintPk struct { - ID uint32 `orm:"pk"` - Name string -} - -type PtrPk struct { - ID *IntegerPk `orm:"pk;rel(one)"` - Positive bool -} - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -var ( - IsMysql = DBARGS.Driver == "mysql" - IsSqlite = DBARGS.Driver == "sqlite3" - IsPostgres = DBARGS.Driver == "postgres" - IsTidb = DBARGS.Driver == "tidb" -) - -var ( - dORM Ormer - dDbBaser dbBaser -) - -var ( - helpinfo = `need driver and source! - - Default DB Drivers. - - driver: url - mysql: https://github.com/go-sql-driver/mysql - sqlite3: https://github.com/mattn/go-sqlite3 - postgres: https://github.com/lib/pq - tidb: https://github.com/pingcap/tidb - - usage: - - go get -u github.com/astaxie/beego/orm - go get -u github.com/go-sql-driver/mysql - go get -u github.com/mattn/go-sqlite3 - go get -u github.com/lib/pq - go get -u github.com/pingcap/tidb - - #### MySQL - mysql -u root -e 'create database orm_test;' - export ORM_DRIVER=mysql - export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm - - - #### Sqlite3 - export ORM_DRIVER=sqlite3 - export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm - - - #### PostgreSQL - psql -c 'create database orm_test;' -U postgres - export ORM_DRIVER=postgres - export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm - - #### TiDB - export ORM_DRIVER=tidb - export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm - - ` -) - -func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() - - if DBARGS.Driver == "" || DBARGS.Source == "" { - fmt.Println(helpinfo) - os.Exit(2) - } - - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) - - alias := getDbAlias("default") - if alias.Driver == DRMySQL { - alias.Engine = "INNODB" - } - -} diff --git a/orm/models_utils.go b/orm/models_utils.go deleted file mode 100644 index 71127a6b..00000000 --- a/orm/models_utils.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "strings" - "time" -) - -// 1 is attr -// 2 is tag -var supportTag = map[string]int{ - "-": 1, - "null": 1, - "index": 1, - "unique": 1, - "pk": 1, - "auto": 1, - "auto_now": 1, - "auto_now_add": 1, - "size": 2, - "column": 2, - "default": 2, - "rel": 2, - "reverse": 2, - "rel_table": 2, - "rel_through": 2, - "digits": 2, - "decimals": 2, - "on_delete": 2, - "type": 2, - "description": 2, -} - -// get reflect.Type name with package path. -func getFullName(typ reflect.Type) string { - return typ.PkgPath() + "." + typ.Name() -} - -// getTableName get struct table name. -// If the struct implement the TableName, then get the result as tablename -// else use the struct name which will apply snakeString. -func getTableName(val reflect.Value) string { - if fun := val.MethodByName("TableName"); fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - // has return and the first val is string - if len(vals) > 0 && vals[0].Kind() == reflect.String { - return vals[0].String() - } - } - return snakeString(reflect.Indirect(val).Type().Name()) -} - -// get table engine, myisam or innodb. -func getTableEngine(val reflect.Value) string { - fun := val.MethodByName("TableEngine") - if fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 && vals[0].Kind() == reflect.String { - return vals[0].String() - } - } - return "" -} - -// get table index from method. -func getTableIndex(val reflect.Value) [][]string { - fun := val.MethodByName("TableIndex") - if fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 && vals[0].CanInterface() { - if d, ok := vals[0].Interface().([][]string); ok { - return d - } - } - } - return nil -} - -// get table unique from method -func getTableUnique(val reflect.Value) [][]string { - fun := val.MethodByName("TableUnique") - if fun.IsValid() { - vals := fun.Call([]reflect.Value{}) - if len(vals) > 0 && vals[0].CanInterface() { - if d, ok := vals[0].Interface().([][]string); ok { - return d - } - } - } - return nil -} - -// get snaked column name -func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { - column := col - if col == "" { - column = nameStrategyMap[nameStrategy](sf.Name) - } - switch ft { - case RelForeignKey, RelOneToOne: - if len(col) == 0 { - column = column + "_id" - } - case RelManyToMany, RelReverseMany, RelReverseOne: - column = sf.Name - } - return column -} - -// return field type as type constant from reflect.Value -func getFieldType(val reflect.Value) (ft int, err error) { - switch val.Type() { - case reflect.TypeOf(new(int8)): - ft = TypeBitField - case reflect.TypeOf(new(int16)): - ft = TypeSmallIntegerField - case reflect.TypeOf(new(int32)), - reflect.TypeOf(new(int)): - ft = TypeIntegerField - case reflect.TypeOf(new(int64)): - ft = TypeBigIntegerField - case reflect.TypeOf(new(uint8)): - ft = TypePositiveBitField - case reflect.TypeOf(new(uint16)): - ft = TypePositiveSmallIntegerField - case reflect.TypeOf(new(uint32)), - reflect.TypeOf(new(uint)): - ft = TypePositiveIntegerField - case reflect.TypeOf(new(uint64)): - ft = TypePositiveBigIntegerField - case reflect.TypeOf(new(float32)), - reflect.TypeOf(new(float64)): - ft = TypeFloatField - case reflect.TypeOf(new(bool)): - ft = TypeBooleanField - case reflect.TypeOf(new(string)): - ft = TypeVarCharField - case reflect.TypeOf(new(time.Time)): - ft = TypeDateTimeField - default: - elm := reflect.Indirect(val) - switch elm.Kind() { - case reflect.Int8: - ft = TypeBitField - case reflect.Int16: - ft = TypeSmallIntegerField - case reflect.Int32, reflect.Int: - ft = TypeIntegerField - case reflect.Int64: - ft = TypeBigIntegerField - case reflect.Uint8: - ft = TypePositiveBitField - case reflect.Uint16: - ft = TypePositiveSmallIntegerField - case reflect.Uint32, reflect.Uint: - ft = TypePositiveIntegerField - case reflect.Uint64: - ft = TypePositiveBigIntegerField - case reflect.Float32, reflect.Float64: - ft = TypeFloatField - case reflect.Bool: - ft = TypeBooleanField - case reflect.String: - ft = TypeVarCharField - default: - if elm.Interface() == nil { - panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) - } - switch elm.Interface().(type) { - case sql.NullInt64: - ft = TypeBigIntegerField - case sql.NullFloat64: - ft = TypeFloatField - case sql.NullBool: - ft = TypeBooleanField - case sql.NullString: - ft = TypeVarCharField - case time.Time: - ft = TypeDateTimeField - } - } - } - if ft&IsFieldType == 0 { - err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) - } - return -} - -// parse struct tag string -func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { - attrs = make(map[string]bool) - tags = make(map[string]string) - for _, v := range strings.Split(data, defaultStructTagDelim) { - if v == "" { - continue - } - v = strings.TrimSpace(v) - if t := strings.ToLower(v); supportTag[t] == 1 { - attrs[t] = true - } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { - name := t[:i] - if supportTag[name] == 2 { - v = v[i+1 : len(v)-1] - tags[name] = v - } - } else { - DebugLog.Println("unsupport orm tag", v) - } - } - return -} diff --git a/orm/orm.go b/orm/orm.go deleted file mode 100644 index c7566b9a..00000000 --- a/orm/orm.go +++ /dev/null @@ -1,602 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -// Package orm provide ORM for MySQL/PostgreSQL/sqlite -// Simple Usage -// -// package main -// -// import ( -// "fmt" -// "github.com/astaxie/beego/orm" -// _ "github.com/go-sql-driver/mysql" // import your used driver -// ) -// -// // Model Struct -// type User struct { -// Id int `orm:"auto"` -// Name string `orm:"size(100)"` -// } -// -// func init() { -// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) -// } -// -// func main() { -// o := orm.NewOrm() -// user := User{Name: "slene"} -// // insert -// id, err := o.Insert(&user) -// // update -// user.Name = "astaxie" -// num, err := o.Update(&user) -// // read one -// u := User{Id: user.Id} -// err = o.Read(&u) -// // delete -// num, err = o.Delete(&u) -// } -// -// more docs: http://beego.me/docs/mvc/model/overview.md -package orm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "os" - "reflect" - "sync" - "time" -) - -// DebugQueries define the debug -const ( - DebugQueries = iota -) - -// Define common vars -var ( - Debug = false - DebugLog = NewLog(os.Stdout) - DefaultRowsLimit = -1 - DefaultRelsDepth = 2 - DefaultTimeLoc = time.Local - ErrTxHasBegan = errors.New(" transaction already begin") - ErrTxDone = errors.New(" transaction not begin") - ErrMultiRows = errors.New(" return multi rows") - ErrNoRows = errors.New(" no row found") - ErrStmtClosed = errors.New(" stmt already closed") - ErrArgs = errors.New(" args error may be empty") - ErrNotImplement = errors.New("have not implement") -) - -// Params stores the Params -type Params map[string]interface{} - -// ParamsList stores paramslist -type ParamsList []interface{} - -type orm struct { - alias *alias - db dbQuerier - isTx bool -} - -var _ Ormer = new(orm) - -// get model info and model reflect value -func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { - val := reflect.ValueOf(md) - ind = reflect.Indirect(val) - typ := ind.Type() - if needPtr && val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - name := getFullName(typ) - if mi, ok := modelCache.getByFullName(name); ok { - return mi, ind - } - panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) -} - -// get field info from model info by given field name -func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { - fi, ok := mi.fields.GetByAny(name) - if !ok { - panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) - } - return fi -} - -// read data to model -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Read(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) -} - -// read data to model, like Read(), but use "SELECT FOR UPDATE" form -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) -} - -// Try to read a row from the database, or insert one if it doesn't exist -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { - cols = append([]string{col1}, cols...) - mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) - if err == ErrNoRows { - // Create - id, err := o.Insert(md) - return (err == nil), id, err - } - - id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - id = int64(vid.Uint()) - } else if mi.fields.pk.rel { - return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) - } else { - id = vid.Int() - } - - return false, id, err -} - -// insert model data to database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Insert(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) - if err != nil { - return id, err - } - - o.setPk(mi, ind, id) - - return id, nil -} - -// set auto pk field -func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) - } - } -} - -// insert some models to database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { - var cnt int64 - - sind := reflect.Indirect(reflect.ValueOf(mds)) - - switch sind.Kind() { - case reflect.Array, reflect.Slice: - if sind.Len() == 0 { - return cnt, ErrArgs - } - default: - return cnt, ErrArgs - } - - if bulk <= 1 { - for i := 0; i < sind.Len(); i++ { - ind := reflect.Indirect(sind.Index(i)) - mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) - if err != nil { - return cnt, err - } - - o.setPk(mi, ind, id) - - cnt++ - } - } else { - mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) - } - return cnt, nil -} - -// InsertOrUpdate data to database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) - if err != nil { - return id, err - } - - o.setPk(mi, ind, id) - - return id, nil -} - -// update model to database. -// cols set the columns those want to update. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Update(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) -} - -// delete model in database -// cols shows the delete conditions values read from. default is pk -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) - if err != nil { - return num, err - } - if num > 0 { - o.setPk(mi, ind, 0) - } - return num, nil -} - -// create a models to models queryer -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md, true) - fi := o.getFieldInfo(mi, name) - - switch { - case fi.fieldType == RelManyToMany: - case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: - default: - panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) - } - - return newQueryM2M(md, o, mi, fi, ind) -} - -// load related models to md model. -// args are limit, offset int and order string. -// -// example: -// orm.LoadRelated(post,"Tags") -// for _,tag := range post.Tags{...} -// -// make sure the relation is defined in model struct tags. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) - - var relDepth int - var limit, offset int64 - var order string - for i, arg := range args { - switch i { - case 0: - if v, ok := arg.(bool); ok { - if v { - relDepth = DefaultRelsDepth - } - } else if v, ok := arg.(int); ok { - relDepth = v - } - case 1: - limit = ToInt64(arg) - case 2: - offset = ToInt64(arg) - case 3: - order, _ = arg.(string) - } - } - - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelReverseOne: - limit = 1 - offset = 0 - } - - qs.limit = limit - qs.offset = offset - qs.relDepth = relDepth - - if len(order) > 0 { - qs.orders = []string{order} - } - - find := ind.FieldByIndex(fi.fieldIndex) - - var nums int64 - var err error - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelReverseOne: - val := reflect.New(find.Type().Elem()) - container := val.Interface() - err = qs.One(container) - if err == nil { - find.Set(val) - nums = 1 - } - default: - nums, err = qs.All(find.Addr().Interface()) - } - - return nums, err -} - -// return a QuerySeter for related models to md model. -// it can do all, update, delete in QuerySeter. -// example: -// qs := orm.QueryRelated(post,"Tag") -// qs.All(&[]*Tag{}) -// -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { - // is this api needed ? - _, _, _, qs := o.queryRelated(md, name) - return qs -} - -// get QuerySeter for related models to md model -func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { - mi, ind := o.getMiInd(md, true) - fi := o.getFieldInfo(mi, name) - - _, _, exist := getExistPk(mi, ind) - if !exist { - panic(ErrMissPK) - } - - var qs *querySet - - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelManyToMany: - if !fi.inModel { - break - } - qs = o.getRelQs(md, mi, fi) - case RelReverseOne, RelReverseMany: - if !fi.inModel { - break - } - qs = o.getReverseQs(md, mi, fi) - } - - if qs == nil { - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) - } - - return mi, fi, ind, qs -} - -// get reverse relation QuerySeter -func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { - case RelReverseOne, RelReverseMany: - default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) - } - - var q *querySet - - if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { - q = newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) - } else { - q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) - } - - return q -} - -// get relation QuerySeter -func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelManyToMany: - default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) - } - - q := newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition() - - if fi.fieldType == RelManyToMany { - q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) - } else { - q.cond = q.cond.And(fi.reverseFieldInfo.column, md) - } - - return q -} - -// return a QuerySeter for table operations. -// table name can be string or struct. -// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - var name string - if table, ok := ptrStructOrTableName.(string); ok { - name = nameStrategyMap[defaultNameStrategy](table) - if mi, ok := modelCache.get(name); ok { - qs = newQuerySet(o, mi) - } - } else { - name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) - if mi, ok := modelCache.getByFullName(name); ok { - qs = newQuerySet(o, mi) - } - } - if qs == nil { - panic(fmt.Errorf(" table name: `%s` not exists", name)) - } - return -} - -// switch to another registered database driver by given name. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -// Using NewOrmUsingDB(name) -func (o *orm) Using(name string) error { - if o.isTx { - panic(fmt.Errorf(" transaction has been start, cannot change db")) - } - if al, ok := dataBaseCache.get(name); ok { - o.alias = al - if Debug { - o.db = newDbQueryLog(al, al.DB) - } else { - o.db = al.DB - } - } else { - return fmt.Errorf(" unknown db alias name `%s`", name) - } - return nil -} - -// begin transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Begin() error { - return o.BeginTx(context.Background(), nil) -} - -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { - if o.isTx { - return ErrTxHasBegan - } - var tx *sql.Tx - tx, err := o.db.(txer).BeginTx(ctx, opts) - if err != nil { - return err - } - o.isTx = true - if Debug { - o.db.(*dbQueryLog).SetDB(tx) - } else { - o.db = tx - } - return nil -} - -// commit transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Commit() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Commit() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// rollback transaction -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Rollback() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Rollback() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// return a raw query seter for raw sql string. -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Raw(query string, args ...interface{}) RawSeter { - return newRawSet(o, query, args) -} - -// return current using database Driver -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) Driver() Driver { - return driver(o.alias.Name) -} - -// return sql.DBStats for current database -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func (o *orm) DBStats() *sql.DBStats { - if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.DB.Stats() - return &stats - } - return nil -} - -// NewOrm create new orm -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func NewOrm() Ormer { - BootStrap() // execute only once - - o := new(orm) - err := o.Using("default") - if err != nil { - panic(err) - } - return o -} - -// NewOrmWithDB create a new ormer object with specify *sql.DB for query -// Deprecated: using pkg/orm. We will remove this method in v2.1.0 -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - al.Name = aliasName - al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: newStmtDecoratorLruWithEvict(), - } - - detectTZ(al) - - o := new(orm) - o.alias = al - - if Debug { - o.db = newDbQueryLog(o.alias, db) - } else { - o.db = db - } - - return o, nil -} diff --git a/orm/orm_conds.go b/orm/orm_conds.go deleted file mode 100644 index f3fd66f0..00000000 --- a/orm/orm_conds.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strings" -) - -// ExprSep define the expression separation -const ( - ExprSep = "__" -) - -type condValue struct { - exprs []string - args []interface{} - cond *Condition - isOr bool - isNot bool - isCond bool - isRaw bool - sql string -} - -// Condition struct. -// work for WHERE conditions. -type Condition struct { - params []condValue -} - -// NewCondition return new condition struct -func NewCondition() *Condition { - c := &Condition{} - return c -} - -// Raw add raw sql to condition -func (c Condition) Raw(expr string, sql string) *Condition { - if len(sql) == 0 { - panic(fmt.Errorf(" sql cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) - return &c -} - -// And add expression to condition -func (c Condition) And(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) - return &c -} - -// AndNot add NOT expression to condition -func (c Condition) AndNot(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) - return &c -} - -// AndCond combine a condition to current condition -func (c *Condition) AndCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true}) - } - return c -} - -// AndNotCond combine a AND NOT condition to current condition -func (c *Condition) AndNotCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) - } - return c -} - -// Or add OR expression to condition -func (c Condition) Or(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) - return &c -} - -// OrNot add OR NOT expression to condition -func (c Condition) OrNot(expr string, args ...interface{}) *Condition { - if expr == "" || len(args) == 0 { - panic(fmt.Errorf(" args cannot empty")) - } - c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) - return &c -} - -// OrCond combine a OR condition to current condition -func (c *Condition) OrCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) - } - return c -} - -// OrNotCond combine a OR NOT condition to current condition -func (c *Condition) OrNotCond(cond *Condition) *Condition { - c = c.clone() - if c == cond { - panic(fmt.Errorf(" cannot use self as sub cond")) - } - - if cond != nil { - c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) - } - return c -} - -// IsEmpty check the condition arguments are empty or not. -func (c *Condition) IsEmpty() bool { - return len(c.params) == 0 -} - -// clone clone a condition -func (c Condition) clone() *Condition { - return &c -} diff --git a/orm/orm_log.go b/orm/orm_log.go deleted file mode 100644 index 5bb3a24f..00000000 --- a/orm/orm_log.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "fmt" - "io" - "log" - "strings" - "time" -) - -// Log implement the log.Logger -type Log struct { - *log.Logger -} - -//costomer log func -var LogFunc func(query map[string]interface{}) - -// NewLog set io.Writer to create a Logger. -func NewLog(out io.Writer) *Log { - d := new(Log) - d.Logger = log.New(out, "[ORM]", log.LstdFlags) - return d -} - -func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { - var logMap = make(map[string]interface{}) - sub := time.Now().Sub(t) / 1e5 - elsp := float64(int(sub)) / 10.0 - logMap["cost_time"] = elsp - flag := " OK" - if err != nil { - flag = "FAIL" - } - logMap["flag"] = flag - con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) - cons := make([]string, 0, len(args)) - for _, arg := range args { - cons = append(cons, fmt.Sprintf("%v", arg)) - } - if len(cons) > 0 { - con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) - } - if err != nil { - con += " - " + err.Error() - } - logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil { - LogFunc(logMap) - } - DebugLog.Println(con) -} - -// statement query logger struct. -// if dev mode, use stmtQueryLog, or use stmtQuerier. -type stmtQueryLog struct { - alias *alias - query string - stmt stmtQuerier -} - -var _ stmtQuerier = new(stmtQueryLog) - -func (d *stmtQueryLog) Close() error { - a := time.Now() - err := d.stmt.Close() - debugLogQueies(d.alias, "st.Close", d.query, a, err) - return err -} - -func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.stmt.Exec(args...) - debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) - return res, err -} - -func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.stmt.Query(args...) - debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) - return res, err -} - -func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { - a := time.Now() - res := d.stmt.QueryRow(args...) - debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) - return res -} - -func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { - d := new(stmtQueryLog) - d.stmt = stmt - d.alias = alias - d.query = query - return d -} - -// database query logger struct. -// if dev mode, use dbQueryLog, or use dbQuerier. -type dbQueryLog struct { - alias *alias - db dbQuerier - tx txer - txe txEnder -} - -var _ dbQuerier = new(dbQueryLog) -var _ txer = new(dbQueryLog) -var _ txEnder = new(dbQueryLog) - -func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.Prepare(query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err -} - -func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.PrepareContext(ctx, query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err -} - -func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.Exec(query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.ExecContext(ctx, query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.Query(query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.QueryContext(ctx, query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err -} - -func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRow(query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res -} - -func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRowContext(ctx, query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res -} - -func (d *dbQueryLog) Begin() (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).Begin() - debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) - return tx, err -} - -func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).BeginTx(ctx, opts) - debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) - return tx, err -} - -func (d *dbQueryLog) Commit() error { - a := time.Now() - err := d.db.(txEnder).Commit() - debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err) - return err -} - -func (d *dbQueryLog) Rollback() error { - a := time.Now() - err := d.db.(txEnder).Rollback() - debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) - return err -} - -func (d *dbQueryLog) SetDB(db dbQuerier) { - d.db = db -} - -func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { - d := new(dbQueryLog) - d.alias = alias - d.db = db - return d -} diff --git a/orm/orm_object.go b/orm/orm_object.go deleted file mode 100644 index de3181ce..00000000 --- a/orm/orm_object.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "reflect" -) - -// an insert queryer struct -type insertSet struct { - mi *modelInfo - orm *orm - stmt stmtQuerier - closed bool -} - -var _ Inserter = new(insertSet) - -// insert model ignore it's registered or not. -func (o *insertSet) Insert(md interface{}) (int64, error) { - if o.closed { - return 0, ErrStmtClosed - } - val := reflect.ValueOf(md) - ind := reflect.Indirect(val) - typ := ind.Type() - name := getFullName(typ) - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) - } - if name != o.mi.fullName { - panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) - } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) - if err != nil { - return id, err - } - if id > 0 { - if o.mi.fields.pk.auto { - if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) - } - } - } - return id, nil -} - -// close insert queryer statement -func (o *insertSet) Close() error { - if o.closed { - return ErrStmtClosed - } - o.closed = true - return o.stmt.Close() -} - -// create new insert queryer. -func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { - bi := new(insertSet) - bi.orm = orm - bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) - if err != nil { - return nil, err - } - if Debug { - bi.stmt = newStmtQueryLog(orm.alias, st, query) - } else { - bi.stmt = st - } - return bi, nil -} diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go deleted file mode 100644 index 6a270a0d..00000000 --- a/orm/orm_querym2m.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import "reflect" - -// model to model struct -type queryM2M struct { - md interface{} - mi *modelInfo - fi *fieldInfo - qs *querySet - ind reflect.Value -} - -// add models to origin models when creating queryM2M. -// example: -// m2m := orm.QueryM2M(post,"Tag") -// m2m.Add(&Tag1{},&Tag2{}) -// for _,tag := range post.Tags{} -// -// make sure the relation is defined in post model struct tag. -func (o *queryM2M) Add(mds ...interface{}) (int64, error) { - fi := o.fi - mi := fi.relThroughModelInfo - mfi := fi.reverseFieldInfo - rfi := fi.reverseFieldInfoTwo - - orm := o.qs.orm - dbase := orm.alias.DbBaser - - var models []interface{} - var otherValues []interface{} - var otherNames []string - - for _, colname := range mi.fields.dbcols { - if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && - mi.fields.columns[colname] != mi.fields.pk { - otherNames = append(otherNames, colname) - } - } - for i, md := range mds { - if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { - otherValues = append(otherValues, md) - mds = append(mds[:i], mds[i+1:]...) - } - } - for _, md := range mds { - val := reflect.ValueOf(md) - if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { - for i := 0; i < val.Len(); i++ { - v := val.Index(i) - if v.CanInterface() { - models = append(models, v.Interface()) - } - } - } else { - models = append(models, md) - } - } - - _, v1, exist := getExistPk(o.mi, o.ind) - if !exist { - panic(ErrMissPK) - } - - names := []string{mfi.column, rfi.column} - - values := make([]interface{}, 0, len(models)*2) - for _, md := range models { - - ind := reflect.Indirect(reflect.ValueOf(md)) - var v2 interface{} - if ind.Kind() != reflect.Struct { - v2 = ind.Interface() - } else { - _, v2, exist = getExistPk(fi.relModelInfo, ind) - if !exist { - panic(ErrMissPK) - } - } - values = append(values, v1, v2) - - } - names = append(names, otherNames...) - values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) -} - -// remove models following the origin model relationship -func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { - fi := o.fi - qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) - - return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() -} - -// check model is existed in relationship of origin model -func (o *queryM2M) Exist(md interface{}) bool { - fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() -} - -// clean all models in related of origin model -func (o *queryM2M) Clear() (int64, error) { - fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() -} - -// count all related models of origin model -func (o *queryM2M) Count() (int64, error) { - fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() -} - -var _ QueryM2Mer = new(queryM2M) - -// create new M2M queryer. -func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { - qm2m := new(queryM2M) - qm2m.md = md - qm2m.mi = mi - qm2m.fi = fi - qm2m.ind = ind - qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) - return qm2m -} diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go deleted file mode 100644 index 878b836b..00000000 --- a/orm/orm_queryset.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "fmt" -) - -type colValue struct { - value int64 - opt operator -} - -type operator int - -// define Col operations -const ( - ColAdd operator = iota - ColMinus - ColMultiply - ColExcept - ColBitAnd - ColBitRShift - ColBitLShift - ColBitXOR - ColBitOr -) - -// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: -// Params{ -// "Nums": ColValue(Col_Add, 10), -// } -func ColValue(opt operator, value interface{}) interface{} { - switch opt { - case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, - ColBitLShift, ColBitXOR, ColBitOr: - default: - panic(fmt.Errorf("orm.ColValue wrong operator")) - } - v, err := StrTo(ToStr(value)).Int64() - if err != nil { - panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) - } - var val colValue - val.value = v - val.opt = opt - return val -} - -// real query struct -type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forupdate bool - orm *orm - ctx context.Context - forContext bool -} - -var _ QuerySeter = new(querySet) - -// add condition expression to QuerySeter. -func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.And(expr, args...) - return &o -} - -// add raw sql to querySeter. -func (o querySet) FilterRaw(expr string, sql string) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.Raw(expr, sql) - return &o -} - -// add NOT condition to querySeter. -func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.AndNot(expr, args...) - return &o -} - -// set offset number -func (o *querySet) setOffset(num interface{}) { - o.offset = ToInt64(num) -} - -// add LIMIT value. -// args[0] means offset, e.g. LIMIT num,offset. -func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { - o.limit = ToInt64(limit) - if len(args) > 0 { - o.setOffset(args[0]) - } - return &o -} - -// add OFFSET value -func (o querySet) Offset(offset interface{}) QuerySeter { - o.setOffset(offset) - return &o -} - -// add GROUP expression -func (o querySet) GroupBy(exprs ...string) QuerySeter { - o.groups = exprs - return &o -} - -// add ORDER expression. -// "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs - return &o -} - -// add DISTINCT to SELECT -func (o querySet) Distinct() QuerySeter { - o.distinct = true - return &o -} - -// add FOR UPDATE to SELECT -func (o querySet) ForUpdate() QuerySeter { - o.forupdate = true - return &o -} - -// set relation model to query together. -// it will query relation models and assign to parent model. -func (o querySet) RelatedSel(params ...interface{}) QuerySeter { - if len(params) == 0 { - o.relDepth = DefaultRelsDepth - } else { - for _, p := range params { - switch val := p.(type) { - case string: - o.related = append(o.related, val) - case int: - o.relDepth = val - default: - panic(fmt.Errorf(" wrong param kind: %v", val)) - } - } - } - return &o -} - -// set condition to QuerySeter. -func (o querySet) SetCond(cond *Condition) QuerySeter { - o.cond = cond - return &o -} - -// get condition from QuerySeter -func (o querySet) GetCond() *Condition { - return o.cond -} - -// return QuerySeter execution result number -func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) -} - -// check result empty or not after QuerySeter executed -func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) - return cnt > 0 -} - -// execute update with parameters -func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) -} - -// execute delete -func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) -} - -// return a insert queryer. -// it can be used in times. -// example: -// i,err := sq.PrepareInsert() -// i.Add(&user1{},&user2{}) -func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) -} - -// query all data and map to containers. -// cols means the columns when querying. -func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) -} - -// query one row data and map to containers. -// cols means the columns when querying. -func (o *querySet) One(container interface{}, cols ...string) error { - o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) - if err != nil { - return err - } - if num == 0 { - return ErrNoRows - } - - if num > 1 { - return ErrMultiRows - } - return nil -} - -// query all data and map to []map[string]interface. -// expres means condition expression. -// it converts data to []map[column]value. -func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) -} - -// query all data and map to [][]interface -// it converts data to [][column_index]value -func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) -} - -// query all data and map to []interface. -// it's designed for one row record set, auto change to []value, not [][column]value. -func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) -} - -// query all rows into map[string]interface with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to map[string]interface{}{ -// "total": 100, -// "found": 200, -// } -func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { - panic(ErrNotImplement) -} - -// query all rows into struct with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to struct { -// Total int -// Found int -// } -func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { - panic(ErrNotImplement) -} - -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - -// create new QuerySeter. -func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { - o := new(querySet) - o.mi = mi - o.orm = orm - return o -} diff --git a/orm/orm_raw.go b/orm/orm_raw.go deleted file mode 100644 index 3325a7ea..00000000 --- a/orm/orm_raw.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "time" -) - -// raw sql string prepared statement -type rawPrepare struct { - rs *rawSet - stmt stmtQuerier - closed bool -} - -func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { - if o.closed { - return nil, ErrStmtClosed - } - return o.stmt.Exec(args...) -} - -func (o *rawPrepare) Close() error { - o.closed = true - return o.stmt.Close() -} - -func newRawPreparer(rs *rawSet) (RawPreparer, error) { - o := new(rawPrepare) - o.rs = rs - - query := rs.query - rs.orm.alias.DbBaser.ReplaceMarks(&query) - - st, err := rs.orm.db.Prepare(query) - if err != nil { - return nil, err - } - if Debug { - o.stmt = newStmtQueryLog(rs.orm.alias, st, query) - } else { - o.stmt = st - } - return o, nil -} - -// raw query seter -type rawSet struct { - query string - args []interface{} - orm *orm -} - -var _ RawSeter = new(rawSet) - -// set args for every query -func (o rawSet) SetArgs(args ...interface{}) RawSeter { - o.args = args - return &o -} - -// execute raw sql and return sql.Result -func (o *rawSet) Exec() (sql.Result, error) { - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - return o.orm.db.Exec(query, args...) -} - -// set field value to row container -func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { - switch ind.Kind() { - case reflect.Bool: - if value == nil { - ind.SetBool(false) - } else if v, ok := value.(bool); ok { - ind.SetBool(v) - } else { - v, _ := StrTo(ToStr(value)).Bool() - ind.SetBool(v) - } - - case reflect.String: - if value == nil { - ind.SetString("") - } else { - ind.SetString(ToStr(value)) - } - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if value == nil { - ind.SetInt(0) - } else { - val := reflect.ValueOf(value) - switch val.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - ind.SetInt(val.Int()) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - ind.SetInt(int64(val.Uint())) - default: - v, _ := StrTo(ToStr(value)).Int64() - ind.SetInt(v) - } - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if value == nil { - ind.SetUint(0) - } else { - val := reflect.ValueOf(value) - switch val.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - ind.SetUint(uint64(val.Int())) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - ind.SetUint(val.Uint()) - default: - v, _ := StrTo(ToStr(value)).Uint64() - ind.SetUint(v) - } - } - case reflect.Float64, reflect.Float32: - if value == nil { - ind.SetFloat(0) - } else { - val := reflect.ValueOf(value) - switch val.Kind() { - case reflect.Float64: - ind.SetFloat(val.Float()) - default: - v, _ := StrTo(ToStr(value)).Float64() - ind.SetFloat(v) - } - } - - case reflect.Struct: - if value == nil { - ind.Set(reflect.Zero(ind.Type())) - return - } - switch ind.Interface().(type) { - case time.Time: - var str string - switch d := value.(type) { - case time.Time: - o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) - ind.Set(reflect.ValueOf(d)) - case []byte: - str = string(d) - case string: - str = d - } - if str != "" { - if len(str) >= 19 { - str = str[:19] - t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) - if err == nil { - t = t.In(DefaultTimeLoc) - ind.Set(reflect.ValueOf(t)) - } - } else if len(str) >= 10 { - str = str[:10] - t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) - if err == nil { - ind.Set(reflect.ValueOf(t)) - } - } - } - case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: - indi := reflect.New(ind.Type()).Interface() - sc, ok := indi.(sql.Scanner) - if !ok { - return - } - err := sc.Scan(value) - if err == nil { - ind.Set(reflect.Indirect(reflect.ValueOf(sc))) - } - } - - case reflect.Ptr: - if value == nil { - ind.Set(reflect.Zero(ind.Type())) - break - } - ind.Set(reflect.New(ind.Type().Elem())) - o.setFieldValue(reflect.Indirect(ind), value) - } -} - -// set field value in loop for slice container -func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { - nInds := *nIndsPtr - - cur := 0 - for i := 0; i < len(sInds); i++ { - sInd := sInds[i] - eTyp := eTyps[i] - - typ := eTyp - isPtr := false - if typ.Kind() == reflect.Ptr { - isPtr = true - typ = typ.Elem() - } - if typ.Kind() == reflect.Ptr { - isPtr = true - typ = typ.Elem() - } - - var nInd reflect.Value - if init { - nInd = reflect.New(sInd.Type()).Elem() - } else { - nInd = nInds[i] - } - - val := reflect.New(typ) - ind := val.Elem() - - tpName := ind.Type().String() - - if ind.Kind() == reflect.Struct { - if tpName == "time.Time" { - value := reflect.ValueOf(refs[cur]).Elem().Interface() - if isPtr && value == nil { - val = reflect.New(val.Type()).Elem() - } else { - o.setFieldValue(ind, value) - } - cur++ - } - - } else { - value := reflect.ValueOf(refs[cur]).Elem().Interface() - if isPtr && value == nil { - val = reflect.New(val.Type()).Elem() - } else { - o.setFieldValue(ind, value) - } - cur++ - } - - if nInd.Kind() == reflect.Slice { - if isPtr { - nInd = reflect.Append(nInd, val) - } else { - nInd = reflect.Append(nInd, ind) - } - } else { - if isPtr { - nInd.Set(val) - } else { - nInd.Set(ind) - } - } - - nInds[i] = nInd - } -} - -// query data and map to container -func (o *rawSet) QueryRow(containers ...interface{}) error { - var ( - refs = make([]interface{}, 0, len(containers)) - sInds []reflect.Value - eTyps []reflect.Type - sMi *modelInfo - ) - structMode := false - for _, container := range containers { - val := reflect.ValueOf(container) - ind := reflect.Indirect(val) - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" all args must be use ptr")) - } - - etyp := ind.Type() - typ := etyp - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - sInds = append(sInds, ind) - eTyps = append(eTyps, etyp) - - if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { - if len(containers) > 1 { - panic(fmt.Errorf(" now support one struct only. see #384")) - } - - structMode = true - fn := getFullName(typ) - if mi, ok := modelCache.getByFullName(fn); ok { - sMi = mi - } - } else { - var ref interface{} - refs = append(refs, &ref) - } - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - rows, err := o.orm.db.Query(query, args...) - if err != nil { - if err == sql.ErrNoRows { - return ErrNoRows - } - return err - } - - defer rows.Close() - - if rows.Next() { - if structMode { - columns, err := rows.Columns() - if err != nil { - return err - } - - columnsMp := make(map[string]interface{}, len(columns)) - - refs = make([]interface{}, 0, len(columns)) - for _, col := range columns { - var ref interface{} - columnsMp[col] = &ref - refs = append(refs, &ref) - } - - if err := rows.Scan(refs...); err != nil { - return err - } - - ind := sInds[0] - - if ind.Kind() == reflect.Ptr { - if ind.IsNil() || !ind.IsValid() { - ind.Set(reflect.New(eTyps[0].Elem())) - } - ind = ind.Elem() - } - - if sMi != nil { - for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { - value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) - } - o.setFieldValue(field, value) - } - } - } else { - for i := 0; i < ind.NumField(); i++ { - f := ind.Field(i) - fe := ind.Type().Field(i) - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) - } - } - } - - } else { - if err := rows.Scan(refs...); err != nil { - return err - } - - nInds := make([]reflect.Value, len(sInds)) - o.loopSetRefs(refs, sInds, &nInds, eTyps, true) - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) - } - } - - } else { - return ErrNoRows - } - - return nil -} - -// query data rows and map to container -func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { - var ( - refs = make([]interface{}, 0, len(containers)) - sInds []reflect.Value - eTyps []reflect.Type - sMi *modelInfo - ) - structMode := false - for _, container := range containers { - val := reflect.ValueOf(container) - sInd := reflect.Indirect(val) - if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { - panic(fmt.Errorf(" all args must be use ptr slice")) - } - - etyp := sInd.Type().Elem() - typ := etyp - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - sInds = append(sInds, sInd) - eTyps = append(eTyps, etyp) - - if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { - if len(containers) > 1 { - panic(fmt.Errorf(" now support one struct only. see #384")) - } - - structMode = true - fn := getFullName(typ) - if mi, ok := modelCache.getByFullName(fn); ok { - sMi = mi - } - } else { - var ref interface{} - refs = append(refs, &ref) - } - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - rows, err := o.orm.db.Query(query, args...) - if err != nil { - return 0, err - } - - defer rows.Close() - - var cnt int64 - nInds := make([]reflect.Value, len(sInds)) - sInd := sInds[0] - - for rows.Next() { - - if structMode { - columns, err := rows.Columns() - if err != nil { - return 0, err - } - - columnsMp := make(map[string]interface{}, len(columns)) - - refs = make([]interface{}, 0, len(columns)) - for _, col := range columns { - var ref interface{} - columnsMp[col] = &ref - refs = append(refs, &ref) - } - - if err := rows.Scan(refs...); err != nil { - return 0, err - } - - if cnt == 0 && !sInd.IsNil() { - sInd.Set(reflect.New(sInd.Type()).Elem()) - } - - var ind reflect.Value - if eTyps[0].Kind() == reflect.Ptr { - ind = reflect.New(eTyps[0].Elem()) - } else { - ind = reflect.New(eTyps[0]) - } - - if ind.Kind() == reflect.Ptr { - ind = ind.Elem() - } - - if sMi != nil { - for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { - value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) - field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) - } - o.setFieldValue(field, value) - } - } - } else { - // define recursive function - var recursiveSetField func(rv reflect.Value) - recursiveSetField = func(rv reflect.Value) { - for i := 0; i < rv.NumField(); i++ { - f := rv.Field(i) - fe := rv.Type().Field(i) - - // check if the field is a Struct - // recursive the Struct type - if fe.Type.Kind() == reflect.Struct { - recursiveSetField(f) - } - - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) - } - } - } - - // init call the recursive function - recursiveSetField(ind) - } - - if eTyps[0].Kind() == reflect.Ptr { - ind = ind.Addr() - } - - sInd = reflect.Append(sInd, ind) - - } else { - if err := rows.Scan(refs...); err != nil { - return 0, err - } - - o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) - } - - cnt++ - } - - if cnt > 0 { - - if structMode { - sInds[0].Set(sInd) - } else { - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) - } - } - } - - return cnt, nil -} - -func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { - var ( - maps []Params - lists []ParamsList - list ParamsList - ) - - typ := 0 - switch container.(type) { - case *[]Params: - typ = 1 - case *[]ParamsList: - typ = 2 - case *ParamsList: - typ = 3 - default: - panic(fmt.Errorf(" unsupport read values type `%T`", container)) - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - - var rs *sql.Rows - rs, err := o.orm.db.Query(query, args...) - if err != nil { - return 0, err - } - - defer rs.Close() - - var ( - refs []interface{} - cnt int64 - cols []string - indexs []int - ) - - for rs.Next() { - if cnt == 0 { - columns, err := rs.Columns() - if err != nil { - return 0, err - } - if len(needCols) > 0 { - indexs = make([]int, 0, len(needCols)) - } else { - indexs = make([]int, 0, len(columns)) - } - - cols = columns - refs = make([]interface{}, len(cols)) - for i := range refs { - var ref sql.NullString - refs[i] = &ref - - if len(needCols) > 0 { - for _, c := range needCols { - if c == cols[i] { - indexs = append(indexs, i) - } - } - } else { - indexs = append(indexs, i) - } - } - } - - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - switch typ { - case 1: - params := make(Params, len(cols)) - for _, i := range indexs { - ref := refs[i] - value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) - if value.Valid { - params[cols[i]] = value.String - } else { - params[cols[i]] = nil - } - } - maps = append(maps, params) - case 2: - params := make(ParamsList, 0, len(cols)) - for _, i := range indexs { - ref := refs[i] - value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) - if value.Valid { - params = append(params, value.String) - } else { - params = append(params, nil) - } - } - lists = append(lists, params) - case 3: - for _, i := range indexs { - ref := refs[i] - value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) - if value.Valid { - list = append(list, value.String) - } else { - list = append(list, nil) - } - } - } - - cnt++ - } - - switch v := container.(type) { - case *[]Params: - *v = maps - case *[]ParamsList: - *v = lists - case *ParamsList: - *v = list - } - - return cnt, nil -} - -func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { - var ( - maps Params - ind *reflect.Value - ) - - var typ int - switch container.(type) { - case *Params: - typ = 1 - default: - typ = 2 - vl := reflect.ValueOf(container) - id := reflect.Indirect(vl) - if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { - panic(fmt.Errorf(" RowsTo unsupport type `%T` need ptr struct", container)) - } - - ind = &id - } - - query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) - - args := getFlatParams(nil, o.args, o.orm.alias.TZ) - - rs, err := o.orm.db.Query(query, args...) - if err != nil { - return 0, err - } - - defer rs.Close() - - var ( - refs []interface{} - cnt int64 - cols []string - ) - - var ( - keyIndex = -1 - valueIndex = -1 - ) - - for rs.Next() { - if cnt == 0 { - columns, err := rs.Columns() - if err != nil { - return 0, err - } - cols = columns - refs = make([]interface{}, len(cols)) - for i := range refs { - if keyCol == cols[i] { - keyIndex = i - } - if typ == 1 || keyIndex == i { - var ref sql.NullString - refs[i] = &ref - } else { - var ref interface{} - refs[i] = &ref - } - if valueCol == cols[i] { - valueIndex = i - } - } - if keyIndex == -1 || valueIndex == -1 { - panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) - } - } - - if err := rs.Scan(refs...); err != nil { - return 0, err - } - - if cnt == 0 { - switch typ { - case 1: - maps = make(Params) - } - } - - key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String - - switch typ { - case 1: - value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) - if value.Valid { - maps[key] = value.String - } else { - maps[key] = nil - } - - default: - if id := ind.FieldByName(camelString(key)); id.IsValid() { - o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) - } - } - - cnt++ - } - - if typ == 1 { - v, _ := container.(*Params) - *v = maps - } - - return cnt, nil -} - -// query data to []map[string]interface -func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { - return o.readValues(container, cols) -} - -// query data to [][]interface -func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { - return o.readValues(container, cols) -} - -// query data to []interface -func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { - return o.readValues(container, cols) -} - -// query all rows into map[string]interface with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to map[string]interface{}{ -// "total": 100, -// "found": 200, -// } -func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { - return o.queryRowsTo(result, keyCol, valueCol) -} - -// query all rows into struct with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to struct { -// Total int -// Found int -// } -func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { - return o.queryRowsTo(ptrStruct, keyCol, valueCol) -} - -// return prepared raw statement for used in times. -func (o *rawSet) Prepare() (RawPreparer, error) { - return newRawPreparer(o) -} - -func newRawSet(orm *orm, query string, args []interface{}) RawSeter { - o := new(rawSet) - o.query = query - o.args = args - o.orm = orm - return o -} diff --git a/orm/orm_test.go b/orm/orm_test.go deleted file mode 100644 index eac7b33a..00000000 --- a/orm/orm_test.go +++ /dev/null @@ -1,2500 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -package orm - -import ( - "bytes" - "context" - "database/sql" - "fmt" - "io/ioutil" - "math" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -var _ = os.PathSeparator - -var ( - testDate = formatDate + " -0700" - testDateTime = formatDateTime + " -0700" - testTime = formatTime + " -0700" -) - -type argAny []interface{} - -// get interface by index from interface slice -func (a argAny) Get(i int, args ...interface{}) (r interface{}) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { - if len(args) == 0 { - return false, fmt.Errorf("miss args") - } - b := args[0] - arg := argAny(args) - - switch v := a.(type) { - case reflect.Kind: - ok = reflect.ValueOf(b).Kind() == v - case time.Time: - if v2, vo := b.(time.Time); vo { - if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) - a = v.Format(format) - b = v2.Format(format) - ok = a == b - } else { - err = fmt.Errorf("compare datetime miss format") - goto wrongArg - } - } - default: - ok = ToStr(a) == ToStr(b) - } - ok = is && ok || !is && !ok - if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } - } - -wrongArg: - if err != nil { - return false, err - } - - return true, nil -} - -func AssertIs(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(true, a, args...); !ok { - return err - } - return nil -} - -func AssertNot(a interface{}, args ...interface{}) error { - if ok, err := ValuesCompare(false, a, args...); !ok { - return err - } - return nil -} - -func getCaller(skip int) string { - pc, file, line, _ := runtime.Caller(skip) - fun := runtime.FuncForPC(pc) - _, fn := filepath.Split(file) - data, err := ioutil.ReadFile(file) - var codes []string - if err == nil { - lines := bytes.Split(data, []byte{'\n'}) - n := 10 - for i := 0; i < n; i++ { - o := line - n - if o < 0 { - continue - } - cur := o + i + 1 - flag := " " - if cur == line { - flag = ">>" - } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) - if code != "" { - codes = append(codes, code) - } - } - } - funName := fun.Name() - if i := strings.LastIndex(funName, "."); i > -1 { - funName = funName[i+1:] - } - return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) -} - -// Deprecated: Using stretchr/testify/assert -func throwFail(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.Fail() - } -} - -func throwFailNow(t *testing.T, err error, args ...interface{}) { - if err != nil { - con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) - if len(args) > 0 { - parts := make([]string, 0, len(args)) - for _, arg := range args { - parts = append(parts, fmt.Sprintf("%v", arg)) - } - con += " " + strings.Join(parts, ", ") - } - t.Error(con) - t.FailNow() - } -} - -func TestGetDB(t *testing.T) { - if db, err := GetDB(); err != nil { - throwFailNow(t, err) - } else { - err = db.Ping() - throwFailNow(t, err) - } -} - -func TestSyncDb(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - err := RunSyncdb("default", true, Debug) - throwFail(t, err) - - modelCache.clean() -} - -func TestRegisterModels(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) - - BootStrap() - - dORM = NewOrm() - dDbBaser = getDbAlias("default").DbBaser -} - -func TestModelSyntax(t *testing.T) { - user := &User{} - ind := reflect.ValueOf(user).Elem() - fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) - throwFail(t, AssertIs(ok, true)) - - mi, ok = modelCache.get("user") - throwFail(t, AssertIs(ok, true)) - if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) - } -} - -var DataValues = map[string]interface{}{ - "Boolean": true, - "Char": "char", - "Text": "text", - "JSON": `{"name":"json"}`, - "Jsonb": `{"name": "jsonb"}`, - "Time": time.Now(), - "Date": time.Now(), - "DateTime": time.Now(), - "Byte": byte(1<<8 - 1), - "Rune": rune(1<<31 - 1), - "Int": int(1<<31 - 1), - "Int8": int8(1<<7 - 1), - "Int16": int16(1<<15 - 1), - "Int32": int32(1<<31 - 1), - "Int64": int64(1<<63 - 1), - "Uint": uint(1<<32 - 1), - "Uint8": uint8(1<<8 - 1), - "Uint16": uint16(1<<16 - 1), - "Uint32": uint32(1<<32 - 1), - "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported - "Float32": float32(100.1234), - "Float64": float64(100.1234), - "Decimal": float64(100.1234), -} - -func TestDataTypes(t *testing.T) { - d := Data{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - if name == "JSON" { - continue - } - e := ind.FieldByName(name) - e.Set(reflect.ValueOf(value)) - } - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = Data{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestNullDataTypes(t *testing.T) { - d := DataNull{} - - if IsPostgres { - // can removed when this fixed - // https://github.com/lib/pq/pull/125 - d.DateTime = time.Now() - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` - d = DataNull{ID: 1, JSON: data} - num, err := dORM.Update(&d) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - d = DataNull{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.JSON, data)) - - throwFail(t, AssertIs(d.NullBool.Valid, false)) - throwFail(t, AssertIs(d.NullString.Valid, false)) - throwFail(t, AssertIs(d.NullInt64.Valid, false)) - throwFail(t, AssertIs(d.NullFloat64.Valid, false)) - - throwFail(t, AssertIs(d.BooleanPtr, nil)) - throwFail(t, AssertIs(d.CharPtr, nil)) - throwFail(t, AssertIs(d.TextPtr, nil)) - throwFail(t, AssertIs(d.BytePtr, nil)) - throwFail(t, AssertIs(d.RunePtr, nil)) - throwFail(t, AssertIs(d.IntPtr, nil)) - throwFail(t, AssertIs(d.Int8Ptr, nil)) - throwFail(t, AssertIs(d.Int16Ptr, nil)) - throwFail(t, AssertIs(d.Int32Ptr, nil)) - throwFail(t, AssertIs(d.Int64Ptr, nil)) - throwFail(t, AssertIs(d.UintPtr, nil)) - throwFail(t, AssertIs(d.Uint8Ptr, nil)) - throwFail(t, AssertIs(d.Uint16Ptr, nil)) - throwFail(t, AssertIs(d.Uint32Ptr, nil)) - throwFail(t, AssertIs(d.Uint64Ptr, nil)) - throwFail(t, AssertIs(d.Float32Ptr, nil)) - throwFail(t, AssertIs(d.Float64Ptr, nil)) - throwFail(t, AssertIs(d.DecimalPtr, nil)) - throwFail(t, AssertIs(d.TimePtr, nil)) - throwFail(t, AssertIs(d.DatePtr, nil)) - throwFail(t, AssertIs(d.DateTimePtr, nil)) - - _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() - throwFail(t, err) - - d = DataNull{ID: 2} - err = dORM.Read(&d) - throwFail(t, err) - - booleanPtr := true - charPtr := string("test") - textPtr := string("test") - bytePtr := byte('t') - runePtr := rune('t') - intPtr := int(42) - int8Ptr := int8(42) - int16Ptr := int16(42) - int32Ptr := int32(42) - int64Ptr := int64(42) - uintPtr := uint(42) - uint8Ptr := uint8(42) - uint16Ptr := uint16(42) - uint32Ptr := uint32(42) - uint64Ptr := uint64(42) - float32Ptr := float32(42.0) - float64Ptr := float64(42.0) - decimalPtr := float64(42.0) - timePtr := time.Now() - datePtr := time.Now() - dateTimePtr := time.Now() - - d = DataNull{ - DateTime: time.Now(), - NullString: sql.NullString{String: "test", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - BooleanPtr: &booleanPtr, - CharPtr: &charPtr, - TextPtr: &textPtr, - BytePtr: &bytePtr, - RunePtr: &runePtr, - IntPtr: &intPtr, - Int8Ptr: &int8Ptr, - Int16Ptr: &int16Ptr, - Int32Ptr: &int32Ptr, - Int64Ptr: &int64Ptr, - UintPtr: &uintPtr, - Uint8Ptr: &uint8Ptr, - Uint16Ptr: &uint16Ptr, - Uint32Ptr: &uint32Ptr, - Uint64Ptr: &uint64Ptr, - Float32Ptr: &float32Ptr, - Float64Ptr: &float64Ptr, - DecimalPtr: &decimalPtr, - TimePtr: &timePtr, - DatePtr: &datePtr, - DateTimePtr: &dateTimePtr, - } - - id, err = dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - d = DataNull{ID: 3} - err = dORM.Read(&d) - throwFail(t, err) - - throwFail(t, AssertIs(d.NullBool.Valid, true)) - throwFail(t, AssertIs(d.NullBool.Bool, true)) - - throwFail(t, AssertIs(d.NullString.Valid, true)) - throwFail(t, AssertIs(d.NullString.String, "test")) - - throwFail(t, AssertIs(d.NullInt64.Valid, true)) - throwFail(t, AssertIs(d.NullInt64.Int64, 42)) - - throwFail(t, AssertIs(d.NullFloat64.Valid, true)) - throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) - - throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) - throwFail(t, AssertIs(*d.CharPtr, charPtr)) - throwFail(t, AssertIs(*d.TextPtr, textPtr)) - throwFail(t, AssertIs(*d.BytePtr, bytePtr)) - throwFail(t, AssertIs(*d.RunePtr, runePtr)) - throwFail(t, AssertIs(*d.IntPtr, intPtr)) - throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) - throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) - throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) - throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) - throwFail(t, AssertIs(*d.UintPtr, uintPtr)) - throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) - throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) - throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) - throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) - throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) - throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) - throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - - // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() - assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) - assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) - assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) - - // test support for pointer fields using RawSeter.QueryRows() - var dnList []*DataNull - Q := dDbBaser.TableQuote() - num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - equal := reflect.DeepEqual(*dnList[0], d) - throwFailNow(t, AssertIs(equal, true)) -} - -func TestDataCustomTypes(t *testing.T) { - d := DataCustom{} - ind := reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - e.Set(reflect.ValueOf(value).Convert(e.Type())) - } - - id, err := dORM.Insert(&d) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - d = DataCustom{ID: 1} - err = dORM.Read(&d) - throwFail(t, err) - - ind = reflect.Indirect(reflect.ValueOf(&d)) - - for name, value := range DataValues { - e := ind.FieldByName(name) - if !e.IsValid() { - continue - } - vu := e.Interface() - value = reflect.ValueOf(value).Convert(e.Type()).Interface() - throwFail(t, AssertIs(vu == value, true), value, vu) - } -} - -func TestCRUD(t *testing.T) { - profile := NewProfile() - profile.Age = 30 - profile.Money = 1234.12 - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 3 - user.IsStaff = true - user.IsActive = true - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - u := &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - - throwFail(t, AssertIs(u.UserName, "slene")) - throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) - throwFail(t, AssertIs(u.Password, "pass")) - throwFail(t, AssertIs(u.Status, 3)) - throwFail(t, AssertIs(u.IsStaff, true)) - throwFail(t, AssertIs(u.IsActive, true)) - - assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) - assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) - - user.UserName = "astaxie" - user.Profile = profile - num, err := dORM.Update(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "astaxie")) - throwFail(t, AssertIs(u.Profile.ID, profile.ID)) - - u = &User{UserName: "astaxie", Password: "pass"} - err = dORM.Read(u, "UserName") - throwFailNow(t, err) - throwFailNow(t, AssertIs(id, 1)) - - u.UserName = "QQ" - u.Password = "111" - num, err = dORM.Update(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFailNow(t, err) - throwFail(t, AssertIs(u.UserName, "QQ")) - throwFail(t, AssertIs(u.Password, "pass")) - - num, err = dORM.Delete(profile) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: user.ID} - err = dORM.Read(u) - throwFail(t, err) - throwFail(t, AssertIs(true, u.Profile == nil)) - - num, err = dORM.Delete(user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - u = &User{ID: 100} - err = dORM.Read(u) - throwFail(t, AssertIs(err, ErrNoRows)) - - ub := UserBig{} - ub.Name = "name" - id, err = dORM.Insert(&ub) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - ub = UserBig{ID: 1} - err = dORM.Read(&ub) - throwFail(t, err) - throwFail(t, AssertIs(ub.Name, "name")) - - num, err = dORM.Delete(&ub, "name") - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertTestData(t *testing.T) { - var users []*User - - profile := NewProfile() - profile.Age = 28 - profile.Money = 1234.12 - - id, err := dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - user := NewUser() - user.UserName = "slene" - user.Email = "vslene@gmail.com" - user.Password = "pass" - user.Status = 1 - user.IsStaff = false - user.IsActive = true - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - profile = NewProfile() - profile.Age = 30 - profile.Money = 4321.09 - - id, err = dORM.Insert(profile) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "astaxie" - user.Email = "astaxie@gmail.com" - user.Password = "password" - user.Status = 2 - user.IsStaff = true - user.IsActive = false - user.Profile = profile - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 3)) - - user = NewUser() - user.UserName = "nobody" - user.Email = "nobody@gmail.com" - user.Password = "nobody" - user.Status = 3 - user.IsStaff = false - user.IsActive = false - - users = append(users, user) - - id, err = dORM.Insert(user) - throwFail(t, err) - throwFail(t, AssertIs(id, 4)) - - tags := []*Tag{ - {Name: "golang", BestPost: &Post{ID: 2}}, - {Name: "example"}, - {Name: "format"}, - {Name: "c++"}, - } - - posts := []*Post{ - {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. -This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, - {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. -With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, - {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. -The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, - } - - comments := []*Comment{ - {Post: posts[0], Content: "a comment"}, - {Post: posts[1], Content: "yes"}, - {Post: posts[1]}, - {Post: posts[1]}, - {Post: posts[2]}, - {Post: posts[2]}, - } - - for _, tag := range tags { - id, err := dORM.Insert(tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, post := range posts { - id, err := dORM.Insert(post) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(post.Tags) - if num > 0 { - nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - - for _, comment := range comments { - id, err := dORM.Insert(comment) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - permissions := []*Permission{ - {Name: "writePosts"}, - {Name: "readComments"}, - {Name: "readPosts"}, - } - - groups := []*Group{ - { - Name: "admins", - Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, - }, - { - Name: "users", - Permissions: []*Permission{permissions[1], permissions[2]}, - }, - } - - for _, permission := range permissions { - id, err := dORM.Insert(permission) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - - for _, group := range groups { - _, err := dORM.Insert(group) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num := len(group.Permissions) - if num > 0 { - nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) - throwFailNow(t, err) - throwFailNow(t, AssertIs(nums, num)) - } - } - -} - -func TestCustomField(t *testing.T) { - user := User{ID: 2} - err := dORM.Read(&user) - throwFailNow(t, err) - - user.Langs = append(user.Langs, "zh-CN", "en-US") - user.Extra.Name = "beego" - user.Extra.Data = "orm" - _, err = dORM.Update(&user, "Langs", "Extra") - throwFailNow(t, err) - - user = User{ID: 2} - err = dORM.Read(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(len(user.Langs), 2)) - throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) - throwFailNow(t, AssertIs(user.Langs[1], "en-US")) - - throwFailNow(t, AssertIs(user.Extra.Name, "beego")) - throwFailNow(t, AssertIs(user.Extra.Data, "orm")) -} - -func TestExpr(t *testing.T) { - user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") - num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("created", time.Now()).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() - // throwFail(t, err) - // throwFail(t, AssertIs(num, 3)) -} - -func TestOperators(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", String("slene")).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__exact", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__iexact", "Slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__contains", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - var shouldNum int - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__contains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("user_name__icontains", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gt", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__gte", 1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("status__lt", Uint(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__lte", Int(3)).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - num, err = qs.Filter("user_name__startswith", "s").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - if IsSqlite || IsTidb { - shouldNum = 1 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__startswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__istartswith", "S").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name__endswith", "e").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - if IsSqlite || IsTidb { - shouldNum = 2 - } else { - shouldNum = 0 - } - - num, err = qs.Filter("user_name__endswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, shouldNum)) - - num, err = qs.Filter("user_name__iendswith", "E").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("status__in", 1, 2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("status__in", []int{1, 2}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - n1, n2 := 1, 2 - num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", 2, 3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Filter("id__between", []int{2, 3}).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("user_name", "= 'slene'").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.FilterRaw("status", "IN (1, 2)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSetCond(t *testing.T) { - cond := NewCondition() - cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - - qs := dORM.QueryTable("user") - num, err := qs.SetCond(cond1).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond2).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond3 := cond.AndNotCond(cond.And("status__in", 1)) - num, err = qs.SetCond(cond3).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond4).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) - num, err = qs.SetCond(cond5).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) -} - -func TestLimit(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Limit(-1).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - num, err = qs.Limit(-1, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - num, err = qs.Limit(0, 2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOffset(t *testing.T) { - var posts []*Post - qs := dORM.QueryTable("post") - num, err := qs.Limit(1).Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Offset(2).All(&posts) - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) -} - -func TestOrderBy(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestAll(t *testing.T) { - var users []*User - qs := dORM.QueryTable("user") - num, err := qs.OrderBy("Id").All(&users) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFail(t, AssertIs(users[0].UserName, "slene")) - throwFail(t, AssertIs(users[1].UserName, "astaxie")) - throwFail(t, AssertIs(users[2].UserName, "nobody")) - - var users2 []User - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").All(&users2) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - - qs = dORM.QueryTable("user") - num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") - throwFail(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(users2), 3)) - throwFailNow(t, AssertIs(users2[0].UserName, "slene")) - throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) - throwFailNow(t, AssertIs(users2[0].ID, 0)) - throwFailNow(t, AssertIs(users2[1].ID, 0)) - throwFailNow(t, AssertIs(users2[2].ID, 0)) - throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) - throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - var users3 []*User - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "nothing").All(&users3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - throwFailNow(t, AssertIs(users3 == nil, false)) -} - -func TestOne(t *testing.T) { - var user User - qs := dORM.QueryTable("user") - err := qs.One(&user) - throwFail(t, err) - - user = User{} - err = qs.OrderBy("Id").Limit(1).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "slene")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - user = User{} - err = qs.OrderBy("-Id").Limit(100).One(&user) - throwFailNow(t, err) - throwFail(t, AssertIs(user.UserName, "nobody")) - throwFail(t, AssertNot(err, ErrMultiRows)) - - err = qs.Filter("user_name", "nothing").One(&user) - throwFail(t, AssertIs(err, ErrNoRows)) - -} - -func TestValues(t *testing.T) { - var maps []Params - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[2]["Profile"], nil)) - } - - num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(maps[0]["UserName"], "slene")) - throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) - throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) - } - - num, err = qs.Filter("UserName", "slene").Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestValuesList(t *testing.T) { - var list []ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("Id").ValuesList(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) - } - - num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0][0], "slene")) - throwFail(t, AssertIs(list[0][1], 28)) - throwFail(t, AssertIs(list[2][1], nil)) - } -} - -func TestValuesFlat(t *testing.T) { - var list ParamsList - qs := dORM.QueryTable("user") - - num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "slene")) - throwFail(t, AssertIs(list[1], "astaxie")) - throwFail(t, AssertIs(list[2], "nobody")) - } -} - -func TestRelatedSel(t *testing.T) { - if IsTidb { - // Skip it. TiDB does not support relation now. - return - } - qs := dORM.QueryTable("user") - num, err := qs.Filter("profile__age", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var user User - err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "slene").RelatedSel().One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertNot(user.Profile, nil)) - if user.Profile != nil { - throwFail(t, AssertIs(user.Profile.Age, 28)) - } - - err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(user.Profile, nil)) - - qs = dORM.QueryTable("user_profile") - num, err = qs.Filter("user__username", "slene").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - var posts []*Post - qs = dORM.QueryTable("post") - num, err = qs.RelatedSel().All(&posts) - throwFail(t, err) - throwFailNow(t, AssertIs(num, 4)) - - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) - throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) -} - -func TestReverseQuery(t *testing.T) { - var profile Profile - err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - profile = Profile{} - err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) - throwFailNow(t, err) - throwFailNow(t, AssertIs(profile.Age, 30)) - - var user User - err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - - user = User{} - err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) - throwFailNow(t, err) - throwFailNow(t, AssertIs(user.UserName, "astaxie")) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - var posts []*Post - num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) - - posts = []*Post{} - num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). - Filter("User__UserName", "slene").RelatedSel().All(&posts) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(posts[0].User == nil, false)) - throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) - - var tags []*Tag - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - tags = []*Tag{} - num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). - Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(tags[0].Name, "golang")) - throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) - throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) - throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) -} - -func TestLoadRelated(t *testing.T) { - // load reverse foreign key - user := User{ID: 3} - - err := dORM.Read(&user) - throwFailNow(t, err) - - num, err := dORM.LoadRelated(&user, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - - num, err = dORM.LoadRelated(&user, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(user.Posts), 2)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(user.Posts), 1)) - throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - - // load reverse one to one - profile := Profile{ID: 3} - profile.BestPost = &Post{ID: 2} - num, err = dORM.Update(&profile, "BestPost") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - err = dORM.Read(&profile) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&profile, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&profile, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(profile.User == nil, false)) - throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) - - // load rel one to one - err = dORM.Read(&user) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&user, "Profile") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - - num, err = dORM.LoadRelated(&user, "Profile", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(user.Profile == nil, false)) - throwFailNow(t, AssertIs(user.Profile.Age, 30)) - throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) - throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) - - post := Post{ID: 2} - - // load rel foreign key - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "User") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - - num, err = dORM.LoadRelated(&post, "User", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(post.User == nil, false)) - throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - throwFailNow(t, AssertIs(post.User.Profile == nil, false)) - throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) - - // load rel m2m - post = Post{ID: 2} - - err = dORM.Read(&post) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&post, "Tags") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - - num, err = dORM.LoadRelated(&post, "Tags", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(post.Tags), 2)) - throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) - throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) - - // load reverse m2m - tag := Tag{ID: 1} - - err = dORM.Read(&tag) - throwFailNow(t, err) - - num, err = dORM.LoadRelated(&tag, "Posts") - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - - num, err = dORM.LoadRelated(&tag, "Posts", true) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) - throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) - throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) -} - -func TestQueryM2M(t *testing.T) { - post := Post{ID: 4} - m2m := dORM.QueryM2M(&post, "Tags") - - tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} - tag2 := &Tag{Name: "TestTag3"} - tag3 := []interface{}{&Tag{Name: "TestTag4"}} - - tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} - - for _, tag := range tags { - _, err := dORM.Insert(tag) - throwFailNow(t, err) - } - - num, err := m2m.Add(tag1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 5)) - - num, err = m2m.Remove(tag3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - exist := m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(tag2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(tag2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - tag := Tag{Name: "test"} - _, err = dORM.Insert(&tag) - throwFailNow(t, err) - - m2m = dORM.QueryM2M(&tag, "Posts") - - post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} - post2 := &Post{Title: "TestPost3"} - post3 := []interface{}{&Post{Title: "TestPost4"}} - - posts := []interface{}{post1[0], post1[1], post2, post3[0]} - - for _, post := range posts { - p := post.(*Post) - p.User = &User{ID: 1} - _, err := dORM.Insert(post) - throwFailNow(t, err) - } - - num, err = m2m.Add(post1) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Add(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Add(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 4)) - - num, err = m2m.Remove(post3) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, true)) - - num, err = m2m.Remove(post2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - exist = m2m.Exist(post2) - throwFailNow(t, AssertIs(exist, false)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Clear() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - - num, err = m2m.Count() - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 0)) - - num, err = dORM.Delete(&tag) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) -} - -func TestQueryRelate(t *testing.T) { - // post := &Post{Id: 2} - - // qs := dORM.QueryRelate(post, "Tags") - // num, err := qs.Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - - // var tags []*Tag - // num, err = qs.All(&tags) - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) - // throwFailNow(t, AssertIs(tags[0].Name, "golang")) - - // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() - // throwFailNow(t, err) - // throwFailNow(t, AssertIs(num, 2)) -} - -func TestPkManyRelated(t *testing.T) { - permission := &Permission{Name: "readPosts"} - err := dORM.Read(permission, "Name") - throwFailNow(t, err) - - var groups []*Group - qs := dORM.QueryTable("Group") - num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) -} - -func TestPrepareInsert(t *testing.T) { - qs := dORM.QueryTable("user") - i, err := qs.PrepareInsert() - throwFailNow(t, err) - - var user User - user.UserName = "testing1" - num, err := i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - user.UserName = "testing2" - num, err = i.Insert(&user) - throwFail(t, err) - throwFail(t, AssertIs(num > 0, true)) - - num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 2)) - - err = i.Close() - throwFail(t, err) - err = i.Close() - throwFail(t, AssertIs(err, ErrStmtClosed)) -} - -func TestRawExec(t *testing.T) { - Q := dDbBaser.TableQuote() - - query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) - res, err := dORM.Raw(query, "testing", "slene").Exec() - throwFail(t, err) - num, err := res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) - - res, err = dORM.Raw(query, "slene", "testing").Exec() - throwFail(t, err) - num, err = res.RowsAffected() - throwFail(t, AssertIs(num, 1), err) -} - -func TestRawQueryRow(t *testing.T) { - var ( - Boolean bool - Char string - Text string - Time time.Time - Date time.Time - DateTime time.Time - Byte byte - Rune rune - Int int - Int8 int - Int16 int16 - Int32 int32 - Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 - Uint32 uint32 - Uint64 uint64 - Float32 float32 - Float64 float64 - Decimal float64 - ) - - dataValues := make(map[string]interface{}, len(DataValues)) - - for k, v := range DataValues { - dataValues[strings.ToLower(k)] = v - } - - Q := dDbBaser.TableQuote() - - cols := []string{ - "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", - "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", - } - sep := fmt.Sprintf("%s, %s", Q, Q) - query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) - var id int - values := []interface{}{ - &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, - &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, - } - err := dORM.Raw(query, 1).QueryRow(values...) - throwFailNow(t, err) - for i, col := range cols { - vu := values[i] - v := reflect.ValueOf(vu).Elem().Interface() - switch col { - case "id": - throwFail(t, AssertIs(id, 1)) - case "time": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testTime)) - case "date": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDate)) - case "datetime": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDateTime)) - default: - throwFail(t, AssertIs(v, dataValues[col])) - } - } - - var ( - uid int - status *int - pid *int - ) - - cols = []string{ - "id", "Status", "profile_id", - } - query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) - err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) - throwFail(t, err) - throwFail(t, AssertIs(uid, 4)) - throwFail(t, AssertIs(*status, 3)) - throwFail(t, AssertIs(pid, nil)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nd *DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - err = dORM.Raw(query, newId).QueryRow(&nd) - throwFailNow(t, err) - - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -// user_profile table -type userProfile struct { - User - Age int - Money float64 -} - -func TestQueryRows(t *testing.T) { - Q := dDbBaser.TableQuote() - - var datas []*Data - - query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err := dORM.Raw(query).QueryRows(&datas) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas), 1)) - - ind := reflect.Indirect(reflect.ValueOf(datas[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var datas2 []Data - - query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) - num, err = dORM.Raw(query).QueryRows(&datas2) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - throwFailNow(t, AssertIs(len(datas2), 1)) - - ind = reflect.Indirect(reflect.ValueOf(datas2[0])) - - for name, value := range DataValues { - e := ind.FieldByName(name) - vu := e.Interface() - switch name { - case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) - case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) - case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - } - throwFail(t, AssertIs(vu == value, true), value, vu) - } - - var ids []int - var usernames []string - query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&ids, &usernames) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 3)) - throwFailNow(t, AssertIs(len(ids), 3)) - throwFailNow(t, AssertIs(ids[0], 2)) - throwFailNow(t, AssertIs(usernames[0], "slene")) - throwFailNow(t, AssertIs(ids[1], 3)) - throwFailNow(t, AssertIs(usernames[1], "astaxie")) - throwFailNow(t, AssertIs(ids[2], 4)) - throwFailNow(t, AssertIs(usernames[2], "nobody")) - - // test query rows by nested struct - var l []userProfile - query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&l) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 2)) - throwFailNow(t, AssertIs(len(l), 2)) - throwFailNow(t, AssertIs(l[0].UserName, "slene")) - throwFailNow(t, AssertIs(l[0].Age, 28)) - throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) - throwFailNow(t, AssertIs(l[1].Age, 30)) - - // test for sql.Null* fields - nData := &DataNull{ - NullString: sql.NullString{String: "test sql.null", Valid: true}, - NullBool: sql.NullBool{Bool: true, Valid: true}, - NullInt64: sql.NullInt64{Int64: 42, Valid: true}, - NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, - } - newId, err := dORM.Insert(nData) - throwFailNow(t, err) - - var nDataList []*DataNull - query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) - num, err = dORM.Raw(query, newId).QueryRows(&nDataList) - throwFailNow(t, err) - throwFailNow(t, AssertIs(num, 1)) - - nd := nDataList[0] - throwFailNow(t, AssertNot(nd, nil)) - throwFail(t, AssertIs(nd.NullBool.Valid, true)) - throwFail(t, AssertIs(nd.NullBool.Bool, true)) - throwFail(t, AssertIs(nd.NullString.Valid, true)) - throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) - throwFail(t, AssertIs(nd.NullInt64.Valid, true)) - throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) - throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) - throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) -} - -func TestRawValues(t *testing.T) { - Q := dDbBaser.TableQuote() - - var maps []Params - query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) - num, err := dORM.Raw(query, 1).Values(&maps) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(maps[0]["user_name"], "slene")) - } - - var lists []ParamsList - num, err = dORM.Raw(query, 1).ValuesList(&lists) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - if num == 1 { - throwFail(t, AssertIs(lists[0][0], "slene")) - } - - query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) - var list ParamsList - num, err = dORM.Raw(query).ValuesFlat(&list) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - if num == 3 { - throwFail(t, AssertIs(list[0], "2")) - throwFail(t, AssertIs(list[1], "3")) - throwFail(t, AssertIs(list[2], nil)) - } -} - -func TestRawPrepare(t *testing.T) { - switch { - case IsMysql || IsSqlite: - - pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - throwFail(t, err) - if pre != nil { - r, err := pre.Exec("name1") - throwFail(t, err) - - tid, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(tid > 0, true)) - - r, err = pre.Exec("name2") - throwFail(t, err) - - id, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+1)) - - r, err = pre.Exec("name3") - throwFail(t, err) - - id, err = r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+2)) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - - case IsPostgres: - - pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() - throwFail(t, err) - if pre != nil { - _, err := pre.Exec("name1") - throwFail(t, err) - - _, err = pre.Exec("name2") - throwFail(t, err) - - _, err = pre.Exec("name3") - throwFail(t, err) - - err = pre.Close() - throwFail(t, err) - - res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) - - if err == nil { - num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - } - } - } -} - -func TestUpdate(t *testing.T) { - qs := dORM.QueryTable("user") - num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ - "is_staff": true, - "is_active": true, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - // with join - num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ - "is_staff": false, - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColAdd, 100), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMinus, 50), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColMultiply, 3), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = qs.Filter("user_name", "slene").Update(Params{ - "Nums": ColValue(ColExcept, 5), - }) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - user := User{UserName: "slene"} - err = dORM.Read(&user, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(user.Nums, 30)) -} - -func TestDelete(t *testing.T) { - qs := dORM.QueryTable("user_profile") - num, err := qs.Filter("user__user_name", "slene").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("user") - num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 6)) - - qs = dORM.QueryTable("post") - num, err = qs.Filter("Id", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 4)) - - qs = dORM.QueryTable("comment") - num, err = qs.Filter("Post__User", 3).Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - qs = dORM.QueryTable("comment") - num, err = qs.Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestTransaction(t *testing.T) { - // this test worked when database support transaction - - o := NewOrm() - err := o.Begin() - throwFail(t, err) - - var names = []string{"1", "2", "3"} - - var tag Tag - tag.Name = names[0] - id, err := o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - switch { - case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() - throwFail(t, err) - if err == nil { - id, err = res.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - } - } - - err = o.Rollback() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - err = o.Begin() - throwFail(t, err) - - tag.Name = "commit" - id, err = o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - o.Commit() - throwFail(t, err) - - num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - -} - -func TestTransactionIsolationLevel(t *testing.T) { - // this test worked when database support transaction isolation level - if IsSqlite { - return - } - - o1 := NewOrm() - o2 := NewOrm() - - // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - throwFail(t, err) - - // o1 insert tag - var tag Tag - tag.Name = "test-transaction" - id, err := o1.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o1 commit - o1.Commit() - - // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) - - // o2 commit and query tag table, get the result - o2.Commit() - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - - num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestBeginTxWithContextCanceled(t *testing.T) { - o := NewOrm() - ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) - - // cancel the context before commit to make it error - cancel() - err = o.Commit() - throwFail(t, AssertIs(err, context.Canceled)) -} - -func TestReadOrCreate(t *testing.T) { - u := &User{ - UserName: "Kyle", - Email: "kylemcc@gmail.com", - Password: "other_pass", - Status: 7, - IsStaff: false, - IsActive: true, - } - - created, pk, err := dORM.ReadOrCreate(u, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.ID, pk)) - throwFail(t, AssertIs(u.UserName, "Kyle")) - throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) - throwFail(t, AssertIs(u.Password, "other_pass")) - throwFail(t, AssertIs(u.Status, 7)) - throwFail(t, AssertIs(u.IsStaff, false)) - throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) - - nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} - created, pk, err = dORM.ReadOrCreate(nu, "UserName") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.UserName, u.UserName)) - throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above - throwFail(t, AssertIs(nu.Password, u.Password)) - throwFail(t, AssertIs(nu.Status, u.Status)) - throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) - throwFail(t, AssertIs(nu.IsActive, u.IsActive)) - - dORM.Delete(u) -} - -func TestInLine(t *testing.T) { - name := "inline" - email := "hello@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 1)) - - il := NewInLine() - il.ID = 1 - err = dORM.Read(il) - throwFail(t, err) - - throwFail(t, AssertIs(il.Name, name)) - throwFail(t, AssertIs(il.Email, email)) - throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) -} - -func TestInLineOneToOne(t *testing.T) { - name := "121" - email := "121@go.com" - inline := NewInLine() - inline.Name = name - inline.Email = email - - id, err := dORM.Insert(inline) - throwFail(t, err) - throwFail(t, AssertIs(id, 2)) - - note := "one2one" - il121 := NewInLineOneToOne() - il121.Note = note - il121.InLine = inline - _, err = dORM.Insert(il121) - throwFail(t, err) - throwFail(t, AssertIs(il121.ID, 1)) - - il := NewInLineOneToOne() - err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) - - throwFail(t, err) - throwFail(t, AssertIs(il.Note, note)) - throwFail(t, AssertIs(il.InLine.ID, id)) - throwFail(t, AssertIs(il.InLine.Name, name)) - throwFail(t, AssertIs(il.InLine.Email, email)) - - rinline := NewInLine() - err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) - - throwFail(t, err) - throwFail(t, AssertIs(rinline.ID, id)) - throwFail(t, AssertIs(rinline.Name, name)) - throwFail(t, AssertIs(rinline.Email, email)) -} - -func TestIntegerPk(t *testing.T) { - its := []IntegerPk{ - {ID: math.MinInt64, Value: "-"}, - {ID: 0, Value: "0"}, - {ID: math.MaxInt64, Value: "+"}, - } - - num, err := dORM.InsertMulti(len(its), its) - throwFail(t, err) - throwFail(t, AssertIs(num, len(its))) - - for _, intPk := range its { - out := IntegerPk{ID: intPk.ID} - err = dORM.Read(&out) - throwFail(t, err) - throwFail(t, AssertIs(out.Value, intPk.Value)) - } - - num, err = dORM.InsertMulti(1, []*IntegerPk{{ - ID: 1, Value: "ok", - }}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestInsertAuto(t *testing.T) { - u := &User{ - UserName: "autoPre", - Email: "autoPre@gmail.com", - } - - id, err := dORM.Insert(u) - throwFail(t, err) - - id += 100 - su := &User{ - ID: int(id), - UserName: "auto", - Email: "auto@gmail.com", - } - - nid, err := dORM.Insert(su) - throwFail(t, err) - throwFail(t, AssertIs(nid, id)) - - users := []User{ - {ID: int(id + 100), UserName: "auto_100"}, - {ID: int(id + 110), UserName: "auto_110"}, - {ID: int(id + 120), UserName: "auto_120"}, - } - num, err := dORM.InsertMulti(100, users) - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) - - u = &User{ - UserName: "auto_121", - } - - nid, err = dORM.Insert(u) - throwFail(t, err) - throwFail(t, AssertIs(nid, id+120+1)) -} - -func TestUintPk(t *testing.T) { - name := "go" - u := &UintPk{ - ID: 8, - Name: name, - } - - created, _, err := dORM.ReadOrCreate(u, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, true)) - throwFail(t, AssertIs(u.Name, name)) - - nu := &UintPk{ID: 8} - created, pk, err := dORM.ReadOrCreate(nu, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(nu.ID, u.ID)) - throwFail(t, AssertIs(pk, u.ID)) - throwFail(t, AssertIs(nu.Name, name)) - - dORM.Delete(u) -} - -func TestPtrPk(t *testing.T) { - parent := &IntegerPk{ID: 10, Value: "10"} - - id, _ := dORM.Insert(parent) - if !IsMysql { - // MySql does not support last_insert_id in this case: see #2382 - throwFail(t, AssertIs(id, 10)) - } - - ptr := PtrPk{ID: parent, Positive: true} - num, err := dORM.InsertMulti(2, []PtrPk{ptr}) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(ptr.ID, parent)) - - nptr := &PtrPk{ID: parent} - created, pk, err := dORM.ReadOrCreate(nptr, "ID") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, true)) - - nptr = &PtrPk{Positive: true} - created, pk, err = dORM.ReadOrCreate(nptr, "Positive") - throwFail(t, err) - throwFail(t, AssertIs(created, false)) - throwFail(t, AssertIs(pk, 10)) - throwFail(t, AssertIs(nptr.ID, parent)) - - nptr.Positive = false - num, err = dORM.Update(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(nptr.ID, parent)) - throwFail(t, AssertIs(nptr.Positive, false)) - - num, err = dORM.Delete(nptr) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) -} - -func TestSnake(t *testing.T) { - cases := map[string]string{ - "i": "i", - "I": "i", - "iD": "i_d", - "ID": "i_d", - "NO": "n_o", - "NOO": "n_o_o", - "NOOooOOoo": "n_o_ooo_o_ooo", - "OrderNO": "order_n_o", - "tagName": "tag_name", - "tag_Name": "tag__name", - "tag_name": "tag_name", - "_tag_name": "_tag_name", - "tag_666name": "tag_666name", - "tag_666Name": "tag_666_name", - } - for name, want := range cases { - got := snakeString(name) - throwFail(t, AssertIs(got, want)) - } -} - -func TestIgnoreCaseTag(t *testing.T) { - type testTagModel struct { - ID int `orm:"pk"` - NOO string `orm:"column(n)"` - Name01 string `orm:"NULL"` - Name02 string `orm:"COLUMN(Name)"` - Name03 string `orm:"Column(name)"` - } - modelCache.clean() - RegisterModel(&testTagModel{}) - info, ok := modelCache.get("test_tag_model") - throwFail(t, AssertIs(ok, true)) - throwFail(t, AssertNot(info, nil)) - if t == nil { - return - } - throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) - throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) - throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) - throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) -} - -func TestInsertOrUpdate(t *testing.T) { - RegisterModel(new(User)) - user := User{UserName: "unique_username133", Status: 1, Password: "o"} - user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} - user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} - dORM.Insert(&user) - test := User{UserName: "unique_username133"} - fmt.Println(dORM.Driver().Name()) - if dORM.Driver().Name() == "sqlite3" { - fmt.Println("sqlite3 is nonsupport") - return - } - // test1 - _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user1.Status, test.Status)) - } - // test2 - _, err = dORM.InsertOrUpdate(&user2, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status, test.Status)) - throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) - } - - // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - if IsPostgres { - return - } - // test3 + - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status+1, test.Status)) - } - // test4 - - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) - } - // test5 * - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) - } - // test6 / - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) - } -} diff --git a/orm/qb.go b/orm/qb.go deleted file mode 100644 index e0655a17..00000000 --- a/orm/qb.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import "errors" - -// QueryBuilder is the Query builder interface -type QueryBuilder interface { - Select(fields ...string) QueryBuilder - ForUpdate() QueryBuilder - From(tables ...string) QueryBuilder - InnerJoin(table string) QueryBuilder - LeftJoin(table string) QueryBuilder - RightJoin(table string) QueryBuilder - On(cond string) QueryBuilder - Where(cond string) QueryBuilder - And(cond string) QueryBuilder - Or(cond string) QueryBuilder - In(vals ...string) QueryBuilder - OrderBy(fields ...string) QueryBuilder - Asc() QueryBuilder - Desc() QueryBuilder - Limit(limit int) QueryBuilder - Offset(offset int) QueryBuilder - GroupBy(fields ...string) QueryBuilder - Having(cond string) QueryBuilder - Update(tables ...string) QueryBuilder - Set(kv ...string) QueryBuilder - Delete(tables ...string) QueryBuilder - InsertInto(table string, fields ...string) QueryBuilder - Values(vals ...string) QueryBuilder - Subquery(sub string, alias string) string - String() string -} - -// NewQueryBuilder return the QueryBuilder -func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { - if driver == "mysql" { - qb = new(MySQLQueryBuilder) - } else if driver == "tidb" { - qb = new(TiDBQueryBuilder) - } else if driver == "postgres" { - err = errors.New("postgres query builder is not supported yet") - } else if driver == "sqlite" { - err = errors.New("sqlite query builder is not supported yet") - } else { - err = errors.New("unknown driver for query builder") - } - return -} diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go deleted file mode 100644 index 23bdc9ee..00000000 --- a/orm/qb_mysql.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "strings" -) - -// CommaSpace is the separation -const CommaSpace = ", " - -// MySQLQueryBuilder is the SQL build -type MySQLQueryBuilder struct { - Tokens []string -} - -// Select will join the fields -func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) - return qb -} - -// ForUpdate add the FOR UPDATE clause -func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") - return qb -} - -// From join the tables -func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) - return qb -} - -// InnerJoin INNER JOIN the table -func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) - return qb -} - -// LeftJoin LEFT JOIN the table -func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) - return qb -} - -// RightJoin RIGHT JOIN the table -func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) - return qb -} - -// On join with on cond -func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) - return qb -} - -// Where join the Where cond -func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) - return qb -} - -// And join the and cond -func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) - return qb -} - -// Or join the or cond -func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) - return qb -} - -// In join the IN (vals) -func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") - return qb -} - -// OrderBy join the Order by fields -func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Asc join the asc -func (qb *MySQLQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") - return qb -} - -// Desc join the desc -func (qb *MySQLQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") - return qb -} - -// Limit join the limit num -func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) - return qb -} - -// Offset join the offset num -func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) - return qb -} - -// GroupBy join the Group by fields -func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Having join the Having cond -func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) - return qb -} - -// Update join the update table -func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) - return qb -} - -// Set join the set kv -func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) - return qb -} - -// Delete join the Delete tables -func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") - if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) - } - return qb -} - -// InsertInto join the insert SQL -func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) - if len(fields) != 0 { - fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") - } - return qb -} - -// Values join the Values(vals) -func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") - return qb -} - -// Subquery join the sub as alias -func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { - return fmt.Sprintf("(%s) AS %s", sub, alias) -} - -// String join all Tokens -func (qb *MySQLQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") -} diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go deleted file mode 100644 index 87b3ae84..00000000 --- a/orm/qb_tidb.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2015 TiDB Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "strings" -) - -// TiDBQueryBuilder is the SQL build -type TiDBQueryBuilder struct { - Tokens []string -} - -// Select will join the fields -func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) - return qb -} - -// ForUpdate add the FOR UPDATE clause -func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") - return qb -} - -// From join the tables -func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) - return qb -} - -// InnerJoin INNER JOIN the table -func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) - return qb -} - -// LeftJoin LEFT JOIN the table -func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) - return qb -} - -// RightJoin RIGHT JOIN the table -func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) - return qb -} - -// On join with on cond -func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) - return qb -} - -// Where join the Where cond -func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) - return qb -} - -// And join the and cond -func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) - return qb -} - -// Or join the or cond -func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) - return qb -} - -// In join the IN (vals) -func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") - return qb -} - -// OrderBy join the Order by fields -func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Asc join the asc -func (qb *TiDBQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") - return qb -} - -// Desc join the desc -func (qb *TiDBQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") - return qb -} - -// Limit join the limit num -func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) - return qb -} - -// Offset join the offset num -func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) - return qb -} - -// GroupBy join the Group by fields -func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Having join the Having cond -func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) - return qb -} - -// Update join the update table -func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) - return qb -} - -// Set join the set kv -func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) - return qb -} - -// Delete join the Delete tables -func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") - if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) - } - return qb -} - -// InsertInto join the insert SQL -func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) - if len(fields) != 0 { - fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") - } - return qb -} - -// Values join the Values(vals) -func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") - return qb -} - -// Subquery join the sub as alias -func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { - return fmt.Sprintf("(%s) AS %s", sub, alias) -} - -// String join all Tokens -func (qb *TiDBQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") -} diff --git a/orm/types.go b/orm/types.go deleted file mode 100644 index 75af7149..00000000 --- a/orm/types.go +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "reflect" - "time" -) - -// Driver define database driver - -type Driver interface { - Name() string - Type() DriverType -} - -// Fielder define field info -type Fielder interface { - String() string - FieldType() int - SetRaw(interface{}) error - RawValue() interface{} -} - -// Ormer define the orm interface -type Ormer interface { - // read data to model - // for example: - // this will find User by Id field - // u = &User{Id: user.Id} - // err = Ormer.Read(u) - // this will find User by UserName field - // u = &User{UserName: "astaxie", Password: "pass"} - // err = Ormer.Read(u, "UserName") - Read(md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. - // Some databases are not support this feature. - ReadForUpdate(md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist - ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) - // insert model data to database - // for example: - // user := new(User) - // id, err = Ormer.Insert(user) - // user must be a pointer and Insert will set user's pk field - Insert(interface{}) (int64, error) - // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") - // if colu type is integer : can use(+-*/), string : convert(colu,"value") - // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") - // if colu type is integer : can use(+-*/), string : colu || "value" - InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) - // insert some models to database - InsertMulti(bulk int, mds interface{}) (int64, error) - // update model to database. - // cols set the columns those want to update. - // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns - // for example: - // user := User{Id: 2} - // user.Langs = append(user.Langs, "zh-CN", "en-US") - // user.Extra.Name = "beego" - // user.Extra.Data = "orm" - // num, err = Ormer.Update(&user, "Langs", "Extra") - Update(md interface{}, cols ...string) (int64, error) - // delete model in database - Delete(md interface{}, cols ...string) (int64, error) - // load related models to md model. - // args are limit, offset int and order string. - // - // example: - // Ormer.LoadRelated(post,"Tags") - // for _,tag := range post.Tags{...} - //args[0] bool true useDefaultRelsDepth ; false depth 0 - //args[0] int loadRelationDepth - //args[1] int limit default limit 1000 - //args[2] int offset default offset 0 - //args[3] string order for example : "-Id" - // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) - // create a models to models queryer - // for example: - // post := Post{Id: 4} - // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M(md interface{}, name string) QueryM2Mer - // return a QuerySeter for table operations. - // table name can be string or struct. - // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), - QueryTable(ptrStructOrTableName interface{}) QuerySeter - // switch to another registered database driver by given name. - Using(name string) error - // begin transaction - // for example: - // o := NewOrm() - // err := o.Begin() - // ... - // err = o.Rollback() - Begin() error - // begin transaction with provided context and option - // the provided context is used until the transaction is committed or rolled back. - // if the context is canceled, the transaction will be rolled back. - // the provided TxOptions is optional and may be nil if defaults should be used. - // if a non-default isolation level is used that the driver doesn't support, an error will be returned. - // for example: - // o := NewOrm() - // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - // ... - // err = o.Rollback() - BeginTx(ctx context.Context, opts *sql.TxOptions) error - // commit transaction - Commit() error - // rollback transaction - Rollback() error - // return a raw query seter for raw sql string. - // for example: - // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() - // // update user testing's name to slene - Raw(query string, args ...interface{}) RawSeter - Driver() Driver - DBStats() *sql.DBStats -} - -// Inserter insert prepared statement -type Inserter interface { - Insert(interface{}) (int64, error) - Close() error -} - -// QuerySeter query seter -type QuerySeter interface { - // add condition expression to QuerySeter. - // for example: - // filter by UserName == 'slene' - // qs.Filter("UserName", "slene") - // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 - // Filter("profile__Age", 28) - // // time compare - // qs.Filter("created", time.Now()) - Filter(string, ...interface{}) QuerySeter - // add raw sql to querySeter. - // for example: - // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") - // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) - FilterRaw(string, string) QuerySeter - // add NOT condition to querySeter. - // have the same usage as Filter - Exclude(string, ...interface{}) QuerySeter - // set condition to QuerySeter. - // sql's where condition - // cond := orm.NewCondition() - // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 - // num, err := qs.SetCond(cond1).Count() - SetCond(*Condition) QuerySeter - // get condition from QuerySeter. - // sql's where condition - // cond := orm.NewCondition() - // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) - // qs = qs.SetCond(cond) - // cond = qs.GetCond() - // cond := cond.Or("profile__age__gt", 2000) - // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 - // num, err := qs.SetCond(cond).Count() - GetCond() *Condition - // add LIMIT value. - // args[0] means offset, e.g. LIMIT num,offset. - // if Limit <= 0 then Limit will be set to default limit ,eg 1000 - // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 - // for example: - // qs.Limit(10, 2) - // // sql-> limit 10 offset 2 - Limit(limit interface{}, args ...interface{}) QuerySeter - // add OFFSET value - // same as Limit function's args[0] - Offset(offset interface{}) QuerySeter - // add GROUP BY expression - // for example: - // qs.GroupBy("id") - GroupBy(exprs ...string) QuerySeter - // add ORDER expression. - // "column" means ASC, "-column" means DESC. - // for example: - // qs.OrderBy("-status") - OrderBy(exprs ...string) QuerySeter - // set relation model to query together. - // it will query relation models and assign to parent model. - // for example: - // // will load all related fields use left join . - // qs.RelatedSel().One(&user) - // // will load related field only profile - // qs.RelatedSel("profile").One(&user) - // user.Profile.Age = 32 - RelatedSel(params ...interface{}) QuerySeter - // Set Distinct - // for example: - // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). - // Distinct(). - // All(&permissions) - Distinct() QuerySeter - // set FOR UPDATE to query. - // for example: - // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) - ForUpdate() QuerySeter - // return QuerySeter execution result number - // for example: - // num, err = qs.Filter("profile__age__gt", 28).Count() - Count() (int64, error) - // check result empty or not after QuerySeter executed - // the same as QuerySeter.Count > 0 - Exist() bool - // execute update with parameters - // for example: - // num, err = qs.Filter("user_name", "slene").Update(Params{ - // "Nums": ColValue(Col_Minus, 50), - // }) // user slene's Nums will minus 50 - // num, err = qs.Filter("UserName", "slene").Update(Params{ - // "user_name": "slene2" - // }) // user slene's name will change to slene2 - Update(values Params) (int64, error) - // delete from table - //for example: - // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - // //delete two user who's name is testing1 or testing2 - Delete() (int64, error) - // return a insert queryer. - // it can be used in times. - // example: - // i,err := sq.PrepareInsert() - // num, err = i.Insert(&user1) // user table will add one record user1 at once - // num, err = i.Insert(&user2) // user table will add one record user2 at once - // err = i.Close() //don't forget call Close - PrepareInsert() (Inserter, error) - // query all data and map to containers. - // cols means the columns when querying. - // for example: - // var users []*User - // qs.All(&users) // users[0],users[1],users[2] ... - All(container interface{}, cols ...string) (int64, error) - // query one row data and map to containers. - // cols means the columns when querying. - // for example: - // var user User - // qs.One(&user) //user.UserName == "slene" - One(container interface{}, cols ...string) error - // query all data and map to []map[string]interface. - // expres means condition expression. - // it converts data to []map[column]value. - // for example: - // var maps []Params - // qs.Values(&maps) //maps[0]["UserName"]=="slene" - Values(results *[]Params, exprs ...string) (int64, error) - // query all data and map to [][]interface - // it converts data to [][column_index]value - // for example: - // var list []ParamsList - // qs.ValuesList(&list) // list[0][1] == "slene" - ValuesList(results *[]ParamsList, exprs ...string) (int64, error) - // query all data and map to []interface. - // it's designed for one column record set, auto change to []value, not [][column]value. - // for example: - // var list ParamsList - // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" - ValuesFlat(result *ParamsList, expr string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to map[string]interface{}{ - // "total": 100, - // "found": 200, - // } - RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to struct { - // Total int - // Found int - // } - RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) -} - -// QueryM2Mer model to model query struct -// all operations are on the m2m table only, will not affect the origin model table -type QueryM2Mer interface { - // add models to origin models when creating queryM2M. - // example: - // m2m := orm.QueryM2M(post,"Tag") - // m2m.Add(&Tag1{},&Tag2{}) - // for _,tag := range post.Tags{}{ ... } - // param could also be any of the follow - // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} - // &Tag{Id:5,Name: "TestTag3"} - // []interface{}{&Tag{Id:6,Name: "TestTag4"}} - // insert one or more rows to m2m table - // make sure the relation is defined in post model struct tag. - Add(...interface{}) (int64, error) - // remove models following the origin model relationship - // only delete rows from m2m table - // for example: - //tag3 := &Tag{Id:5,Name: "TestTag3"} - //num, err = m2m.Remove(tag3) - Remove(...interface{}) (int64, error) - // check model is existed in relationship of origin model - Exist(interface{}) bool - // clean all models in related of origin model - Clear() (int64, error) - // count all related models of origin model - Count() (int64, error) -} - -// RawPreparer raw query statement -type RawPreparer interface { - Exec(...interface{}) (sql.Result, error) - Close() error -} - -// RawSeter raw query seter -// create From Ormer.Raw -// for example: -// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) -// rs := Ormer.Raw(sql, 1) -type RawSeter interface { - //execute sql and get result - Exec() (sql.Result, error) - //query data and map to container - //for example: - // var name string - // var id int - // rs.QueryRow(&id,&name) // id==2 name=="slene" - QueryRow(containers ...interface{}) error - - // query data rows and map to container - // var ids []int - // var names []int - // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) - // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} - QueryRows(containers ...interface{}) (int64, error) - SetArgs(...interface{}) RawSeter - // query data to []map[string]interface - // see QuerySeter's Values - Values(container *[]Params, cols ...string) (int64, error) - // query data to [][]interface - // see QuerySeter's ValuesList - ValuesList(container *[]ParamsList, cols ...string) (int64, error) - // query data to []interface - // see QuerySeter's ValuesFlat - ValuesFlat(container *ParamsList, cols ...string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to map[string]interface{}{ - // "total": 100, - // "found": 200, - // } - RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to struct { - // Total int - // Found int - // } - RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) - - // return prepared raw statement for used in times. - // for example: - // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) - Prepare() (RawPreparer, error) -} - -// stmtQuerier statement querier -type stmtQuerier interface { - Close() error - Exec(args ...interface{}) (sql.Result, error) - //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) - Query(args ...interface{}) (*sql.Rows, error) - //QueryContext(args ...interface{}) (*sql.Rows, error) - QueryRow(args ...interface{}) *sql.Row - //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row -} - -// db querier -type dbQuerier interface { - Prepare(query string) (*sql.Stmt, error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - Exec(query string, args ...interface{}) (sql.Result, error) - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row - QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row -} - -// type DB interface { -// Begin() (*sql.Tx, error) -// Prepare(query string) (stmtQuerier, error) -// Exec(query string, args ...interface{}) (sql.Result, error) -// Query(query string, args ...interface{}) (*sql.Rows, error) -// QueryRow(query string, args ...interface{}) *sql.Row -// } - -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - -// base database struct -type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - SupportUpdateJoin() bool - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - OperatorSQL(string) string - GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) - GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) - MaxLimit() uint64 - TableQuote() string - ReplaceMarks(*string) - HasReturningID(*modelInfo, *string) bool - TimeFromDB(*time.Time, *time.Location) - TimeToDB(*time.Time, *time.Location) - DbTypes() map[string]string - GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) - ShowTablesQuery() string - ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool - collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error -} diff --git a/orm/utils.go b/orm/utils.go deleted file mode 100644 index 3ff76772..00000000 --- a/orm/utils.go +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "math/big" - "reflect" - "strconv" - "strings" - "time" -) - -type fn func(string) string - -var ( - nameStrategyMap = map[string]fn{ - defaultNameStrategy: snakeString, - SnakeAcronymNameStrategy: snakeStringWithAcronym, - } - defaultNameStrategy = "snakeString" - SnakeAcronymNameStrategy = "snakeStringWithAcronym" - nameStrategy = defaultNameStrategy -) - -// StrTo is the target string -type StrTo string - -// Set string -func (f *StrTo) Set(v string) { - if v != "" { - *f = StrTo(v) - } else { - f.Clear() - } -} - -// Clear string -func (f *StrTo) Clear() { - *f = StrTo(0x1E) -} - -// Exist check string exist -func (f StrTo) Exist() bool { - return string(f) != string(0x1E) -} - -// Bool string to bool -func (f StrTo) Bool() (bool, error) { - return strconv.ParseBool(f.String()) -} - -// Float32 string to float32 -func (f StrTo) Float32() (float32, error) { - v, err := strconv.ParseFloat(f.String(), 32) - return float32(v), err -} - -// Float64 string to float64 -func (f StrTo) Float64() (float64, error) { - return strconv.ParseFloat(f.String(), 64) -} - -// Int string to int -func (f StrTo) Int() (int, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) - return int(v), err -} - -// Int8 string to int8 -func (f StrTo) Int8() (int8, error) { - v, err := strconv.ParseInt(f.String(), 10, 8) - return int8(v), err -} - -// Int16 string to int16 -func (f StrTo) Int16() (int16, error) { - v, err := strconv.ParseInt(f.String(), 10, 16) - return int16(v), err -} - -// Int32 string to int32 -func (f StrTo) Int32() (int32, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) - return int32(v), err -} - -// Int64 string to int64 -func (f StrTo) Int64() (int64, error) { - v, err := strconv.ParseInt(f.String(), 10, 64) - if err != nil { - i := new(big.Int) - ni, ok := i.SetString(f.String(), 10) // octal - if !ok { - return v, err - } - return ni.Int64(), nil - } - return v, err -} - -// Uint string to uint -func (f StrTo) Uint() (uint, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) - return uint(v), err -} - -// Uint8 string to uint8 -func (f StrTo) Uint8() (uint8, error) { - v, err := strconv.ParseUint(f.String(), 10, 8) - return uint8(v), err -} - -// Uint16 string to uint16 -func (f StrTo) Uint16() (uint16, error) { - v, err := strconv.ParseUint(f.String(), 10, 16) - return uint16(v), err -} - -// Uint32 string to uint32 -func (f StrTo) Uint32() (uint32, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) - return uint32(v), err -} - -// Uint64 string to uint64 -func (f StrTo) Uint64() (uint64, error) { - v, err := strconv.ParseUint(f.String(), 10, 64) - if err != nil { - i := new(big.Int) - ni, ok := i.SetString(f.String(), 10) - if !ok { - return v, err - } - return ni.Uint64(), nil - } - return v, err -} - -// String string to string -func (f StrTo) String() string { - if f.Exist() { - return string(f) - } - return "" -} - -// ToStr interface to string -func ToStr(value interface{}, args ...int) (s string) { - switch v := value.(type) { - case bool: - s = strconv.FormatBool(v) - case float32: - s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) - case float64: - s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) - case int: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int8: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int16: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int32: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) - case int64: - s = strconv.FormatInt(v, argInt(args).Get(0, 10)) - case uint: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint8: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint16: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint32: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) - case uint64: - s = strconv.FormatUint(v, argInt(args).Get(0, 10)) - case string: - s = v - case []byte: - s = string(v) - default: - s = fmt.Sprintf("%v", v) - } - return s -} - -// ToInt64 interface to int64 -func ToInt64(value interface{}) (d int64) { - val := reflect.ValueOf(value) - switch value.(type) { - case int, int8, int16, int32, int64: - d = val.Int() - case uint, uint8, uint16, uint32, uint64: - d = int64(val.Uint()) - default: - panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) - } - return -} - -func snakeStringWithAcronym(s string) string { - data := make([]byte, 0, len(s)*2) - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - before := false - after := false - if i > 0 { - before = s[i-1] >= 'a' && s[i-1] <= 'z' - } - if i+1 < num { - after = s[i+1] >= 'a' && s[i+1] <= 'z' - } - if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { - data = append(data, '_') - } - data = append(data, d) - } - return strings.ToLower(string(data[:])) -} - -// snake string, XxYy to xx_yy , XxYY to xx_y_y -func snakeString(s string) string { - data := make([]byte, 0, len(s)*2) - j := false - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - if i > 0 && d >= 'A' && d <= 'Z' && j { - data = append(data, '_') - } - if d != '_' { - j = true - } - data = append(data, d) - } - return strings.ToLower(string(data[:])) -} - -// SetNameStrategy set different name strategy -func SetNameStrategy(s string) { - if SnakeAcronymNameStrategy != s { - nameStrategy = defaultNameStrategy - } - nameStrategy = s -} - -// camel string, xx_yy to XxYy -func camelString(s string) string { - data := make([]byte, 0, len(s)) - flag, num := true, len(s)-1 - for i := 0; i <= num; i++ { - d := s[i] - if d == '_' { - flag = true - continue - } else if flag { - if d >= 'a' && d <= 'z' { - d = d - 32 - } - flag = false - } - data = append(data, d) - } - return string(data[:]) -} - -type argString []string - -// get string by index from string slice -func (a argString) Get(i int, args ...string) (r string) { - if i >= 0 && i < len(a) { - r = a[i] - } else if len(args) > 0 { - r = args[0] - } - return -} - -type argInt []int - -// get int by index from int slice -func (a argInt) Get(i int, args ...int) (r int) { - if i >= 0 && i < len(a) { - r = a[i] - } - if len(args) > 0 { - r = args[0] - } - return -} - -// parse time to string with location -func timeParse(dateString, format string) (time.Time, error) { - tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) - return tp, err -} - -// get pointer indirect type -func indirectType(v reflect.Type) reflect.Type { - switch v.Kind() { - case reflect.Ptr: - return indirectType(v.Elem()) - default: - return v - } -} diff --git a/orm/utils_test.go b/orm/utils_test.go deleted file mode 100644 index 7d94cada..00000000 --- a/orm/utils_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "testing" -) - -func TestCamelString(t *testing.T) { - snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} - - answer := make(map[string]string) - for i, v := range snake { - answer[v] = camel[i] - } - - for _, v := range snake { - res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeString(t *testing.T) { - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} - -func TestSnakeStringWithAcronym(t *testing.T) { - camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} - snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} - - answer := make(map[string]string) - for i, v := range camel { - answer[v] = snake[i] - } - - for _, v := range camel { - res := snakeStringWithAcronym(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } -} diff --git a/parser.go b/parser.go deleted file mode 100644 index 3a311894..00000000 --- a/parser.go +++ /dev/null @@ -1,591 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "encoding/json" - "errors" - "fmt" - "go/ast" - "go/parser" - "go/token" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "unicode" - - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var globalRouterTemplate = `package {{.routersDir}} - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context/param"{{.globalimport}} -) - -func init() { -{{.globalinfo}} -} -` - -var ( - lastupdateFilename = "lastupdate.tmp" - commentFilename string - pkgLastupdate map[string]int64 - genInfoList map[string][]ControllerComments - - routerHooks = map[string]int{ - "beego.BeforeStatic": BeforeStatic, - "beego.BeforeRouter": BeforeRouter, - "beego.BeforeExec": BeforeExec, - "beego.AfterExec": AfterExec, - "beego.FinishRouter": FinishRouter, - } - - routerHooksMapping = map[int]string{ - BeforeStatic: "beego.BeforeStatic", - BeforeRouter: "beego.BeforeRouter", - BeforeExec: "beego.BeforeExec", - AfterExec: "beego.AfterExec", - FinishRouter: "beego.FinishRouter", - } -) - -const commentPrefix = "commentsRouter_" - -func init() { - pkgLastupdate = make(map[string]int64) -} - -func parserPkg(pkgRealpath, pkgpath string) error { - rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") - commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) - commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" - if !compareFile(pkgRealpath) { - logs.Info(pkgRealpath + " no changed") - return nil - } - genInfoList = make(map[string][]ControllerComments) - fileSet := token.NewFileSet() - astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { - name := info.Name() - return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") - }, parser.ParseComments) - - if err != nil { - return err - } - for _, pkg := range astPkgs { - for _, fl := range pkg.Files { - for _, d := range fl.Decls { - switch specDecl := d.(type) { - case *ast.FuncDecl: - if specDecl.Recv != nil { - exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser - if ok { - parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) - } - } - } - } - } - } - genRouterCode(pkgRealpath) - savetoFile(pkgRealpath) - return nil -} - -type parsedComment struct { - routerPath string - methods []string - params map[string]parsedParam - filters []parsedFilter - imports []parsedImport -} - -type parsedImport struct { - importPath string - importAlias string -} - -type parsedFilter struct { - pattern string - pos int - filter string - params []bool -} - -type parsedParam struct { - name string - datatype string - location string - defValue string - required bool -} - -func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { - if f.Doc != nil { - parsedComments, err := parseComment(f.Doc.List) - if err != nil { - return err - } - for _, parsedComment := range parsedComments { - if parsedComment.routerPath != "" { - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = f.Name.String() - cc.Router = parsedComment.routerPath - cc.AllowHTTPMethods = parsedComment.methods - cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) - cc.FilterComments = buildFilters(parsedComment.filters) - cc.ImportComments = buildImports(parsedComment.imports) - genInfoList[key] = append(genInfoList[key], cc) - } - } - } - return nil -} - -func buildImports(pis []parsedImport) []*ControllerImportComments { - var importComments []*ControllerImportComments - - for _, pi := range pis { - importComments = append(importComments, &ControllerImportComments{ - ImportPath: pi.importPath, - ImportAlias: pi.importAlias, - }) - } - - return importComments -} - -func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { - var filterComments []*ControllerFilterComments - - for _, pf := range pfs { - var ( - returnOnOutput bool - resetParams bool - ) - - if len(pf.params) >= 1 { - returnOnOutput = pf.params[0] - } - - if len(pf.params) >= 2 { - resetParams = pf.params[1] - } - - filterComments = append(filterComments, &ControllerFilterComments{ - Filter: pf.filter, - Pattern: pf.pattern, - Pos: pf.pos, - ReturnOnOutput: returnOnOutput, - ResetParams: resetParams, - }) - } - - return filterComments -} - -func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { - result := make([]*param.MethodParam, 0, len(funcParams)) - for _, fparam := range funcParams { - for _, pName := range fparam.Names { - methodParam := buildMethodParam(fparam, pName.Name, pc) - result = append(result, methodParam) - } - } - return result -} - -func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { - options := []param.MethodParamOption{} - if cparam, ok := pc.params[name]; ok { - //Build param from comment info - name = cparam.name - if cparam.required { - options = append(options, param.IsRequired) - } - switch cparam.location { - case "body": - options = append(options, param.InBody) - case "header": - options = append(options, param.InHeader) - case "path": - options = append(options, param.InPath) - } - if cparam.defValue != "" { - options = append(options, param.Default(cparam.defValue)) - } - } else { - if paramInPath(name, pc.routerPath) { - options = append(options, param.InPath) - } - } - return param.New(name, options...) -} - -func paramInPath(name, route string) bool { - return strings.HasSuffix(route, ":"+name) || - strings.Contains(route, ":"+name+"/") -} - -var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) - -func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { - pcs = []*parsedComment{} - params := map[string]parsedParam{} - filters := []parsedFilter{} - imports := []parsedImport{} - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Param") { - pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) - if len(pv) < 4 { - logs.Error("Invalid @Param format. Needs at least 4 parameters") - } - p := parsedParam{} - names := strings.SplitN(pv[0], "=>", 2) - p.name = names[0] - funcParamName := p.name - if len(names) > 1 { - funcParamName = names[1] - } - p.location = pv[1] - p.datatype = pv[2] - switch len(pv) { - case 5: - p.required, _ = strconv.ParseBool(pv[3]) - case 6: - p.defValue = pv[3] - p.required, _ = strconv.ParseBool(pv[4]) - } - params[funcParamName] = p - } - } - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Import") { - iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) - if len(iv) == 0 || len(iv) > 2 { - logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") - continue - } - - p := parsedImport{} - p.importPath = iv[0] - - if len(iv) == 2 { - p.importAlias = iv[1] - } - - imports = append(imports, p) - } - } - -filterLoop: - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Filter") { - fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) - if len(fv) < 3 { - logs.Error("Invalid @Filter format. Needs at least 3 parameters") - continue filterLoop - } - - p := parsedFilter{} - p.pattern = fv[0] - posName := fv[1] - if pos, exists := routerHooks[posName]; exists { - p.pos = pos - } else { - logs.Error("Invalid @Filter pos: ", posName) - continue filterLoop - } - - p.filter = fv[2] - fvParams := fv[3:] - for _, fvParam := range fvParams { - switch fvParam { - case "true": - p.params = append(p.params, true) - case "false": - p.params = append(p.params, false) - default: - logs.Error("Invalid @Filter param: ", fvParam) - continue filterLoop - } - } - - filters = append(filters, p) - } - } - - for _, c := range lines { - var pc = &parsedComment{} - pc.params = params - pc.filters = filters - pc.imports = imports - - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - matches := routeRegex.FindStringSubmatch(t) - if len(matches) == 3 { - pc.routerPath = matches[1] - methods := matches[2] - if methods == "" { - pc.methods = []string{"get"} - //pc.hasGet = true - } else { - pc.methods = strings.Split(methods, ",") - //pc.hasGet = strings.Contains(methods, "get") - } - pcs = append(pcs, pc) - } else { - return nil, errors.New("Router information is missing") - } - } - } - return -} - -// direct copy from bee\g_docs.go -// analysis params return []string -// @Param query form string true "The email for login" -// [query form string true "The email for login"] -func getparams(str string) []string { - var s []rune - var j int - var start bool - var r []string - var quoted int8 - for _, c := range str { - if unicode.IsSpace(c) && quoted == 0 { - if !start { - continue - } else { - start = false - j++ - r = append(r, string(s)) - s = make([]rune, 0) - continue - } - } - - start = true - if c == '"' { - quoted ^= 1 - continue - } - s = append(s, c) - } - if len(s) > 0 { - r = append(r, string(s)) - } - return r -} - -func genRouterCode(pkgRealpath string) { - os.Mkdir(getRouterDir(pkgRealpath), 0755) - logs.Info("generate router from comments") - var ( - globalinfo string - globalimport string - sortKey []string - ) - for k := range genInfoList { - sortKey = append(sortKey, k) - } - sort.Strings(sortKey) - for _, k := range sortKey { - cList := genInfoList[k] - sort.Sort(ControllerCommentsSlice(cList)) - for _, c := range cList { - allmethod := "nil" - if len(c.AllowHTTPMethods) > 0 { - allmethod = "[]string{" - for _, m := range c.AllowHTTPMethods { - allmethod += `"` + m + `",` - } - allmethod = strings.TrimRight(allmethod, ",") + "}" - } - - params := "nil" - if len(c.Params) > 0 { - params = "[]map[string]string{" - for _, p := range c.Params { - for k, v := range p { - params = params + `map[string]string{` + k + `:"` + v + `"},` - } - } - params = strings.TrimRight(params, ",") + "}" - } - - methodParams := "param.Make(" - if len(c.MethodParams) > 0 { - lines := make([]string, 0, len(c.MethodParams)) - for _, m := range c.MethodParams { - lines = append(lines, fmt.Sprint(m)) - } - methodParams += "\n " + - strings.Join(lines, ",\n ") + - ",\n " - } - methodParams += ")" - - imports := "" - if len(c.ImportComments) > 0 { - for _, i := range c.ImportComments { - var s string - if i.ImportAlias != "" { - s = fmt.Sprintf(` - %s "%s"`, i.ImportAlias, i.ImportPath) - } else { - s = fmt.Sprintf(` - "%s"`, i.ImportPath) - } - if !strings.Contains(globalimport, s) { - imports += s - } - } - } - - filters := "" - if len(c.FilterComments) > 0 { - for _, f := range c.FilterComments { - filters += fmt.Sprintf(` &beego.ControllerFilter{ - Pattern: "%s", - Pos: %s, - Filter: %s, - ReturnOnOutput: %v, - ResetParams: %v, - },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) - } - } - - if filters == "" { - filters = "nil" - } else { - filters = fmt.Sprintf(`[]*beego.ControllerFilter{ -%s - }`, filters) - } - - globalimport += imports - - globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + `Router: "` + c.Router + `"` + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Filters: ` + filters + `, - Params: ` + params + `}) -` - } - } - - if globalinfo != "" { - f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) - if err != nil { - panic(err) - } - defer f.Close() - - routersDir := AppConfig.DefaultString("routersdir", "routers") - content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) - content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) - content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) - f.WriteString(content) - } -} - -func compareFile(pkgRealpath string) bool { - if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { - return true - } - if utils.FileExists(lastupdateFilename) { - content, err := ioutil.ReadFile(lastupdateFilename) - if err != nil { - return true - } - json.Unmarshal(content, &pkgLastupdate) - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return true - } - if v, ok := pkgLastupdate[pkgRealpath]; ok { - if lastupdate <= v { - return false - } - } - } - return true -} - -func savetoFile(pkgRealpath string) { - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return - } - pkgLastupdate[pkgRealpath] = lastupdate - d, err := json.Marshal(pkgLastupdate) - if err != nil { - return - } - ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) -} - -func getpathTime(pkgRealpath string) (lastupdate int64, err error) { - fl, err := ioutil.ReadDir(pkgRealpath) - if err != nil { - return lastupdate, err - } - for _, f := range fl { - if lastupdate < f.ModTime().UnixNano() { - lastupdate = f.ModTime().UnixNano() - } - } - return lastupdate, nil -} - -func getRouterDir(pkgRealpath string) string { - dir := filepath.Dir(pkgRealpath) - for { - routersDir := AppConfig.DefaultString("routersdir", "routers") - d := filepath.Join(dir, routersDir) - if utils.FileExists(d) { - return d - } - - if r, _ := filepath.Rel(dir, AppPath); r == "." { - return d - } - // Parent dir. - dir = filepath.Dir(dir) - } -} diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go deleted file mode 100644 index 10e25f3f..00000000 --- a/plugins/apiauth/apiauth.go +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package apiauth provides handlers to enable apiauth support. -// -// Simple Usage: -// import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/apiauth" -// ) -// -// func main(){ -// // apiauth every request -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) -// beego.Run() -// } -// -// Advanced Usage: -// -// func getAppSecret(appid string) string { -// // get appsecret by appid -// // maybe store in configure, maybe in database -// } -// -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) -// -// Information: -// -// In the request user should include these params in the query -// -// 1. appid -// -// appid is assigned to the application -// -// 2. signature -// -// get the signature use apiauth.Signature() -// -// when you send to server remember use url.QueryEscape() -// -// 3. timestamp: -// -// send the request time, the format is yyyy-mm-dd HH:ii:ss -// -package apiauth - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "fmt" - "net/url" - "sort" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// AppIDToAppSecret is used to get appsecret throw appid -type AppIDToAppSecret func(string) string - -// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret -func APIBasicAuth(appid, appkey string) beego.FilterFunc { - ft := func(aid string) string { - if aid == appid { - return appkey - } - return "" - } - return APISecretAuth(ft, 300) -} - -// APIBaiscAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) beego.FilterFunc { - return APIBasicAuth(appid, appkey) -} - -// APISecretAuth use AppIdToAppSecret verify and -func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { - return func(ctx *context.Context) { - if ctx.Input.Query("appid") == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: appid") - return - } - appsecret := f(ctx.Input.Query("appid")) - if appsecret == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("not exist this appid") - return - } - if ctx.Input.Query("signature") == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: signature") - return - } - if ctx.Input.Query("timestamp") == "" { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: timestamp") - return - } - u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) - if err != nil { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") - return - } - t := time.Now() - if t.Sub(u).Seconds() > float64(timeout) { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timeout! the request time is long ago, please try again") - return - } - if ctx.Input.Query("signature") != - Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { - ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("auth failed") - } - } -} - -// Signature used to generate signature with the appsecret/method/params/RequestURI -func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { - var b bytes.Buffer - keys := make([]string, len(params)) - pa := make(map[string]string) - for k, v := range params { - pa[k] = v[0] - keys = append(keys, k) - } - - sort.Strings(keys) - - for _, key := range keys { - if key == "signature" { - continue - } - - val := pa[key] - if key != "" && val != "" { - b.WriteString(key) - b.WriteString(val) - } - } - - stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL) - - sha256 := sha256.New - hash := hmac.New(sha256, []byte(appsecret)) - hash.Write([]byte(stringToSign)) - return base64.StdEncoding.EncodeToString(hash.Sum(nil)) -} diff --git a/plugins/apiauth/apiauth_test.go b/plugins/apiauth/apiauth_test.go deleted file mode 100644 index 1f56cb0f..00000000 --- a/plugins/apiauth/apiauth_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package apiauth - -import ( - "net/url" - "testing" -) - -func TestSignature(t *testing.T) { - appsecret := "beego secret" - method := "GET" - RequestURL := "http://localhost/test/url" - params := make(url.Values) - params.Add("arg1", "hello") - params.Add("arg2", "beego") - - signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" - if Signature(appsecret, method, params, RequestURL) != signature { - t.Error("Signature error") - } -} diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go deleted file mode 100644 index c478044a..00000000 --- a/plugins/auth/basic.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package auth provides handlers to enable basic auth support. -// Simple Usage: -// import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/auth" -// ) -// -// func main(){ -// // authenticate every request -// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) -// beego.Run() -// } -// -// -// Advanced Usage: -// -// func SecretAuth(username, password string) bool { -// return username == "astaxie" && password == "helloBeego" -// } -// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") -// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) -package auth - -import ( - "encoding/base64" - "net/http" - "strings" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -var defaultRealm = "Authorization Required" - -// Basic is the http basic auth -func Basic(username string, password string) beego.FilterFunc { - secrets := func(user, pass string) bool { - return user == username && pass == password - } - return NewBasicAuthenticator(secrets, defaultRealm) -} - -// NewBasicAuthenticator return the BasicAuth -func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { - return func(ctx *context.Context) { - a := &BasicAuth{Secrets: secrets, Realm: Realm} - if username := a.CheckAuth(ctx.Request); username == "" { - a.RequireAuth(ctx.ResponseWriter, ctx.Request) - } - } -} - -// SecretProvider is the SecretProvider function -type SecretProvider func(user, pass string) bool - -// BasicAuth store the SecretProvider and Realm -type BasicAuth struct { - Secrets SecretProvider - Realm string -} - -// CheckAuth Checks the username/password combination from the request. Returns -// either an empty string (authentication failed) or the name of the -// authenticated user. -// Supports MD5 and SHA1 password entries -func (a *BasicAuth) CheckAuth(r *http.Request) string { - s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(s) != 2 || s[0] != "Basic" { - return "" - } - - b, err := base64.StdEncoding.DecodeString(s[1]) - if err != nil { - return "" - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return "" - } - - if a.Secrets(pair[0], pair[1]) { - return pair[0] - } - return "" -} - -// RequireAuth http.Handler for BasicAuth which initiates the authentication process -// (or requires reauthentication). -func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { - w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) - w.WriteHeader(401) - w.Write([]byte("401 Unauthorized\n")) -} diff --git a/plugins/authz/authz.go b/plugins/authz/authz.go deleted file mode 100644 index 9dc0db76..00000000 --- a/plugins/authz/authz.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. -// Simple Usage: -// import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/authz" -// "github.com/casbin/casbin" -// ) -// -// func main(){ -// // mediate the access for every request -// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) -// beego.Run() -// } -// -// -// Advanced Usage: -// -// func main(){ -// e := casbin.NewEnforcer("authz_model.conf", "") -// e.AddRoleForUser("alice", "admin") -// e.AddPolicy(...) -// -// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) -// beego.Run() -// } -package authz - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/casbin/casbin" - "net/http" -) - -// NewAuthorizer returns the authorizer. -// Use a casbin enforcer as input -func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { - return func(ctx *context.Context) { - a := &BasicAuthorizer{enforcer: e} - - if !a.CheckPermission(ctx.Request) { - a.RequirePermission(ctx.ResponseWriter) - } - } -} - -// BasicAuthorizer stores the casbin handler -type BasicAuthorizer struct { - enforcer *casbin.Enforcer -} - -// GetUserName gets the user name from the request. -// Currently, only HTTP basic authentication is supported -func (a *BasicAuthorizer) GetUserName(r *http.Request) string { - username, _, _ := r.BasicAuth() - return username -} - -// CheckPermission checks the user/method/path combination from the request. -// Returns true (permission granted) or false (permission forbidden) -func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { - user := a.GetUserName(r) - method := r.Method - path := r.URL.Path - return a.enforcer.Enforce(user, path, method) -} - -// RequirePermission returns the 403 Forbidden to the client -func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { - w.WriteHeader(403) - w.Write([]byte("403 Forbidden\n")) -} diff --git a/plugins/authz/authz_model.conf b/plugins/authz/authz_model.conf deleted file mode 100644 index d1b3dbd7..00000000 --- a/plugins/authz/authz_model.conf +++ /dev/null @@ -1,14 +0,0 @@ -[request_definition] -r = sub, obj, act - -[policy_definition] -p = sub, obj, act - -[role_definition] -g = _, _ - -[policy_effect] -e = some(where (p.eft == allow)) - -[matchers] -m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/plugins/authz/authz_policy.csv b/plugins/authz/authz_policy.csv deleted file mode 100644 index c062dd3e..00000000 --- a/plugins/authz/authz_policy.csv +++ /dev/null @@ -1,7 +0,0 @@ -p, alice, /dataset1/*, GET -p, alice, /dataset1/resource1, POST -p, bob, /dataset2/resource1, * -p, bob, /dataset2/resource2, GET -p, bob, /dataset2/folder1/*, POST -p, dataset1_admin, /dataset1/*, * -g, cathy, dataset1_admin \ No newline at end of file diff --git a/plugins/authz/authz_test.go b/plugins/authz/authz_test.go deleted file mode 100644 index 49aed84c..00000000 --- a/plugins/authz/authz_test.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package authz - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/plugins/auth" - "github.com/casbin/casbin" - "net/http" - "net/http/httptest" - "testing" -) - -func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { - r, _ := http.NewRequest(method, path, nil) - r.SetBasicAuth(user, "123") - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - if w.Code != code { - t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) - } -} - -func TestBasic(t *testing.T) { - handler := beego.NewControllerRegister() - - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) - - handler.Any("*", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) -} - -func TestPathWildcard(t *testing.T) { - handler := beego.NewControllerRegister() - - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) - - handler.Any("*", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) - testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) - - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) -} - -func TestRBAC(t *testing.T) { - handler := beego.NewControllerRegister() - - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) - e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) - - handler.Any("*", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) - - // delete all roles on user cathy, so cathy cannot access any resources now. - e.DeleteRolesForUser("cathy") - - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) -} diff --git a/plugins/cors/cors.go b/plugins/cors/cors.go deleted file mode 100644 index 45c327ab..00000000 --- a/plugins/cors/cors.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package cors provides handlers to enable CORS support. -// Usage -// import ( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/cors" -// ) -// -// func main() { -// // CORS for https://foo.* origins, allowing: -// // - PUT and PATCH methods -// // - Origin header -// // - Credentials share -// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ -// AllowOrigins: []string{"https://*.foo.com"}, -// AllowMethods: []string{"PUT", "PATCH"}, -// AllowHeaders: []string{"Origin"}, -// ExposeHeaders: []string{"Content-Length"}, -// AllowCredentials: true, -// })) -// beego.Run() -// } -package cors - -import ( - "net/http" - "regexp" - "strconv" - "strings" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -const ( - headerAllowOrigin = "Access-Control-Allow-Origin" - headerAllowCredentials = "Access-Control-Allow-Credentials" - headerAllowHeaders = "Access-Control-Allow-Headers" - headerAllowMethods = "Access-Control-Allow-Methods" - headerExposeHeaders = "Access-Control-Expose-Headers" - headerMaxAge = "Access-Control-Max-Age" - - headerOrigin = "Origin" - headerRequestMethod = "Access-Control-Request-Method" - headerRequestHeaders = "Access-Control-Request-Headers" -) - -var ( - defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} - // Regex patterns are generated from AllowOrigins. These are used and generated internally. - allowOriginPatterns = []string{} -) - -// Options represents Access Control options. -type Options struct { - // If set, all origins are allowed. - AllowAllOrigins bool - // A list of allowed origins. Wild cards and FQDNs are supported. - AllowOrigins []string - // If set, allows to share auth credentials such as cookies. - AllowCredentials bool - // A list of allowed HTTP methods. - AllowMethods []string - // A list of allowed HTTP headers. - AllowHeaders []string - // A list of exposed HTTP headers. - ExposeHeaders []string - // Max age of the CORS headers. - MaxAge time.Duration -} - -// Header converts options into CORS headers. -func (o *Options) Header(origin string) (headers map[string]string) { - headers = make(map[string]string) - // if origin is not allowed, don't extend the headers - // with CORS headers. - if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { - return - } - - // add allow origin - if o.AllowAllOrigins { - headers[headerAllowOrigin] = "*" - } else { - headers[headerAllowOrigin] = origin - } - - // add allow credentials - headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) - - // add allow methods - if len(o.AllowMethods) > 0 { - headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") - } - - // add allow headers - if len(o.AllowHeaders) > 0 { - headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") - } - - // add exposed header - if len(o.ExposeHeaders) > 0 { - headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") - } - // add a max age header - if o.MaxAge > time.Duration(0) { - headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) - } - return -} - -// PreflightHeader converts options into CORS headers for a preflight response. -func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { - headers = make(map[string]string) - if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { - return - } - // verify if requested method is allowed - for _, method := range o.AllowMethods { - if method == rMethod { - headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") - break - } - } - - // verify if requested headers are allowed - var allowed []string - for _, rHeader := range strings.Split(rHeaders, ",") { - rHeader = strings.TrimSpace(rHeader) - lookupLoop: - for _, allowedHeader := range o.AllowHeaders { - if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { - allowed = append(allowed, rHeader) - break lookupLoop - } - } - } - - headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) - // add allow origin - if o.AllowAllOrigins { - headers[headerAllowOrigin] = "*" - } else { - headers[headerAllowOrigin] = origin - } - - // add allowed headers - if len(allowed) > 0 { - headers[headerAllowHeaders] = strings.Join(allowed, ",") - } - - // add exposed headers - if len(o.ExposeHeaders) > 0 { - headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") - } - // add a max age header - if o.MaxAge > time.Duration(0) { - headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) - } - return -} - -// IsOriginAllowed looks up if the origin matches one of the patterns -// generated from Options.AllowOrigins patterns. -func (o *Options) IsOriginAllowed(origin string) (allowed bool) { - for _, pattern := range allowOriginPatterns { - allowed, _ = regexp.MatchString(pattern, origin) - if allowed { - return - } - } - return -} - -// Allow enables CORS for requests those match the provided options. -func Allow(opts *Options) beego.FilterFunc { - // Allow default headers if nothing is specified. - if len(opts.AllowHeaders) == 0 { - opts.AllowHeaders = defaultAllowHeaders - } - - for _, origin := range opts.AllowOrigins { - pattern := regexp.QuoteMeta(origin) - pattern = strings.Replace(pattern, "\\*", ".*", -1) - pattern = strings.Replace(pattern, "\\?", ".", -1) - allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") - } - - return func(ctx *context.Context) { - var ( - origin = ctx.Input.Header(headerOrigin) - requestedMethod = ctx.Input.Header(headerRequestMethod) - requestedHeaders = ctx.Input.Header(headerRequestHeaders) - // additional headers to be added - // to the response. - headers map[string]string - ) - - if ctx.Input.Method() == "OPTIONS" && - (requestedMethod != "" || requestedHeaders != "") { - headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) - for key, value := range headers { - ctx.Output.Header(key, value) - } - ctx.ResponseWriter.WriteHeader(http.StatusOK) - return - } - headers = opts.Header(origin) - - for key, value := range headers { - ctx.Output.Header(key, value) - } - } -} diff --git a/plugins/cors/cors_test.go b/plugins/cors/cors_test.go deleted file mode 100644 index 34039143..00000000 --- a/plugins/cors/cors_test.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cors - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" -) - -// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header -type HTTPHeaderGuardRecorder struct { - *httptest.ResponseRecorder - savedHeaderMap http.Header -} - -// NewRecorder return HttpHeaderGuardRecorder -func NewRecorder() *HTTPHeaderGuardRecorder { - return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} -} - -func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { - gr.ResponseRecorder.WriteHeader(code) - gr.savedHeaderMap = gr.ResponseRecorder.Header() -} - -func (gr *HTTPHeaderGuardRecorder) Header() http.Header { - if gr.savedHeaderMap != nil { - // headers were written. clone so we don't get updates - clone := make(http.Header) - for k, v := range gr.savedHeaderMap { - clone[k] = v - } - return clone - } - return gr.ResponseRecorder.Header() -} - -func Test_AllowAll(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { - t.Errorf("Allow-Origin header should be *") - } -} - -func Test_AllowRegexMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://bar.foo.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != origin { - t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) - } -} - -func Test_AllowRegexNoMatch(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowOrigins: []string{"https://*.foo.com"}, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - origin := "https://ww.foo.com.evil.com" - r, _ := http.NewRequest("PUT", "/foo", nil) - r.Header.Add("Origin", origin) - handler.ServeHTTP(recorder, r) - - headerValue := recorder.HeaderMap.Get(headerAllowOrigin) - if headerValue != "" { - t.Errorf("Allow-Origin header should not exist, found %v", headerValue) - } -} - -func Test_OtherHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - ExposeHeaders: []string{"Content-Length", "Hello"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) - methodsVal := recorder.HeaderMap.Get(headerAllowMethods) - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) - maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) - - if credentialsVal != "true" { - t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) - } - - if methodsVal != "PATCH,GET" { - t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) - } - - if headersVal != "Origin,X-whatever" { - t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) - } - - if exposedHeadersVal != "Content-Length,Hello" { - t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) - } - - if maxAgeVal != "300" { - t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) - } -} - -func Test_DefaultAllowHeaders(t *testing.T) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - - r, _ := http.NewRequest("PUT", "/foo", nil) - handler.ServeHTTP(recorder, r) - - headersVal := recorder.HeaderMap.Get(headerAllowHeaders) - if headersVal != "Origin,Accept,Content-Type,Authorization" { - t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) - } -} - -func Test_Preflight(t *testing.T) { - recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowMethods: []string{"PUT", "PATCH"}, - AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, - })) - - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(200) - }) - - r, _ := http.NewRequest("OPTIONS", "/foo", nil) - r.Header.Add(headerRequestMethod, "PUT") - r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") - handler.ServeHTTP(recorder, r) - - headers := recorder.Header() - methodsVal := headers.Get(headerAllowMethods) - headersVal := headers.Get(headerAllowHeaders) - originVal := headers.Get(headerAllowOrigin) - - if methodsVal != "PUT,PATCH" { - t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) - } - - if !strings.Contains(headersVal, "X-whatever") { - t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) - } - - if !strings.Contains(headersVal, "x-casesensitive") { - t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) - } - - if originVal != "*" { - t.Errorf("Allow-Origin is expected to be *, found %v", originVal) - } - - if recorder.Code != http.StatusOK { - t.Errorf("Status code is expected to be 200, found %d", recorder.Code) - } -} - -func Benchmark_WithoutCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} - -func Benchmark_WithCORS(b *testing.B) { - recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ - AllowAllOrigins: true, - AllowCredentials: true, - AllowMethods: []string{"PATCH", "GET"}, - AllowHeaders: []string{"Origin", "X-whatever"}, - MaxAge: 5 * time.Minute, - })) - handler.Any("/foo", func(ctx *context.Context) { - ctx.Output.SetStatus(500) - }) - b.ResetTimer() - r, _ := http.NewRequest("PUT", "/foo", nil) - for i := 0; i < b.N; i++ { - handler.ServeHTTP(recorder, r) - } -} diff --git a/policy.go b/policy.go deleted file mode 100644 index 358a0539..00000000 --- a/policy.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2016 beego authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - - "github.com/astaxie/beego/context" -) - -// PolicyFunc defines a policy function which is invoked before the controller handler is executed. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type PolicyFunc func(*context.Context) - -// FindPolicy Find Router info for URL -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { - var urlPath = cont.Input.URL() - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - httpMethod := cont.Input.Method() - isWildcard := false - // Find policy for current method - t, ok := p.policies[httpMethod] - // If not found - find policy for whole controller - if !ok { - t, ok = p.policies["*"] - isWildcard = true - } - if ok { - runObjects := t.Match(urlPath, cont) - if r, ok := runObjects.([]PolicyFunc); ok { - return r - } else if !isWildcard { - // If no policies found and we checked not for "*" method - try to find it - t, ok = p.policies["*"] - if ok { - runObjects = t.Match(urlPath, cont) - if r, ok = runObjects.([]PolicyFunc); ok { - return r - } - } - } - } - return nil -} - -func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { - method = strings.ToUpper(method) - p.enablePolicy = true - if !BConfig.RouterCaseSensitive { - pattern = strings.ToLower(pattern) - } - if t, ok := p.policies[method]; ok { - t.AddRouter(pattern, r) - } else { - t := NewTree() - t.AddRouter(pattern, r) - p.policies[method] = t - } -} - -// Policy Register new policy in beego -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Policy(pattern, method string, policy ...PolicyFunc) { - BeeApp.Handlers.addToPolicy(method, pattern, policy...) -} - -// Find policies and execute if were found -func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { - if !p.enablePolicy { - return false - } - // Find Policy for method - policyList := p.FindPolicy(cont) - if len(policyList) > 0 { - // Run policies - for _, runPolicy := range policyList { - runPolicy(cont) - if cont.ResponseWriter.Started { - return true - } - } - return false - } - return false -} diff --git a/router.go b/router.go deleted file mode 100644 index 1be495ab..00000000 --- a/router.go +++ /dev/null @@ -1,1085 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "net/http" - "os" - "path" - "path/filepath" - "reflect" - "strconv" - "strings" - "sync" - "time" - - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" -) - -// default filter execution points -const ( - BeforeStatic = iota - BeforeRouter - BeforeExec - AfterExec - FinishRouter -) - -const ( - routerTypeBeego = iota - routerTypeRESTFul - routerTypeHandler -) - -var ( - // HTTPMETHOD list the supported http methods. - // Deprecated: using pkg/, we will delete this in v2.1.0 - HTTPMETHOD = map[string]bool{ - "GET": true, - "POST": true, - "PUT": true, - "DELETE": true, - "PATCH": true, - "OPTIONS": true, - "HEAD": true, - "TRACE": true, - "CONNECT": true, - "MKCOL": true, - "COPY": true, - "MOVE": true, - "PROPFIND": true, - "PROPPATCH": true, - "LOCK": true, - "UNLOCK": true, - } - // these beego.Controller's methods shouldn't reflect to AutoRouter - exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", - "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", - "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", - "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", - "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", - "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", - "GetControllerAndAction", "ServeFormatted"} - - urlPlaceholder = "{{placeholder}}" - // DefaultAccessLogFilter will skip the accesslog if return true - // Deprecated: using pkg/, we will delete this in v2.1.0 - DefaultAccessLogFilter FilterHandler = &logFilter{} -) - -// FilterHandler is an interface for -// Deprecated: using pkg/, we will delete this in v2.1.0 -type FilterHandler interface { - Filter(*beecontext.Context) bool -} - -// default log filter static file will not show -type logFilter struct { -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (l *logFilter) Filter(ctx *beecontext.Context) bool { - requestPath := path.Clean(ctx.Request.URL.Path) - if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { - return true - } - for prefix := range BConfig.WebConfig.StaticDir { - if strings.HasPrefix(requestPath, prefix) { - return true - } - } - return false -} - -// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ExceptMethodAppend(action string) { - exceptMethod = append(exceptMethod, action) -} - -// ControllerInfo holds information about the controller. -type ControllerInfo struct { - pattern string - controllerType reflect.Type - methods map[string]string - handler http.Handler - runFunction FilterFunc - routerType int - initialize func() ControllerInterface - methodParams []*param.MethodParam -} - -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (c *ControllerInfo) GetPattern() string { - return c.pattern -} - -// ControllerRegister containers registered router rules, controller handlers and filters. -// Deprecated: using pkg/, we will delete this in v2.1.0 -type ControllerRegister struct { - routers map[string]*Tree - enablePolicy bool - policies map[string]*Tree - enableFilter bool - filters [FinishRouter + 1][]*FilterRouter - pool sync.Pool -} - -// NewControllerRegister returns a new ControllerRegister. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewControllerRegister() *ControllerRegister { - return &ControllerRegister{ - routers: make(map[string]*Tree), - policies: make(map[string]*Tree), - pool: sync.Pool{ - New: func() interface{} { - return beecontext.NewContext() - }, - }, - } -} - -// Add controller handler and pattern rules to ControllerRegister. -// usage: -// default methods is the same name as method -// Add("/user",&UserController{}) -// Add("/api/list",&RestController{},"*:ListFood") -// Add("/api/create",&RestController{},"post:CreateFood") -// Add("/api/update",&RestController{},"put:UpdateFood") -// Add("/api/delete",&RestController{},"delete:DeleteFood") -// Add("/api",&RestController{},"get,post:ApiFunc" -// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) -} - -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") - } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } - } - } - } - - route := &ControllerInfo{} - route.pattern = pattern - route.methods = methods - route.routerType = routerTypeBeego - route.controllerType = t - route.initialize = func() ControllerInterface { - vc := reflect.New(route.controllerType) - execController, ok := vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - - elemVal := reflect.ValueOf(c).Elem() - elemType := reflect.TypeOf(c).Elem() - execElem := reflect.ValueOf(execController).Elem() - - numOfFields := elemVal.NumField() - for i := 0; i < numOfFields; i++ { - fieldType := elemType.Field(i) - elemField := execElem.FieldByName(fieldType.Name) - if elemField.CanSet() { - fieldVal := elemVal.Field(i) - elemField.Set(fieldVal) - } - } - - return execController - } - - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } - } -} - -func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { - if !BConfig.RouterCaseSensitive { - pattern = strings.ToLower(pattern) - } - if t, ok := p.routers[method]; ok { - t.AddRouter(pattern, r) - } else { - t := NewTree() - t.AddRouter(pattern, r) - p.routers[method] = t - } -} - -// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller -// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Include(cList ...ControllerInterface) { - if BConfig.RunMode == DEV { - skip := make(map[string]bool, 10) - wgopath := utils.GetGOPATHs() - go111module := os.Getenv(`GO111MODULE`) - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - // for go modules - if go111module == `on` { - pkgpath := filepath.Join(WorkPath, "..", t.PkgPath()) - if utils.FileExists(pkgpath) { - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } else { - if len(wgopath) == 0 { - panic("you are in dev mode. So please set gopath") - } - pkgpath := "" - for _, wg := range wgopath { - wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) - if utils.FileExists(wg) { - pkgpath = wg - break - } - } - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } - } - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - key := t.PkgPath() + ":" + t.Name() - if comm, ok := GlobalControllerRouter[key]; ok { - for _, a := range comm { - for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) - } - - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) - } - } - } -} - -// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context -// And don't forget to give back context to pool -// example: -// ctx := p.GetContext() -// ctx.Reset(w, q) -// defer p.GiveBackContext(ctx) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) GetContext() *beecontext.Context { - return p.pool.Get().(*beecontext.Context) -} - -// GiveBackContext put the ctx into pool so that it could be reuse -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { - // clear input cached data - ctx.Input.Clear() - // clear output cached data - ctx.Output.Clear() - p.pool.Put(ctx) -} - -// Get add get method -// usage: -// Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Get(pattern string, f FilterFunc) { - p.AddMethod("get", pattern, f) -} - -// Post add post method -// usage: -// Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Post(pattern string, f FilterFunc) { - p.AddMethod("post", pattern, f) -} - -// Put add put method -// usage: -// Put("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Put(pattern string, f FilterFunc) { - p.AddMethod("put", pattern, f) -} - -// Delete add delete method -// usage: -// Delete("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { - p.AddMethod("delete", pattern, f) -} - -// Head add head method -// usage: -// Head("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Head(pattern string, f FilterFunc) { - p.AddMethod("head", pattern, f) -} - -// Patch add patch method -// usage: -// Patch("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { - p.AddMethod("patch", pattern, f) -} - -// Options add options method -// usage: -// Options("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Options(pattern string, f FilterFunc) { - p.AddMethod("options", pattern, f) -} - -// Any add all method -// usage: -// Any("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Any(pattern string, f FilterFunc) { - p.AddMethod("*", pattern, f) -} - -// AddMethod add http method router -// usage: -// AddMethod("get","/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } - route.methods = methods - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } -} - -// Handler add user defined Handler -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.handler = h - if len(options) > 0 { - if _, ok := options[0].(bool); ok { - pattern = path.Join(pattern, "?:all(.*)") - } - } - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } -} - -// AddAuto router to ControllerRegister. -// example beego.AddAuto(&MainContorlller{}), -// MainController has method List and Page. -// visit the url /main/list to execute List function -// /main/page to execute Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) AddAuto(c ControllerInterface) { - p.AddAutoPrefix("/", c) -} - -// AddAutoPrefix Add auto router to ControllerRegister with prefix. -// example beego.AddAutoPrefix("/admin",&MainContorlller{}), -// MainController has method List and Page. -// visit the url /admin/main/list to execute List function -// /admin/main/page to execute Page function. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { - reflectVal := reflect.ValueOf(c) - rt := reflectVal.Type() - ct := reflect.Indirect(reflectVal).Type() - controllerName := strings.TrimSuffix(ct.Name(), "Controller") - for i := 0; i < rt.NumMethod(); i++ { - if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct - pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") - patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") - patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) - patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - p.addToRouter(m, patternInit, route) - p.addToRouter(m, patternFix, route) - p.addToRouter(m, patternFixInit, route) - } - } - } -} - -// InsertFilter Add a FilterFunc with pattern rule and action constant. -// params is for: -// 1. setting the returnOnOutput value (false allows multiple filters to execute) -// 2. determining whether or not params need to be reset. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, - } - if !BConfig.RouterCaseSensitive { - mr.pattern = strings.ToLower(pattern) - } - - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } - mr.tree.AddRouter(pattern, true) - return p.insertFilterRouter(pos, mr) -} - -// add Filter into -func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { - if pos < BeforeStatic || pos > FinishRouter { - return errors.New("can not find your filter position") - } - p.enableFilter = true - p.filters[pos] = append(p.filters[pos], mr) - return nil -} - -// URLFor does another controller handler in this request function. -// it can access any controller method. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { - paths := strings.Split(endpoint, ".") - if len(paths) <= 1 { - logs.Warn("urlfor endpoint must like path.controller.method") - return "" - } - if len(values)%2 != 0 { - logs.Warn("urlfor params must key-value pair") - return "" - } - params := make(map[string]string) - if len(values) > 0 { - key := "" - for k, v := range values { - if k%2 == 0 { - key = fmt.Sprint(v) - } else { - params[key] = fmt.Sprint(v) - } - } - } - controllerName := strings.Join(paths[:len(paths)-1], "/") - methodName := paths[len(paths)-1] - for m, t := range p.routers { - ok, url := p.getURL(t, "/", controllerName, methodName, params, m) - if ok { - return url - } - } - return "" -} - -func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { - for _, subtree := range t.fixrouters { - u := path.Join(url, subtree.prefix) - ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) - if ok { - return ok, u - } - } - if t.wildcard != nil { - u := path.Join(url, urlPlaceholder) - ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) - if ok { - return ok, u - } - } - for _, l := range t.leaves { - if c, ok := l.runObject.(*ControllerInfo); ok { - if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { - find := false - if HTTPMETHOD[strings.ToUpper(methodName)] { - if len(c.methods) == 0 { - find = true - } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { - find = true - } else if m, ok = c.methods["*"]; ok && m == methodName { - find = true - } - } - if !find { - for m, md := range c.methods { - if (m == "*" || m == httpMethod) && md == methodName { - find = true - } - } - } - if find { - if l.regexps == nil { - if len(l.wildcards) == 0 { - return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) - } - if len(l.wildcards) == 1 { - if v, ok := params[l.wildcards[0]]; ok { - delete(params, l.wildcards[0]) - return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) - } - return false, "" - } - if len(l.wildcards) == 3 && l.wildcards[0] == "." { - if p, ok := params[":path"]; ok { - if e, isok := params[":ext"]; isok { - delete(params, ":path") - delete(params, ":ext") - return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) - } - } - } - canSkip := false - for _, v := range l.wildcards { - if v == ":" { - canSkip = true - continue - } - if u, ok := params[v]; ok { - delete(params, v) - url = strings.Replace(url, urlPlaceholder, u, 1) - } else { - if canSkip { - canSkip = false - continue - } - return false, "" - } - } - return true, url + toURL(params) - } - var i int - var startReg bool - regURL := "" - for _, v := range strings.Trim(l.regexps.String(), "^$") { - if v == '(' { - startReg = true - continue - } else if v == ')' { - startReg = false - if v, ok := params[l.wildcards[i]]; ok { - delete(params, l.wildcards[i]) - regURL = regURL + v - i++ - } else { - break - } - } else if !startReg { - regURL = string(append([]rune(regURL), v)) - } - } - if l.regexps.MatchString(regURL) { - ps := strings.Split(regURL, "/") - for _, p := range ps { - url = strings.Replace(url, urlPlaceholder, p, 1) - } - return true, url + toURL(params) - } - } - } - } - } - - return false, "" -} - -func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { - var preFilterParams map[string]string - for _, filterR := range p.filters[pos] { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if filterR.resetParams { - preFilterParams = context.Input.Params() - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - if filterR.resetParams { - context.Input.ResetParams() - for k, v := range preFilterParams { - context.Input.SetParam(k, v) - } - } - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - } - return false -} - -// Implement http.Handler interface. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - startTime := time.Now() - var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool - ) - context := p.GetContext() - - context.Reset(rw, r) - - defer p.GiveBackContext(context) - if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(context) - } - - context.Output.EnableGzip = BConfig.EnableGzip - - if BConfig.RunMode == DEV { - context.Output.Header("Server", BConfig.ServerName) - } - - var urlPath = r.URL.Path - - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - - // filter wrong http method - if !HTTPMETHOD[r.Method] { - exception("405", context) - goto Admin - } - - // filter for static file - if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { - goto Admin - } - - serverStaticRouter(context) - - if context.ResponseWriter.Started { - findRouter = true - goto Admin - } - - if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !context.Input.IsUpload() { - // connection will close if the incoming data are larger (RFC 7231, 6.5.11) - if r.ContentLength > BConfig.MaxMemory { - logs.Error(errors.New("payload too large")) - exception("413", context) - goto Admin - } - context.Input.CopyBody(BConfig.MaxMemory) - } - context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) - } - - // session init - if BConfig.WebConfig.Session.SessionOn { - var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) - if err != nil { - logs.Error(err) - exception("503", context) - goto Admin - } - defer func() { - if context.Input.CruSession != nil { - context.Input.CruSession.SessionRelease(rw) - } - }() - } - if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { - goto Admin - } - // User can define RunController and RunMethod in filter - if context.Input.RunController != nil && context.Input.RunMethod != "" { - findRouter = true - runMethod = context.Input.RunMethod - runRouter = context.Input.RunController - } else { - routerInfo, findRouter = p.FindRouter(context) - } - - // if no matches to url, throw a not found exception - if !findRouter { - exception("404", context) - goto Admin - } - if splat := context.Input.Param(":splat"); splat != "" { - for k, v := range strings.Split(splat, "/") { - context.Input.SetParam(strconv.Itoa(k), v) - } - } - - if routerInfo != nil { - // store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) - } - - // execute middleware filters - if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { - goto Admin - } - - // check policies - if p.execPolicy(context, urlPath) { - goto Admin - } - - if routerInfo != nil { - if routerInfo.routerType == routerTypeRESTFul { - if _, ok := routerInfo.methods[r.Method]; ok { - isRunnable = true - routerInfo.runFunction(context) - } else { - exception("405", context) - goto Admin - } - } else if routerInfo.routerType == routerTypeHandler { - isRunnable = true - routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) - } else { - runRouter = routerInfo.controllerType - methodParams = routerInfo.methodParams - method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { - method = http.MethodPut - } - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { - method = http.MethodDelete - } - if m, ok := routerInfo.methods[method]; ok { - runMethod = m - } else if m, ok = routerInfo.methods["*"]; ok { - runMethod = m - } else { - runMethod = method - } - } - } - - // also defined runRouter & runMethod from filter - if !isRunnable { - // Invoke the request handler - var execController ControllerInterface - if routerInfo != nil && routerInfo.initialize != nil { - execController = routerInfo.initialize() - } else { - vc := reflect.New(runRouter) - var ok bool - execController, ok = vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - } - - // call the controller init function - execController.Init(context, runRouter.Name(), runMethod, execController) - - // call prepare function - execController.Prepare() - - // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf - if BConfig.WebConfig.EnableXSRF { - execController.XSRFToken() - if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || - (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { - execController.CheckXSRFCookie() - } - } - - execController.URLMapping() - - if !context.ResponseWriter.Started { - // exec main logic - switch runMethod { - case http.MethodGet: - execController.Get() - case http.MethodPost: - execController.Post() - case http.MethodDelete: - execController.Delete() - case http.MethodPut: - execController.Put() - case http.MethodHead: - execController.Head() - case http.MethodPatch: - execController.Patch() - case http.MethodOptions: - execController.Options() - case http.MethodTrace: - execController.Trace() - default: - if !execController.HandlerFunc(runMethod) { - vc := reflect.ValueOf(execController) - method := vc.MethodByName(runMethod) - in := param.ConvertParams(methodParams, method.Type(), context) - out := method.Call(in) - - // For backward compatibility we only handle response if we had incoming methodParams - if methodParams != nil { - p.handleParamResponse(context, execController, out) - } - } - } - - // render template - if !context.ResponseWriter.Started && context.Output.Status == 0 { - if BConfig.WebConfig.AutoRender { - if err := execController.Render(); err != nil { - logs.Error(err) - } - } - } - } - - // finish all runRouter. release resource - execController.Finish() - } - - // execute middleware filters - if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { - goto Admin - } - - if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { - goto Admin - } - -Admin: - // admin module record QPS - - statusCode := context.ResponseWriter.Status - if statusCode == 0 { - statusCode = 200 - } - - LogAccess(context, &startTime, statusCode) - - timeDur := time.Since(startTime) - context.ResponseWriter.Elapsed = timeDur - if BConfig.Listen.EnableAdmin { - pattern := "" - if routerInfo != nil { - pattern = routerInfo.pattern - } - - if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { - routerName := "" - if runRouter != nil { - routerName = runRouter.Name() - } - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) - } - } - - if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { - match := map[bool]string{true: "match", false: "nomatch"} - devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", - context.Input.IP(), - logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), - timeDur.String(), - match[findRouter], - logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), - r.URL.Path) - if routerInfo != nil { - devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) - } - - logs.Debug(devInfo) - } - // Call WriteHeader if status code has been set changed - if context.Output.Status != 0 { - context.ResponseWriter.WriteHeader(context.Output.Status) - } -} - -func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { - // looping in reverse order for the case when both error and value are returned and error sets the response status code - for i := len(results) - 1; i >= 0; i-- { - result := results[i] - if result.Kind() != reflect.Interface || !result.IsNil() { - resultValue := result.Interface() - context.RenderMethodResult(resultValue) - } - } - if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { - context.Output.SetStatus(200) - } -} - -// FindRouter Find Router info for URL -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { - var urlPath = context.Input.URL() - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - httpMethod := context.Input.Method() - if t, ok := p.routers[httpMethod]; ok { - runObject := t.Match(urlPath, context) - if r, ok := runObject.(*ControllerInfo); ok { - return r, true - } - } - return -} - -func toURL(params map[string]string) string { - if len(params) == 0 { - return "" - } - u := "?" - for k, v := range params { - u += k + "=" + v + "&" - } - return strings.TrimRight(u, "&") -} - -// LogAccess logging info HTTP Access -// Deprecated: using pkg/, we will delete this in v2.1.0 -func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { - // Skip logging if AccessLogs config is false - if !BConfig.Log.AccessLogs { - return - } - // Skip logging static requests unless EnableStaticLogs config is true - if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { - return - } - var ( - requestTime time.Time - elapsedTime time.Duration - r = ctx.Request - ) - if startTime != nil { - requestTime = *startTime - elapsedTime = time.Since(*startTime) - } - record := &logs.AccessLogRecord{ - RemoteAddr: ctx.Input.IP(), - RequestTime: requestTime, - RequestMethod: r.Method, - Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), - ServerProtocol: r.Proto, - Host: r.Host, - Status: statusCode, - ElapsedTime: elapsedTime, - HTTPReferrer: r.Header.Get("Referer"), - HTTPUserAgent: r.Header.Get("User-Agent"), - RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: r.ContentLength, - } - logs.AccessLog(record, BConfig.Log.AccessLogsFormat) -} diff --git a/router_test.go b/router_test.go deleted file mode 100644 index 8ec7927a..00000000 --- a/router_test.go +++ /dev/null @@ -1,732 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" -) - -type TestController struct { - Controller -} - -func (tc *TestController) Get() { - tc.Data["Username"] = "astaxie" - tc.Ctx.Output.Body([]byte("ok")) -} - -func (tc *TestController) Post() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) -} - -func (tc *TestController) Param() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) -} - -func (tc *TestController) List() { - tc.Ctx.Output.Body([]byte("i am list")) -} - -func (tc *TestController) Params() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2"))) -} - -func (tc *TestController) Myext() { - tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext"))) -} - -func (tc *TestController) GetURL() { - tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext"))) -} - -func (tc *TestController) GetParams() { - tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" + - tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn")) -} - -func (tc *TestController) GetManyRouter() { - tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page")) -} - -func (tc *TestController) GetEmptyBody() { - var res []byte - tc.Ctx.Output.Body(res) -} - -type JSONController struct { - Controller -} - -func (jc *JSONController) Prepare() { - jc.Data["json"] = "prepare" - jc.ServeJSON(true) -} - -func (jc *JSONController) Get() { - jc.Data["Username"] = "astaxie" - jc.Ctx.Output.Body([]byte("ok")) -} - -func TestUrlFor(t *testing.T) { - handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") - if a := handler.URLFor("TestController.List"); a != "/api/list" { - logs.Info(a) - t.Errorf("TestController.List must equal to /api/list") - } - if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { - t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a) - } -} - -func TestUrlFor3(t *testing.T) { - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { - t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) - } - if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { - t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) - } -} - -func TestUrlFor2(t *testing.T) { - handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") - handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) - if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { - logs.Info(handler.URLFor("TestController.GetURL")) - t.Errorf("TestController.List must equal to /v1/astaxie/edit") - } - - if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != - "/v1/za/cms_12_123.html" { - logs.Info(handler.URLFor("TestController.List")) - t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") - } - if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != - "/v1/za_cms/ttt_12_123.html" { - logs.Info(handler.URLFor("TestController.Param")) - t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") - } - if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", - ":title", "aaaa", ":entid", "aaaa") != - "/1111/11/aaaa/aaaa" { - logs.Info(handler.URLFor("TestController.Get")) - t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") - } -} - -func TestUserFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/api/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestPostFunc(t *testing.T) { - r, _ := http.NewRequest("POST", "/astaxie", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/:name", &TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "astaxie" { - t.Errorf("post func should astaxie") - } -} - -func TestAutoFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/test/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestAutoFunc2(t *testing.T) { - r, _ := http.NewRequest("GET", "/Test/List", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestAutoFuncParams(t *testing.T) { - r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "20091112" { - t.Errorf("user define func can't run") - } -} - -func TestAutoExtFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/test/myext.json", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAuto(&TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "json" { - t.Errorf("user define func can't run") - } -} - -func TestRouteOk(t *testing.T) { - - r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") - handler.ServeHTTP(w, r) - body := w.Body.String() - if body != "anderson+thomas+kungfu" { - t.Errorf("url param set to [%s];", body) - } -} - -func TestManyRoute(t *testing.T) { - - r, _ := http.NewRequest("GET", "/beego32-12.html", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") - handler.ServeHTTP(w, r) - - body := w.Body.String() - - if body != "3212" { - t.Errorf("url param set to [%s];", body) - } -} - -// Test for issue #1669 -func TestEmptyResponse(t *testing.T) { - - r, _ := http.NewRequest("GET", "/beego-empty.html", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") - handler.ServeHTTP(w, r) - - if body := w.Body.String(); body != "" { - t.Error("want empty body") - } -} - -func TestNotFound(t *testing.T) { - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.ServeHTTP(w, r) - - if w.Code != http.StatusNotFound { - t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound) - } -} - -// TestStatic tests the ability to serve static -// content from the filesystem -func TestStatic(t *testing.T) { - r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.ServeHTTP(w, r) - - if w.Code != 404 { - t.Errorf("handler.Static failed to serve file") - } -} - -func TestPrepare(t *testing.T) { - r, _ := http.NewRequest("GET", "/json/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/json/list", &JSONController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != `"prepare"` { - t.Errorf(w.Body.String() + "user define func can't run") - } -} - -func TestAutoPrefix(t *testing.T) { - r, _ := http.NewRequest("GET", "/admin/test/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.AddAutoPrefix("/admin", &TestController{}) - handler.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("TestAutoPrefix can't run") - } -} - -func TestRouterGet(t *testing.T) { - r, _ := http.NewRequest("GET", "/user", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Get("/user", func(ctx *context.Context) { - ctx.Output.Body([]byte("Get userlist")) - }) - handler.ServeHTTP(w, r) - if w.Body.String() != "Get userlist" { - t.Errorf("TestRouterGet can't run") - } -} - -func TestRouterPost(t *testing.T) { - r, _ := http.NewRequest("POST", "/user/123", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - handler.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestRouterPost can't run") - } -} - -func sayhello(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("sayhello")) -} - -func TestRouterHandler(t *testing.T) { - r, _ := http.NewRequest("POST", "/sayhi", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Handler("/sayhi", http.HandlerFunc(sayhello)) - handler.ServeHTTP(w, r) - if w.Body.String() != "sayhello" { - t.Errorf("TestRouterHandler can't run") - } -} - -func TestRouterHandlerAll(t *testing.T) { - r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Handler("/sayhi", http.HandlerFunc(sayhello), true) - handler.ServeHTTP(w, r) - if w.Body.String() != "sayhello" { - t.Errorf("TestRouterHandler can't run") - } -} - -// -// Benchmarks NewApp: -// - -func beegoFilterFunc(ctx *context.Context) { - ctx.WriteString("hello") -} - -type AdminController struct { - Controller -} - -func (a *AdminController) Get() { - a.Ctx.WriteString("hello") -} - -func TestRouterFunc(t *testing.T) { - mux := NewControllerRegister() - mux.Get("/action", beegoFilterFunc) - mux.Post("/action", beegoFilterFunc) - rw, r := testRequest("GET", "/action") - mux.ServeHTTP(rw, r) - if rw.Body.String() != "hello" { - t.Errorf("TestRouterFunc can't run") - } -} - -func BenchmarkFunc(b *testing.B) { - mux := NewControllerRegister() - mux.Get("/action", beegoFilterFunc) - rw, r := testRequest("GET", "/action") - b.ResetTimer() - for i := 0; i < b.N; i++ { - mux.ServeHTTP(rw, r) - } -} - -func BenchmarkController(b *testing.B) { - mux := NewControllerRegister() - mux.Add("/action", &AdminController{}) - rw, r := testRequest("GET", "/action") - b.ResetTimer() - for i := 0; i < b.N; i++ { - mux.ServeHTTP(rw, r) - } -} - -func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) { - request, _ := http.NewRequest(method, path, nil) - recorder := httptest.NewRecorder() - - return recorder, request -} - -// Expectation: A Filter with the correct configuration should be created given -// specific parameters. -func TestInsertFilter(t *testing.T) { - testName := "TestInsertFilter" - - mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) - if !mux.filters[BeforeRouter][0].returnOnOutput { - t.Errorf( - "%s: passing no variadic params should set returnOnOutput to true", - testName) - } - if mux.filters[BeforeRouter][0].resetParams { - t.Errorf( - "%s: passing no variadic params should set resetParams to false", - testName) - } - - mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) - if mux.filters[BeforeRouter][0].returnOnOutput { - t.Errorf( - "%s: passing false as 1st variadic param should set returnOnOutput to false", - testName) - } - - mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) - if !mux.filters[BeforeRouter][0].resetParams { - t.Errorf( - "%s: passing true as 2nd variadic param should set resetParams to true", - testName) - } -} - -// Expectation: the second variadic arg should cause the execution of the filter -// to preserve the parameters from before its execution. -func TestParamResetFilter(t *testing.T) { - testName := "TestParamResetFilter" - route := "/beego/*" // splat - path := "/beego/routes/routes" - - mux := NewControllerRegister() - - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) - - mux.Get(route, beegoHandleResetParams) - - rw, r := testRequest("GET", path) - mux.ServeHTTP(rw, r) - - // The two functions, `beegoResetParams` and `beegoHandleResetParams` add - // a response header of `Splat`. The expectation here is that that Header - // value should match what the _request's_ router set, not the filter's. - - headers := rw.Result().Header - if len(headers["Splat"]) != 1 { - t.Errorf( - "%s: There was an error in the test. Splat param not set in Header", - testName) - } - if headers["Splat"][0] != "routes/routes" { - t.Errorf( - "%s: expected `:splat` param to be [routes/routes] but it was [%s]", - testName, headers["Splat"][0]) - } -} - -// Execution point: BeforeRouter -// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle -func TestFilterBeforeRouter(t *testing.T) { - testName := "TestFilterBeforeRouter" - url := "/beforeRouter" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "BeforeRouter1") { - t.Errorf(testName + " BeforeRouter did not run") - } - if strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " BeforeRouter did not return properly") - } -} - -// Execution point: BeforeExec -// expectation: only BeforeExec function is executed, match as router determines route only -func TestFilterBeforeExec(t *testing.T) { - testName := "TestFilterBeforeExec" - url := "/beforeExec" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "BeforeExec1") { - t.Errorf(testName + " BeforeExec did not run") - } - if strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " BeforeExec did not return properly") - } - if strings.Contains(rw.Body.String(), "BeforeRouter") { - t.Errorf(testName + " BeforeRouter ran in error") - } -} - -// Execution point: AfterExec -// expectation: only AfterExec function is executed, match as router handles -func TestFilterAfterExec(t *testing.T) { - testName := "TestFilterAfterExec" - url := "/afterExec" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "AfterExec1") { - t.Errorf(testName + " AfterExec did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - if strings.Contains(rw.Body.String(), "BeforeRouter") { - t.Errorf(testName + " BeforeRouter ran in error") - } - if strings.Contains(rw.Body.String(), "BeforeExec") { - t.Errorf(testName + " BeforeExec ran in error") - } -} - -// Execution point: FinishRouter -// expectation: only FinishRouter function is executed, match as router handles -func TestFilterFinishRouter(t *testing.T) { - testName := "TestFilterFinishRouter" - url := "/finishRouter" - - mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if strings.Contains(rw.Body.String(), "FinishRouter1") { - t.Errorf(testName + " FinishRouter did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - if strings.Contains(rw.Body.String(), "AfterExec1") { - t.Errorf(testName + " AfterExec ran in error") - } - if strings.Contains(rw.Body.String(), "BeforeRouter") { - t.Errorf(testName + " BeforeRouter ran in error") - } - if strings.Contains(rw.Body.String(), "BeforeExec") { - t.Errorf(testName + " BeforeExec ran in error") - } -} - -// Execution point: FinishRouter -// expectation: only first FinishRouter function is executed, match as router handles -func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { - testName := "TestFilterFinishRouterMultiFirstOnly" - url := "/finishRouterMultiFirstOnly" - - mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "FinishRouter1") { - t.Errorf(testName + " FinishRouter1 did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - // not expected in body - if strings.Contains(rw.Body.String(), "FinishRouter2") { - t.Errorf(testName + " FinishRouter2 did run") - } -} - -// Execution point: FinishRouter -// expectation: both FinishRouter functions execute, match as router handles -func TestFilterFinishRouterMulti(t *testing.T) { - testName := "TestFilterFinishRouterMulti" - url := "/finishRouterMulti" - - mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) - - mux.Get(url, beegoFilterFunc) - - rw, r := testRequest("GET", url) - mux.ServeHTTP(rw, r) - - if !strings.Contains(rw.Body.String(), "FinishRouter1") { - t.Errorf(testName + " FinishRouter1 did not run") - } - if !strings.Contains(rw.Body.String(), "hello") { - t.Errorf(testName + " handler did not run properly") - } - if !strings.Contains(rw.Body.String(), "FinishRouter2") { - t.Errorf(testName + " FinishRouter2 did not run properly") - } -} - -func beegoFilterNoOutput(ctx *context.Context) { -} - -func beegoBeforeRouter1(ctx *context.Context) { - ctx.WriteString("|BeforeRouter1") -} - -func beegoBeforeExec1(ctx *context.Context) { - ctx.WriteString("|BeforeExec1") -} - -func beegoAfterExec1(ctx *context.Context) { - ctx.WriteString("|AfterExec1") -} - -func beegoFinishRouter1(ctx *context.Context) { - ctx.WriteString("|FinishRouter1") -} - -func beegoFinishRouter2(ctx *context.Context) { - ctx.WriteString("|FinishRouter2") -} - -func beegoResetParams(ctx *context.Context) { - ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) -} - -func beegoHandleResetParams(ctx *context.Context) { - ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) -} - -// YAML -type YAMLController struct { - Controller -} - -func (jc *YAMLController) Prepare() { - jc.Data["yaml"] = "prepare" - jc.ServeYAML() -} - -func (jc *YAMLController) Get() { - jc.Data["Username"] = "astaxie" - jc.Ctx.Output.Body([]byte("ok")) -} - -func TestYAMLPrepare(t *testing.T) { - r, _ := http.NewRequest("GET", "/yaml/list", nil) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Add("/yaml/list", &YAMLController{}) - handler.ServeHTTP(w, r) - if strings.TrimSpace(w.Body.String()) != "prepare" { - t.Errorf(w.Body.String()) - } -} - -func TestRouterEntityTooLargeCopyBody(t *testing.T) { - _MaxMemory := BConfig.MaxMemory - _CopyRequestBody := BConfig.CopyRequestBody - BConfig.CopyRequestBody = true - BConfig.MaxMemory = 20 - - b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) - r, _ := http.NewRequest("POST", "/user/123", b) - w := httptest.NewRecorder() - - handler := NewControllerRegister() - handler.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - handler.ServeHTTP(w, r) - - BConfig.CopyRequestBody = _CopyRequestBody - BConfig.MaxMemory = _MaxMemory - - if w.Code != http.StatusRequestEntityTooLarge { - t.Errorf("TestRouterRequestEntityTooLarge can't run") - } -} diff --git a/session/README.md b/session/README.md deleted file mode 100644 index 6d0a297e..00000000 --- a/session/README.md +++ /dev/null @@ -1,114 +0,0 @@ -session -============== - -session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`. - -## How to install? - - go get github.com/astaxie/beego/session - - -## What providers are supported? - -As of now this session manager support memory, file, Redis and MySQL. - - -## How to use it? - -First you must import it - - import ( - "github.com/astaxie/beego/session" - ) - -Then in you web app init the global session manager - - var globalSessions *session.Manager - -* Use **memory** as provider: - - func init() { - globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) - go globalSessions.GC() - } - -* Use **file** as provider, the last param is the path where you want file to be stored: - - func init() { - globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`) - go globalSessions.GC() - } - -* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: - - func init() { - globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`) - go globalSessions.GC() - } - -* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name): - - func init() { - globalSessions, _ = session.NewManager( - "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`) - go globalSessions.GC() - } - -* Use **Cookie** as provider: - - func init() { - globalSessions, _ = session.NewManager( - "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) - go globalSessions.GC() - } - - -Finally in the handlerfunc you can use it like this - - func login(w http.ResponseWriter, r *http.Request) { - sess := globalSessions.SessionStart(w, r) - defer sess.SessionRelease(w) - username := sess.Get("username") - fmt.Println(username) - if r.Method == "GET" { - t, _ := template.ParseFiles("login.gtpl") - t.Execute(w, nil) - } else { - fmt.Println("username:", r.Form["username"]) - sess.Set("username", r.Form["username"]) - fmt.Println("password:", r.Form["password"]) - } - } - - -## How to write own provider? - -When you develop a web app, maybe you want to write own provider because you must meet the requirements. - -Writing a provider is easy. You only need to define two struct types -(Session and Provider), which satisfy the interface definition. -Maybe you will find the **memory** provider is a good example. - - type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data - } - - type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (SessionStore, error) - SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (SessionStore, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() - } - - -## LICENSE - -BSD License http://creativecommons.org/licenses/BSD/ diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go deleted file mode 100644 index 707d042c..00000000 --- a/session/couchbase/sess_couchbase.go +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package couchbase for session provider -// -// depend on github.com/couchbaselabs/go-couchbasee -// -// go install github.com/couchbaselabs/go-couchbase -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/couchbase" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package couchbase - -import ( - "net/http" - "strings" - "sync" - - couchbase "github.com/couchbase/go-couchbase" - - "github.com/astaxie/beego/session" -) - -var couchbpder = &Provider{} - -// SessionStore store each session -type SessionStore struct { - b *couchbase.Bucket - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Provider couchabse provided -type Provider struct { - maxlifetime int64 - savePath string - pool string - bucket string - b *couchbase.Bucket -} - -// Set value to couchabse session -func (cs *SessionStore) Set(key, value interface{}) error { - cs.lock.Lock() - defer cs.lock.Unlock() - cs.values[key] = value - return nil -} - -// Get value from couchabse session -func (cs *SessionStore) Get(key interface{}) interface{} { - cs.lock.RLock() - defer cs.lock.RUnlock() - if v, ok := cs.values[key]; ok { - return v - } - return nil -} - -// Delete value in couchbase session by given key -func (cs *SessionStore) Delete(key interface{}) error { - cs.lock.Lock() - defer cs.lock.Unlock() - delete(cs.values, key) - return nil -} - -// Flush Clean all values in couchbase session -func (cs *SessionStore) Flush() error { - cs.lock.Lock() - defer cs.lock.Unlock() - cs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID Get couchbase session store id -func (cs *SessionStore) SessionID() string { - return cs.sid -} - -// SessionRelease Write couchbase session with Gob string -func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { - defer cs.b.Close() - - bo, err := session.EncodeGob(cs.values) - if err != nil { - return - } - - cs.b.Set(cs.sid, int(cs.maxlifetime), bo) -} - -func (cp *Provider) getBucket() *couchbase.Bucket { - c, err := couchbase.Connect(cp.savePath) - if err != nil { - return nil - } - - pool, err := c.GetPool(cp.pool) - if err != nil { - return nil - } - - bucket, err := pool.GetBucket(cp.bucket) - if err != nil { - return nil - } - - return bucket -} - -// SessionInit init couchbase session -// savepath like couchbase server REST/JSON URL -// e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { - cp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - cp.savePath = configs[0] - } - if len(configs) > 1 { - cp.pool = configs[1] - } - if len(configs) > 2 { - cp.bucket = configs[2] - } - - return nil -} - -// SessionRead read couchbase session by sid -func (cp *Provider) SessionRead(sid string) (session.Store, error) { - cp.b = cp.getBucket() - - var ( - kv map[interface{}]interface{} - err error - doc []byte - ) - - err = cp.b.Get(sid, &doc) - if err != nil { - return nil, err - } else if doc == nil { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(doc) - if err != nil { - return nil, err - } - } - - cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} - return cs, nil -} - -// SessionExist Check couchbase session exist. -// it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) bool { - cp.b = cp.getBucket() - defer cp.b.Close() - - var doc []byte - - if err := cp.b.Get(sid, &doc); err != nil || doc == nil { - return false - } - return true -} - -// SessionRegenerate remove oldsid and use sid to generate new session -func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - cp.b = cp.getBucket() - - var doc []byte - if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil { - cp.b.Set(sid, int(cp.maxlifetime), "") - } else { - err := cp.b.Delete(oldsid) - if err != nil { - return nil, err - } - _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc) - } - - err := cp.b.Get(sid, &doc) - if err != nil { - return nil, err - } - var kv map[interface{}]interface{} - if doc == nil { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(doc) - if err != nil { - return nil, err - } - } - - cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} - return cs, nil -} - -// SessionDestroy Remove bucket in this couchbase -func (cp *Provider) SessionDestroy(sid string) error { - cp.b = cp.getBucket() - defer cp.b.Close() - - cp.b.Delete(sid) - return nil -} - -// SessionGC Recycle -func (cp *Provider) SessionGC() { -} - -// SessionAll return all active session -func (cp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("couchbase", couchbpder) -} diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go deleted file mode 100644 index ee81df67..00000000 --- a/session/ledis/ledis_session.go +++ /dev/null @@ -1,173 +0,0 @@ -// Package ledis provide session Provider -package ledis - -import ( - "net/http" - "strconv" - "strings" - "sync" - - "github.com/ledisdb/ledisdb/config" - "github.com/ledisdb/ledisdb/ledis" - - "github.com/astaxie/beego/session" -) - -var ( - ledispder = &Provider{} - c *ledis.DB -) - -// SessionStore ledis session store -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in ledis session -func (ls *SessionStore) Set(key, value interface{}) error { - ls.lock.Lock() - defer ls.lock.Unlock() - ls.values[key] = value - return nil -} - -// Get value in ledis session -func (ls *SessionStore) Get(key interface{}) interface{} { - ls.lock.RLock() - defer ls.lock.RUnlock() - if v, ok := ls.values[key]; ok { - return v - } - return nil -} - -// Delete value in ledis session -func (ls *SessionStore) Delete(key interface{}) error { - ls.lock.Lock() - defer ls.lock.Unlock() - delete(ls.values, key) - return nil -} - -// Flush clear all values in ledis session -func (ls *SessionStore) Flush() error { - ls.lock.Lock() - defer ls.lock.Unlock() - ls.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get ledis session id -func (ls *SessionStore) SessionID() string { - return ls.sid -} - -// SessionRelease save session values to ledis -func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(ls.values) - if err != nil { - return - } - c.Set([]byte(ls.sid), b) - c.Expire([]byte(ls.sid), ls.maxlifetime) -} - -// Provider ledis session provider -type Provider struct { - maxlifetime int64 - savePath string - db int -} - -// SessionInit init ledis session -// savepath like ledis server saveDataPath,pool size -// e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { - var err error - lp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) == 1 { - lp.savePath = configs[0] - } else if len(configs) == 2 { - lp.savePath = configs[0] - lp.db, err = strconv.Atoi(configs[1]) - if err != nil { - return err - } - } - cfg := new(config.Config) - cfg.DataDir = lp.savePath - - var ledisInstance *ledis.Ledis - ledisInstance, err = ledis.Open(cfg) - if err != nil { - return err - } - c, err = ledisInstance.Select(lp.db) - return err -} - -// SessionRead read ledis session by sid -func (lp *Provider) SessionRead(sid string) (session.Store, error) { - var ( - kv map[interface{}]interface{} - err error - ) - - kvs, _ := c.Get([]byte(sid)) - - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob(kvs); err != nil { - return nil, err - } - } - - ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} - return ls, nil -} - -// SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) bool { - count, _ := c.Exists([]byte(sid)) - return count != 0 -} - -// SessionRegenerate generate new sid for ledis session -func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - count, _ := c.Exists([]byte(sid)) - if count == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set([]byte(sid), []byte("")) - c.Expire([]byte(sid), lp.maxlifetime) - } else { - data, _ := c.Get([]byte(oldsid)) - c.Set([]byte(sid), data) - c.Expire([]byte(sid), lp.maxlifetime) - } - return lp.SessionRead(sid) -} - -// SessionDestroy delete ledis session by id -func (lp *Provider) SessionDestroy(sid string) error { - c.Del([]byte(sid)) - return nil -} - -// SessionGC Impelment method, no used. -func (lp *Provider) SessionGC() { -} - -// SessionAll return all active session -func (lp *Provider) SessionAll() int { - return 0 -} -func init() { - session.Register("ledis", ledispder) -} diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go deleted file mode 100644 index 85a2d815..00000000 --- a/session/memcache/sess_memcache.go +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memcache for session provider -// -// depend on github.com/bradfitz/gomemcache/memcache -// -// go install github.com/bradfitz/gomemcache/memcache -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/memcache" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package memcache - -import ( - "net/http" - "strings" - "sync" - - "github.com/astaxie/beego/session" - - "github.com/bradfitz/gomemcache/memcache" -) - -var mempder = &MemProvider{} -var client *memcache.Client - -// SessionStore memcache session store -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in memcache session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in memcache session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in memcache session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in memcache session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get memcache session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to memcache -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} - client.Set(&item) -} - -// MemProvider memcache session provider -type MemProvider struct { - maxlifetime int64 - conninfo []string - poolsize int - password string -} - -// SessionInit init memcache session -// savepath like -// e.g. 127.0.0.1:9090 -func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - rp.conninfo = strings.Split(savePath, ";") - client = memcache.New(rp.conninfo...) - return nil -} - -// SessionRead read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { - if client == nil { - if err := rp.connectInit(); err != nil { - return nil, err - } - } - item, err := client.Get(sid) - if err != nil { - if err == memcache.ErrCacheMiss { - rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} - return rs, nil - } - return nil, err - } - var kv map[interface{}]interface{} - if len(item.Value) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(item.Value) - if err != nil { - return nil, err - } - } - rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) bool { - if client == nil { - if err := rp.connectInit(); err != nil { - return false - } - } - if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - if client == nil { - if err := rp.connectInit(); err != nil { - return nil, err - } - } - var contain []byte - if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - item.Key = sid - item.Value = []byte("") - item.Expiration = int32(rp.maxlifetime) - client.Set(item) - } else { - client.Delete(oldsid) - item.Key = sid - item.Expiration = int32(rp.maxlifetime) - client.Set(item) - contain = item.Value - } - - var kv map[interface{}]interface{} - if len(contain) == 0 { - kv = make(map[interface{}]interface{}) - } else { - var err error - kv, err = session.DecodeGob(contain) - if err != nil { - return nil, err - } - } - - rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionDestroy delete memcache session by id -func (rp *MemProvider) SessionDestroy(sid string) error { - if client == nil { - if err := rp.connectInit(); err != nil { - return err - } - } - - return client.Delete(sid) -} - -func (rp *MemProvider) connectInit() error { - client = memcache.New(rp.conninfo...) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *MemProvider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *MemProvider) SessionAll() int { - return 0 -} - -func init() { - session.Register("memcache", mempder) -} diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go deleted file mode 100644 index 301353ab..00000000 --- a/session/mysql/sess_mysql.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package mysql for session provider -// -// depends on github.com/go-sql-driver/mysql: -// -// go install github.com/go-sql-driver/mysql -// -// mysql session support need create table as sql: -// CREATE TABLE `session` ( -// `session_key` char(64) NOT NULL, -// `session_data` blob, -// `session_expiry` int(11) unsigned NOT NULL, -// PRIMARY KEY (`session_key`) -// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/mysql" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package mysql - -import ( - "database/sql" - "net/http" - "sync" - "time" - - "github.com/astaxie/beego/session" - // import mysql driver - _ "github.com/go-sql-driver/mysql" -) - -var ( - // TableName store the session in MySQL - TableName = "session" - mysqlpder = &Provider{} -) - -// SessionStore mysql session store -type SessionStore struct { - c *sql.DB - sid string - lock sync.RWMutex - values map[interface{}]interface{} -} - -// Set value in mysql session. -// it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.values[key] = value - return nil -} - -// Get value from mysql session -func (st *SessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.values[key]; ok { - return v - } - return nil -} - -// Delete value in mysql session -func (st *SessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.values, key) - return nil -} - -// Flush clear all values in mysql session -func (st *SessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get session id of this mysql session store -func (st *SessionStore) SessionID() string { - return st.sid -} - -// SessionRelease save mysql session values to database. -// must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { - defer st.c.Close() - b, err := session.EncodeGob(st.values) - if err != nil { - return - } - st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", - b, time.Now().Unix(), st.sid) -} - -// Provider mysql session provider -type Provider struct { - maxlifetime int64 - savePath string -} - -// connect to mysql -func (mp *Provider) connectInit() *sql.DB { - db, e := sql.Open("mysql", mp.savePath) - if e != nil { - return nil - } - return db -} - -// SessionInit init mysql session. -// savepath is the connection string of mysql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { - mp.maxlifetime = maxlifetime - mp.savePath = savePath - return nil -} - -// SessionRead get mysql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", - sid, "", time.Now().Unix()) - } - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) bool { - c := mp.connectInit() - defer c.Close() - row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - return err != sql.ErrNoRows -} - -// SessionRegenerate generate new sid for mysql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) - } - c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionDestroy delete mysql session by sid -func (mp *Provider) SessionDestroy(sid string) error { - c := mp.connectInit() - c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) - c.Close() - return nil -} - -// SessionGC delete expired values in mysql session -func (mp *Provider) SessionGC() { - c := mp.connectInit() - c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) - c.Close() -} - -// SessionAll count values in mysql session -func (mp *Provider) SessionAll() int { - c := mp.connectInit() - defer c.Close() - var total int - err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) - if err != nil { - return 0 - } - return total -} - -func init() { - session.Register("mysql", mysqlpder) -} diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go deleted file mode 100644 index 0b8b9645..00000000 --- a/session/postgres/sess_postgresql.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package postgres for session provider -// -// depends on github.com/lib/pq: -// -// go install github.com/lib/pq -// -// -// needs this table in your database: -// -// CREATE TABLE session ( -// session_key char(64) NOT NULL, -// session_data bytea, -// session_expiry timestamp NOT NULL, -// CONSTRAINT session_key PRIMARY KEY(session_key) -// ); -// -// will be activated with these settings in app.conf: -// -// SessionOn = true -// SessionProvider = postgresql -// SessionSavePath = "user=a password=b dbname=c sslmode=disable" -// SessionName = session -// -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/postgresql" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package postgres - -import ( - "database/sql" - "net/http" - "sync" - "time" - - "github.com/astaxie/beego/session" - // import postgresql Driver - _ "github.com/lib/pq" -) - -var postgresqlpder = &Provider{} - -// SessionStore postgresql session store -type SessionStore struct { - c *sql.DB - sid string - lock sync.RWMutex - values map[interface{}]interface{} -} - -// Set value in postgresql session. -// it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.values[key] = value - return nil -} - -// Get value from postgresql session -func (st *SessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.values[key]; ok { - return v - } - return nil -} - -// Delete value in postgresql session -func (st *SessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.values, key) - return nil -} - -// Flush clear all values in postgresql session -func (st *SessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get session id of this postgresql session store -func (st *SessionStore) SessionID() string { - return st.sid -} - -// SessionRelease save postgresql session values to database. -// must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { - defer st.c.Close() - b, err := session.EncodeGob(st.values) - if err != nil { - return - } - st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", - b, time.Now().Format(time.RFC3339), st.sid) - -} - -// Provider postgresql session provider -type Provider struct { - maxlifetime int64 - savePath string -} - -// connect to postgresql -func (mp *Provider) connectInit() *sql.DB { - db, e := sql.Open("postgres", mp.savePath) - if e != nil { - return nil - } - return db -} - -// SessionInit init postgresql session. -// savepath is the connection string of postgresql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { - mp.maxlifetime = maxlifetime - mp.savePath = savePath - return nil -} - -// SessionRead get postgresql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=$1", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - _, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", - sid, "", time.Now().Format(time.RFC3339)) - - if err != nil { - return nil, err - } - } else if err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) bool { - c := mp.connectInit() - defer c.Close() - row := c.QueryRow("select session_data from session where session_key=$1", sid) - var sessiondata []byte - err := row.Scan(&sessiondata) - return err != sql.ErrNoRows -} - -// SessionRegenerate generate new sid for postgresql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=$1", oldsid) - var sessiondata []byte - err := row.Scan(&sessiondata) - if err == sql.ErrNoRows { - c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", - oldsid, "", time.Now().Format(time.RFC3339)) - } - c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid) - var kv map[interface{}]interface{} - if len(sessiondata) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob(sessiondata) - if err != nil { - return nil, err - } - } - rs := &SessionStore{c: c, sid: sid, values: kv} - return rs, nil -} - -// SessionDestroy delete postgresql session by sid -func (mp *Provider) SessionDestroy(sid string) error { - c := mp.connectInit() - c.Exec("DELETE FROM session where session_key=$1", sid) - c.Close() - return nil -} - -// SessionGC delete expired values in postgresql session -func (mp *Provider) SessionGC() { - c := mp.connectInit() - c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) - c.Close() -} - -// SessionAll count values in postgresql session -func (mp *Provider) SessionAll() int { - c := mp.connectInit() - defer c.Close() - var total int - err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) - if err != nil { - return 0 - } - return total -} - -func init() { - session.Register("postgresql", postgresqlpder) -} diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go deleted file mode 100644 index 5c382d61..00000000 --- a/session/redis/sess_redis.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/gomodule/redigo/redis -// -// go install github.com/gomodule/redigo/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package redis - -import ( - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/session" - - "github.com/gomodule/redigo/redis" -) - -var redispder = &Provider{} - -// MaxPoolSize redis max pool size -var MaxPoolSize = 100 - -// SessionStore redis session store -type SessionStore struct { - p *redis.Pool - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p.Get() - defer c.Close() - c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) -} - -// Provider redis session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Pool -} - -// SessionInit init redis session -// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second -// e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = MaxPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - var idleTimeout time.Duration = 0 - if len(configs) > 4 { - timeout, err := strconv.Atoi(configs[4]) - if err == nil && timeout > 0 { - idleTimeout = time.Duration(timeout) * time.Second - } - } - rp.poollist = &redis.Pool{ - Dial: func() (redis.Conn, error) { - c, err := redis.Dial("tcp", rp.savePath) - if err != nil { - return nil, err - } - if rp.password != "" { - if _, err = c.Do("AUTH", rp.password); err != nil { - c.Close() - return nil, err - } - } - // some redis proxy such as twemproxy is not support select command - if rp.dbNum > 0 { - _, err = c.Do("SELECT", rp.dbNum) - if err != nil { - c.Close() - return nil, err - } - } - return c, err - }, - MaxIdle: rp.poolsize, - } - - rp.poollist.IdleTimeout = idleTimeout - - return rp.poollist.Get().Err() -} - -// SessionRead read redis session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - var kv map[interface{}]interface{} - - kvs, err := redis.String(c.Do("GET", sid)) - if err != nil && err != redis.ErrNil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist.Get() - defer c.Close() - - if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Do("SET", sid, "", "EX", rp.maxlifetime) - } else { - c.Do("RENAME", oldsid, sid) - c.Do("EXPIRE", sid, rp.maxlifetime) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist.Get() - defer c.Close() - - c.Do("DEL", sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis", redispder) -} diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go deleted file mode 100644 index 262fa2e3..00000000 --- a/session/redis_cluster/redis_cluster.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/go-redis/redis -// -// go install github.com/go-redis/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis_cluster" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package redis_cluster - -import ( - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" - "net/http" - "strconv" - "strings" - "sync" - "time" -) - -var redispder = &Provider{} - -// MaxPoolSize redis_cluster max pool size -var MaxPoolSize = 1000 - -// SessionStore redis_cluster session store -type SessionStore struct { - p *rediss.ClusterClient - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis_cluster session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis_cluster session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis_cluster session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis_cluster session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis_cluster session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis_cluster -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) -} - -// Provider redis_cluster session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *rediss.ClusterClient -} - -// SessionInit init redis_cluster session -// savepath like redis server addr,pool size,password,dbnum -// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = MaxPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - - rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ - Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - }) - return rp.poollist.Ping().Err() -} - -// SessionRead read redis_cluster session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - var kv map[interface{}]interface{} - kvs, err := rp.poollist.Get(sid).Result() - if err != nil && err != rediss.Nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist - if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis_cluster session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist - - if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) - } else { - c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist - c.Del(sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis_cluster", redispder) -} diff --git a/session/redis_sentinel/sess_redis_sentinel.go b/session/redis_sentinel/sess_redis_sentinel.go deleted file mode 100644 index 6ecb2977..00000000 --- a/session/redis_sentinel/sess_redis_sentinel.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/go-redis/redis -// -// go install github.com/go-redis/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis_sentinel" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) -// go globalSessions.GC() -// } -// -// more detail about params: please check the notes on the function SessionInit in this package -package redis_sentinel - -import ( - "github.com/astaxie/beego/session" - "github.com/go-redis/redis" - "net/http" - "strconv" - "strings" - "sync" - "time" -) - -var redispder = &Provider{} - -// DefaultPoolSize redis_sentinel default pool size -var DefaultPoolSize = 100 - -// SessionStore redis_sentinel session store -type SessionStore struct { - p *redis.Client - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis_sentinel session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis_sentinel session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis_sentinel session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis_sentinel session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis_sentinel session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis_sentinel -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) -} - -// Provider redis_sentinel session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Client - masterName string -} - -// SessionInit init redis_sentinel session -// savepath like redis sentinel addr,pool size,password,dbnum,masterName -// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = DefaultPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = DefaultPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - if len(configs) > 4 { - if configs[4] != "" { - rp.masterName = configs[4] - } else { - rp.masterName = "mymaster" - } - } else { - rp.masterName = "mymaster" - } - - rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ - SentinelAddrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - MasterName: rp.masterName, - }) - - return rp.poollist.Ping().Err() -} - -// SessionRead read redis_sentinel session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - var kv map[interface{}]interface{} - kvs, err := rp.poollist.Get(sid).Result() - if err != nil && err != redis.Nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist - if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis_sentinel session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist - - if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) - } else { - c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist - c.Del(sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis_sentinel", redispder) -} diff --git a/session/redis_sentinel/sess_redis_sentinel_test.go b/session/redis_sentinel/sess_redis_sentinel_test.go deleted file mode 100644 index fd4155c6..00000000 --- a/session/redis_sentinel/sess_redis_sentinel_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package redis_sentinel - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/astaxie/beego/session" -) - -func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } - globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) - if e != nil { - t.Log(e) - return - } - //todo test if e==nil - go globalSessions.GC() - - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start failed:", err) - } - defer sess.SessionRelease(w) - - // SET AND GET - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set username failed:", err) - } - username := sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } - - // DELETE - err = sess.Delete("username") - if err != nil { - t.Fatal("delete username failed:", err) - } - username = sess.Get("username") - if username != nil { - t.Fatal("delete username failed") - } - - // FLUSH - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set failed:", err) - } - err = sess.Set("password", "1qaz2wsx") - if err != nil { - t.Fatal("set failed:", err) - } - username = sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } - password := sess.Get("password") - if password != "1qaz2wsx" { - t.Fatal("get password failed") - } - err = sess.Flush() - if err != nil { - t.Fatal("flush failed:", err) - } - username = sess.Get("username") - if username != nil { - t.Fatal("flush failed") - } - password = sess.Get("password") - if password != nil { - t.Fatal("flush failed") - } - - sess.SessionRelease(w) - -} diff --git a/session/sess_cookie.go b/session/sess_cookie.go deleted file mode 100644 index 6ad5debc..00000000 --- a/session/sess_cookie.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "crypto/aes" - "crypto/cipher" - "encoding/json" - "net/http" - "net/url" - "sync" -) - -var cookiepder = &CookieProvider{} - -// CookieSessionStore Cookie SessionStore -type CookieSessionStore struct { - sid string - values map[interface{}]interface{} // session data - lock sync.RWMutex -} - -// Set value to cookie session. -// the value are encoded as gob with hash block string. -func (st *CookieSessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.values[key] = value - return nil -} - -// Get value from cookie session -func (st *CookieSessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.values[key]; ok { - return v - } - return nil -} - -// Delete value in cookie session -func (st *CookieSessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.values, key) - return nil -} - -// Flush Clean all values in cookie session -func (st *CookieSessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID Return id of this cookie session -func (st *CookieSessionStore) SessionID() string { - return st.sid -} - -// SessionRelease Write cookie session to http response cookie -func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { - st.lock.Lock() - encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) - st.lock.Unlock() - if err == nil { - cookie := &http.Cookie{Name: cookiepder.config.CookieName, - Value: url.QueryEscape(encodedCookie), - Path: "/", - HttpOnly: true, - Secure: cookiepder.config.Secure, - MaxAge: cookiepder.config.Maxage} - http.SetCookie(w, cookie) - } -} - -type cookieConfig struct { - SecurityKey string `json:"securityKey"` - BlockKey string `json:"blockKey"` - SecurityName string `json:"securityName"` - CookieName string `json:"cookieName"` - Secure bool `json:"secure"` - Maxage int `json:"maxage"` -} - -// CookieProvider Cookie session provider -type CookieProvider struct { - maxlifetime int64 - config *cookieConfig - block cipher.Block -} - -// SessionInit Init cookie session provider with max lifetime and config json. -// maxlifetime is ignored. -// json config: -// securityKey - hash string -// blockKey - gob encode hash string. it's saved as aes crypto. -// securityName - recognized name in encoded cookie string -// cookieName - cookie name -// maxage - cookie max life time. -func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { - pder.config = &cookieConfig{} - err := json.Unmarshal([]byte(config), pder.config) - if err != nil { - return err - } - if pder.config.BlockKey == "" { - pder.config.BlockKey = string(generateRandomKey(16)) - } - if pder.config.SecurityName == "" { - pder.config.SecurityName = string(generateRandomKey(20)) - } - pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) - if err != nil { - return err - } - pder.maxlifetime = maxlifetime - return nil -} - -// SessionRead Get SessionStore in cooke. -// decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (Store, error) { - maps, _ := decodeCookie(pder.block, - pder.config.SecurityKey, - pder.config.SecurityName, - sid, pder.maxlifetime) - if maps == nil { - maps = make(map[interface{}]interface{}) - } - rs := &CookieSessionStore{sid: sid, values: maps} - return rs, nil -} - -// SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) bool { - return true -} - -// SessionRegenerate Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { - return nil, nil -} - -// SessionDestroy Implement method, no used. -func (pder *CookieProvider) SessionDestroy(sid string) error { - return nil -} - -// SessionGC Implement method, no used. -func (pder *CookieProvider) SessionGC() { -} - -// SessionAll Implement method, return 0. -func (pder *CookieProvider) SessionAll() int { - return 0 -} - -// SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { - return nil -} - -func init() { - Register("cookie", cookiepder) -} diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go deleted file mode 100644 index b6726005..00000000 --- a/session/sess_cookie_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestCookie(t *testing.T) { - config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` - conf := new(ManagerConfig) - if err := json.Unmarshal([]byte(config), conf); err != nil { - t.Fatal("json decode error", err) - } - globalSessions, err := NewManager("cookie", conf) - if err != nil { - t.Fatal("init cookie session err", err) - } - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("set error,", err) - } - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set error,", err) - } - if username := sess.Get("username"); username != "astaxie" { - t.Fatal("get username error") - } - sess.SessionRelease(w) - if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { - t.Fatal("setcookie error") - } else { - parts := strings.Split(strings.TrimSpace(cookiestr), ";") - for k, v := range parts { - nameval := strings.Split(v, "=") - if k == 0 && nameval[0] != "gosessionid" { - t.Fatal("error") - } - } - } -} - -func TestDestorySessionCookie(t *testing.T) { - config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` - conf := new(ManagerConfig) - if err := json.Unmarshal([]byte(config), conf); err != nil { - t.Fatal("json decode error", err) - } - globalSessions, err := NewManager("cookie", conf) - if err != nil { - t.Fatal("init cookie session err", err) - } - - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - session, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start err,", err) - } - - // request again ,will get same sesssion id . - r1, _ := http.NewRequest("GET", "/", nil) - r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) - w = httptest.NewRecorder() - newSession, err := globalSessions.SessionStart(w, r1) - if err != nil { - t.Fatal("session start err,", err) - } - if newSession.SessionID() != session.SessionID() { - t.Fatal("get cookie session id is not the same again.") - } - - // After destroy session , will get a new session id . - globalSessions.SessionDestroy(w, r1) - r2, _ := http.NewRequest("GET", "/", nil) - r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) - - w = httptest.NewRecorder() - newSession, err = globalSessions.SessionStart(w, r2) - if err != nil { - t.Fatal("session start error") - } - if newSession.SessionID() == session.SessionID() { - t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") - } -} diff --git a/session/sess_file.go b/session/sess_file.go deleted file mode 100644 index 47ad54a7..00000000 --- a/session/sess_file.go +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "errors" - "fmt" - "io/ioutil" - "net/http" - "os" - "path" - "path/filepath" - "strings" - "sync" - "time" -) - -var ( - filepder = &FileProvider{} - gcmaxlifetime int64 -) - -// FileSessionStore File session store -type FileSessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} -} - -// Set value to file session -func (fs *FileSessionStore) Set(key, value interface{}) error { - fs.lock.Lock() - defer fs.lock.Unlock() - fs.values[key] = value - return nil -} - -// Get value from file session -func (fs *FileSessionStore) Get(key interface{}) interface{} { - fs.lock.RLock() - defer fs.lock.RUnlock() - if v, ok := fs.values[key]; ok { - return v - } - return nil -} - -// Delete value in file session by given key -func (fs *FileSessionStore) Delete(key interface{}) error { - fs.lock.Lock() - defer fs.lock.Unlock() - delete(fs.values, key) - return nil -} - -// Flush Clean all values in file session -func (fs *FileSessionStore) Flush() error { - fs.lock.Lock() - defer fs.lock.Unlock() - fs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID Get file session store id -func (fs *FileSessionStore) SessionID() string { - return fs.sid -} - -// SessionRelease Write file session to local file with Gob string -func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { - filepder.lock.Lock() - defer filepder.lock.Unlock() - b, err := EncodeGob(fs.values) - if err != nil { - SLogger.Println(err) - return - } - _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) - var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) - if err != nil { - SLogger.Println(err) - return - } - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) - if err != nil { - SLogger.Println(err) - return - } - } else { - return - } - f.Truncate(0) - f.Seek(0, 0) - f.Write(b) - f.Close() -} - -// FileProvider File session provider -type FileProvider struct { - lock sync.RWMutex - maxlifetime int64 - savePath string -} - -// SessionInit Init file session provider. -// savePath sets the session files path. -func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { - fp.maxlifetime = maxlifetime - fp.savePath = savePath - return nil -} - -// SessionRead Read file session by sid. -// if file is not exist, create it. -// the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (Store, error) { - invalidChars := "./" - if strings.ContainsAny(sid, invalidChars) { - return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) - } - if len(sid) < 2 { - return nil, errors.New("length of the sid is less than 2") - } - filepder.lock.Lock() - defer filepder.lock.Unlock() - - err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755) - if err != nil { - SLogger.Println(err.Error()) - } - _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - } else { - return nil, err - } - - defer f.Close() - - os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) - var kv map[interface{}]interface{} - b, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - if len(b) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(b) - if err != nil { - return nil, err - } - } - - ss := &FileSessionStore{sid: sid, values: kv} - return ss, nil -} - -// SessionExist Check file session exist. -// it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) bool { - filepder.lock.Lock() - defer filepder.lock.Unlock() - - if len(sid) < 2 { - SLogger.Println("min length of session id is 2", sid) - return false - } - - _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return err == nil -} - -// SessionDestroy Remove all files in this save path -func (fp *FileProvider) SessionDestroy(sid string) error { - filepder.lock.Lock() - defer filepder.lock.Unlock() - os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return nil -} - -// SessionGC Recycle files in save path -func (fp *FileProvider) SessionGC() { - filepder.lock.Lock() - defer filepder.lock.Unlock() - - gcmaxlifetime = fp.maxlifetime - filepath.Walk(fp.savePath, gcpath) -} - -// SessionAll Get active file session number. -// it walks save path to count files. -func (fp *FileProvider) SessionAll() int { - a := &activeSession{} - err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { - return a.visit(path, f, err) - }) - if err != nil { - SLogger.Printf("filepath.Walk() returned %v\n", err) - return 0 - } - return a.total -} - -// SessionRegenerate Generate new sid for file session. -// it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { - filepder.lock.Lock() - defer filepder.lock.Unlock() - - oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) - oldSidFile := path.Join(oldPath, oldsid) - newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) - newSidFile := path.Join(newPath, sid) - - // new sid file is exist - _, err := os.Stat(newSidFile) - if err == nil { - return nil, fmt.Errorf("newsid %s exist", newSidFile) - } - - err = os.MkdirAll(newPath, 0755) - if err != nil { - SLogger.Println(err.Error()) - } - - // if old sid file exist - // 1.read and parse file content - // 2.write content to new sid file - // 3.remove old sid file, change new sid file atime and ctime - // 4.return FileSessionStore - _, err = os.Stat(oldSidFile) - if err == nil { - b, err := ioutil.ReadFile(oldSidFile) - if err != nil { - return nil, err - } - - var kv map[interface{}]interface{} - if len(b) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(b) - if err != nil { - return nil, err - } - } - - ioutil.WriteFile(newSidFile, b, 0777) - os.Remove(oldSidFile) - os.Chtimes(newSidFile, time.Now(), time.Now()) - ss := &FileSessionStore{sid: sid, values: kv} - return ss, nil - } - - // if old sid file not exist, just create new sid file and return - newf, err := os.Create(newSidFile) - if err != nil { - return nil, err - } - newf.Close() - ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} - return ss, nil -} - -// remove file in save path if expired -func gcpath(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if info.IsDir() { - return nil - } - if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() { - os.Remove(path) - } - return nil -} - -type activeSession struct { - total int -} - -func (as *activeSession) visit(paths string, f os.FileInfo, err error) error { - if err != nil { - return err - } - if f.IsDir() { - return nil - } - as.total = as.total + 1 - return nil -} - -func init() { - Register("file", filepder) -} diff --git a/session/sess_file_test.go b/session/sess_file_test.go deleted file mode 100644 index 021c43fc..00000000 --- a/session/sess_file_test.go +++ /dev/null @@ -1,386 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "fmt" - "os" - "sync" - "testing" - "time" -) - -const sid = "Session_id" -const sidNew = "Session_id_new" -const sessionPath = "./_session_runtime" - -var ( - mutex sync.Mutex -) - -func TestFileProvider_SessionInit(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - if fp.maxlifetime != 180 { - t.Error() - } - - if fp.savePath != sessionPath { - t.Error() - } -} - -func TestFileProvider_SessionExist(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - if fp.SessionExist(sid) { - t.Error() - } - - _, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - if !fp.SessionExist(sid) { - t.Error() - } -} - -func TestFileProvider_SessionExist2(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - if fp.SessionExist(sid) { - t.Error() - } - - if fp.SessionExist("") { - t.Error() - } - - if fp.SessionExist("1") { - t.Error() - } -} - -func TestFileProvider_SessionRead(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - s, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - _ = s.Set("sessionValue", 18975) - v := s.Get("sessionValue") - - if v.(int) != 18975 { - t.Error() - } -} - -func TestFileProvider_SessionRead1(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - _, err := fp.SessionRead("") - if err == nil { - t.Error(err) - } - - _, err = fp.SessionRead("1") - if err == nil { - t.Error(err) - } -} - -func TestFileProvider_SessionAll(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 546 - - for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - } - - if fp.SessionAll() != sessionCount { - t.Error() - } -} - -func TestFileProvider_SessionRegenerate(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - _, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - if !fp.SessionExist(sid) { - t.Error() - } - - _, err = fp.SessionRegenerate(sid, sidNew) - if err != nil { - t.Error(err) - } - - if fp.SessionExist(sid) { - t.Error() - } - - if !fp.SessionExist(sidNew) { - t.Error() - } -} - -func TestFileProvider_SessionDestroy(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - _, err := fp.SessionRead(sid) - if err != nil { - t.Error(err) - } - - if !fp.SessionExist(sid) { - t.Error() - } - - err = fp.SessionDestroy(sid) - if err != nil { - t.Error(err) - } - - if fp.SessionExist(sid) { - t.Error() - } -} - -func TestFileProvider_SessionGC(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(1, sessionPath) - - sessionCount := 412 - - for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - } - - time.Sleep(2 * time.Second) - - fp.SessionGC() - if fp.SessionAll() != 0 { - t.Error() - } -} - -func TestFileSessionStore_Set(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 100 - s, _ := fp.SessionRead(sid) - for i := 1; i <= sessionCount; i++ { - err := s.Set(i, i) - if err != nil { - t.Error(err) - } - } -} - -func TestFileSessionStore_Get(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 100 - s, _ := fp.SessionRead(sid) - for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) - - v := s.Get(i) - if v.(int) != i { - t.Error() - } - } -} - -func TestFileSessionStore_Delete(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - s, _ := fp.SessionRead(sid) - s.Set("1", 1) - - if s.Get("1") == nil { - t.Error() - } - - s.Delete("1") - - if s.Get("1") != nil { - t.Error() - } -} - -func TestFileSessionStore_Flush(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 100 - s, _ := fp.SessionRead(sid) - for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) - } - - _ = s.Flush() - - for i := 1; i <= sessionCount; i++ { - if s.Get(i) != nil { - t.Error() - } - } -} - -func TestFileSessionStore_SessionID(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - - sessionCount := 85 - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { - t.Error(err) - } - } -} - -func TestFileSessionStore_SessionRelease(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - filepder.savePath = sessionPath - sessionCount := 85 - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - - s.Set(i, i) - s.SessionRelease(nil) - } - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - - if s.Get(i).(int) != i { - t.Error() - } - } -} diff --git a/session/sess_mem.go b/session/sess_mem.go deleted file mode 100644 index 64d8b056..00000000 --- a/session/sess_mem.go +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "container/list" - "net/http" - "sync" - "time" -) - -var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} - -// MemSessionStore memory session store. -// it saved sessions in a map in memory. -type MemSessionStore struct { - sid string //session id - timeAccessed time.Time //last access time - value map[interface{}]interface{} //session store - lock sync.RWMutex -} - -// Set value to memory session -func (st *MemSessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.value[key] = value - return nil -} - -// Get value from memory session by key -func (st *MemSessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.value[key]; ok { - return v - } - return nil -} - -// Delete in memory session by key -func (st *MemSessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.value, key) - return nil -} - -// Flush clear all values in memory session -func (st *MemSessionStore) Flush() error { - st.lock.Lock() - defer st.lock.Unlock() - st.value = make(map[interface{}]interface{}) - return nil -} - -// SessionID get this id of memory session store -func (st *MemSessionStore) SessionID() string { - return st.sid -} - -// SessionRelease Implement method, no used. -func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { -} - -// MemProvider Implement the provider interface -type MemProvider struct { - lock sync.RWMutex // locker - sessions map[string]*list.Element // map in memory - list *list.List // for gc - maxlifetime int64 - savePath string -} - -// SessionInit init memory session -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { - pder.maxlifetime = maxlifetime - pder.savePath = savePath - return nil -} - -// SessionRead get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (Store, error) { - pder.lock.RLock() - if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) - pder.lock.RUnlock() - return element.Value.(*MemSessionStore), nil - } - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushFront(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil -} - -// SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) bool { - pder.lock.RLock() - defer pder.lock.RUnlock() - if _, ok := pder.sessions[sid]; ok { - return true - } - return false -} - -// SessionRegenerate generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { - pder.lock.RLock() - if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) - pder.lock.RUnlock() - pder.lock.Lock() - element.Value.(*MemSessionStore).sid = sid - pder.sessions[sid] = element - delete(pder.sessions, oldsid) - pder.lock.Unlock() - return element.Value.(*MemSessionStore), nil - } - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushFront(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil -} - -// SessionDestroy delete session store in memory session by id -func (pder *MemProvider) SessionDestroy(sid string) error { - pder.lock.Lock() - defer pder.lock.Unlock() - if element, ok := pder.sessions[sid]; ok { - delete(pder.sessions, sid) - pder.list.Remove(element) - return nil - } - return nil -} - -// SessionGC clean expired session stores in memory session -func (pder *MemProvider) SessionGC() { - pder.lock.RLock() - for { - element := pder.list.Back() - if element == nil { - break - } - if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { - pder.lock.RUnlock() - pder.lock.Lock() - pder.list.Remove(element) - delete(pder.sessions, element.Value.(*MemSessionStore).sid) - pder.lock.Unlock() - pder.lock.RLock() - } else { - break - } - } - pder.lock.RUnlock() -} - -// SessionAll get count number of memory session -func (pder *MemProvider) SessionAll() int { - return pder.list.Len() -} - -// SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { - pder.lock.Lock() - defer pder.lock.Unlock() - if element, ok := pder.sessions[sid]; ok { - element.Value.(*MemSessionStore).timeAccessed = time.Now() - pder.list.MoveToFront(element) - return nil - } - return nil -} - -func init() { - Register("memory", mempder) -} diff --git a/session/sess_mem_test.go b/session/sess_mem_test.go deleted file mode 100644 index 2e8934b8..00000000 --- a/session/sess_mem_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestMem(t *testing.T) { - config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` - conf := new(ManagerConfig) - if err := json.Unmarshal([]byte(config), conf); err != nil { - t.Fatal("json decode error", err) - } - globalSessions, _ := NewManager("memory", conf) - go globalSessions.GC() - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("set error,", err) - } - defer sess.SessionRelease(w) - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set error,", err) - } - if username := sess.Get("username"); username != "astaxie" { - t.Fatal("get username error") - } - if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { - t.Fatal("setcookie error") - } else { - parts := strings.Split(strings.TrimSpace(cookiestr), ";") - for k, v := range parts { - nameval := strings.Split(v, "=") - if k == 0 && nameval[0] != "gosessionid" { - t.Fatal("error") - } - } - } -} diff --git a/session/sess_test.go b/session/sess_test.go deleted file mode 100644 index 906abec2..00000000 --- a/session/sess_test.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "crypto/aes" - "encoding/json" - "testing" -) - -func Test_gob(t *testing.T) { - a := make(map[interface{}]interface{}) - a["username"] = "astaxie" - a[12] = 234 - a["user"] = User{"asta", "xie"} - b, err := EncodeGob(a) - if err != nil { - t.Error(err) - } - c, err := DecodeGob(b) - if err != nil { - t.Error(err) - } - if len(c) == 0 { - t.Error("decodeGob empty") - } - if c["username"] != "astaxie" { - t.Error("decode string error") - } - if c[12] != 234 { - t.Error("decode int error") - } - if c["user"].(User).Username != "asta" { - t.Error("decode struct error") - } -} - -type User struct { - Username string - NickName string -} - -func TestGenerate(t *testing.T) { - str := generateRandomKey(20) - if len(str) != 20 { - t.Fatal("generate length is not equal to 20") - } -} - -func TestCookieEncodeDecode(t *testing.T) { - hashKey := "testhashKey" - blockkey := generateRandomKey(16) - block, err := aes.NewCipher(blockkey) - if err != nil { - t.Fatal("NewCipher:", err) - } - securityName := string(generateRandomKey(20)) - val := make(map[interface{}]interface{}) - val["name"] = "astaxie" - val["gender"] = "male" - str, err := encodeCookie(block, hashKey, securityName, val) - if err != nil { - t.Fatal("encodeCookie:", err) - } - dst, err := decodeCookie(block, hashKey, securityName, str, 3600) - if err != nil { - t.Fatal("decodeCookie", err) - } - if dst["name"] != "astaxie" { - t.Fatal("dst get map error") - } - if dst["gender"] != "male" { - t.Fatal("dst get map error") - } -} - -func TestParseConfig(t *testing.T) { - s := `{"cookieName":"gosessionid","gclifetime":3600}` - cf := new(ManagerConfig) - cf.EnableSetCookie = true - err := json.Unmarshal([]byte(s), cf) - if err != nil { - t.Fatal("parse json error,", err) - } - if cf.CookieName != "gosessionid" { - t.Fatal("parseconfig get cookiename error") - } - if cf.Gclifetime != 3600 { - t.Fatal("parseconfig get gclifetime error") - } - - cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` - cf2 := new(ManagerConfig) - cf2.EnableSetCookie = true - err = json.Unmarshal([]byte(cc), cf2) - if err != nil { - t.Fatal("parse json error,", err) - } - if cf2.CookieName != "gosessionid" { - t.Fatal("parseconfig get cookiename error") - } - if cf2.Gclifetime != 3600 { - t.Fatal("parseconfig get gclifetime error") - } - if cf2.EnableSetCookie { - t.Fatal("parseconfig get enableSetCookie error") - } - cconfig := new(cookieConfig) - err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) - if err != nil { - t.Fatal("parse ProviderConfig err,", err) - } - if cconfig.CookieName != "gosessionid" { - t.Fatal("ProviderConfig get cookieName error") - } - if cconfig.SecurityKey != "beegocookiehashkey" { - t.Fatal("ProviderConfig get securityKey error") - } -} diff --git a/session/sess_utils.go b/session/sess_utils.go deleted file mode 100644 index 20915bb6..00000000 --- a/session/sess_utils.go +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "bytes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "crypto/subtle" - "encoding/base64" - "encoding/gob" - "errors" - "fmt" - "io" - "strconv" - "time" - - "github.com/astaxie/beego/utils" -) - -func init() { - gob.Register([]interface{}{}) - gob.Register(map[int]interface{}{}) - gob.Register(map[string]interface{}{}) - gob.Register(map[interface{}]interface{}{}) - gob.Register(map[string]string{}) - gob.Register(map[int]string{}) - gob.Register(map[int]int{}) - gob.Register(map[int]int64{}) -} - -// EncodeGob encode the obj to gob -func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { - for _, v := range obj { - gob.Register(v) - } - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(obj) - if err != nil { - return []byte(""), err - } - return buf.Bytes(), nil -} - -// DecodeGob decode data to map -func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { - buf := bytes.NewBuffer(encoded) - dec := gob.NewDecoder(buf) - var out map[interface{}]interface{} - err := dec.Decode(&out) - if err != nil { - return nil, err - } - return out, nil -} - -// generateRandomKey creates a random key with the given strength. -func generateRandomKey(strength int) []byte { - k := make([]byte, strength) - if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil { - return utils.RandomCreateBytes(strength) - } - return k -} - -// Encryption ----------------------------------------------------------------- - -// encrypt encrypts a value using the given block in counter mode. -// -// A random initialization vector (http://goo.gl/zF67k) with the length of the -// block size is prepended to the resulting ciphertext. -func encrypt(block cipher.Block, value []byte) ([]byte, error) { - iv := generateRandomKey(block.BlockSize()) - if iv == nil { - return nil, errors.New("encrypt: failed to generate random iv") - } - // Encrypt it. - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(value, value) - // Return iv + ciphertext. - return append(iv, value...), nil -} - -// decrypt decrypts a value using the given block in counter mode. -// -// The value to be decrypted must be prepended by a initialization vector -// (http://goo.gl/zF67k) with the length of the block size. -func decrypt(block cipher.Block, value []byte) ([]byte, error) { - size := block.BlockSize() - if len(value) > size { - // Extract iv. - iv := value[:size] - // Extract ciphertext. - value = value[size:] - // Decrypt it. - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(value, value) - return value, nil - } - return nil, errors.New("decrypt: the value could not be decrypted") -} - -func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { - var err error - var b []byte - // 1. EncodeGob. - if b, err = EncodeGob(value); err != nil { - return "", err - } - // 2. Encrypt (optional). - if b, err = encrypt(block, b); err != nil { - return "", err - } - b = encode(b) - // 3. Create MAC for "name|date|value". Extra pipe to be used later. - b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) - h := hmac.New(sha256.New, []byte(hashKey)) - h.Write(b) - sig := h.Sum(nil) - // Append mac, remove name. - b = append(b, sig...)[len(name)+1:] - // 4. Encode to base64. - b = encode(b) - // Done. - return string(b), nil -} - -func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { - // 1. Decode from base64. - b, err := decode([]byte(value)) - if err != nil { - return nil, err - } - // 2. Verify MAC. Value is "date|value|mac". - parts := bytes.SplitN(b, []byte("|"), 3) - if len(parts) != 3 { - return nil, errors.New("Decode: invalid value format") - } - - b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) - h := hmac.New(sha256.New, []byte(hashKey)) - h.Write(b) - sig := h.Sum(nil) - if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { - return nil, errors.New("Decode: the value is not valid") - } - // 3. Verify date ranges. - var t1 int64 - if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { - return nil, errors.New("Decode: invalid timestamp") - } - t2 := time.Now().UTC().Unix() - if t1 > t2 { - return nil, errors.New("Decode: timestamp is too new") - } - if t1 < t2-gcmaxlifetime { - return nil, errors.New("Decode: expired timestamp") - } - // 4. Decrypt (optional). - b, err = decode(parts[1]) - if err != nil { - return nil, err - } - if b, err = decrypt(block, b); err != nil { - return nil, err - } - // 5. DecodeGob. - dst, err := DecodeGob(b) - if err != nil { - return nil, err - } - return dst, nil -} - -// Encoding ------------------------------------------------------------------- - -// encode encodes a value using base64. -func encode(value []byte) []byte { - encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) - base64.URLEncoding.Encode(encoded, value) - return encoded -} - -// decode decodes a cookie using base64. -func decode(value []byte) ([]byte, error) { - decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) - b, err := base64.URLEncoding.Decode(decoded, value) - if err != nil { - return nil, err - } - return decoded[:b], nil -} diff --git a/session/session.go b/session/session.go deleted file mode 100644 index eb85360a..00000000 --- a/session/session.go +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package session provider -// -// Usage: -// import( -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package session - -import ( - "crypto/rand" - "encoding/hex" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/textproto" - "net/url" - "os" - "time" -) - -// Store contains all data for one session process with specific id. -type Store interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data -} - -// Provider contains global session methods and saved SessionStores. -// it can operate a SessionStore by its id. -type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (Store, error) - SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (Store, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() -} - -var provides = make(map[string]Provider) - -// SLogger a helpful variable to log information about session -var SLogger = NewSessionLog(os.Stderr) - -// Register makes a session provide available by the provided name. -// If Register is called twice with the same name or if driver is nil, -// it panics. -func Register(name string, provide Provider) { - if provide == nil { - panic("session: Register provide is nil") - } - if _, dup := provides[name]; dup { - panic("session: Register called twice for provider " + name) - } - provides[name] = provide -} - -//GetProvider -func GetProvider(name string) (Provider, error) { - provider, ok := provides[name] - if !ok { - return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name) - } - return provider, nil -} - -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` -} - -// Manager contains Provider and its configuration. -type Manager struct { - provider Provider - config *ManagerConfig -} - -// NewManager Create new Manager with provider name and json config string. -// provider name: -// 1. cookie -// 2. file -// 3. memory -// 4. redis -// 5. mysql -// json config: -// 1. is https default false -// 2. hashfunc default sha1 -// 3. hashkey default beegosessionkey -// 4. maxage default is none -func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { - provider, ok := provides[provideName] - if !ok { - return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) - } - - if cf.Maxlifetime == 0 { - cf.Maxlifetime = cf.Gclifetime - } - - if cf.EnableSidInHTTPHeader { - if cf.SessionNameInHTTPHeader == "" { - panic(errors.New("SessionNameInHTTPHeader is empty")) - } - - strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader) - if cf.SessionNameInHTTPHeader != strMimeHeader { - strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader - panic(errors.New(strErrMsg)) - } - } - - err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) - if err != nil { - return nil, err - } - - if cf.SessionIDLength == 0 { - cf.SessionIDLength = 16 - } - - return &Manager{ - provider, - cf, - }, nil -} - -// GetProvider return current manager's provider -func (manager *Manager) GetProvider() Provider { - return manager.provider -} - -// getSid retrieves session identifier from HTTP Request. -// First try to retrieve id by reading from cookie, session cookie name is configurable, -// if not exist, then retrieve id from querying parameters. -// -// error is not nil when there is anything wrong. -// sid is empty when need to generate a new session id -// otherwise return an valid session id. -func (manager *Manager) getSid(r *http.Request) (string, error) { - cookie, errs := r.Cookie(manager.config.CookieName) - if errs != nil || cookie.Value == "" { - var sid string - if manager.config.EnableSidInURLQuery { - errs := r.ParseForm() - if errs != nil { - return "", errs - } - - sid = r.FormValue(manager.config.CookieName) - } - - // if not found in Cookie / param, then read it from request headers - if manager.config.EnableSidInHTTPHeader && sid == "" { - sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader] - if isFound && len(sids) != 0 { - return sids[0], nil - } - } - - return sid, nil - } - - // HTTP Request contains cookie for sessionid info. - return url.QueryUnescape(cookie.Value) -} - -// SessionStart generate or read the session id from http request. -// if session id exists, return SessionStore with this id. -func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) { - sid, errs := manager.getSid(r) - if errs != nil { - return nil, errs - } - - if sid != "" && manager.provider.SessionExist(sid) { - return manager.provider.SessionRead(sid) - } - - // Generate a new session - sid, errs = manager.sessionID() - if errs != nil { - return nil, errs - } - - session, err = manager.provider.SessionRead(sid) - if err != nil { - return nil, err - } - cookie := &http.Cookie{ - Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, - } - if manager.config.CookieLifeTime > 0 { - cookie.MaxAge = manager.config.CookieLifeTime - cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) - } - if manager.config.EnableSetCookie { - http.SetCookie(w, cookie) - } - r.AddCookie(cookie) - - if manager.config.EnableSidInHTTPHeader { - r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) - w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) - } - - return -} - -// SessionDestroy Destroy session by its id in http request cookie. -func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { - if manager.config.EnableSidInHTTPHeader { - r.Header.Del(manager.config.SessionNameInHTTPHeader) - w.Header().Del(manager.config.SessionNameInHTTPHeader) - } - - cookie, err := r.Cookie(manager.config.CookieName) - if err != nil || cookie.Value == "" { - return - } - - sid, _ := url.QueryUnescape(cookie.Value) - manager.provider.SessionDestroy(sid) - if manager.config.EnableSetCookie { - expiration := time.Now() - cookie = &http.Cookie{Name: manager.config.CookieName, - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Expires: expiration, - MaxAge: -1, - Domain: manager.config.Domain} - - http.SetCookie(w, cookie) - } -} - -// GetSessionStore Get SessionStore by its id. -func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { - sessions, err = manager.provider.SessionRead(sid) - return -} - -// GC Start session gc process. -// it can do gc in times after gc lifetime. -func (manager *Manager) GC() { - manager.provider.SessionGC() - time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) -} - -// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. -func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { - sid, err := manager.sessionID() - if err != nil { - return - } - cookie, err := r.Cookie(manager.config.CookieName) - if err != nil || cookie.Value == "" { - //delete old cookie - session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, - } - } else { - oldsid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRegenerate(oldsid, sid) - cookie.Value = url.QueryEscape(sid) - cookie.HttpOnly = true - cookie.Path = "/" - } - if manager.config.CookieLifeTime > 0 { - cookie.MaxAge = manager.config.CookieLifeTime - cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) - } - if manager.config.EnableSetCookie { - http.SetCookie(w, cookie) - } - r.AddCookie(cookie) - - if manager.config.EnableSidInHTTPHeader { - r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) - w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) - } - - return -} - -// GetActiveSession Get all active sessions count number. -func (manager *Manager) GetActiveSession() int { - return manager.provider.SessionAll() -} - -// SetSecure Set cookie with https. -func (manager *Manager) SetSecure(secure bool) { - manager.config.Secure = secure -} - -func (manager *Manager) sessionID() (string, error) { - b := make([]byte, manager.config.SessionIDLength) - n, err := rand.Read(b) - if n != len(b) || err != nil { - return "", fmt.Errorf("Could not successfully read from the system CSPRNG") - } - return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil -} - -// Set cookie with https. -func (manager *Manager) isSecure(req *http.Request) bool { - if !manager.config.Secure { - return false - } - if req.URL.Scheme != "" { - return req.URL.Scheme == "https" - } - if req.TLS == nil { - return false - } - return true -} - -// Log implement the log.Logger -type Log struct { - *log.Logger -} - -// NewSessionLog set io.Writer to create a Logger for session. -func NewSessionLog(out io.Writer) *Log { - sl := new(Log) - sl.Logger = log.New(out, "[SESSION]", 1e9) - return sl -} diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go deleted file mode 100644 index de0c6360..00000000 --- a/session/ssdb/sess_ssdb.go +++ /dev/null @@ -1,199 +0,0 @@ -package ssdb - -import ( - "errors" - "net/http" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/session" - "github.com/ssdb/gossdb/ssdb" -) - -var ssdbProvider = &Provider{} - -// Provider holds ssdb client and configs -type Provider struct { - client *ssdb.Client - host string - port int - maxLifetime int64 -} - -func (p *Provider) connectInit() error { - var err error - if p.host == "" || p.port == 0 { - return errors.New("SessionInit First") - } - p.client, err = ssdb.Connect(p.host, p.port) - return err -} - -// SessionInit init the ssdb with the config -func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { - p.maxLifetime = maxLifetime - address := strings.Split(savePath, ":") - p.host = address[0] - - var err error - if p.port, err = strconv.Atoi(address[1]); err != nil { - return err - } - return p.connectInit() -} - -// SessionRead return a ssdb client session Store -func (p *Provider) SessionRead(sid string) (session.Store, error) { - if p.client == nil { - if err := p.connectInit(); err != nil { - return nil, err - } - } - var kv map[interface{}]interface{} - value, err := p.client.Get(sid) - if err != nil { - return nil, err - } - if value == nil || len(value.(string)) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(value.(string))) - if err != nil { - return nil, err - } - } - rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} - return rs, nil -} - -// SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) bool { - if p.client == nil { - if err := p.connectInit(); err != nil { - panic(err) - } - } - value, err := p.client.Get(sid) - if err != nil { - panic(err) - } - if value == nil || len(value.(string)) == 0 { - return false - } - return true -} - -// SessionRegenerate regenerate session with new sid and delete oldsid -func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - //conn.Do("setx", key, v, ttl) - if p.client == nil { - if err := p.connectInit(); err != nil { - return nil, err - } - } - value, err := p.client.Get(oldsid) - if err != nil { - return nil, err - } - var kv map[interface{}]interface{} - if value == nil || len(value.(string)) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = session.DecodeGob([]byte(value.(string))) - if err != nil { - return nil, err - } - _, err = p.client.Del(oldsid) - if err != nil { - return nil, err - } - } - _, e := p.client.Do("setx", sid, value, p.maxLifetime) - if e != nil { - return nil, e - } - rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} - return rs, nil -} - -// SessionDestroy destroy the sid -func (p *Provider) SessionDestroy(sid string) error { - if p.client == nil { - if err := p.connectInit(); err != nil { - return err - } - } - _, err := p.client.Del(sid) - return err -} - -// SessionGC not implemented -func (p *Provider) SessionGC() { -} - -// SessionAll not implemented -func (p *Provider) SessionAll() int { - return 0 -} - -// SessionStore holds the session information which stored in ssdb -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxLifetime int64 - client *ssdb.Client -} - -// Set the key and value -func (s *SessionStore) Set(key, value interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - s.values[key] = value - return nil -} - -// Get return the value by the key -func (s *SessionStore) Get(key interface{}) interface{} { - s.lock.Lock() - defer s.lock.Unlock() - if value, ok := s.values[key]; ok { - return value - } - return nil -} - -// Delete the key in session store -func (s *SessionStore) Delete(key interface{}) error { - s.lock.Lock() - defer s.lock.Unlock() - delete(s.values, key) - return nil -} - -// Flush delete all keys and values -func (s *SessionStore) Flush() error { - s.lock.Lock() - defer s.lock.Unlock() - s.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID return the sessionID -func (s *SessionStore) SessionID() string { - return s.sid -} - -// SessionRelease Store the keyvalues into ssdb -func (s *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(s.values) - if err != nil { - return - } - s.client.Do("setx", s.sid, string(b), s.maxLifetime) -} - -func init() { - session.Register("ssdb", ssdbProvider) -} diff --git a/staticfile.go b/staticfile.go deleted file mode 100644 index e26776c5..00000000 --- a/staticfile.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "errors" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/hashicorp/golang-lru" -) - -var errNotStaticRequest = errors.New("request not a static file request") - -func serverStaticRouter(ctx *context.Context) { - if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { - return - } - - forbidden, filePath, fileInfo, err := lookupFile(ctx) - if err == errNotStaticRequest { - return - } - - if forbidden { - exception("403", ctx) - return - } - - if filePath == "" || fileInfo == nil { - if BConfig.RunMode == DEV { - logs.Warn("Can't find/open the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - if fileInfo.IsDir() { - requestURL := ctx.Input.URL() - if requestURL[len(requestURL)-1] != '/' { - redirectURL := requestURL + "/" - if ctx.Request.URL.RawQuery != "" { - redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery - } - ctx.Redirect(302, redirectURL) - } else { - //serveFile will list dir - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) - } - return - } else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) { - //over size file serve with http module - http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) - return - } - - var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) - var acceptEncoding string - if enableCompress { - acceptEncoding = context.ParseEncoding(ctx.Request) - } - b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding) - if err != nil { - if BConfig.RunMode == DEV { - logs.Warn("Can't compress the file:", filePath, err) - } - http.NotFound(ctx.ResponseWriter, ctx.Request) - return - } - - if b { - ctx.Output.Header("Content-Encoding", n) - } else { - ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) - } - - http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader) -} - -type serveContentHolder struct { - data []byte - modTime time.Time - size int64 - originSize int64 //original file size:to judge file changed - encoding string -} - -type serveContentReader struct { - *bytes.Reader -} - -var ( - staticFileLruCache *lru.Cache - lruLock sync.RWMutex -) - -func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { - if staticFileLruCache == nil { - //avoid lru cache error - if BConfig.WebConfig.StaticCacheFileNum >= 1 { - staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum) - } else { - staticFileLruCache, _ = lru.New(1) - } - } - mapKey := acceptEncoding + ":" + filePath - lruLock.RLock() - var mapFile *serveContentHolder - if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { - mapFile = cacheItem.(*serveContentHolder) - } - lruLock.RUnlock() - if isOk(mapFile, fi) { - reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} - return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil - } - lruLock.Lock() - defer lruLock.Unlock() - if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { - mapFile = cacheItem.(*serveContentHolder) - } - if !isOk(mapFile, fi) { - file, err := os.Open(filePath) - if err != nil { - return false, "", nil, nil, err - } - defer file.Close() - var bufferWriter bytes.Buffer - _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) - if err != nil { - return false, "", nil, nil, err - } - mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n} - if isOk(mapFile, fi) { - staticFileLruCache.Add(mapKey, mapFile) - } - } - - reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} - return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil -} - -func isOk(s *serveContentHolder, fi os.FileInfo) bool { - if s == nil { - return false - } else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) { - return false - } - return s.modTime == fi.ModTime() && s.originSize == fi.Size() -} - -// isStaticCompress detect static files -func isStaticCompress(filePath string) bool { - for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip { - if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { - return true - } - } - return false -} - -// searchFile search the file by url path -// if none the static file prefix matches ,return notStaticRequestErr -func searchFile(ctx *context.Context) (string, os.FileInfo, error) { - requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path)) - // special processing : favicon.ico/robots.txt can be in any static dir - if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { - file := path.Join(".", requestPath) - if fi, _ := os.Stat(file); fi != nil { - return file, fi, nil - } - for _, staticDir := range BConfig.WebConfig.StaticDir { - filePath := path.Join(staticDir, requestPath) - if fi, _ := os.Stat(filePath); fi != nil { - return filePath, fi, nil - } - } - return "", nil, errNotStaticRequest - } - - for prefix, staticDir := range BConfig.WebConfig.StaticDir { - if !strings.Contains(requestPath, prefix) { - continue - } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { - continue - } - filePath := path.Join(staticDir, requestPath[len(prefix):]) - if fi, err := os.Stat(filePath); fi != nil { - return filePath, fi, err - } - } - return "", nil, errNotStaticRequest -} - -// lookupFile find the file to serve -// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) -// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex -func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { - fp, fi, err := searchFile(ctx) - if fp == "" || fi == nil { - return false, "", nil, err - } - if !fi.IsDir() { - return false, fp, fi, err - } - if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' { - ifp := filepath.Join(fp, "index.html") - if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { - return false, ifp, ifi, err - } - } - return !BConfig.WebConfig.DirectoryIndex, fp, fi, err -} diff --git a/staticfile_test.go b/staticfile_test.go deleted file mode 100644 index e46c13ec..00000000 --- a/staticfile_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package beego - -import ( - "bytes" - "compress/gzip" - "compress/zlib" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "testing" -) - -var currentWorkDir, _ = os.Getwd() -var licenseFile = filepath.Join(currentWorkDir, "LICENSE") - -func testOpenFile(encoding string, content []byte, t *testing.T) { - fi, _ := os.Stat(licenseFile) - b, n, sch, reader, err := openFile(licenseFile, fi, encoding) - if err != nil { - t.Log(err) - t.Fail() - } - - t.Log("open static file encoding "+n, b) - - assetOpenFileAndContent(sch, reader, content, t) -} -func TestOpenStaticFile_1(t *testing.T) { - file, _ := os.Open(licenseFile) - content, _ := ioutil.ReadAll(file) - testOpenFile("", content, t) -} - -func TestOpenStaticFileGzip_1(t *testing.T) { - file, _ := os.Open(licenseFile) - var zipBuf bytes.Buffer - fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) - io.Copy(fileWriter, file) - fileWriter.Close() - content, _ := ioutil.ReadAll(&zipBuf) - - testOpenFile("gzip", content, t) -} -func TestOpenStaticFileDeflate_1(t *testing.T) { - file, _ := os.Open(licenseFile) - var zipBuf bytes.Buffer - fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) - io.Copy(fileWriter, file) - fileWriter.Close() - content, _ := ioutil.ReadAll(&zipBuf) - - testOpenFile("deflate", content, t) -} - -func TestStaticCacheWork(t *testing.T) { - encodings := []string{"", "gzip", "deflate"} - - fi, _ := os.Stat(licenseFile) - for _, encoding := range encodings { - _, _, first, _, err := openFile(licenseFile, fi, encoding) - if err != nil { - t.Error(err) - continue - } - - _, _, second, _, err := openFile(licenseFile, fi, encoding) - if err != nil { - t.Error(err) - continue - } - - address1 := fmt.Sprintf("%p", first) - address2 := fmt.Sprintf("%p", second) - if address1 != address2 { - t.Errorf("encoding '%v' can not hit cache", encoding) - } - } -} - -func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { - t.Log(sch.size, len(content)) - if sch.size != int64(len(content)) { - t.Log("static content file size not same") - t.Fail() - } - bs, _ := ioutil.ReadAll(reader) - for i, v := range content { - if v != bs[i] { - t.Log("content not same") - t.Fail() - } - } - if staticFileLruCache.Len() == 0 { - t.Log("men map is empty") - t.Fail() - } -} diff --git a/swagger/swagger.go b/swagger/swagger.go deleted file mode 100644 index a55676cd..00000000 --- a/swagger/swagger.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Swagger™ is a project used to describe and document RESTful APIs. -// -// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. -// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. - -// Package swagger struct definition -package swagger - -// Swagger list the resource -type Swagger struct { - SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` - Infos Information `json:"info" yaml:"info"` - Host string `json:"host,omitempty" yaml:"host,omitempty"` - BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Paths map[string]*Item `json:"paths" yaml:"paths"` - Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` - SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` - Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` - Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` - ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` -} - -// Information Provides metadata about the API. The metadata can be used by the clients if needed. -type Information struct { - Title string `json:"title,omitempty" yaml:"title,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Version string `json:"version,omitempty" yaml:"version,omitempty"` - TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"` - - Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"` - License *License `json:"license,omitempty" yaml:"license,omitempty"` -} - -// Contact information for the exposed API. -type Contact struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - URL string `json:"url,omitempty" yaml:"url,omitempty"` - EMail string `json:"email,omitempty" yaml:"email,omitempty"` -} - -// License information for the exposed API. -type License struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - URL string `json:"url,omitempty" yaml:"url,omitempty"` -} - -// Item Describes the operations available on a single path. -type Item struct { - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` - Get *Operation `json:"get,omitempty" yaml:"get,omitempty"` - Put *Operation `json:"put,omitempty" yaml:"put,omitempty"` - Post *Operation `json:"post,omitempty" yaml:"post,omitempty"` - Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"` - Options *Operation `json:"options,omitempty" yaml:"options,omitempty"` - Head *Operation `json:"head,omitempty" yaml:"head,omitempty"` - Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"` -} - -// Operation Describes a single API operation on a path. -type Operation struct { - Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` - Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` - Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` - Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` - Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` - Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` - Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` - Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` - Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` -} - -// Parameter Describes a single operation parameter. -type Parameter struct { - In string `json:"in,omitempty" yaml:"in,omitempty"` - Name string `json:"name,omitempty" yaml:"name,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Required bool `json:"required,omitempty" yaml:"required,omitempty"` - Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` - Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` -} - -// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". -// http://swagger.io/specification/#itemsObject -type ParameterItems struct { - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array. - CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"` - Default string `json:"default,omitempty" yaml:"default,omitempty"` -} - -// Schema Object allows the definition of input and output data types. -type Schema struct { - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` - Title string `json:"title,omitempty" yaml:"title,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Required []string `json:"required,omitempty" yaml:"required,omitempty"` - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` - Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` - Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"` - Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` -} - -// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification -type Propertie struct { - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` - Title string `json:"title,omitempty" yaml:"title,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` - Type string `json:"type,omitempty" yaml:"type,omitempty"` - Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` - Required []string `json:"required,omitempty" yaml:"required,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` - Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` - Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"` - AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"` -} - -// Response as they are returned from executing this operation. -type Response struct { - Description string `json:"description" yaml:"description"` - Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` - Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` -} - -// Security Allows the definition of a security scheme that can be used by the operations -type Security struct { - Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2". - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Name string `json:"name,omitempty" yaml:"name,omitempty"` - In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header". - Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode". - AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"` - TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"` - Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme. -} - -// Tag Allows adding meta data to a single tag that is used by the Operation Object -type Tag struct { - Name string `json:"name,omitempty" yaml:"name,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` -} - -// ExternalDocs include Additional external documentation -type ExternalDocs struct { - Description string `json:"description,omitempty" yaml:"description,omitempty"` - URL string `json:"url,omitempty" yaml:"url,omitempty"` -} diff --git a/template.go b/template.go deleted file mode 100644 index 69b178ca..00000000 --- a/template.go +++ /dev/null @@ -1,417 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "html/template" - "io" - "io/ioutil" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var ( - beegoTplFuncMap = make(template.FuncMap) - beeViewPathTemplateLocked = false - // beeViewPathTemplates caching map and supported template file extensions per view - beeViewPathTemplates = make(map[string]map[string]*template.Template) - templatesLock sync.RWMutex - // beeTemplateExt stores the template extension which will build - beeTemplateExt = []string{"tpl", "html", "gohtml"} - // beeTemplatePreprocessors stores associations of extension -> preprocessor handler - beeTemplateEngines = map[string]templatePreProcessor{} - beeTemplateFS = defaultFSFunc -) - -// ExecuteTemplate applies the template with name to the specified data object, -// writing the output to wr. -// A template will be executed safely in parallel. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { - return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) -} - -// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, -// writing the output to wr. -// A template will be executed safely in parallel. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { - if BConfig.RunMode == DEV { - templatesLock.RLock() - defer templatesLock.RUnlock() - } - if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { - if t, ok := beeTemplates[name]; ok { - var err error - if t.Lookup(name) != nil { - err = t.ExecuteTemplate(wr, name, data) - } else { - err = t.Execute(wr, data) - } - if err != nil { - logs.Trace("template Execute err:", err) - } - return err - } - panic("can't find templatefile in the path:" + viewPath + "/" + name) - } - panic("Unknown view path:" + viewPath) -} - -func init() { - beegoTplFuncMap["dateformat"] = DateFormat - beegoTplFuncMap["date"] = Date - beegoTplFuncMap["compare"] = Compare - beegoTplFuncMap["compare_not"] = CompareNot - beegoTplFuncMap["not_nil"] = NotNil - beegoTplFuncMap["not_null"] = NotNil - beegoTplFuncMap["substr"] = Substr - beegoTplFuncMap["html2str"] = HTML2str - beegoTplFuncMap["str2html"] = Str2html - beegoTplFuncMap["htmlquote"] = Htmlquote - beegoTplFuncMap["htmlunquote"] = Htmlunquote - beegoTplFuncMap["renderform"] = RenderForm - beegoTplFuncMap["assets_js"] = AssetsJs - beegoTplFuncMap["assets_css"] = AssetsCSS - beegoTplFuncMap["config"] = GetConfig - beegoTplFuncMap["map_get"] = MapGet - - // Comparisons - beegoTplFuncMap["eq"] = eq // == - beegoTplFuncMap["ge"] = ge // >= - beegoTplFuncMap["gt"] = gt // > - beegoTplFuncMap["le"] = le // <= - beegoTplFuncMap["lt"] = lt // < - beegoTplFuncMap["ne"] = ne // != - - beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method -} - -// AddFuncMap let user to register a func in the template. -func AddFuncMap(key string, fn interface{}) error { - beegoTplFuncMap[key] = fn - return nil -} - -type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) - -type templateFile struct { - root string - files map[string][]string -} - -// visit will make the paths into two part,the first is subDir (without tf.root),the second is full path(without tf.root). -// if tf.root="views" and -// paths is "views/errors/404.html",the subDir will be "errors",the file will be "errors/404.html" -// paths is "views/admin/errors/404.html",the subDir will be "admin/errors",the file will be "admin/errors/404.html" -func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error { - if f == nil { - return err - } - if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 { - return nil - } - if !HasTemplateExt(paths) { - return nil - } - - replace := strings.NewReplacer("\\", "/") - file := strings.TrimLeft(replace.Replace(paths[len(tf.root):]), "/") - subDir := filepath.Dir(file) - - tf.files[subDir] = append(tf.files[subDir], file) - return nil -} - -// HasTemplateExt return this path contains supported template extension of beego or not. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func HasTemplateExt(paths string) bool { - for _, v := range beeTemplateExt { - if strings.HasSuffix(paths, "."+v) { - return true - } - } - return false -} - -// AddTemplateExt add new extension for template. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddTemplateExt(ext string) { - for _, v := range beeTemplateExt { - if v == ext { - return - } - } - beeTemplateExt = append(beeTemplateExt, ext) -} - -// AddViewPath adds a new path to the supported view paths. -//Can later be used by setting a controller ViewPath to this folder -//will panic if called after beego.Run() -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddViewPath(viewPath string) error { - if beeViewPathTemplateLocked { - if _, exist := beeViewPathTemplates[viewPath]; exist { - return nil //Ignore if viewpath already exists - } - panic("Can not add new view paths after beego.Run()") - } - beeViewPathTemplates[viewPath] = make(map[string]*template.Template) - return BuildTemplate(viewPath) -} - -func lockViewPaths() { - beeViewPathTemplateLocked = true -} - -// BuildTemplate will build all template files in a directory. -// it makes beego can render any template file in view directory. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func BuildTemplate(dir string, files ...string) error { - var err error - fs := beeTemplateFS() - f, err := fs.Open(dir) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return errors.New("dir open err") - } - defer f.Close() - - beeTemplates, ok := beeViewPathTemplates[dir] - if !ok { - panic("Unknown view path: " + dir) - } - self := &templateFile{ - root: dir, - files: make(map[string][]string), - } - err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { - return self.visit(path, f, err) - }) - if err != nil { - fmt.Printf("Walk() returned %v\n", err) - return err - } - buildAllFiles := len(files) == 0 - for _, v := range self.files { - for _, file := range v { - if buildAllFiles || utils.InSlice(file, files) { - templatesLock.Lock() - ext := filepath.Ext(file) - var t *template.Template - if len(ext) == 0 { - t, err = getTemplate(self.root, fs, file, v...) - } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { - t, err = fn(self.root, file, beegoTplFuncMap) - } else { - t, err = getTemplate(self.root, fs, file, v...) - } - if err != nil { - logs.Error("parse template err:", file, err) - templatesLock.Unlock() - return err - } - beeTemplates[file] = t - templatesLock.Unlock() - } - } - } - return nil -} - -func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) { - var fileAbsPath string - var rParent string - var err error - if strings.HasPrefix(file, "../") { - rParent = filepath.Join(filepath.Dir(parent), file) - fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) - } else { - rParent = file - fileAbsPath = filepath.Join(root, file) - } - f, err := fs.Open(fileAbsPath) - if err != nil { - panic("can't find template file:" + file) - } - defer f.Close() - data, err := ioutil.ReadAll(f) - if err != nil { - return nil, [][]string{}, err - } - t, err = t.New(file).Parse(string(data)) - if err != nil { - return nil, [][]string{}, err - } - reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") - allSub := reg.FindAllStringSubmatch(string(data), -1) - for _, m := range allSub { - if len(m) == 2 { - tl := t.Lookup(m[1]) - if tl != nil { - continue - } - if !HasTemplateExt(m[1]) { - continue - } - _, _, err = getTplDeep(root, fs, m[1], rParent, t) - if err != nil { - return nil, [][]string{}, err - } - } - } - return t, allSub, nil -} - -func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) { - t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) - var subMods [][]string - t, subMods, err = getTplDeep(root, fs, file, "", t) - if err != nil { - return nil, err - } - t, err = _getTemplate(t, root, fs, subMods, others...) - - if err != nil { - return nil, err - } - return -} - -func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) { - t = t0 - for _, m := range subMods { - if len(m) == 2 { - tpl := t.Lookup(m[1]) - if tpl != nil { - continue - } - //first check filename - for _, otherFile := range others { - if otherFile == m[1] { - var subMods1 [][]string - t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) - if err != nil { - logs.Trace("template parse file err:", err) - } else if len(subMods1) > 0 { - t, err = _getTemplate(t, root, fs, subMods1, others...) - } - break - } - } - //second check define - for _, otherFile := range others { - var data []byte - fileAbsPath := filepath.Join(root, otherFile) - f, err := fs.Open(fileAbsPath) - if err != nil { - f.Close() - logs.Trace("template file parse error, not success open file:", err) - continue - } - data, err = ioutil.ReadAll(f) - f.Close() - if err != nil { - logs.Trace("template file parse error, not success read file:", err) - continue - } - reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") - allSub := reg.FindAllStringSubmatch(string(data), -1) - for _, sub := range allSub { - if len(sub) == 2 && sub[1] == m[1] { - var subMods1 [][]string - t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) - if err != nil { - logs.Trace("template parse file err:", err) - } else if len(subMods1) > 0 { - t, err = _getTemplate(t, root, fs, subMods1, others...) - if err != nil { - logs.Trace("template parse file err:", err) - } - } - break - } - } - } - } - - } - return -} - -type templateFSFunc func() http.FileSystem - -func defaultFSFunc() http.FileSystem { - return FileSystem{} -} - -// SetTemplateFSFunc set default filesystem function -// Deprecated: using pkg/, we will delete this in v2.1.0 -func SetTemplateFSFunc(fnt templateFSFunc) { - beeTemplateFS = fnt -} - -// SetViewsPath sets view directory path in beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func SetViewsPath(path string) *App { - BConfig.WebConfig.ViewsPath = path - return BeeApp -} - -// SetStaticPath sets static directory path and proper url pattern in beego application. -// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". -// Deprecated: using pkg/, we will delete this in v2.1.0 -func SetStaticPath(url string, path string) *App { - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - if url != "/" { - url = strings.TrimRight(url, "/") - } - BConfig.WebConfig.StaticDir[url] = path - return BeeApp -} - -// DelStaticPath removes the static folder setting in this url pattern in beego application. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func DelStaticPath(url string) *App { - if !strings.HasPrefix(url, "/") { - url = "/" + url - } - if url != "/" { - url = strings.TrimRight(url, "/") - } - delete(BConfig.WebConfig.StaticDir, url) - return BeeApp -} - -// AddTemplateEngine add a new templatePreProcessor which support extension -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AddTemplateEngine(extension string, fn templatePreProcessor) *App { - AddTemplateExt(extension) - beeTemplateEngines[extension] = fn - return BeeApp -} diff --git a/template_test.go b/template_test.go deleted file mode 100644 index bde9c100..00000000 --- a/template_test.go +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "github.com/astaxie/beego/test" - "github.com/elazarl/go-bindata-assetfs" - "net/http" - "os" - "path/filepath" - "testing" -) - -var header = `{{define "header"}} -

Hello, astaxie!

-{{end}}` - -var index = ` - - - beego welcome template - - -{{template "block"}} -{{template "header"}} -{{template "blocks/block.tpl"}} - - -` - -var block = `{{define "block"}} -

Hello, blocks!

-{{end}}` - -func TestTemplate(t *testing.T) { - dir := "_beeTmp" - files := []string{ - "header.tpl", - "index.tpl", - "blocks/block.tpl", - } - if err := os.MkdirAll(dir, 0777); err != nil { - t.Fatal(err) - } - for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - if k == 0 { - f.WriteString(header) - } else if k == 1 { - f.WriteString(index) - } else if k == 2 { - f.WriteString(block) - } - - f.Close() - } - } - if err := AddViewPath(dir); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if len(beeTemplates) != 3 { - t.Fatalf("should be 3 but got %v", len(beeTemplates)) - } - if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", nil); err != nil { - t.Fatal(err) - } - for _, name := range files { - os.RemoveAll(filepath.Join(dir, name)) - } - os.RemoveAll(dir) -} - -var menu = ` -` -var user = ` - - - beego welcome template - - -{{template "../public/menu.tpl"}} - - -` - -func TestRelativeTemplate(t *testing.T) { - dir := "_beeTmp" - - //Just add dir to known viewPaths - if err := AddViewPath(dir); err != nil { - t.Fatal(err) - } - - files := []string{ - "easyui/public/menu.tpl", - "easyui/rbac/user.tpl", - } - if err := os.MkdirAll(dir, 0777); err != nil { - t.Fatal(err) - } - for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - if k == 0 { - f.WriteString(menu) - } else if k == 1 { - f.WriteString(user) - } - f.Close() - } - } - if err := BuildTemplate(dir, files[1]); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { - t.Fatal(err) - } - for _, name := range files { - os.RemoveAll(filepath.Join(dir, name)) - } - os.RemoveAll(dir) -} - -var add = `{{ template "layout_blog.tpl" . }} -{{ define "css" }} - -{{ end}} - - -{{ define "content" }} -

{{ .Title }}

-

This is SomeVar: {{ .SomeVar }}

-{{ end }} - -{{ define "js" }} - -{{ end}}` - -var layoutBlog = ` - - - Lin Li - - - - - {{ block "css" . }}{{ end }} - - - -
- {{ block "content" . }}{{ end }} -
- - - {{ block "js" . }}{{ end }} - -` - -var output = ` - - - Lin Li - - - - - - - - - - -
- -

Hello

-

This is SomeVar: val

- -
- - - - - - - - - - - - -` - -func TestTemplateLayout(t *testing.T) { - dir := "_beeTmp" - files := []string{ - "add.tpl", - "layout_blog.tpl", - } - if err := os.MkdirAll(dir, 0777); err != nil { - t.Fatal(err) - } - for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - if k == 0 { - f.WriteString(add) - } else if k == 1 { - f.WriteString(layoutBlog) - } - f.Close() - } - } - if err := AddViewPath(dir); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if len(beeTemplates) != 2 { - t.Fatalf("should be 2 but got %v", len(beeTemplates)) - } - out := bytes.NewBufferString("") - if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { - t.Fatal(err) - } - if out.String() != output { - t.Log(out.String()) - t.Fatal("Compare failed") - } - for _, name := range files { - os.RemoveAll(filepath.Join(dir, name)) - } - os.RemoveAll(dir) -} - -type TestingFileSystem struct { - assetfs *assetfs.AssetFS -} - -func (d TestingFileSystem) Open(name string) (http.File, error) { - return d.assetfs.Open(name) -} - -var outputBinData = ` - - - beego welcome template - - - - -

Hello, blocks!

- - -

Hello, astaxie!

- - - -

Hello

-

This is SomeVar: val

- - -` - -func TestFsBinData(t *testing.T) { - SetTemplateFSFunc(func() http.FileSystem { - return TestingFileSystem{&assetfs.AssetFS{Asset: test.Asset, AssetDir: test.AssetDir, AssetInfo: test.AssetInfo}} - }) - dir := "views" - if err := AddViewPath("views"); err != nil { - t.Fatal(err) - } - beeTemplates := beeViewPathTemplates[dir] - if len(beeTemplates) != 3 { - t.Fatalf("should be 3 but got %v", len(beeTemplates)) - } - if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { - t.Fatal(err) - } - out := bytes.NewBufferString("") - if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { - t.Fatal(err) - } - - if out.String() != outputBinData { - t.Log(out.String()) - t.Fatal("Compare failed") - } -} diff --git a/templatefunc.go b/templatefunc.go deleted file mode 100644 index 9e7c42fc..00000000 --- a/templatefunc.go +++ /dev/null @@ -1,798 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "html" - "html/template" - "net/url" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" - formatDateTimeT = "2006-01-02T15:04:05" -) - -// Substr returns the substr from start to length. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Substr(s string, start, length int) string { - bt := []rune(s) - if start < 0 { - start = 0 - } - if start > len(bt) { - start = start % len(bt) - } - var end int - if (start + length) > (len(bt) - 1) { - end = len(bt) - } else { - end = start + length - } - return string(bt[start:end]) -} - -// HTML2str returns escaping text convert from html. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func HTML2str(html string) string { - - re := regexp.MustCompile(`\<[\S\s]+?\>`) - html = re.ReplaceAllStringFunc(html, strings.ToLower) - - //remove STYLE - re = regexp.MustCompile(`\`) - html = re.ReplaceAllString(html, "") - - //remove SCRIPT - re = regexp.MustCompile(`\`) - html = re.ReplaceAllString(html, "") - - re = regexp.MustCompile(`\<[\S\s]+?\>`) - html = re.ReplaceAllString(html, "\n") - - re = regexp.MustCompile(`\s{2,}`) - html = re.ReplaceAllString(html, "\n") - - return strings.TrimSpace(html) -} - -// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" -// Deprecated: using pkg/, we will delete this in v2.1.0 -func DateFormat(t time.Time, layout string) (datestring string) { - datestring = t.Format(layout) - return -} - -// DateFormat pattern rules. -var datePatterns = []string{ - // year - "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 - - // month - "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 - "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 - "M", "Jan", // A short textual representation of a month, three letters Jan through Dec - "F", "January", // A full textual representation of a month, such as January or March January through December - - // day - "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 - "j", "2", // Day of the month without leading zeros 1 to 31 - - // week - "D", "Mon", // A textual representation of a day, three letters Mon through Sun - "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday - - // time - "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 - "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 - "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 - "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 - - "a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm - "A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM - - "i", "04", // Minutes with leading zeros 00 to 59 - "s", "05", // Seconds, with leading zeros 00 through 59 - - // time zone - "T", "MST", - "P", "-07:00", - "O", "-0700", - - // RFC 2822 - "r", time.RFC1123Z, -} - -// DateParse Parse Date use PHP time format. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func DateParse(dateString, format string) (time.Time, error) { - replacer := strings.NewReplacer(datePatterns...) - format = replacer.Replace(format) - return time.ParseInLocation(format, dateString, time.Local) -} - -// Date takes a PHP like date func to Go's time format. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Date(t time.Time, format string) string { - replacer := strings.NewReplacer(datePatterns...) - format = replacer.Replace(format) - return t.Format(format) -} - -// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. -// Whitespace is trimmed. Used by the template parser as "eq". -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Compare(a, b interface{}) (equal bool) { - equal = false - if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) { - equal = true - } - return -} - -// CompareNot !Compare -// Deprecated: using pkg/, we will delete this in v2.1.0 -func CompareNot(a, b interface{}) (equal bool) { - return !Compare(a, b) -} - -// NotNil the same as CompareNot -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NotNil(a interface{}) (isNil bool) { - return CompareNot(a, nil) -} - -// GetConfig get the Appconfig -// Deprecated: using pkg/, we will delete this in v2.1.0 -func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { - switch returnType { - case "String": - value = AppConfig.String(key) - case "Bool": - value, err = AppConfig.Bool(key) - case "Int": - value, err = AppConfig.Int(key) - case "Int64": - value, err = AppConfig.Int64(key) - case "Float": - value, err = AppConfig.Float(key) - case "DIY": - value, err = AppConfig.DIY(key) - default: - err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") - } - - if err != nil { - if reflect.TypeOf(returnType) != reflect.TypeOf(defaultVal) { - err = errors.New("defaultVal type does not match returnType") - } else { - value, err = defaultVal, nil - } - } else if reflect.TypeOf(value).Kind() == reflect.String { - if value == "" { - if reflect.TypeOf(defaultVal).Kind() != reflect.String { - err = errors.New("defaultVal type must be a String if the returnType is a String") - } else { - value = defaultVal.(string) - } - } - } - - return -} - -// Str2html Convert string to template.HTML type. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Str2html(raw string) template.HTML { - return template.HTML(raw) -} - -// Htmlquote returns quoted html string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Htmlquote(text string) string { - //HTML编码为实体符号 - /* - Encodes `text` for raw use in HTML. - >>> htmlquote("<'&\\">") - '<'&">' - */ - - text = html.EscapeString(text) - text = strings.NewReplacer( - `“`, "“", - `”`, "”", - ` `, " ", - ).Replace(text) - - return strings.TrimSpace(text) -} - -// Htmlunquote returns unquoted html string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func Htmlunquote(text string) string { - //实体符号解释为HTML - /* - Decodes `text` that's HTML quoted. - >>> htmlunquote('<'&">') - '<\\'&">' - */ - - text = html.UnescapeString(text) - - return strings.TrimSpace(text) -} - -// URLFor returns url string with another registered controller handler with params. -// usage: -// -// URLFor(".index") -// print URLFor("index") -// router /login -// print URLFor("login") -// print URLFor("login", "next","/"") -// router /profile/:username -// print UrlFor("profile", ":username","John Doe") -// result: -// / -// /login -// /login?next=/ -// /user/John%20Doe -// -// more detail http://beego.me/docs/mvc/controller/urlbuilding.md -// Deprecated: using pkg/, we will delete this in v2.1.0 -func URLFor(endpoint string, values ...interface{}) string { - return BeeApp.Handlers.URLFor(endpoint, values...) -} - -// AssetsJs returns script tag with src string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AssetsJs(text string) template.HTML { - - text = "" - - return template.HTML(text) -} - -// AssetsCSS returns stylesheet link tag with src string. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func AssetsCSS(text string) template.HTML { - - text = "" - - return template.HTML(text) -} - -// ParseForm will parse form values to struct via tag. -// Support for anonymous struct. -func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { - for i := 0; i < objT.NumField(); i++ { - fieldV := objV.Field(i) - if !fieldV.CanSet() { - continue - } - - fieldT := objT.Field(i) - if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { - err := parseFormToStruct(form, fieldT.Type, fieldV) - if err != nil { - return err - } - continue - } - - tags := strings.Split(fieldT.Tag.Get("form"), ",") - var tag string - if len(tags) == 0 || len(tags[0]) == 0 { - tag = fieldT.Name - } else if tags[0] == "-" { - continue - } else { - tag = tags[0] - } - - formValues := form[tag] - var value string - if len(formValues) == 0 { - defaultValue := fieldT.Tag.Get("default") - if defaultValue != "" { - value = defaultValue - } else { - continue - } - } - if len(formValues) == 1 { - value = formValues[0] - if value == "" { - continue - } - } - - switch fieldT.Type.Kind() { - case reflect.Bool: - if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { - fieldV.SetBool(true) - continue - } - if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { - fieldV.SetBool(false) - continue - } - b, err := strconv.ParseBool(value) - if err != nil { - return err - } - fieldV.SetBool(b) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return err - } - fieldV.SetInt(x) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return err - } - fieldV.SetUint(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(value, 64) - if err != nil { - return err - } - fieldV.SetFloat(x) - case reflect.Interface: - fieldV.Set(reflect.ValueOf(value)) - case reflect.String: - fieldV.SetString(value) - case reflect.Struct: - switch fieldT.Type.String() { - case "time.Time": - var ( - t time.Time - err error - ) - if len(value) >= 25 { - value = value[:25] - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if len(value) >= 19 { - if strings.Contains(value, "T") { - value = value[:19] - t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) - } else { - value = value[:19] - t, err = time.ParseInLocation(formatDateTime, value, time.Local) - } - } else if len(value) >= 10 { - if len(value) > 10 { - value = value[:10] - } - t, err = time.ParseInLocation(formatDate, value, time.Local) - } else if len(value) >= 8 { - if len(value) > 8 { - value = value[:8] - } - t, err = time.ParseInLocation(formatTime, value, time.Local) - } - if err != nil { - return err - } - fieldV.Set(reflect.ValueOf(t)) - } - case reflect.Slice: - if fieldT.Type == sliceOfInts { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - val, err := strconv.Atoi(formVals[i]) - if err != nil { - return err - } - fieldV.Index(i).SetInt(int64(val)) - } - } else if fieldT.Type == sliceOfStrings { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - fieldV.Index(i).SetString(formVals[i]) - } - } - } - } - return nil -} - -// ParseForm will parse form values to struct via tag. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func ParseForm(form url.Values, obj interface{}) error { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { - return fmt.Errorf("%v must be a struct pointer", obj) - } - objT = objT.Elem() - objV = objV.Elem() - - return parseFormToStruct(form, objT, objV) -} - -var sliceOfInts = reflect.TypeOf([]int(nil)) -var sliceOfStrings = reflect.TypeOf([]string(nil)) - -var unKind = map[reflect.Kind]bool{ - reflect.Uintptr: true, - reflect.Complex64: true, - reflect.Complex128: true, - reflect.Array: true, - reflect.Chan: true, - reflect.Func: true, - reflect.Map: true, - reflect.Ptr: true, - reflect.Slice: true, - reflect.Struct: true, - reflect.UnsafePointer: true, -} - -// RenderForm will render object to form html. -// obj must be a struct pointer. -// Deprecated: using pkg/, we will delete this in v2.1.0 -func RenderForm(obj interface{}) template.HTML { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { - return template.HTML("") - } - objT = objT.Elem() - objV = objV.Elem() - - var raw []string - for i := 0; i < objT.NumField(); i++ { - fieldV := objV.Field(i) - if !fieldV.CanSet() || unKind[fieldV.Kind()] { - continue - } - - fieldT := objT.Field(i) - - label, name, fType, id, class, ignored, required := parseFormTag(fieldT) - if ignored { - continue - } - - raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required)) - } - return template.HTML(strings.Join(raw, "
")) -} - -// renderFormField returns a string containing HTML of a single form field. -func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { - if id != "" { - id = " id=\"" + id + "\"" - } - - if class != "" { - class = " class=\"" + class + "\"" - } - - requiredString := "" - if required { - requiredString = " required" - } - - if isValidForInput(fType) { - return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) - } - - return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) -} - -// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. -func isValidForInput(fType string) bool { - validInputTypes := strings.Fields("text password checkbox radio submit reset hidden image file button search email url tel number range date month week time datetime datetime-local color") - for _, validType := range validInputTypes { - if fType == validType { - return true - } - } - return false -} - -// parseFormTag takes the stuct-tag of a StructField and parses the `form` value. -// returned are the form label, name-property, type and wether the field should be ignored. -func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { - tags := strings.Split(fieldT.Tag.Get("form"), ",") - label = fieldT.Name + ": " - name = fieldT.Name - fType = "text" - ignored = false - id = fieldT.Tag.Get("id") - class = fieldT.Tag.Get("class") - - required = false - requiredField := fieldT.Tag.Get("required") - if requiredField != "-" && requiredField != "" { - required, _ = strconv.ParseBool(requiredField) - } - - switch len(tags) { - case 1: - if tags[0] == "-" { - ignored = true - } - if len(tags[0]) > 0 { - name = tags[0] - } - case 2: - if len(tags[0]) > 0 { - name = tags[0] - } - if len(tags[1]) > 0 { - fType = tags[1] - } - case 3: - if len(tags[0]) > 0 { - name = tags[0] - } - if len(tags[1]) > 0 { - fType = tags[1] - } - if len(tags[2]) > 0 { - label = tags[2] - } - } - - return -} - -func isStructPtr(t reflect.Type) bool { - return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct -} - -// go1.2 added template funcs. begin -var ( - errBadComparisonType = errors.New("invalid type for comparison") - errBadComparison = errors.New("incompatible types for comparison") - errNoComparison = errors.New("missing argument for comparison") -) - -type kind int - -const ( - invalidKind kind = iota - boolKind - complexKind - intKind - floatKind - stringKind - uintKind -) - -func basicKind(v reflect.Value) (kind, error) { - switch v.Kind() { - case reflect.Bool: - return boolKind, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return intKind, nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return uintKind, nil - case reflect.Float32, reflect.Float64: - return floatKind, nil - case reflect.Complex64, reflect.Complex128: - return complexKind, nil - case reflect.String: - return stringKind, nil - } - return invalidKind, errBadComparisonType -} - -// eq evaluates the comparison a == b || a == c || ... -func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { - v1 := reflect.ValueOf(arg1) - k1, err := basicKind(v1) - if err != nil { - return false, err - } - if len(arg2) == 0 { - return false, errNoComparison - } - for _, arg := range arg2 { - v2 := reflect.ValueOf(arg) - k2, err := basicKind(v2) - if err != nil { - return false, err - } - if k1 != k2 { - return false, errBadComparison - } - truth := false - switch k1 { - case boolKind: - truth = v1.Bool() == v2.Bool() - case complexKind: - truth = v1.Complex() == v2.Complex() - case floatKind: - truth = v1.Float() == v2.Float() - case intKind: - truth = v1.Int() == v2.Int() - case stringKind: - truth = v1.String() == v2.String() - case uintKind: - truth = v1.Uint() == v2.Uint() - default: - panic("invalid kind") - } - if truth { - return true, nil - } - } - return false, nil -} - -// ne evaluates the comparison a != b. -func ne(arg1, arg2 interface{}) (bool, error) { - // != is the inverse of ==. - equal, err := eq(arg1, arg2) - return !equal, err -} - -// lt evaluates the comparison a < b. -func lt(arg1, arg2 interface{}) (bool, error) { - v1 := reflect.ValueOf(arg1) - k1, err := basicKind(v1) - if err != nil { - return false, err - } - v2 := reflect.ValueOf(arg2) - k2, err := basicKind(v2) - if err != nil { - return false, err - } - if k1 != k2 { - return false, errBadComparison - } - truth := false - switch k1 { - case boolKind, complexKind: - return false, errBadComparisonType - case floatKind: - truth = v1.Float() < v2.Float() - case intKind: - truth = v1.Int() < v2.Int() - case stringKind: - truth = v1.String() < v2.String() - case uintKind: - truth = v1.Uint() < v2.Uint() - default: - panic("invalid kind") - } - return truth, nil -} - -// le evaluates the comparison <= b. -func le(arg1, arg2 interface{}) (bool, error) { - // <= is < or ==. - lessThan, err := lt(arg1, arg2) - if lessThan || err != nil { - return lessThan, err - } - return eq(arg1, arg2) -} - -// gt evaluates the comparison a > b. -func gt(arg1, arg2 interface{}) (bool, error) { - // > is the inverse of <=. - lessOrEqual, err := le(arg1, arg2) - if err != nil { - return false, err - } - return !lessOrEqual, nil -} - -// ge evaluates the comparison a >= b. -func ge(arg1, arg2 interface{}) (bool, error) { - // >= is the inverse of <. - lessThan, err := lt(arg1, arg2) - if err != nil { - return false, err - } - return !lessThan, nil -} - -// MapGet getting value from map by keys -// usage: -// Data["m"] = M{ -// "a": 1, -// "1": map[string]float64{ -// "c": 4, -// }, -// } -// -// {{ map_get m "a" }} // return 1 -// {{ map_get m 1 "c" }} // return 4 -// Deprecated: using pkg/, we will delete this in v2.1.0 -func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { - arg1Type := reflect.TypeOf(arg1) - arg1Val := reflect.ValueOf(arg1) - - if arg1Type.Kind() == reflect.Map && len(arg2) > 0 { - // check whether arg2[0] type equals to arg1 key type - // if they are different, make conversion - arg2Val := reflect.ValueOf(arg2[0]) - arg2Type := reflect.TypeOf(arg2[0]) - if arg2Type.Kind() != arg1Type.Key().Kind() { - // convert arg2Value to string - var arg2ConvertedVal interface{} - arg2String := fmt.Sprintf("%v", arg2[0]) - - // convert string representation to any other type - switch arg1Type.Key().Kind() { - case reflect.Bool: - arg2ConvertedVal, _ = strconv.ParseBool(arg2String) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - arg2ConvertedVal, _ = strconv.ParseInt(arg2String, 0, 64) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - arg2ConvertedVal, _ = strconv.ParseUint(arg2String, 0, 64) - case reflect.Float32, reflect.Float64: - arg2ConvertedVal, _ = strconv.ParseFloat(arg2String, 64) - case reflect.String: - arg2ConvertedVal = arg2String - default: - arg2ConvertedVal = arg2Val.Interface() - } - arg2Val = reflect.ValueOf(arg2ConvertedVal) - } - - storedVal := arg1Val.MapIndex(arg2Val) - - if storedVal.IsValid() { - var result interface{} - - switch arg1Type.Elem().Kind() { - case reflect.Bool: - result = storedVal.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - result = storedVal.Int() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - result = storedVal.Uint() - case reflect.Float32, reflect.Float64: - result = storedVal.Float() - case reflect.String: - result = storedVal.String() - default: - result = storedVal.Interface() - } - - // if there is more keys, handle this recursively - if len(arg2) > 1 { - return MapGet(result, arg2[1:]...) - } - return result, nil - } - return nil, nil - - } - return nil, nil -} diff --git a/templatefunc_test.go b/templatefunc_test.go deleted file mode 100644 index b4c19c2e..00000000 --- a/templatefunc_test.go +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "html/template" - "net/url" - "reflect" - "testing" - "time" -) - -func TestSubstr(t *testing.T) { - s := `012345` - if Substr(s, 0, 2) != "01" { - t.Error("should be equal") - } - if Substr(s, 0, 100) != "012345" { - t.Error("should be equal") - } - if Substr(s, 12, 100) != "012345" { - t.Error("should be equal") - } -} - -func TestHtml2str(t *testing.T) { - h := `<123> 123\n - - - \n` - if HTML2str(h) != "123\\n\n\\n" { - t.Error("should be equal") - } -} - -func TestDateFormat(t *testing.T) { - ts := "Mon, 01 Jul 2013 13:27:42 CST" - tt, _ := time.Parse(time.RFC1123, ts) - - if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } -} - -func TestDate(t *testing.T) { - ts := "Mon, 01 Jul 2013 13:27:42 CST" - tt, _ := time.Parse(time.RFC1123, ts) - - if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } - if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { - t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) - } - if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { - t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) - } - if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { - t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) - } -} - -func TestCompareRelated(t *testing.T) { - if !Compare("abc", "abc") { - t.Error("should be equal") - } - if Compare("abc", "aBc") { - t.Error("should be not equal") - } - if !Compare("1", 1) { - t.Error("should be equal") - } - if CompareNot("abc", "abc") { - t.Error("should be equal") - } - if !CompareNot("abc", "aBc") { - t.Error("should be not equal") - } - if !NotNil("a string") { - t.Error("should not be nil") - } -} - -func TestHtmlquote(t *testing.T) { - h := `<' ”“&">` - s := `<' ”“&">` - if Htmlquote(s) != h { - t.Error("should be equal") - } -} - -func TestHtmlunquote(t *testing.T) { - h := `<' ”“&">` - s := `<' ”“&">` - if Htmlunquote(h) != s { - t.Error("should be equal") - } -} - -func TestParseForm(t *testing.T) { - type ExtendInfo struct { - Hobby []string `form:"hobby"` - Memo string - } - - type OtherInfo struct { - Organization string `form:"organization"` - Title string `form:"title"` - ExtendInfo - } - - type user struct { - ID int `form:"-"` - tag string `form:"tag"` - Name interface{} `form:"username"` - Age int `form:"age,text"` - Email string - Intro string `form:",textarea"` - StrBool bool `form:"strbool"` - Date time.Time `form:"date,2006-01-02"` - OtherInfo - } - - u := user{} - form := url.Values{ - "ID": []string{"1"}, - "-": []string{"1"}, - "tag": []string{"no"}, - "username": []string{"test"}, - "age": []string{"40"}, - "Email": []string{"test@gmail.com"}, - "Intro": []string{"I am an engineer!"}, - "strbool": []string{"yes"}, - "date": []string{"2014-11-12"}, - "organization": []string{"beego"}, - "title": []string{"CXO"}, - "hobby": []string{"", "Basketball", "Football"}, - "memo": []string{"nothing"}, - } - if err := ParseForm(form, u); err == nil { - t.Fatal("nothing will be changed") - } - if err := ParseForm(form, &u); err != nil { - t.Fatal(err) - } - if u.ID != 0 { - t.Errorf("ID should equal 0 but got %v", u.ID) - } - if len(u.tag) != 0 { - t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) - } - if u.Name.(string) != "test" { - t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) - } - if u.Age != 40 { - t.Errorf("Age should equal 40 but got %v", u.Age) - } - if u.Email != "test@gmail.com" { - t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) - } - if u.Intro != "I am an engineer!" { - t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) - } - if !u.StrBool { - t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) - } - y, m, d := u.Date.Date() - if y != 2014 || m.String() != "November" || d != 12 { - t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) - } - if u.Organization != "beego" { - t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) - } - if u.Title != "CXO" { - t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) - } - if u.Hobby[0] != "" { - t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) - } - if u.Hobby[1] != "Basketball" { - t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) - } - if u.Hobby[2] != "Football" { - t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) - } - if len(u.Memo) != 0 { - t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) - } -} - -func TestRenderForm(t *testing.T) { - type user struct { - ID int `form:"-"` - Name interface{} `form:"username"` - Age int `form:"age,text,年龄:"` - Sex string - Email []string - Intro string `form:",textarea"` - Ignored string `form:"-"` - } - - u := user{Name: "test", Intro: "Some Text"} - output := RenderForm(u) - if output != template.HTML("") { - t.Errorf("output should be empty but got %v", output) - } - output = RenderForm(&u) - result := template.HTML( - `Name:
` + - `年龄:
` + - `Sex:
` + - `Intro: `) - if output != result { - t.Errorf("output should equal `%v` but got `%v`", result, output) - } -} - -func TestRenderFormField(t *testing.T) { - html := renderFormField("Label: ", "Name", "text", "Value", "", "", false) - if html != `Label: ` { - t.Errorf("Wrong html output for input[type=text]: %v ", html) - } - - html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false) - if html != `Label: ` { - t.Errorf("Wrong html output for textarea: %v ", html) - } - - html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true) - if html != `Label: ` { - t.Errorf("Wrong html output for textarea: %v ", html) - } -} - -func TestParseFormTag(t *testing.T) { - // create struct to contain field with different types of struct-tag `form` - type user struct { - All int `form:"name,text,年龄:"` - NoName int `form:",hidden,年龄:"` - OnlyLabel int `form:",,年龄:"` - OnlyName int `form:"name" id:"name" class:"form-name"` - Ignored int `form:"-"` - Required int `form:"name" required:"true"` - IgnoreRequired int `form:"name"` - NotRequired int `form:"name" required:"false"` - } - - objT := reflect.TypeOf(&user{}).Elem() - - label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0)) - if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { - t.Errorf("Form Tag with name, label and type was not correctly parsed.") - } - - label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1)) - if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { - t.Errorf("Form Tag with label and type but without name was not correctly parsed.") - } - - label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2)) - if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { - t.Errorf("Form Tag containing only label was not correctly parsed.") - } - - label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3)) - if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && - id == "name" && class == "form-name") { - t.Errorf("Form Tag containing only name was not correctly parsed.") - } - - _, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4)) - if !ignored { - t.Errorf("Form Tag that should be ignored was not correctly parsed.") - } - - _, name, _, _, _, _, required := parseFormTag(objT.Field(5)) - if !(name == "name" && required) { - t.Errorf("Form Tag containing only name and required was not correctly parsed.") - } - - _, name, _, _, _, _, required = parseFormTag(objT.Field(6)) - if !(name == "name" && !required) { - t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") - } - - _, name, _, _, _, _, required = parseFormTag(objT.Field(7)) - if !(name == "name" && !required) { - t.Errorf("Form Tag containing only name and not required was not correctly parsed.") - } - -} - -func TestMapGet(t *testing.T) { - // test one level map - m1 := map[string]int64{ - "a": 1, - "1": 2, - } - - if res, err := MapGet(m1, "a"); err == nil { - if res.(int64) != 1 { - t.Errorf("Should return 1, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - if res, err := MapGet(m1, "1"); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - if res, err := MapGet(m1, 1); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - // test 2 level map - m2 := M{ - "1": map[string]float64{ - "2": 3.5, - }, - } - - if res, err := MapGet(m2, 1, 2); err == nil { - if res.(float64) != 3.5 { - t.Errorf("Should return 3.5, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - // test 5 level map - m5 := M{ - "1": M{ - "2": M{ - "3": M{ - "4": M{ - "5": 1.2, - }, - }, - }, - }, - } - - if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { - if res.(float64) != 1.2 { - t.Errorf("Should return 1.2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } - - // check whether element not exists in map - if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { - if res != nil { - t.Errorf("Should return nil, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } -} diff --git a/testing/assertions.go b/testing/assertions.go deleted file mode 100644 index 96c5d4dd..00000000 --- a/testing/assertions.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testing diff --git a/testing/client.go b/testing/client.go deleted file mode 100644 index c3737e9c..00000000 --- a/testing/client.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testing - -import ( - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/httplib" -) - -var port = "" -var baseURL = "http://localhost:" - -// TestHTTPRequest beego test request client -type TestHTTPRequest struct { - httplib.BeegoHTTPRequest -} - -func getPort() string { - if port == "" { - config, err := config.NewConfig("ini", "../conf/app.conf") - if err != nil { - return "8080" - } - port = config.String("httpport") - return port - } - return port -} - -// Get returns test client in GET method -func Get(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Get(baseURL + getPort() + path)} -} - -// Post returns test client in POST method -func Post(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Post(baseURL + getPort() + path)} -} - -// Put returns test client in PUT method -func Put(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Put(baseURL + getPort() + path)} -} - -// Delete returns test client in DELETE method -func Delete(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Delete(baseURL + getPort() + path)} -} - -// Head returns test client in HEAD method -func Head(path string) *TestHTTPRequest { - return &TestHTTPRequest{*httplib.Head(baseURL + getPort() + path)} -} diff --git a/toolbox/healthcheck.go b/toolbox/healthcheck.go deleted file mode 100644 index e3544b3a..00000000 --- a/toolbox/healthcheck.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package toolbox healthcheck -// -// type DatabaseCheck struct { -// } -// -// func (dc *DatabaseCheck) Check() error { -// if dc.isConnected() { -// return nil -// } else { -// return errors.New("can't connect database") -// } -// } -// -// AddHealthCheck("database",&DatabaseCheck{}) -// -// more docs: http://beego.me/docs/module/toolbox.md -package toolbox - -// AdminCheckList holds health checker map -var AdminCheckList map[string]HealthChecker - -// HealthChecker health checker interface -type HealthChecker interface { - Check() error -} - -// AddHealthCheck add health checker with name string -func AddHealthCheck(name string, hc HealthChecker) { - AdminCheckList[name] = hc -} - -func init() { - AdminCheckList = make(map[string]HealthChecker) -} diff --git a/toolbox/profile.go b/toolbox/profile.go deleted file mode 100644 index 06e40ede..00000000 --- a/toolbox/profile.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "fmt" - "io" - "log" - "os" - "path" - "runtime" - "runtime/debug" - "runtime/pprof" - "strconv" - "time" -) - -var startTime = time.Now() -var pid int - -func init() { - pid = os.Getpid() -} - -// ProcessInput parse input command string -func ProcessInput(input string, w io.Writer) { - switch input { - case "lookup goroutine": - p := pprof.Lookup("goroutine") - p.WriteTo(w, 2) - case "lookup heap": - p := pprof.Lookup("heap") - p.WriteTo(w, 2) - case "lookup threadcreate": - p := pprof.Lookup("threadcreate") - p.WriteTo(w, 2) - case "lookup block": - p := pprof.Lookup("block") - p.WriteTo(w, 2) - case "get cpuprof": - GetCPUProfile(w) - case "get memprof": - MemProf(w) - case "gc summary": - PrintGCSummary(w) - } -} - -// MemProf record memory profile in pprof -func MemProf(w io.Writer) { - filename := "mem-" + strconv.Itoa(pid) + ".memprof" - if f, err := os.Create(filename); err != nil { - fmt.Fprintf(w, "create file %s error %s\n", filename, err.Error()) - log.Fatal("record heap profile failed: ", err) - } else { - runtime.GC() - pprof.WriteHeapProfile(f) - f.Close() - fmt.Fprintf(w, "create heap profile %s \n", filename) - _, fl := path.Split(os.Args[0]) - fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) - } -} - -// GetCPUProfile start cpu profile monitor -func GetCPUProfile(w io.Writer) { - sec := 30 - filename := "cpu-" + strconv.Itoa(pid) + ".pprof" - f, err := os.Create(filename) - if err != nil { - fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) - log.Fatal("record cpu profile failed: ", err) - } - pprof.StartCPUProfile(f) - time.Sleep(time.Duration(sec) * time.Second) - pprof.StopCPUProfile() - - fmt.Fprintf(w, "create cpu profile %s \n", filename) - _, fl := path.Split(os.Args[0]) - fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) -} - -// PrintGCSummary print gc information to io.Writer -func PrintGCSummary(w io.Writer) { - memStats := &runtime.MemStats{} - runtime.ReadMemStats(memStats) - gcstats := &debug.GCStats{PauseQuantiles: make([]time.Duration, 100)} - debug.ReadGCStats(gcstats) - - printGC(memStats, gcstats, w) -} - -func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { - - if gcstats.NumGC > 0 { - lastPause := gcstats.Pause[0] - elapsed := time.Now().Sub(startTime) - overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 - allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() - - fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", - gcstats.NumGC, - toS(lastPause), - toS(avg(gcstats.Pause)), - overhead, - toH(memStats.Alloc), - toH(memStats.Sys), - toH(uint64(allocatedRate)), - toS(gcstats.PauseQuantiles[94]), - toS(gcstats.PauseQuantiles[98]), - toS(gcstats.PauseQuantiles[99])) - } else { - // while GC has disabled - elapsed := time.Now().Sub(startTime) - allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() - - fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", - toH(memStats.Alloc), - toH(memStats.Sys), - toH(uint64(allocatedRate))) - } -} - -func avg(items []time.Duration) time.Duration { - var sum time.Duration - for _, item := range items { - sum += item - } - return time.Duration(int64(sum) / int64(len(items))) -} - -// format bytes number friendly -func toH(bytes uint64) string { - switch { - case bytes < 1024: - return fmt.Sprintf("%dB", bytes) - case bytes < 1024*1024: - return fmt.Sprintf("%.2fK", float64(bytes)/1024) - case bytes < 1024*1024*1024: - return fmt.Sprintf("%.2fM", float64(bytes)/1024/1024) - default: - return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) - } -} - -// short string format -func toS(d time.Duration) string { - - u := uint64(d) - if u < uint64(time.Second) { - switch { - case u == 0: - return "0" - case u < uint64(time.Microsecond): - return fmt.Sprintf("%.2fns", float64(u)) - case u < uint64(time.Millisecond): - return fmt.Sprintf("%.2fus", float64(u)/1000) - default: - return fmt.Sprintf("%.2fms", float64(u)/1000/1000) - } - } else { - switch { - case u < uint64(time.Minute): - return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) - case u < uint64(time.Hour): - return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) - default: - return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) - } - } - -} diff --git a/toolbox/profile_test.go b/toolbox/profile_test.go deleted file mode 100644 index 07a20c4e..00000000 --- a/toolbox/profile_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "os" - "testing" -) - -func TestProcessInput(t *testing.T) { - ProcessInput("lookup goroutine", os.Stdout) - ProcessInput("lookup heap", os.Stdout) - ProcessInput("lookup threadcreate", os.Stdout) - ProcessInput("lookup block", os.Stdout) - ProcessInput("gc summary", os.Stdout) -} diff --git a/toolbox/statistics.go b/toolbox/statistics.go deleted file mode 100644 index fd73dfb3..00000000 --- a/toolbox/statistics.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "fmt" - "sync" - "time" -) - -// Statistics struct -type Statistics struct { - RequestURL string - RequestController string - RequestNum int64 - MinTime time.Duration - MaxTime time.Duration - TotalTime time.Duration -} - -// URLMap contains several statistics struct to log different data -type URLMap struct { - lock sync.RWMutex - LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit - urlmap map[string]map[string]*Statistics -} - -// AddStatistics add statistics task. -// it needs request method, request url, request controller and statistics time duration -func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { - m.lock.Lock() - defer m.lock.Unlock() - if method, ok := m.urlmap[requestURL]; ok { - if s, ok := method[requestMethod]; ok { - s.RequestNum++ - if s.MaxTime < requesttime { - s.MaxTime = requesttime - } - if s.MinTime > requesttime { - s.MinTime = requesttime - } - s.TotalTime += requesttime - } else { - nb := &Statistics{ - RequestURL: requestURL, - RequestController: requestController, - RequestNum: 1, - MinTime: requesttime, - MaxTime: requesttime, - TotalTime: requesttime, - } - m.urlmap[requestURL][requestMethod] = nb - } - - } else { - if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { - return - } - methodmap := make(map[string]*Statistics) - nb := &Statistics{ - RequestURL: requestURL, - RequestController: requestController, - RequestNum: 1, - MinTime: requesttime, - MaxTime: requesttime, - TotalTime: requesttime, - } - methodmap[requestMethod] = nb - m.urlmap[requestURL] = methodmap - } -} - -// GetMap put url statistics result in io.Writer -func (m *URLMap) GetMap() map[string]interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - - var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} - - var resultLists [][]string - content := make(map[string]interface{}) - content["Fields"] = fields - - for k, v := range m.urlmap { - for kk, vv := range v { - result := []string{ - fmt.Sprintf("% -50s", k), - fmt.Sprintf("% -10s", kk), - fmt.Sprintf("% -16d", vv.RequestNum), - fmt.Sprintf("%d", vv.TotalTime), - fmt.Sprintf("% -16s", toS(vv.TotalTime)), - fmt.Sprintf("%d", vv.MaxTime), - fmt.Sprintf("% -16s", toS(vv.MaxTime)), - fmt.Sprintf("%d", vv.MinTime), - fmt.Sprintf("% -16s", toS(vv.MinTime)), - fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), - fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), - } - resultLists = append(resultLists, result) - } - } - content["Data"] = resultLists - return content -} - -// GetMapData return all mapdata -func (m *URLMap) GetMapData() []map[string]interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - - var resultLists []map[string]interface{} - - for k, v := range m.urlmap { - for kk, vv := range v { - result := map[string]interface{}{ - "request_url": k, - "method": kk, - "times": vv.RequestNum, - "total_time": toS(vv.TotalTime), - "max_time": toS(vv.MaxTime), - "min_time": toS(vv.MinTime), - "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), - } - resultLists = append(resultLists, result) - } - } - return resultLists -} - -// StatisticsMap hosld global statistics data map -var StatisticsMap *URLMap - -func init() { - StatisticsMap = &URLMap{ - urlmap: make(map[string]map[string]*Statistics), - } -} diff --git a/toolbox/statistics_test.go b/toolbox/statistics_test.go deleted file mode 100644 index ac29476c..00000000 --- a/toolbox/statistics_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "encoding/json" - "testing" - "time" -) - -func TestStatics(t *testing.T) { - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) - StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) - StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) - StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) - t.Log(StatisticsMap.GetMap()) - - data := StatisticsMap.GetMapData() - b, err := json.Marshal(data) - if err != nil { - t.Errorf(err.Error()) - } - - t.Log(string(b)) -} diff --git a/toolbox/task.go b/toolbox/task.go deleted file mode 100644 index fb2c5f16..00000000 --- a/toolbox/task.go +++ /dev/null @@ -1,640 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "log" - "math" - "sort" - "strconv" - "strings" - "sync" - "time" -) - -// bounds provides a range of acceptable values (plus a map of name to value). -type bounds struct { - min, max uint - names map[string]uint -} - -// The bounds for each field. -var ( - AdminTaskList map[string]Tasker - taskLock sync.RWMutex - stop chan bool - changed chan bool - isstart bool - seconds = bounds{0, 59, nil} - minutes = bounds{0, 59, nil} - hours = bounds{0, 23, nil} - days = bounds{1, 31, nil} - months = bounds{1, 12, map[string]uint{ - "jan": 1, - "feb": 2, - "mar": 3, - "apr": 4, - "may": 5, - "jun": 6, - "jul": 7, - "aug": 8, - "sep": 9, - "oct": 10, - "nov": 11, - "dec": 12, - }} - weeks = bounds{0, 6, map[string]uint{ - "sun": 0, - "mon": 1, - "tue": 2, - "wed": 3, - "thu": 4, - "fri": 5, - "sat": 6, - }} -) - -const ( - // Set the top bit if a star was included in the expression. - starBit = 1 << 63 -) - -// Schedule time taks schedule -type Schedule struct { - Second uint64 - Minute uint64 - Hour uint64 - Day uint64 - Month uint64 - Week uint64 -} - -// TaskFunc task func type -type TaskFunc func() error - -// Tasker task interface -type Tasker interface { - GetSpec() string - GetStatus() string - Run() error - SetNext(time.Time) - GetNext() time.Time - SetPrev(time.Time) - GetPrev() time.Time -} - -// task error -type taskerr struct { - t time.Time - errinfo string -} - -// Task task struct -// It's not a thread-safe structure. -// Only nearest errors will be saved in ErrList -type Task struct { - Taskname string - Spec *Schedule - SpecStr string - DoFunc TaskFunc - Prev time.Time - Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution -} - -// NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { - - task := &Task{ - Taskname: tname, - DoFunc: f, - // Make configurable - ErrLimit: 100, - SpecStr: spec, - // we only store the pointer, so it won't use too many space - Errlist: make([]*taskerr, 100, 100), - } - task.SetCron(spec) - return task -} - -// GetSpec get spec string -func (t *Task) GetSpec() string { - return t.SpecStr -} - -// GetStatus get current task status -func (t *Task) GetStatus() string { - var str string - for _, v := range t.Errlist { - str += v.t.String() + ":" + v.errinfo + "
" - } - return str -} - -// Run run all tasks -func (t *Task) Run() error { - err := t.DoFunc() - if err != nil { - index := t.errCnt % t.ErrLimit - t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} - t.errCnt++ - } - return err -} - -// SetNext set next time for this task -func (t *Task) SetNext(now time.Time) { - t.Next = t.Spec.Next(now) -} - -// GetNext get the next call time of this task -func (t *Task) GetNext() time.Time { - return t.Next -} - -// SetPrev set prev time of this task -func (t *Task) SetPrev(now time.Time) { - t.Prev = now -} - -// GetPrev get prev time of this task -func (t *Task) GetPrev() time.Time { - return t.Prev -} - -// six columns mean: -// second:0-59 -// minute:0-59 -// hour:1-23 -// day:1-31 -// month:1-12 -// week:0-6(0 means Sunday) - -// SetCron some signals: -// *: any time -// ,:  separate signal -//   -:duration -// /n : do as n times of time duration -///////////////////////////////////////////////////////// -// 0/30 * * * * * every 30s -// 0 43 21 * * * 21:43 -// 0 15 05 * * *    05:15 -// 0 0 17 * * * 17:00 -// 0 0 17 * * 1 17:00 in every Monday -// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday -// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month -// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month -// 0 42 4 1 * *     4:42 on the 1st day of month -// 0 0 21 * * 1-6   21:00 from Monday to Saturday -// 0 0,10,20,30,40,50 * * * *  every 10 min duration -// 0 */10 * * * *        every 10 min duration -// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time -// 0 0 1 * * *         1:00 -// 0 0 */1 * * *        0 min of hour in 1 hour duration -// 0 0 * * * *         0 min of hour in 1 hour duration -// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 -// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month -func (t *Task) SetCron(spec string) { - t.Spec = t.parse(spec) -} - -func (t *Task) parse(spec string) *Schedule { - if len(spec) > 0 && spec[0] == '@' { - return t.parseSpec(spec) - } - // Split on whitespace. We require 5 or 6 fields. - // (second) (minute) (hour) (day of month) (month) (day of week, optional) - fields := strings.Fields(spec) - if len(fields) != 5 && len(fields) != 6 { - log.Panicf("Expected 5 or 6 fields, found %d: %s", len(fields), spec) - } - - // If a sixth field is not provided (DayOfWeek), then it is equivalent to star. - if len(fields) == 5 { - fields = append(fields, "*") - } - - schedule := &Schedule{ - Second: getField(fields[0], seconds), - Minute: getField(fields[1], minutes), - Hour: getField(fields[2], hours), - Day: getField(fields[3], days), - Month: getField(fields[4], months), - Week: getField(fields[5], weeks), - } - - return schedule -} - -func (t *Task) parseSpec(spec string) *Schedule { - switch spec { - case "@yearly", "@annually": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: 1 << days.min, - Month: 1 << months.min, - Week: all(weeks), - } - - case "@monthly": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: 1 << days.min, - Month: all(months), - Week: all(weeks), - } - - case "@weekly": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: all(days), - Month: all(months), - Week: 1 << weeks.min, - } - - case "@daily", "@midnight": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: 1 << hours.min, - Day: all(days), - Month: all(months), - Week: all(weeks), - } - - case "@hourly": - return &Schedule{ - Second: 1 << seconds.min, - Minute: 1 << minutes.min, - Hour: all(hours), - Day: all(days), - Month: all(months), - Week: all(weeks), - } - } - log.Panicf("Unrecognized descriptor: %s", spec) - return nil -} - -// Next set schedule to next time -func (s *Schedule) Next(t time.Time) time.Time { - - // Start at the earliest possible time (the upcoming second). - t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) - - // This flag indicates whether a field has been incremented. - added := false - - // If no time is found within five years, return zero. - yearLimit := t.Year() + 5 - -WRAP: - if t.Year() > yearLimit { - return time.Time{} - } - - // Find the first applicable month. - // If it's this month, then do nothing. - for 1< 0 - dowMatch = 1< 0 - ) - - if s.Day&starBit > 0 || s.Week&starBit > 0 { - return domMatch && dowMatch - } - return domMatch || dowMatch -} - -// StartTask start all tasks -func StartTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - //If already started, no need to start another goroutine. - return - } - isstart = true - go run() -} - -func run() { - now := time.Now().Local() - for _, t := range AdminTaskList { - t.SetNext(now) - } - - for { - // we only use RLock here because NewMapSorter copy the reference, do not change any thing - taskLock.RLock() - sortList := NewMapSorter(AdminTaskList) - taskLock.RUnlock() - sortList.Sort() - var effective time.Time - if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { - // If there are no entries yet, just sleep - it still handles new entries - // and stop requests. - effective = now.AddDate(10, 0, 0) - } else { - effective = sortList.Vals[0].GetNext() - } - select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext() != effective { - break - } - go e.Run() - e.SetPrev(e.GetNext()) - e.SetNext(effective) - } - continue - case <-changed: - now = time.Now().Local() - taskLock.Lock() - for _, t := range AdminTaskList { - t.SetNext(now) - } - taskLock.Unlock() - continue - case <-stop: - return - } - } -} - -// StopTask stop all tasks -func StopTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - isstart = false - stop <- true - } - -} - -// AddTask add task with name -func AddTask(taskname string, t Tasker) { - taskLock.Lock() - defer taskLock.Unlock() - t.SetNext(time.Now().Local()) - AdminTaskList[taskname] = t - if isstart { - changed <- true - } -} - -// DeleteTask delete task with name -func DeleteTask(taskname string) { - taskLock.Lock() - defer taskLock.Unlock() - delete(AdminTaskList, taskname) - if isstart { - changed <- true - } -} - -// MapSorter sort map for tasker -type MapSorter struct { - Keys []string - Vals []Tasker -} - -// NewMapSorter create new tasker map -func NewMapSorter(m map[string]Tasker) *MapSorter { - ms := &MapSorter{ - Keys: make([]string, 0, len(m)), - Vals: make([]Tasker, 0, len(m)), - } - for k, v := range m { - ms.Keys = append(ms.Keys, k) - ms.Vals = append(ms.Vals, v) - } - return ms -} - -// Sort sort tasker map -func (ms *MapSorter) Sort() { - sort.Sort(ms) -} - -func (ms *MapSorter) Len() int { return len(ms.Keys) } -func (ms *MapSorter) Less(i, j int) bool { - if ms.Vals[i].GetNext().IsZero() { - return false - } - if ms.Vals[j].GetNext().IsZero() { - return true - } - return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) -} -func (ms *MapSorter) Swap(i, j int) { - ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] - ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] -} - -func getField(field string, r bounds) uint64 { - // list = range {"," range} - var bits uint64 - ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) - for _, expr := range ranges { - bits |= getRange(expr, r) - } - return bits -} - -// getRange returns the bits indicated by the given expression: -// number | number "-" number [ "/" number ] -func getRange(expr string, r bounds) uint64 { - - var ( - start, end, step uint - rangeAndStep = strings.Split(expr, "/") - lowAndHigh = strings.Split(rangeAndStep[0], "-") - singleDigit = len(lowAndHigh) == 1 - ) - - var extrastar uint64 - if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" { - start = r.min - end = r.max - extrastar = starBit - } else { - start = parseIntOrName(lowAndHigh[0], r.names) - switch len(lowAndHigh) { - case 1: - end = start - case 2: - end = parseIntOrName(lowAndHigh[1], r.names) - default: - log.Panicf("Too many hyphens: %s", expr) - } - } - - switch len(rangeAndStep) { - case 1: - step = 1 - case 2: - step = mustParseInt(rangeAndStep[1]) - - // Special handling: "N/step" means "N-max/step". - if singleDigit { - end = r.max - } - default: - log.Panicf("Too many slashes: %s", expr) - } - - if start < r.min { - log.Panicf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr) - } - if end > r.max { - log.Panicf("End of range (%d) above maximum (%d): %s", end, r.max, expr) - } - if start > end { - log.Panicf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr) - } - - return getBits(start, end, step) | extrastar -} - -// parseIntOrName returns the (possibly-named) integer contained in expr. -func parseIntOrName(expr string, names map[string]uint) uint { - if names != nil { - if namedInt, ok := names[strings.ToLower(expr)]; ok { - return namedInt - } - } - return mustParseInt(expr) -} - -// mustParseInt parses the given expression as an int or panics. -func mustParseInt(expr string) uint { - num, err := strconv.Atoi(expr) - if err != nil { - log.Panicf("Failed to parse int from %s: %s", expr, err) - } - if num < 0 { - log.Panicf("Negative number (%d) not allowed: %s", num, expr) - } - - return uint(num) -} - -// getBits sets all bits in the range [min, max], modulo the given step size. -func getBits(min, max, step uint) uint64 { - var bits uint64 - - // If step is 1, use shifts. - if step == 1 { - return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) - } - - // Else, use a simple loop. - for i := min; i <= max; i += step { - bits |= 1 << i - } - return bits -} - -// all returns all bits within the given bounds. (plus the star bit) -func all(r bounds) uint64 { - return getBits(r.min, r.max, 1) | starBit -} - -func init() { - AdminTaskList = make(map[string]Tasker) - stop = make(chan bool) - changed = make(chan bool) -} diff --git a/toolbox/task_test.go b/toolbox/task_test.go deleted file mode 100644 index b63f4391..00000000 --- a/toolbox/task_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "errors" - "fmt" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) - err := tk.Run() - if err != nil { - t.Fatal(err) - } - AddTask("taska", tk) - StartTask() - time.Sleep(6 * time.Second) - StopTask() -} - -func TestSpec(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) - - AddTask("tk1", tk1) - AddTask("tk2", tk2) - AddTask("tk3", tk3) - StartTask() - defer StopTask() - - select { - case <-time.After(200 * time.Second): - t.FailNow() - case <-wait(wg): - } -} - -func TestTask_Run(t *testing.T) { - cnt := -1 - task := func() error { - cnt++ - fmt.Printf("Hello, world! %d \n", cnt) - return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) - } - tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200; i++ { - e := tk.Run() - assert.NotNil(t, e) - } - - l := tk.Errlist - assert.Equal(t, 100, len(l)) - assert.Equal(t, "Hello, world! 100", l[0].errinfo) - assert.Equal(t, "Hello, world! 101", l[1].errinfo) -} - -func wait(wg *sync.WaitGroup) chan bool { - ch := make(chan bool) - go func() { - wg.Wait() - ch <- true - }() - return ch -} diff --git a/tree.go b/tree.go deleted file mode 100644 index 7fa3a7cb..00000000 --- a/tree.go +++ /dev/null @@ -1,590 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "path" - "regexp" - "strings" - - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" -) - -var ( - allowSuffixExt = []string{".json", ".xml", ".html"} -) - -// Tree has three elements: FixRouter/wildcard/leaves -// fixRouter stores Fixed Router -// wildcard stores params -// leaves store the endpoint information -// Deprecated: using pkg/, we will delete this in v2.1.0 -type Tree struct { - //prefix set for static router - prefix string - //search fix route first - fixrouters []*Tree - //if set, failure to match fixrouters search then search wildcard - wildcard *Tree - //if set, failure to match wildcard search - leaves []*leafInfo -} - -// NewTree return a new Tree -// Deprecated: using pkg/, we will delete this in v2.1.0 -func NewTree() *Tree { - return &Tree{} -} - -// AddTree will add tree to the exist Tree -// prefix should has no params -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (t *Tree) AddTree(prefix string, tree *Tree) { - t.addtree(splitPath(prefix), tree, nil, "") -} - -func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg string) { - if len(segments) == 0 { - panic("prefix should has path") - } - seg := segments[0] - iswild, params, regexpStr := splitSegment(seg) - // if it's ? meaning can igone this, so add one more rule for it - if len(params) > 0 && params[0] == ":" { - params = params[1:] - if len(segments[1:]) > 0 { - t.addtree(segments[1:], tree, append(wildcards, params...), reg) - } else { - filterTreeWithPrefix(tree, wildcards, reg) - } - } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr - if !iswild && utils.InSlice(":splat", wildcards) { - iswild = true - regexpStr = seg - } - //Rule: /user/:id/* - if seg == "*" && len(wildcards) > 0 && reg == "" { - regexpStr = "(.+)" - } - if len(segments) == 1 { - if iswild { - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "([^.]+).(.+)" - } else { - for _, w := range params { - if w == "." || w == ":" { - continue - } - regexpStr = "([^/]+)/" + regexpStr - } - } - } - reg = strings.Trim(reg+"/"+regexpStr, "/") - filterTreeWithPrefix(tree, append(wildcards, params...), reg) - t.wildcard = tree - } else { - reg = strings.Trim(reg+"/"+regexpStr, "/") - filterTreeWithPrefix(tree, append(wildcards, params...), reg) - tree.prefix = seg - t.fixrouters = append(t.fixrouters, tree) - } - return - } - - if iswild { - if t.wildcard == nil { - t.wildcard = NewTree() - } - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "([^.]+).(.+)" - params = params[1:] - } else { - for range params { - regexpStr = "([^/]+)/" + regexpStr - } - } - } else { - if seg == "*.*" { - params = params[1:] - } - } - reg = strings.TrimRight(strings.TrimRight(reg, "/")+"/"+regexpStr, "/") - t.wildcard.addtree(segments[1:], tree, append(wildcards, params...), reg) - } else { - subTree := NewTree() - subTree.prefix = seg - t.fixrouters = append(t.fixrouters, subTree) - subTree.addtree(segments[1:], tree, append(wildcards, params...), reg) - } -} - -func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) { - for _, v := range t.fixrouters { - filterTreeWithPrefix(v, wildcards, reg) - } - if t.wildcard != nil { - filterTreeWithPrefix(t.wildcard, wildcards, reg) - } - for _, l := range t.leaves { - if reg != "" { - if l.regexps != nil { - l.wildcards = append(wildcards, l.wildcards...) - l.regexps = regexp.MustCompile("^" + reg + "/" + strings.Trim(l.regexps.String(), "^$") + "$") - } else { - for _, v := range l.wildcards { - if v == ":splat" { - reg = reg + "/(.+)" - } else { - reg = reg + "/([^/]+)" - } - } - l.regexps = regexp.MustCompile("^" + reg + "$") - l.wildcards = append(wildcards, l.wildcards...) - } - } else { - l.wildcards = append(wildcards, l.wildcards...) - if l.regexps != nil { - for _, w := range wildcards { - if w == ":splat" { - reg = "(.+)/" + reg - } else { - reg = "([^/]+)/" + reg - } - } - l.regexps = regexp.MustCompile("^" + reg + strings.Trim(l.regexps.String(), "^$") + "$") - } - } - } -} - -// AddRouter call addseg function -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (t *Tree) AddRouter(pattern string, runObject interface{}) { - t.addseg(splitPath(pattern), runObject, nil, "") -} - -// "/" -// "admin" -> -func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { - if len(segments) == 0 { - if reg != "" { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) - } else { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) - } - } else { - seg := segments[0] - iswild, params, regexpStr := splitSegment(seg) - // if it's ? meaning can igone this, so add one more rule for it - if len(params) > 0 && params[0] == ":" { - t.addseg(segments[1:], route, wildcards, reg) - params = params[1:] - } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr - if !iswild && utils.InSlice(":splat", wildcards) { - iswild = true - regexpStr = seg - } - //Rule: /user/:id/* - if seg == "*" && len(wildcards) > 0 && reg == "" { - regexpStr = "(.+)" - } - if iswild { - if t.wildcard == nil { - t.wildcard = NewTree() - } - if regexpStr != "" { - if reg == "" { - rr := "" - for _, w := range wildcards { - if w == ":splat" { - rr = rr + "(.+)/" - } else { - rr = rr + "([^/]+)/" - } - } - regexpStr = rr + regexpStr - } else { - regexpStr = "/" + regexpStr - } - } else if reg != "" { - if seg == "*.*" { - regexpStr = "/([^.]+).(.+)" - params = params[1:] - } else { - for range params { - regexpStr = "/([^/]+)" + regexpStr - } - } - } else { - if seg == "*.*" { - params = params[1:] - } - } - t.wildcard.addseg(segments[1:], route, append(wildcards, params...), reg+regexpStr) - } else { - var subTree *Tree - for _, sub := range t.fixrouters { - if sub.prefix == seg { - subTree = sub - break - } - } - if subTree == nil { - subTree = NewTree() - subTree.prefix = seg - t.fixrouters = append(t.fixrouters, subTree) - } - subTree.addseg(segments[1:], route, wildcards, reg) - } - } -} - -// Match router to runObject & params -// Deprecated: using pkg/, we will delete this in v2.1.0 -func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { - if len(pattern) == 0 || pattern[0] != '/' { - return nil - } - w := make([]string, 0, 20) - return t.match(pattern[1:], pattern, w, ctx) -} - -func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { - if len(pattern) > 0 { - i := 0 - for ; i < len(pattern) && pattern[i] == '/'; i++ { - } - pattern = pattern[i:] - } - // Handle leaf nodes: - if len(pattern) == 0 { - for _, l := range t.leaves { - if ok := l.match(treePattern, wildcardValues, ctx); ok { - return l.runObject - } - } - if t.wildcard != nil { - for _, l := range t.wildcard.leaves { - if ok := l.match(treePattern, wildcardValues, ctx); ok { - return l.runObject - } - } - } - return nil - } - var seg string - i, l := 0, len(pattern) - for ; i < l && pattern[i] != '/'; i++ { - } - if i == 0 { - seg = pattern - pattern = "" - } else { - seg = pattern[:i] - pattern = pattern[i:] - } - for _, subTree := range t.fixrouters { - if subTree.prefix == seg { - if len(pattern) != 0 && pattern[0] == '/' { - treePattern = pattern[1:] - } else { - treePattern = pattern - } - runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) - if runObject != nil { - break - } - } - } - if runObject == nil && len(t.fixrouters) > 0 { - // Filter the .json .xml .html extension - for _, str := range allowSuffixExt { - if strings.HasSuffix(seg, str) { - for _, subTree := range t.fixrouters { - if subTree.prefix == seg[:len(seg)-len(str)] { - runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) - if runObject != nil { - ctx.Input.SetParam(":ext", str[1:]) - } - } - } - } - } - } - if runObject == nil && t.wildcard != nil { - runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) - } - - if runObject == nil && len(t.leaves) > 0 { - wildcardValues = append(wildcardValues, seg) - start, i := 0, 0 - for ; i < len(pattern); i++ { - if pattern[i] == '/' { - if i != 0 && start < len(pattern) { - wildcardValues = append(wildcardValues, pattern[start:i]) - } - start = i + 1 - continue - } - } - if start > 0 { - wildcardValues = append(wildcardValues, pattern[start:i]) - } - for _, l := range t.leaves { - if ok := l.match(treePattern, wildcardValues, ctx); ok { - return l.runObject - } - } - } - return runObject -} - -type leafInfo struct { - // names of wildcards that lead to this leaf. eg, ["id" "name"] for the wildcard ":id" and ":name" - wildcards []string - - // if the leaf is regexp - regexps *regexp.Regexp - - runObject interface{} -} - -func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { - //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) - if leaf.regexps == nil { - if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path - return true - } - // match * - if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { - ctx.Input.SetParam(":splat", treePattern) - return true - } - // match *.* or :id - if len(leaf.wildcards) >= 2 && leaf.wildcards[len(leaf.wildcards)-2] == ":path" && leaf.wildcards[len(leaf.wildcards)-1] == ":ext" { - if len(leaf.wildcards) == 2 { - lastone := wildcardValues[len(wildcardValues)-1] - strs := strings.SplitN(lastone, ".", 2) - if len(strs) == 2 { - ctx.Input.SetParam(":ext", strs[1]) - } - ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[:len(wildcardValues)-1]...), strs[0])) - return true - } else if len(wildcardValues) < 2 { - return false - } - var index int - for index = 0; index < len(leaf.wildcards)-2; index++ { - ctx.Input.SetParam(leaf.wildcards[index], wildcardValues[index]) - } - lastone := wildcardValues[len(wildcardValues)-1] - strs := strings.SplitN(lastone, ".", 2) - if len(strs) == 2 { - ctx.Input.SetParam(":ext", strs[1]) - } - if index > (len(wildcardValues) - 1) { - ctx.Input.SetParam(":path", "") - } else { - ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[index:len(wildcardValues)-1]...), strs[0])) - } - return true - } - // match :id - if len(leaf.wildcards) != len(wildcardValues) { - return false - } - for j, v := range leaf.wildcards { - ctx.Input.SetParam(v, wildcardValues[j]) - } - return true - } - - if !leaf.regexps.MatchString(path.Join(wildcardValues...)) { - return false - } - matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...)) - for i, match := range matches[1:] { - if i < len(leaf.wildcards) { - ctx.Input.SetParam(leaf.wildcards[i], match) - } - } - return true -} - -// "/" -> [] -// "/admin" -> ["admin"] -// "/admin/" -> ["admin"] -// "/admin/users" -> ["admin", "users"] -func splitPath(key string) []string { - key = strings.Trim(key, "/ ") - if key == "" { - return []string{} - } - return strings.Split(key, "/") -} - -// "admin" -> false, nil, "" -// ":id" -> true, [:id], "" -// "?:id" -> true, [: :id], "" : meaning can empty -// ":id:int" -> true, [:id], ([0-9]+) -// ":name:string" -> true, [:name], ([\w]+) -// ":id([0-9]+)" -> true, [:id], ([0-9]+) -// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+) -// "cms_:id_:page.html" -> true, [:id_ :page], cms_(.+)(.+).html -// "cms_:id(.+)_:page.html" -> true, [:id :page], cms_(.+)_(.+).html -// "*" -> true, [:splat], "" -// "*.*" -> true,[. :path :ext], "" . meaning separator -func splitSegment(key string) (bool, []string, string) { - if strings.HasPrefix(key, "*") { - if key == "*.*" { - return true, []string{".", ":path", ":ext"}, "" - } - return true, []string{":splat"}, "" - } - if strings.ContainsAny(key, ":") { - var paramsNum int - var out []rune - var start bool - var startexp bool - var param []rune - var expt []rune - var skipnum int - params := []string{} - reg := regexp.MustCompile(`[a-zA-Z0-9_]+`) - for i, v := range key { - if skipnum > 0 { - skipnum-- - continue - } - if start { - //:id:int and :name:string - if v == ':' { - if len(key) >= i+4 { - if key[i+1:i+4] == "int" { - out = append(out, []rune("([0-9]+)")...) - params = append(params, ":"+string(param)) - start = false - startexp = false - skipnum = 3 - param = make([]rune, 0) - paramsNum++ - continue - } - } - if len(key) >= i+7 { - if key[i+1:i+7] == "string" { - out = append(out, []rune(`([\w]+)`)...) - params = append(params, ":"+string(param)) - paramsNum++ - start = false - startexp = false - skipnum = 6 - param = make([]rune, 0) - continue - } - } - } - // params only support a-zA-Z0-9 - if reg.MatchString(string(v)) { - param = append(param, v) - continue - } - if v != '(' { - out = append(out, []rune(`(.+)`)...) - params = append(params, ":"+string(param)) - param = make([]rune, 0) - paramsNum++ - start = false - startexp = false - } - } - if startexp { - if v != ')' { - expt = append(expt, v) - continue - } - } - // Escape Sequence '\' - if i > 0 && key[i-1] == '\\' { - out = append(out, v) - } else if v == ':' { - param = make([]rune, 0) - start = true - } else if v == '(' { - startexp = true - start = false - if len(param) > 0 { - params = append(params, ":"+string(param)) - param = make([]rune, 0) - } - paramsNum++ - expt = make([]rune, 0) - expt = append(expt, '(') - } else if v == ')' { - startexp = false - expt = append(expt, ')') - out = append(out, expt...) - param = make([]rune, 0) - } else if v == '?' { - params = append(params, ":") - } else { - out = append(out, v) - } - } - if len(param) > 0 { - if paramsNum > 0 { - out = append(out, []rune(`(.+)`)...) - } - params = append(params, ":"+string(param)) - } - return true, params, string(out) - } - return false, nil, "" -} diff --git a/tree_test.go b/tree_test.go deleted file mode 100644 index d412a348..00000000 --- a/tree_test.go +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - "testing" - - "github.com/astaxie/beego/context" -) - -type testinfo struct { - url string - requesturl string - params map[string]string -} - -var routers []testinfo - -func init() { - routers = make([]testinfo, 0) - routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) - routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) - routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) - routers = append(routers, testinfo{"/", "/", nil}) - routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) - routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) - routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) - routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) - routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) - routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) - routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) - routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) - routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", - "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", - map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) - routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) - routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) - routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) - routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", - "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", - map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) - routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) - routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) - routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) -} - -func TestTreeRouters(t *testing.T) { - for _, r := range routers { - tr := NewTree() - tr.AddRouter(r.url, "astaxie") - ctx := context.NewContext() - obj := tr.Match(r.requesturl, ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) - } - if r.params != nil { - for k, v := range r.params { - if vv := ctx.Input.Param(k); vv != v { - t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) - } else if vv == "" && v != "" { - t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) - } - } - } - } -} - -func TestStaticPath(t *testing.T) { - tr := NewTree() - tr.AddRouter("/topic/:id", "wildcard") - tr.AddRouter("/topic", "static") - ctx := context.NewContext() - obj := tr.Match("/topic", ctx) - if obj == nil || obj.(string) != "static" { - t.Fatal("/topic is a static route") - } - obj = tr.Match("/topic/1", ctx) - if obj == nil || obj.(string) != "wildcard" { - t.Fatal("/topic/1 is a wildcard route") - } -} - -func TestAddTree(t *testing.T) { - tr := NewTree() - tr.AddRouter("/shop/:id/account", "astaxie") - tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") - t1 := NewTree() - t1.AddTree("/v1/zl", tr) - ctx := context.NewContext() - obj := t1.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/zl/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" { - t.Fatal("get :id param error") - } - ctx.Input.Reset(ctx) - obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { - t.Fatal("get :sd :id :page param error") - } - - t2 := NewTree() - t2.AddTree("/v1/:shopid", tr) - ctx.Input.Reset(ctx) - obj = t2.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { - t.Fatal("get :id :shopid param error") - } - ctx.Input.Reset(ctx) - obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get :shopid param error") - } - if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { - t.Fatal("get :sd :id :page :shopid param error") - } -} - -func TestAddTree2(t *testing.T) { - tr := NewTree() - tr.AddRouter("/shop/:id/account", "astaxie") - tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") - t3 := NewTree() - t3.AddTree("/:version(v1|v2)/:prefix", tr) - ctx := context.NewContext() - obj := t3.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { - t.Fatal("get :id :prefix :version param error") - } -} - -func TestAddTree3(t *testing.T) { - tr := NewTree() - tr.AddRouter("/create", "astaxie") - tr.AddRouter("/shop/:sd/account", "astaxie") - t3 := NewTree() - t3.AddTree("/table/:num", tr) - ctx := context.NewContext() - obj := t3.Match("/table/123/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/table/:num/shop/:sd/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { - t.Fatal("get :num :sd param error") - } - ctx.Input.Reset(ctx) - obj = t3.Match("/table/123/create", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/table/:num/create can't get obj ") - } -} - -func TestAddTree4(t *testing.T) { - tr := NewTree() - tr.AddRouter("/create", "astaxie") - tr.AddRouter("/shop/:sd/:account", "astaxie") - t4 := NewTree() - t4.AddTree("/:info:int/:num/:id", tr) - ctx := context.NewContext() - obj := t4.Match("/12/123/456/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || - ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || - ctx.Input.Param(":account") != "account" { - t.Fatal("get :info :num :id :sd :account param error") - } - ctx.Input.Reset(ctx) - obj = t4.Match("/12/123/456/create", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:info:int/:num/:id/create can't get obj ") - } -} - -// Test for issue #1595 -func TestAddTree5(t *testing.T) { - tr := NewTree() - tr.AddRouter("/v1/shop/:id", "shopdetail") - tr.AddRouter("/v1/shop/", "shophome") - ctx := context.NewContext() - obj := tr.Match("/v1/shop/", ctx) - if obj == nil || obj.(string) != "shophome" { - t.Fatal("url /v1/shop/ need match router /v1/shop/ ") - } -} - -func TestSplitPath(t *testing.T) { - a := splitPath("") - if len(a) != 0 { - t.Fatal("/ should retrun []") - } - a = splitPath("/") - if len(a) != 0 { - t.Fatal("/ should retrun []") - } - a = splitPath("/admin") - if len(a) != 1 || a[0] != "admin" { - t.Fatal("/admin should retrun [admin]") - } - a = splitPath("/admin/") - if len(a) != 1 || a[0] != "admin" { - t.Fatal("/admin/ should retrun [admin]") - } - a = splitPath("/admin/users") - if len(a) != 2 || a[0] != "admin" || a[1] != "users" { - t.Fatal("/admin should retrun [admin users]") - } - a = splitPath("/admin/:id:int") - if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { - t.Fatal("/admin should retrun [admin :id:int]") - } -} - -func TestSplitSegment(t *testing.T) { - - items := map[string]struct { - isReg bool - params []string - regStr string - }{ - "admin": {false, nil, ""}, - "*": {true, []string{":splat"}, ""}, - "*.*": {true, []string{".", ":path", ":ext"}, ""}, - ":id": {true, []string{":id"}, ""}, - "?:id": {true, []string{":", ":id"}, ""}, - ":id:int": {true, []string{":id"}, "([0-9]+)"}, - ":name:string": {true, []string{":name"}, `([\w]+)`}, - ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, - ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, - ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, - "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, - `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, - `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, - } - - for pattern, v := range items { - b, w, r := splitSegment(pattern) - if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { - t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) - } - } -} diff --git a/unregroute_test.go b/unregroute_test.go deleted file mode 100644 index 08b1b77b..00000000 --- a/unregroute_test.go +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -// -// The unregroute_test.go contains tests for the unregister route -// functionality, that allows overriding route paths in children project -// that embed parent routers. -// - -const contentRootOriginal = "ok-original-root" -const contentLevel1Original = "ok-original-level1" -const contentLevel2Original = "ok-original-level2" - -const contentRootReplacement = "ok-replacement-root" -const contentLevel1Replacement = "ok-replacement-level1" -const contentLevel2Replacement = "ok-replacement-level2" - -// TestPreUnregController will supply content for the original routes, -// before unregistration -type TestPreUnregController struct { - Controller -} - -func (tc *TestPreUnregController) GetFixedRoot() { - tc.Ctx.Output.Body([]byte(contentRootOriginal)) -} -func (tc *TestPreUnregController) GetFixedLevel1() { - tc.Ctx.Output.Body([]byte(contentLevel1Original)) -} -func (tc *TestPreUnregController) GetFixedLevel2() { - tc.Ctx.Output.Body([]byte(contentLevel2Original)) -} - -// TestPostUnregController will supply content for the overriding routes, -// after the original ones are unregistered. -type TestPostUnregController struct { - Controller -} - -func (tc *TestPostUnregController) GetFixedRoot() { - tc.Ctx.Output.Body([]byte(contentRootReplacement)) -} -func (tc *TestPostUnregController) GetFixedLevel1() { - tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) -} -func (tc *TestPostUnregController) GetFixedLevel2() { - tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) -} - -// TestUnregisterFixedRouteRoot replaces just the root fixed route path. -// In this case, for a path like "/level1/level2" or "/level1", those actions -// should remain intact, and continue to serve the original content. -func TestUnregisterFixedRouteRoot(t *testing.T) { - - var method = "GET" - - handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") - - // Test original root - testHelperFnContentCheck(t, handler, "Test original root", - method, "/", contentRootOriginal) - - // Test original level 1 - testHelperFnContentCheck(t, handler, "Test original level 1", - method, "/level1", contentLevel1Original) - - // Test original level 2 - testHelperFnContentCheck(t, handler, "Test original level 2", - method, "/level1/level2", contentLevel2Original) - - // Remove only the root path - findAndRemoveSingleTree(handler.routers[method]) - - // Replace the root path TestPreUnregController action with the action from - // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") - - // Test replacement root (expect change) - testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) - - // Test level 1 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) - - // Test level 2 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - -} - -// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. -// In this case, for a path like "/level1/level2" or "/", those actions -// should remain intact, and continue to serve the original content. -func TestUnregisterFixedRouteLevel1(t *testing.T) { - - var method = "GET" - - handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") - - // Test original root - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original root", - method, "/", contentRootOriginal) - - // Test original level 1 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 1", - method, "/level1", contentLevel1Original) - - // Test original level 2 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 2", - method, "/level1/level2", contentLevel2Original) - - // Remove only the level1 path - subPaths := splitPath("/level1") - if handler.routers[method].prefix == strings.Trim("/level1", "/ ") { - findAndRemoveSingleTree(handler.routers[method]) - } else { - findAndRemoveTree(subPaths, handler.routers[method], method) - } - - // Replace the "level1" path TestPreUnregController action with the action from - // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") - - // Test replacement root (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) - - // Test level 1 (expect change) - testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement) - - // Test level 2 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - -} - -// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed -// route path. In this case, for a path like "/level1" or "/", those actions -// should remain intact, and continue to serve the original content. -func TestUnregisterFixedRouteLevel2(t *testing.T) { - - var method = "GET" - - handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") - - // Test original root - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original root", - method, "/", contentRootOriginal) - - // Test original level 1 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 1", - method, "/level1", contentLevel1Original) - - // Test original level 2 - testHelperFnContentCheck(t, handler, - "TestUnregisterFixedRouteLevel1.Test original level 2", - method, "/level1/level2", contentLevel2Original) - - // Remove only the level2 path - subPaths := splitPath("/level1/level2") - if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") { - findAndRemoveSingleTree(handler.routers[method]) - } else { - findAndRemoveTree(subPaths, handler.routers[method], method) - } - - // Replace the "/level1/level2" path TestPreUnregController action with the action from - // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") - - // Test replacement root (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) - - // Test level 1 (expect no change from the original) - testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) - - // Test level 2 (expect change) - testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) - -} - -func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, - testName, method, path, expectedBodyContent string) { - - r, err := http.NewRequest(method, path, nil) - if err != nil { - t.Errorf("httpRecorderBodyTest NewRequest error: %v", err) - return - } - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - body := w.Body.String() - if body != expectedBodyContent { - t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body) - } -} diff --git a/utils/caller.go b/utils/caller.go deleted file mode 100644 index 73c52a62..00000000 --- a/utils/caller.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "reflect" - "runtime" -) - -// GetFuncName get function name -func GetFuncName(i interface{}) string { - return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() -} diff --git a/utils/caller_test.go b/utils/caller_test.go deleted file mode 100644 index 0675f0aa..00000000 --- a/utils/caller_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "strings" - "testing" -) - -func TestGetFuncName(t *testing.T) { - name := GetFuncName(TestGetFuncName) - t.Log(name) - if !strings.HasSuffix(name, ".TestGetFuncName") { - t.Error("get func name error") - } -} diff --git a/utils/captcha/LICENSE b/utils/captcha/LICENSE deleted file mode 100644 index 0ad73ae0..00000000 --- a/utils/captcha/LICENSE +++ /dev/null @@ -1,19 +0,0 @@ -Copyright (c) 2011-2014 Dmitry Chestnykh - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/utils/captcha/README.md b/utils/captcha/README.md deleted file mode 100644 index dbc2026b..00000000 --- a/utils/captcha/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# Captcha - -an example for use captcha - -``` -package controllers - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/utils/captcha" -) - -var cpt *captcha.Captcha - -func init() { - // use beego cache system store the captcha data - store := cache.NewMemoryCache() - cpt = captcha.NewWithFilter("/captcha/", store) -} - -type MainController struct { - beego.Controller -} - -func (this *MainController) Get() { - this.TplName = "index.tpl" -} - -func (this *MainController) Post() { - this.TplName = "index.tpl" - - this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) -} -``` - -template usage - -``` -{{.Success}} -
- {{create_captcha}} - -
-``` diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go deleted file mode 100644 index 42ac70d3..00000000 --- a/utils/captcha/captcha.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package captcha implements generation and verification of image CAPTCHAs. -// an example for use captcha -// -// ``` -// package controllers -// -// import ( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/cache" -// "github.com/astaxie/beego/utils/captcha" -// ) -// -// var cpt *captcha.Captcha -// -// func init() { -// // use beego cache system store the captcha data -// store := cache.NewMemoryCache() -// cpt = captcha.NewWithFilter("/captcha/", store) -// } -// -// type MainController struct { -// beego.Controller -// } -// -// func (this *MainController) Get() { -// this.TplName = "index.tpl" -// } -// -// func (this *MainController) Post() { -// this.TplName = "index.tpl" -// -// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) -// } -// ``` -// -// template usage -// -// ``` -// {{.Success}} -//
-// {{create_captcha}} -// -//
-// ``` -package captcha - -import ( - "fmt" - "html/template" - "net/http" - "path" - "strings" - "time" - - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var ( - defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} -) - -const ( - // default captcha attributes - challengeNums = 6 - expiration = 600 * time.Second - fieldIDName = "captcha_id" - fieldCaptchaName = "captcha" - cachePrefix = "captcha_" - defaultURLPrefix = "/captcha/" -) - -// Captcha struct -type Captcha struct { - // beego cache store - store cache.Cache - - // url prefix for captcha image - URLPrefix string - - // specify captcha id input field name - FieldIDName string - // specify captcha result input field name - FieldCaptchaName string - - // captcha image width and height - StdWidth int - StdHeight int - - // captcha chars nums - ChallengeNums int - - // captcha expiration seconds - Expiration time.Duration - - // cache key prefix - CachePrefix string -} - -// generate key string -func (c *Captcha) key(id string) string { - return c.CachePrefix + id -} - -// generate rand chars with default chars -func (c *Captcha) genRandChars() []byte { - return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) -} - -// Handler beego filter handler for serve captcha image -func (c *Captcha) Handler(ctx *context.Context) { - var chars []byte - - id := path.Base(ctx.Request.RequestURI) - if i := strings.Index(id, "."); i != -1 { - id = id[:i] - } - - key := c.key(id) - - if len(ctx.Input.Query("reload")) > 0 { - chars = c.genRandChars() - if err := c.store.Put(key, chars, c.Expiration); err != nil { - ctx.Output.SetStatus(500) - ctx.WriteString("captcha reload error") - logs.Error("Reload Create Captcha Error:", err) - return - } - } else { - if v, ok := c.store.Get(key).([]byte); ok { - chars = v - } else { - ctx.Output.SetStatus(404) - ctx.WriteString("captcha not found") - return - } - } - - img := NewImage(chars, c.StdWidth, c.StdHeight) - if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { - logs.Error("Write Captcha Image Error:", err) - } -} - -// CreateCaptchaHTML template func for output html -func (c *Captcha) CreateCaptchaHTML() template.HTML { - value, err := c.CreateCaptcha() - if err != nil { - logs.Error("Create Captcha Error:", err) - return "" - } - - // create html - return template.HTML(fmt.Sprintf(``+ - ``+ - ``+ - ``, c.FieldIDName, value, c.URLPrefix, value, c.URLPrefix, value)) -} - -// CreateCaptcha create a new captcha id -func (c *Captcha) CreateCaptcha() (string, error) { - // generate captcha id - id := string(utils.RandomCreateBytes(15)) - - // get the captcha chars - chars := c.genRandChars() - - // save to store - if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { - return "", err - } - - return id, nil -} - -// VerifyReq verify from a request -func (c *Captcha) VerifyReq(req *http.Request) bool { - req.ParseForm() - return c.Verify(req.Form.Get(c.FieldIDName), req.Form.Get(c.FieldCaptchaName)) -} - -// Verify direct verify id and challenge string -func (c *Captcha) Verify(id string, challenge string) (success bool) { - if len(challenge) == 0 || len(id) == 0 { - return - } - - var chars []byte - - key := c.key(id) - - if v, ok := c.store.Get(key).([]byte); ok { - chars = v - } else { - return - } - - defer func() { - // finally remove it - c.store.Delete(key) - }() - - if len(chars) != len(challenge) { - return - } - // verify challenge - for i, c := range chars { - if c != challenge[i]-48 { - return - } - } - - return true -} - -// NewCaptcha create a new captcha.Captcha -func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { - cpt := &Captcha{} - cpt.store = store - cpt.FieldIDName = fieldIDName - cpt.FieldCaptchaName = fieldCaptchaName - cpt.ChallengeNums = challengeNums - cpt.Expiration = expiration - cpt.CachePrefix = cachePrefix - cpt.StdWidth = stdWidth - cpt.StdHeight = stdHeight - - if len(urlPrefix) == 0 { - urlPrefix = defaultURLPrefix - } - - if urlPrefix[len(urlPrefix)-1] != '/' { - urlPrefix += "/" - } - - cpt.URLPrefix = urlPrefix - - return cpt -} - -// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image -// and add a template func for output html -func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { - cpt := NewCaptcha(urlPrefix, store) - - // create filter for serve captcha image - beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) - - // add to template func map - beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) - - return cpt -} diff --git a/utils/captcha/image.go b/utils/captcha/image.go deleted file mode 100644 index c3c9a83a..00000000 --- a/utils/captcha/image.go +++ /dev/null @@ -1,501 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import ( - "bytes" - "image" - "image/color" - "image/png" - "io" - "math" -) - -const ( - fontWidth = 11 - fontHeight = 18 - blackChar = 1 - - // Standard width and height of a captcha image. - stdWidth = 240 - stdHeight = 80 - // Maximum absolute skew factor of a single digit. - maxSkew = 0.7 - // Number of background circles. - circleCount = 20 -) - -var font = [][]byte{ - { // 0 - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 1 - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - }, - { // 2 - 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - }, - { // 3 - 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 4 - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, - 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, - 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - }, - { // 5 - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 6 - 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, - 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, - 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 7 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, - }, - { // 8 - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, - 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - }, - { // 9 - 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, - 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, - }, -} - -// Image struct -type Image struct { - *image.Paletted - numWidth int - numHeight int - dotSize int -} - -var prng = &siprng{} - -// randIntn returns a pseudorandom non-negative int in range [0, n). -func randIntn(n int) int { - return prng.Intn(n) -} - -// randInt returns a pseudorandom int in range [from, to]. -func randInt(from, to int) int { - return prng.Intn(to+1-from) + from -} - -// randFloat returns a pseudorandom float64 in range [from, to]. -func randFloat(from, to float64) float64 { - return (to-from)*prng.Float64() + from -} - -func randomPalette() color.Palette { - p := make([]color.Color, circleCount+1) - // Transparent color. - p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00} - // Primary color. - prim := color.RGBA{ - uint8(randIntn(129)), - uint8(randIntn(129)), - uint8(randIntn(129)), - 0xFF, - } - p[1] = prim - // Circle colors. - for i := 2; i <= circleCount; i++ { - p[i] = randomBrightness(prim, 255) - } - return p -} - -// NewImage returns a new captcha image of the given width and height with the -// given digits, where each digit must be in range 0-9. -func NewImage(digits []byte, width, height int) *Image { - m := new(Image) - m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette()) - m.calculateSizes(width, height, len(digits)) - // Randomly position captcha inside the image. - maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize - maxy := height - m.numHeight - m.dotSize*2 - var border int - if width > height { - border = height / 5 - } else { - border = width / 5 - } - x := randInt(border, maxx-border) - y := randInt(border, maxy-border) - // Draw digits. - for _, n := range digits { - m.drawDigit(font[n], x, y) - x += m.numWidth + m.dotSize - } - // Draw strike-through line. - m.strikeThrough() - // Apply wave distortion. - m.distort(randFloat(5, 10), randFloat(100, 200)) - // Fill image with random circles. - m.fillWithCircles(circleCount, m.dotSize) - return m -} - -// encodedPNG encodes an image to PNG and returns -// the result as a byte slice. -func (m *Image) encodedPNG() []byte { - var buf bytes.Buffer - if err := png.Encode(&buf, m.Paletted); err != nil { - panic(err.Error()) - } - return buf.Bytes() -} - -// WriteTo writes captcha image in PNG format into the given writer. -func (m *Image) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write(m.encodedPNG()) - return int64(n), err -} - -func (m *Image) calculateSizes(width, height, ncount int) { - // Goal: fit all digits inside the image. - var border int - if width > height { - border = height / 4 - } else { - border = width / 4 - } - // Convert everything to floats for calculations. - w := float64(width - border*2) - h := float64(height - border*2) - // fw takes into account 1-dot spacing between digits. - fw := float64(fontWidth + 1) - fh := float64(fontHeight) - nc := float64(ncount) - // Calculate the width of a single digit taking into account only the - // width of the image. - nw := w / nc - // Calculate the height of a digit from this width. - nh := nw * fh / fw - // Digit too high? - if nh > h { - // Fit digits based on height. - nh = h - nw = fw / fh * nh - } - // Calculate dot size. - m.dotSize = int(nh / fh) - if m.dotSize < 1 { - m.dotSize = 1 - } - // Save everything, making the actual width smaller by 1 dot to account - // for spacing between digits. - m.numWidth = int(nw) - m.dotSize - m.numHeight = int(nh) -} - -func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) { - for x := fromX; x <= toX; x++ { - m.SetColorIndex(x, y, colorIdx) - } -} - -func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) { - f := 1 - radius - dfx := 1 - dfy := -2 * radius - xo := 0 - yo := radius - - m.SetColorIndex(x, y+radius, colorIdx) - m.SetColorIndex(x, y-radius, colorIdx) - m.drawHorizLine(x-radius, x+radius, y, colorIdx) - - for xo < yo { - if f >= 0 { - yo-- - dfy += 2 - f += dfy - } - xo++ - dfx += 2 - f += dfx - m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx) - m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx) - m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx) - m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx) - } -} - -func (m *Image) fillWithCircles(n, maxradius int) { - maxx := m.Bounds().Max.X - maxy := m.Bounds().Max.Y - for i := 0; i < n; i++ { - colorIdx := uint8(randInt(1, circleCount-1)) - r := randInt(1, maxradius) - m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx) - } -} - -func (m *Image) strikeThrough() { - maxx := m.Bounds().Max.X - maxy := m.Bounds().Max.Y - y := randInt(maxy/3, maxy-maxy/3) - amplitude := randFloat(5, 20) - period := randFloat(80, 180) - dx := 2.0 * math.Pi / period - for x := 0; x < maxx; x++ { - xo := amplitude * math.Cos(float64(y)*dx) - yo := amplitude * math.Sin(float64(x)*dx) - for yn := 0; yn < m.dotSize; yn++ { - r := randInt(0, m.dotSize) - m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1) - } - } -} - -func (m *Image) drawDigit(digit []byte, x, y int) { - skf := randFloat(-maxSkew, maxSkew) - xs := float64(x) - r := m.dotSize / 2 - y += randInt(-r, r) - for yo := 0; yo < fontHeight; yo++ { - for xo := 0; xo < fontWidth; xo++ { - if digit[yo*fontWidth+xo] != blackChar { - continue - } - m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1) - } - xs += skf - x = int(xs) - } -} - -func (m *Image) distort(amplude float64, period float64) { - w := m.Bounds().Max.X - h := m.Bounds().Max.Y - - oldm := m.Paletted - newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette) - - dx := 2.0 * math.Pi / period - for x := 0; x < w; x++ { - for y := 0; y < h; y++ { - xo := amplude * math.Sin(float64(y)*dx) - yo := amplude * math.Cos(float64(x)*dx) - newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo))) - } - } - m.Paletted = newm -} - -func randomBrightness(c color.RGBA, max uint8) color.RGBA { - minc := min3(c.R, c.G, c.B) - maxc := max3(c.R, c.G, c.B) - if maxc > max { - return c - } - n := randIntn(int(max-maxc)) - int(minc) - return color.RGBA{ - uint8(int(c.R) + n), - uint8(int(c.G) + n), - uint8(int(c.B) + n), - c.A, - } -} - -func min3(x, y, z uint8) (m uint8) { - m = x - if y < m { - m = y - } - if z < m { - m = z - } - return -} - -func max3(x, y, z uint8) (m uint8) { - m = x - if y > m { - m = y - } - if z > m { - m = z - } - return -} diff --git a/utils/captcha/image_test.go b/utils/captcha/image_test.go deleted file mode 100644 index 5e35b7f7..00000000 --- a/utils/captcha/image_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import ( - "testing" - - "github.com/astaxie/beego/utils" -) - -type byteCounter struct { - n int64 -} - -func (bc *byteCounter) Write(b []byte) (int, error) { - bc.n += int64(len(b)) - return len(b), nil -} - -func BenchmarkNewImage(b *testing.B) { - b.StopTimer() - d := utils.RandomCreateBytes(challengeNums, defaultChars...) - b.StartTimer() - for i := 0; i < b.N; i++ { - NewImage(d, stdWidth, stdHeight) - } -} - -func BenchmarkImageWriteTo(b *testing.B) { - b.StopTimer() - d := utils.RandomCreateBytes(challengeNums, defaultChars...) - b.StartTimer() - counter := &byteCounter{} - for i := 0; i < b.N; i++ { - img := NewImage(d, stdWidth, stdHeight) - img.WriteTo(counter) - b.SetBytes(counter.n) - counter.n = 0 - } -} diff --git a/utils/captcha/siprng.go b/utils/captcha/siprng.go deleted file mode 100644 index 5e256cf9..00000000 --- a/utils/captcha/siprng.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import ( - "crypto/rand" - "encoding/binary" - "io" - "sync" -) - -// siprng is PRNG based on SipHash-2-4. -type siprng struct { - mu sync.Mutex - k0, k1, ctr uint64 -} - -// siphash implements SipHash-2-4, accepting a uint64 as a message. -func siphash(k0, k1, m uint64) uint64 { - // Initialization. - v0 := k0 ^ 0x736f6d6570736575 - v1 := k1 ^ 0x646f72616e646f6d - v2 := k0 ^ 0x6c7967656e657261 - v3 := k1 ^ 0x7465646279746573 - t := uint64(8) << 56 - - // Compression. - v3 ^= m - - // Round 1. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 2. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - v0 ^= m - - // Compress last block. - v3 ^= t - - // Round 1. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 2. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - v0 ^= t - - // Finalization. - v2 ^= 0xff - - // Round 1. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 2. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 3. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - // Round 4. - v0 += v1 - v1 = v1<<13 | v1>>(64-13) - v1 ^= v0 - v0 = v0<<32 | v0>>(64-32) - - v2 += v3 - v3 = v3<<16 | v3>>(64-16) - v3 ^= v2 - - v0 += v3 - v3 = v3<<21 | v3>>(64-21) - v3 ^= v0 - - v2 += v1 - v1 = v1<<17 | v1>>(64-17) - v1 ^= v2 - v2 = v2<<32 | v2>>(64-32) - - return v0 ^ v1 ^ v2 ^ v3 -} - -// rekey sets a new PRNG key, which is read from crypto/rand. -func (p *siprng) rekey() { - var k [16]byte - if _, err := io.ReadFull(rand.Reader, k[:]); err != nil { - panic(err.Error()) - } - p.k0 = binary.LittleEndian.Uint64(k[0:8]) - p.k1 = binary.LittleEndian.Uint64(k[8:16]) - p.ctr = 1 -} - -// Uint64 returns a new pseudorandom uint64. -// It rekeys PRNG on the first call and every 64 MB of generated data. -func (p *siprng) Uint64() uint64 { - p.mu.Lock() - if p.ctr == 0 || p.ctr > 8*1024*1024 { - p.rekey() - } - v := siphash(p.k0, p.k1, p.ctr) - p.ctr++ - p.mu.Unlock() - return v -} - -func (p *siprng) Int63() int64 { - return int64(p.Uint64() & 0x7fffffffffffffff) -} - -func (p *siprng) Uint32() uint32 { - return uint32(p.Uint64()) -} - -func (p *siprng) Int31() int32 { - return int32(p.Uint32() & 0x7fffffff) -} - -func (p *siprng) Intn(n int) int { - if n <= 0 { - panic("invalid argument to Intn") - } - if n <= 1<<31-1 { - return int(p.Int31n(int32(n))) - } - return int(p.Int63n(int64(n))) -} - -func (p *siprng) Int63n(n int64) int64 { - if n <= 0 { - panic("invalid argument to Int63n") - } - max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) - v := p.Int63() - for v > max { - v = p.Int63() - } - return v % n -} - -func (p *siprng) Int31n(n int32) int32 { - if n <= 0 { - panic("invalid argument to Int31n") - } - max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) - v := p.Int31() - for v > max { - v = p.Int31() - } - return v % n -} - -func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) } diff --git a/utils/captcha/siprng_test.go b/utils/captcha/siprng_test.go deleted file mode 100644 index 189d3d3c..00000000 --- a/utils/captcha/siprng_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package captcha - -import "testing" - -func TestSiphash(t *testing.T) { - good := uint64(0xe849e8bb6ffe2567) - cur := siphash(0, 0, 0) - if cur != good { - t.Fatalf("siphash: expected %x, got %x", good, cur) - } -} - -func BenchmarkSiprng(b *testing.B) { - b.SetBytes(8) - p := &siprng{} - for i := 0; i < b.N; i++ { - p.Uint64() - } -} diff --git a/utils/debug.go b/utils/debug.go deleted file mode 100644 index 93c27b70..00000000 --- a/utils/debug.go +++ /dev/null @@ -1,478 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "bytes" - "fmt" - "log" - "reflect" - "runtime" -) - -var ( - dunno = []byte("???") - centerDot = []byte("·") - dot = []byte(".") -) - -type pointerInfo struct { - prev *pointerInfo - n int - addr uintptr - pos int - used []int -} - -// Display print the data in console -func Display(data ...interface{}) { - display(true, data...) -} - -// GetDisplayString return data print string -func GetDisplayString(data ...interface{}) string { - return display(false, data...) -} - -func display(displayed bool, data ...interface{}) string { - var pc, file, line, ok = runtime.Caller(2) - - if !ok { - return "" - } - - var buf = new(bytes.Buffer) - - fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) - - fmt.Fprintf(buf, "\n[Variables]\n") - - for i := 0; i < len(data); i += 2 { - var output = fomateinfo(len(data[i].(string))+3, data[i+1]) - fmt.Fprintf(buf, "%s = %s", data[i], output) - } - - if displayed { - log.Print(buf) - } - return buf.String() -} - -// return data dump and format bytes -func fomateinfo(headlen int, data ...interface{}) []byte { - var buf = new(bytes.Buffer) - - if len(data) > 1 { - fmt.Fprint(buf, " ") - - fmt.Fprint(buf, "[") - - fmt.Fprintln(buf) - } - - for k, v := range data { - var buf2 = new(bytes.Buffer) - var pointers *pointerInfo - var interfaces = make([]reflect.Value, 0, 10) - - printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) - - if k < len(data)-1 { - fmt.Fprint(buf2, ", ") - } - - fmt.Fprintln(buf2) - - buf.Write(buf2.Bytes()) - } - - if len(data) > 1 { - fmt.Fprintln(buf) - - fmt.Fprint(buf, " ") - - fmt.Fprint(buf, "]") - } - - return buf.Bytes() -} - -// check data is golang basic type -func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool { - switch kind { - case reflect.Bool: - return true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return true - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - return true - case reflect.Float32, reflect.Float64: - return true - case reflect.Complex64, reflect.Complex128: - return true - case reflect.String: - return true - case reflect.Chan: - return true - case reflect.Invalid: - return true - case reflect.Interface: - for _, in := range *interfaces { - if reflect.DeepEqual(in, val) { - return true - } - } - return false - case reflect.UnsafePointer: - if val.IsNil() { - return true - } - - var elem = val.Elem() - - if isSimpleType(elem, elem.Kind(), pointers, interfaces) { - return true - } - - var addr = val.Elem().UnsafeAddr() - - for p := *pointers; p != nil; p = p.prev { - if addr == p.addr { - return true - } - } - - return false - } - - return false -} - -// dump value -func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { - var t = val.Kind() - - switch t { - case reflect.Bool: - fmt.Fprint(buf, val.Bool()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fmt.Fprint(buf, val.Int()) - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - fmt.Fprint(buf, val.Uint()) - case reflect.Float32, reflect.Float64: - fmt.Fprint(buf, val.Float()) - case reflect.Complex64, reflect.Complex128: - fmt.Fprint(buf, val.Complex()) - case reflect.UnsafePointer: - fmt.Fprintf(buf, "unsafe.Pointer(0x%X)", val.Pointer()) - case reflect.Ptr: - if val.IsNil() { - fmt.Fprint(buf, "nil") - return - } - - var addr = val.Elem().UnsafeAddr() - - for p := *pointers; p != nil; p = p.prev { - if addr == p.addr { - p.used = append(p.used, buf.Len()) - fmt.Fprintf(buf, "0x%X", addr) - return - } - } - - *pointers = &pointerInfo{ - prev: *pointers, - addr: addr, - pos: buf.Len(), - used: make([]int, 0), - } - - fmt.Fprint(buf, "&") - - printKeyValue(buf, val.Elem(), pointers, interfaces, structFilter, formatOutput, indent, level) - case reflect.String: - fmt.Fprint(buf, "\"", val.String(), "\"") - case reflect.Interface: - var value = val.Elem() - - if !value.IsValid() { - fmt.Fprint(buf, "nil") - } else { - for _, in := range *interfaces { - if reflect.DeepEqual(in, val) { - fmt.Fprint(buf, "repeat") - return - } - } - - *interfaces = append(*interfaces, val) - - printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) - } - case reflect.Struct: - var t = val.Type() - - fmt.Fprint(buf, t) - fmt.Fprint(buf, "{") - - for i := 0; i < val.NumField(); i++ { - if formatOutput { - fmt.Fprintln(buf) - } else { - fmt.Fprint(buf, " ") - } - - var name = t.Field(i).Name - - if formatOutput { - for ind := 0; ind < level; ind++ { - fmt.Fprint(buf, indent) - } - } - - fmt.Fprint(buf, name) - fmt.Fprint(buf, ": ") - - if structFilter != nil && structFilter(t.String(), name) { - fmt.Fprint(buf, "ignore") - } else { - printKeyValue(buf, val.Field(i), pointers, interfaces, structFilter, formatOutput, indent, level+1) - } - - fmt.Fprint(buf, ",") - } - - if formatOutput { - fmt.Fprintln(buf) - - for ind := 0; ind < level-1; ind++ { - fmt.Fprint(buf, indent) - } - } else { - fmt.Fprint(buf, " ") - } - - fmt.Fprint(buf, "}") - case reflect.Array, reflect.Slice: - fmt.Fprint(buf, val.Type()) - fmt.Fprint(buf, "{") - - var allSimple = true - - for i := 0; i < val.Len(); i++ { - var elem = val.Index(i) - - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) - - if !isSimple { - allSimple = false - } - - if formatOutput && !isSimple { - fmt.Fprintln(buf) - } else { - fmt.Fprint(buf, " ") - } - - if formatOutput && !isSimple { - for ind := 0; ind < level; ind++ { - fmt.Fprint(buf, indent) - } - } - - printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) - - if i != val.Len()-1 || !allSimple { - fmt.Fprint(buf, ",") - } - } - - if formatOutput && !allSimple { - fmt.Fprintln(buf) - - for ind := 0; ind < level-1; ind++ { - fmt.Fprint(buf, indent) - } - } else { - fmt.Fprint(buf, " ") - } - - fmt.Fprint(buf, "}") - case reflect.Map: - var t = val.Type() - var keys = val.MapKeys() - - fmt.Fprint(buf, t) - fmt.Fprint(buf, "{") - - var allSimple = true - - for i := 0; i < len(keys); i++ { - var elem = val.MapIndex(keys[i]) - - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) - - if !isSimple { - allSimple = false - } - - if formatOutput && !isSimple { - fmt.Fprintln(buf) - } else { - fmt.Fprint(buf, " ") - } - - if formatOutput && !isSimple { - for ind := 0; ind <= level; ind++ { - fmt.Fprint(buf, indent) - } - } - - printKeyValue(buf, keys[i], pointers, interfaces, structFilter, formatOutput, indent, level+1) - fmt.Fprint(buf, ": ") - printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) - - if i != val.Len()-1 || !allSimple { - fmt.Fprint(buf, ",") - } - } - - if formatOutput && !allSimple { - fmt.Fprintln(buf) - - for ind := 0; ind < level-1; ind++ { - fmt.Fprint(buf, indent) - } - } else { - fmt.Fprint(buf, " ") - } - - fmt.Fprint(buf, "}") - case reflect.Chan: - fmt.Fprint(buf, val.Type()) - case reflect.Invalid: - fmt.Fprint(buf, "invalid") - default: - fmt.Fprint(buf, "unknow") - } -} - -// PrintPointerInfo dump pointer value -func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { - var anyused = false - var pointerNum = 0 - - for p := pointers; p != nil; p = p.prev { - if len(p.used) > 0 { - anyused = true - } - pointerNum++ - p.n = pointerNum - } - - if anyused { - var pointerBufs = make([][]rune, pointerNum+1) - - for i := 0; i < len(pointerBufs); i++ { - var pointerBuf = make([]rune, buf.Len()+headlen) - - for j := 0; j < len(pointerBuf); j++ { - pointerBuf[j] = ' ' - } - - pointerBufs[i] = pointerBuf - } - - for pn := 0; pn <= pointerNum; pn++ { - for p := pointers; p != nil; p = p.prev { - if len(p.used) > 0 && p.n >= pn { - if pn == p.n { - pointerBufs[pn][p.pos+headlen] = '└' - - var maxpos = 0 - - for i, pos := range p.used { - if i < len(p.used)-1 { - pointerBufs[pn][pos+headlen] = '┴' - } else { - pointerBufs[pn][pos+headlen] = '┘' - } - - maxpos = pos - } - - for i := 0; i < maxpos-p.pos-1; i++ { - if pointerBufs[pn][i+p.pos+headlen+1] == ' ' { - pointerBufs[pn][i+p.pos+headlen+1] = '─' - } - } - } else { - pointerBufs[pn][p.pos+headlen] = '│' - - for _, pos := range p.used { - if pointerBufs[pn][pos+headlen] == ' ' { - pointerBufs[pn][pos+headlen] = '│' - } else { - pointerBufs[pn][pos+headlen] = '┼' - } - } - } - } - } - - buf.WriteString(string(pointerBufs[pn]) + "\n") - } - } -} - -// Stack get stack bytes -func Stack(skip int, indent string) []byte { - var buf = new(bytes.Buffer) - - for i := skip; ; i++ { - var pc, file, line, ok = runtime.Caller(i) - - if !ok { - break - } - - buf.WriteString(indent) - - fmt.Fprintf(buf, "at %s() [%s:%d]\n", function(pc), file, line) - } - - return buf.Bytes() -} - -// return the name of the function containing the PC if possible, -func function(pc uintptr) []byte { - fn := runtime.FuncForPC(pc) - if fn == nil { - return dunno - } - name := []byte(fn.Name()) - // The name includes the path name to the package, which is unnecessary - // since the file name is already included. Plus, it has center dots. - // That is, we see - // runtime/debug.*T·ptrmethod - // and want - // *T.ptrmethod - if period := bytes.Index(name, dot); period >= 0 { - name = name[period+1:] - } - name = bytes.Replace(name, centerDot, dot, -1) - return name -} diff --git a/utils/debug_test.go b/utils/debug_test.go deleted file mode 100644 index efb8924e..00000000 --- a/utils/debug_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "testing" -) - -type mytype struct { - next *mytype - prev *mytype -} - -func TestPrint(t *testing.T) { - Display("v1", 1, "v2", 2, "v3", 3) -} - -func TestPrintPoint(t *testing.T) { - var v1 = new(mytype) - var v2 = new(mytype) - - v1.prev = nil - v1.next = v2 - - v2.prev = v1 - v2.next = nil - - Display("v1", v1, "v2", v2) -} - -func TestPrintString(t *testing.T) { - str := GetDisplayString("v1", 1, "v2", 2) - println(str) -} diff --git a/utils/file.go b/utils/file.go deleted file mode 100644 index 6090eb17..00000000 --- a/utils/file.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "bufio" - "errors" - "io" - "os" - "path/filepath" - "regexp" -) - -// SelfPath gets compiled executable file absolute path -func SelfPath() string { - path, _ := filepath.Abs(os.Args[0]) - return path -} - -// SelfDir gets compiled executable file directory -func SelfDir() string { - return filepath.Dir(SelfPath()) -} - -// FileExists reports whether the named file or directory exists. -func FileExists(name string) bool { - if _, err := os.Stat(name); err != nil { - if os.IsNotExist(err) { - return false - } - } - return true -} - -// SearchFile Search a file in paths. -// this is often used in search config file in /etc ~/ -func SearchFile(filename string, paths ...string) (fullpath string, err error) { - for _, path := range paths { - if fullpath = filepath.Join(path, filename); FileExists(fullpath) { - return - } - } - err = errors.New(fullpath + " not found in paths") - return -} - -// GrepFile like command grep -E -// for example: GrepFile(`^hello`, "hello.txt") -// \n is striped while read -func GrepFile(patten string, filename string) (lines []string, err error) { - re, err := regexp.Compile(patten) - if err != nil { - return - } - - fd, err := os.Open(filename) - if err != nil { - return - } - lines = make([]string, 0) - reader := bufio.NewReader(fd) - prefix := "" - var isLongLine bool - for { - byteLine, isPrefix, er := reader.ReadLine() - if er != nil && er != io.EOF { - return nil, er - } - if er == io.EOF { - break - } - line := string(byteLine) - if isPrefix { - prefix += line - continue - } else { - isLongLine = true - } - - line = prefix + line - if isLongLine { - prefix = "" - } - if re.MatchString(line) { - lines = append(lines, line) - } - } - return lines, nil -} diff --git a/utils/file_test.go b/utils/file_test.go deleted file mode 100644 index 84443e20..00000000 --- a/utils/file_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "path/filepath" - "reflect" - "testing" - - "github.com/stretchr/testify/assert" -) - -var noExistedFile = "/tmp/not_existed_file" - -func TestSelfPath(t *testing.T) { - path := SelfPath() - if path == "" { - t.Error("path cannot be empty") - } - t.Logf("SelfPath: %s", path) -} - -func TestSelfDir(t *testing.T) { - dir := SelfDir() - t.Logf("SelfDir: %s", dir) -} - -func TestFileExists(t *testing.T) { - if !FileExists("./file.go") { - t.Errorf("./file.go should exists, but it didn't") - } - - if FileExists(noExistedFile) { - t.Errorf("Weird, how could this file exists: %s", noExistedFile) - } -} - -func TestSearchFile(t *testing.T) { - path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) - if err != nil { - t.Error(err) - } - t.Log(path) - - _, err = SearchFile(noExistedFile, ".") - if err == nil { - t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) - } -} - -func TestGrepFile(t *testing.T) { - _, err := GrepFile("", noExistedFile) - if err == nil { - t.Error("expect file-not-existed error, but got nothing") - } - - path := filepath.Join(".", "testdata", "grepe.test") - lines, err := GrepFile(`^\s*[^#]+`, path) - assert.Nil(t, err) - - if !reflect.DeepEqual(lines, []string{"hello", "world"}) { - t.Errorf("expect [hello world], but receive %v", lines) - } -} diff --git a/utils/mail.go b/utils/mail.go deleted file mode 100644 index 80a366ca..00000000 --- a/utils/mail.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "mime" - "mime/multipart" - "net/mail" - "net/smtp" - "net/textproto" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "sync" -) - -const ( - maxLineLength = 76 - - upperhex = "0123456789ABCDEF" -) - -// Email is the type used for email messages -type Email struct { - Auth smtp.Auth - Identity string `json:"identity"` - Username string `json:"username"` - Password string `json:"password"` - Host string `json:"host"` - Port int `json:"port"` - From string `json:"from"` - To []string - Bcc []string - Cc []string - Subject string - Text string // Plaintext message (optional) - HTML string // Html message (optional) - Headers textproto.MIMEHeader - Attachments []*Attachment - ReadReceipt []string -} - -// Attachment is a struct representing an email attachment. -// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question -type Attachment struct { - Filename string - Header textproto.MIMEHeader - Content []byte -} - -// NewEMail create new Email struct with config json. -// config json is followed from Email struct fields. -func NewEMail(config string) *Email { - e := new(Email) - e.Headers = textproto.MIMEHeader{} - err := json.Unmarshal([]byte(config), e) - if err != nil { - return nil - } - return e -} - -// Bytes Make all send information to byte -func (e *Email) Bytes() ([]byte, error) { - buff := &bytes.Buffer{} - w := multipart.NewWriter(buff) - // Set the appropriate headers (overwriting any conflicts) - // Leave out Bcc (only included in envelope headers) - e.Headers.Set("To", strings.Join(e.To, ",")) - if e.Cc != nil { - e.Headers.Set("Cc", strings.Join(e.Cc, ",")) - } - e.Headers.Set("From", e.From) - e.Headers.Set("Subject", e.Subject) - if len(e.ReadReceipt) != 0 { - e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ",")) - } - e.Headers.Set("MIME-Version", "1.0") - - // Write the envelope headers (including any custom headers) - if err := headerToBytes(buff, e.Headers); err != nil { - return nil, fmt.Errorf("Failed to render message headers: %s", err) - } - - e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) - fmt.Fprintf(buff, "%s:", "Content-Type") - fmt.Fprintf(buff, " %s\r\n", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) - - // Start the multipart/mixed part - fmt.Fprintf(buff, "--%s\r\n", w.Boundary()) - header := textproto.MIMEHeader{} - // Check to see if there is a Text or HTML field - if e.Text != "" || e.HTML != "" { - subWriter := multipart.NewWriter(buff) - // Create the multipart alternative part - header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary())) - // Write the header - if err := headerToBytes(buff, header); err != nil { - return nil, fmt.Errorf("Failed to render multipart message headers: %s", err) - } - // Create the body sections - if e.Text != "" { - header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) - header.Set("Content-Transfer-Encoding", "quoted-printable") - if _, err := subWriter.CreatePart(header); err != nil { - return nil, err - } - // Write the text - if err := quotePrintEncode(buff, e.Text); err != nil { - return nil, err - } - } - if e.HTML != "" { - header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) - header.Set("Content-Transfer-Encoding", "quoted-printable") - if _, err := subWriter.CreatePart(header); err != nil { - return nil, err - } - // Write the text - if err := quotePrintEncode(buff, e.HTML); err != nil { - return nil, err - } - } - if err := subWriter.Close(); err != nil { - return nil, err - } - } - // Create attachment part, if necessary - for _, a := range e.Attachments { - ap, err := w.CreatePart(a.Header) - if err != nil { - return nil, err - } - // Write the base64Wrapped content to the part - base64Wrap(ap, a.Content) - } - if err := w.Close(); err != nil { - return nil, err - } - return buff.Bytes(), nil -} - -// AttachFile Add attach file to the send mail -func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { - if len(args) < 1 || len(args) > 2 { // change && to || - err = errors.New("Must specify a file name and number of parameters can not exceed at least two") - return - } - filename := args[0] - id := "" - if len(args) > 1 { - id = args[1] - } - f, err := os.Open(filename) - if err != nil { - return - } - defer f.Close() - ct := mime.TypeByExtension(filepath.Ext(filename)) - basename := path.Base(filename) - return e.Attach(f, basename, ct, id) -} - -// Attach is used to attach content from an io.Reader to the email. -// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. -func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { - if len(args) < 1 || len(args) > 2 { // change && to || - err = errors.New("Must specify the file type and number of parameters can not exceed at least two") - return - } - c := args[0] //Content-Type - id := "" - if len(args) > 1 { - id = args[1] //Content-ID - } - var buffer bytes.Buffer - if _, err = io.Copy(&buffer, r); err != nil { - return - } - at := &Attachment{ - Filename: filename, - Header: textproto.MIMEHeader{}, - Content: buffer.Bytes(), - } - // Get the Content-Type to be used in the MIMEHeader - if c != "" { - at.Header.Set("Content-Type", c) - } else { - // If the Content-Type is blank, set the Content-Type to "application/octet-stream" - at.Header.Set("Content-Type", "application/octet-stream") - } - if id != "" { - at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename)) - at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id)) - } else { - at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) - } - at.Header.Set("Content-Transfer-Encoding", "base64") - e.Attachments = append(e.Attachments, at) - return at, nil -} - -// Send will send out the mail -func (e *Email) Send() error { - if e.Auth == nil { - e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host) - } - // Merge the To, Cc, and Bcc fields - to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc)) - to = append(append(append(to, e.To...), e.Cc...), e.Bcc...) - // Check to make sure there is at least one recipient and one "From" address - if len(to) == 0 { - return errors.New("Must specify at least one To address") - } - - // Use the username if no From is provided - if len(e.From) == 0 { - e.From = e.Username - } - - from, err := mail.ParseAddress(e.From) - if err != nil { - return err - } - - // use mail's RFC 2047 to encode any string - e.Subject = qEncode("utf-8", e.Subject) - - raw, err := e.Bytes() - if err != nil { - return err - } - return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw) -} - -// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045) -func quotePrintEncode(w io.Writer, s string) error { - var buf [3]byte - mc := 0 - for i := 0; i < len(s); i++ { - c := s[i] - // We're assuming Unix style text formats as input (LF line break), and - // quoted-printble uses CRLF line breaks. (Literal CRs will become - // "=0D", but probably shouldn't be there to begin with!) - if c == '\n' { - io.WriteString(w, "\r\n") - mc = 0 - continue - } - - var nextOut []byte - if isPrintable(c) { - nextOut = append(buf[:0], c) - } else { - nextOut = buf[:] - qpEscape(nextOut, c) - } - - // Add a soft line break if the next (encoded) byte would push this line - // to or past the limit. - if mc+len(nextOut) >= maxLineLength { - if _, err := io.WriteString(w, "=\r\n"); err != nil { - return err - } - mc = 0 - } - - if _, err := w.Write(nextOut); err != nil { - return err - } - mc += len(nextOut) - } - // No trailing end-of-line?? Soft line break, then. TODO: is this sane? - if mc > 0 { - io.WriteString(w, "=\r\n") - } - return nil -} - -// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise -func isPrintable(c byte) bool { - return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t') -} - -// qpEscape is a helper function for quotePrintEncode which escapes a -// non-printable byte. Expects len(dest) == 3. -func qpEscape(dest []byte, c byte) { - const nums = "0123456789ABCDEF" - dest[0] = '=' - dest[1] = nums[(c&0xf0)>>4] - dest[2] = nums[(c & 0xf)] -} - -// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer -func headerToBytes(w io.Writer, t textproto.MIMEHeader) error { - for k, v := range t { - // Write the header key - _, err := fmt.Fprintf(w, "%s:", k) - if err != nil { - return err - } - // Write each value in the header - for _, c := range v { - _, err := fmt.Fprintf(w, " %s\r\n", c) - if err != nil { - return err - } - } - } - return nil -} - -// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars) -// The output is then written to the specified io.Writer -func base64Wrap(w io.Writer, b []byte) { - // 57 raw bytes per 76-byte base64 line. - const maxRaw = 57 - // Buffer for each line, including trailing CRLF. - var buffer [maxLineLength + len("\r\n")]byte - copy(buffer[maxLineLength:], "\r\n") - // Process raw chunks until there's no longer enough to fill a line. - for len(b) >= maxRaw { - base64.StdEncoding.Encode(buffer[:], b[:maxRaw]) - w.Write(buffer[:]) - b = b[maxRaw:] - } - // Handle the last chunk of bytes. - if len(b) > 0 { - out := buffer[:base64.StdEncoding.EncodedLen(len(b))] - base64.StdEncoding.Encode(out, b) - out = append(out, "\r\n"...) - w.Write(out) - } -} - -// Encode returns the encoded-word form of s. If s is ASCII without special -// characters, it is returned unchanged. The provided charset is the IANA -// charset name of s. It is case insensitive. -// RFC 2047 encoded-word -func qEncode(charset, s string) string { - if !needsEncoding(s) { - return s - } - return encodeWord(charset, s) -} - -func needsEncoding(s string) bool { - for _, b := range s { - if (b < ' ' || b > '~') && b != '\t' { - return true - } - } - return false -} - -// encodeWord encodes a string into an encoded-word. -func encodeWord(charset, s string) string { - buf := getBuffer() - - buf.WriteString("=?") - buf.WriteString(charset) - buf.WriteByte('?') - buf.WriteByte('q') - buf.WriteByte('?') - - enc := make([]byte, 3) - for i := 0; i < len(s); i++ { - b := s[i] - switch { - case b == ' ': - buf.WriteByte('_') - case b <= '~' && b >= '!' && b != '=' && b != '?' && b != '_': - buf.WriteByte(b) - default: - enc[0] = '=' - enc[1] = upperhex[b>>4] - enc[2] = upperhex[b&0x0f] - buf.Write(enc) - } - } - buf.WriteString("?=") - - es := buf.String() - putBuffer(buf) - return es -} - -var bufPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, -} - -func getBuffer() *bytes.Buffer { - return bufPool.Get().(*bytes.Buffer) -} - -func putBuffer(buf *bytes.Buffer) { - if buf.Len() > 1024 { - return - } - buf.Reset() - bufPool.Put(buf) -} diff --git a/utils/mail_test.go b/utils/mail_test.go deleted file mode 100644 index c38356a2..00000000 --- a/utils/mail_test.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import "testing" - -func TestMail(t *testing.T) { - config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` - mail := NewEMail(config) - if mail.Username != "astaxie@gmail.com" { - t.Fatal("email parse get username error") - } - if mail.Password != "astaxie" { - t.Fatal("email parse get password error") - } - if mail.Host != "smtp.gmail.com" { - t.Fatal("email parse get host error") - } - if mail.Port != 587 { - t.Fatal("email parse get port error") - } - mail.To = []string{"xiemengjun@gmail.com"} - mail.From = "astaxie@gmail.com" - mail.Subject = "hi, just from beego!" - mail.Text = "Text Body is, of course, supported!" - mail.HTML = "

Fancy Html is supported, too!

" - mail.AttachFile("/Users/astaxie/github/beego/beego.go") - mail.Send() -} diff --git a/utils/pagination/controller.go b/utils/pagination/controller.go deleted file mode 100644 index 2f022d0c..00000000 --- a/utils/pagination/controller.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagination - -import ( - "github.com/astaxie/beego/context" -) - -// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). -func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { - paginator = NewPaginator(context.Request, per, nums) - context.Input.SetData("paginator", &paginator) - return -} diff --git a/utils/pagination/doc.go b/utils/pagination/doc.go deleted file mode 100644 index 9abc6d78..00000000 --- a/utils/pagination/doc.go +++ /dev/null @@ -1,58 +0,0 @@ -/* -Package pagination provides utilities to setup a paginator within the -context of a http request. - -Usage - -In your beego.Controller: - - package controllers - - import "github.com/astaxie/beego/utils/pagination" - - type PostsController struct { - beego.Controller - } - - func (this *PostsController) ListAllPosts() { - // sets this.Data["paginator"] with the current offset (from the url query param) - postsPerPage := 20 - paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) - - // fetch the next 20 posts - this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) - } - - -In your view templates: - - {{if .paginator.HasPages}} - - {{end}} - -See also - -http://beego.me/docs/mvc/view/page.md - -*/ -package pagination diff --git a/utils/pagination/paginator.go b/utils/pagination/paginator.go deleted file mode 100644 index c6db31e0..00000000 --- a/utils/pagination/paginator.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagination - -import ( - "math" - "net/http" - "net/url" - "strconv" -) - -// Paginator within the state of a http request. -type Paginator struct { - Request *http.Request - PerPageNums int - MaxPages int - - nums int64 - pageRange []int - pageNums int - page int -} - -// PageNums Returns the total number of pages. -func (p *Paginator) PageNums() int { - if p.pageNums != 0 { - return p.pageNums - } - pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) - if p.MaxPages > 0 { - pageNums = math.Min(pageNums, float64(p.MaxPages)) - } - p.pageNums = int(pageNums) - return p.pageNums -} - -// Nums Returns the total number of items (e.g. from doing SQL count). -func (p *Paginator) Nums() int64 { - return p.nums -} - -// SetNums Sets the total number of items. -func (p *Paginator) SetNums(nums interface{}) { - p.nums, _ = toInt64(nums) -} - -// Page Returns the current page. -func (p *Paginator) Page() int { - if p.page != 0 { - return p.page - } - if p.Request.Form == nil { - p.Request.ParseForm() - } - p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) - if p.page > p.PageNums() { - p.page = p.PageNums() - } - if p.page <= 0 { - p.page = 1 - } - return p.page -} - -// Pages Returns a list of all pages. -// -// Usage (in a view template): -// -// {{range $index, $page := .paginator.Pages}} -// -// {{$page}} -// -// {{end}} -func (p *Paginator) Pages() []int { - if p.pageRange == nil && p.nums > 0 { - var pages []int - pageNums := p.PageNums() - page := p.Page() - switch { - case page >= pageNums-4 && pageNums > 9: - start := pageNums - 9 + 1 - pages = make([]int, 9) - for i := range pages { - pages[i] = start + i - } - case page >= 5 && pageNums > 9: - start := page - 5 + 1 - pages = make([]int, int(math.Min(9, float64(page+4+1)))) - for i := range pages { - pages[i] = start + i - } - default: - pages = make([]int, int(math.Min(9, float64(pageNums)))) - for i := range pages { - pages[i] = i + 1 - } - } - p.pageRange = pages - } - return p.pageRange -} - -// PageLink Returns URL for a given page index. -func (p *Paginator) PageLink(page int) string { - link, _ := url.ParseRequestURI(p.Request.URL.String()) - values := link.Query() - if page == 1 { - values.Del("p") - } else { - values.Set("p", strconv.Itoa(page)) - } - link.RawQuery = values.Encode() - return link.String() -} - -// PageLinkPrev Returns URL to the previous page. -func (p *Paginator) PageLinkPrev() (link string) { - if p.HasPrev() { - link = p.PageLink(p.Page() - 1) - } - return -} - -// PageLinkNext Returns URL to the next page. -func (p *Paginator) PageLinkNext() (link string) { - if p.HasNext() { - link = p.PageLink(p.Page() + 1) - } - return -} - -// PageLinkFirst Returns URL to the first page. -func (p *Paginator) PageLinkFirst() (link string) { - return p.PageLink(1) -} - -// PageLinkLast Returns URL to the last page. -func (p *Paginator) PageLinkLast() (link string) { - return p.PageLink(p.PageNums()) -} - -// HasPrev Returns true if the current page has a predecessor. -func (p *Paginator) HasPrev() bool { - return p.Page() > 1 -} - -// HasNext Returns true if the current page has a successor. -func (p *Paginator) HasNext() bool { - return p.Page() < p.PageNums() -} - -// IsActive Returns true if the given page index points to the current page. -func (p *Paginator) IsActive(page int) bool { - return p.Page() == page -} - -// Offset Returns the current offset. -func (p *Paginator) Offset() int { - return (p.Page() - 1) * p.PerPageNums -} - -// HasPages Returns true if there is more than one page. -func (p *Paginator) HasPages() bool { - return p.PageNums() > 1 -} - -// NewPaginator Instantiates a paginator struct for the current http request. -func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { - p := Paginator{} - p.Request = req - if per <= 0 { - per = 10 - } - p.PerPageNums = per - p.SetNums(nums) - return &p -} diff --git a/utils/pagination/utils.go b/utils/pagination/utils.go deleted file mode 100644 index 686e68b0..00000000 --- a/utils/pagination/utils.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pagination - -import ( - "fmt" - "reflect" -) - -// ToInt64 convert any numeric value to int64 -func toInt64(value interface{}) (d int64, err error) { - val := reflect.ValueOf(value) - switch value.(type) { - case int, int8, int16, int32, int64: - d = val.Int() - case uint, uint8, uint16, uint32, uint64: - d = int64(val.Uint()) - default: - err = fmt.Errorf("ToInt64 need numeric not `%T`", value) - } - return -} diff --git a/utils/rand.go b/utils/rand.go deleted file mode 100644 index 344d1cd5..00000000 --- a/utils/rand.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "crypto/rand" - r "math/rand" - "time" -) - -var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) - -// RandomCreateBytes generate random []byte by specify chars. -func RandomCreateBytes(n int, alphabets ...byte) []byte { - if len(alphabets) == 0 { - alphabets = alphaNum - } - var bytes = make([]byte, n) - var randBy bool - if num, err := rand.Read(bytes); num != n || err != nil { - r.Seed(time.Now().UnixNano()) - randBy = true - } - for i, b := range bytes { - if randBy { - bytes[i] = alphabets[r.Intn(len(alphabets))] - } else { - bytes[i] = alphabets[b%byte(len(alphabets))] - } - } - return bytes -} diff --git a/utils/rand_test.go b/utils/rand_test.go deleted file mode 100644 index 6c238b5e..00000000 --- a/utils/rand_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import "testing" - -func TestRand_01(t *testing.T) { - bs0 := RandomCreateBytes(16) - bs1 := RandomCreateBytes(16) - - t.Log(string(bs0), string(bs1)) - if string(bs0) == string(bs1) { - t.FailNow() - } - - bs0 = RandomCreateBytes(4, []byte(`a`)...) - - if string(bs0) != "aaaa" { - t.FailNow() - } -} diff --git a/utils/safemap.go b/utils/safemap.go deleted file mode 100644 index 1793030a..00000000 --- a/utils/safemap.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "sync" -) - -// BeeMap is a map with lock -type BeeMap struct { - lock *sync.RWMutex - bm map[interface{}]interface{} -} - -// NewBeeMap return new safemap -func NewBeeMap() *BeeMap { - return &BeeMap{ - lock: new(sync.RWMutex), - bm: make(map[interface{}]interface{}), - } -} - -// Get from maps return the k's value -func (m *BeeMap) Get(k interface{}) interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - if val, ok := m.bm[k]; ok { - return val - } - return nil -} - -// Set Maps the given key and value. Returns false -// if the key is already in the map and changes nothing. -func (m *BeeMap) Set(k interface{}, v interface{}) bool { - m.lock.Lock() - defer m.lock.Unlock() - if val, ok := m.bm[k]; !ok { - m.bm[k] = v - } else if val != v { - m.bm[k] = v - } else { - return false - } - return true -} - -// Check Returns true if k is exist in the map. -func (m *BeeMap) Check(k interface{}) bool { - m.lock.RLock() - defer m.lock.RUnlock() - _, ok := m.bm[k] - return ok -} - -// Delete the given key and value. -func (m *BeeMap) Delete(k interface{}) { - m.lock.Lock() - defer m.lock.Unlock() - delete(m.bm, k) -} - -// Items returns all items in safemap. -func (m *BeeMap) Items() map[interface{}]interface{} { - m.lock.RLock() - defer m.lock.RUnlock() - r := make(map[interface{}]interface{}) - for k, v := range m.bm { - r[k] = v - } - return r -} - -// Count returns the number of items within the map. -func (m *BeeMap) Count() int { - m.lock.RLock() - defer m.lock.RUnlock() - return len(m.bm) -} diff --git a/utils/safemap_test.go b/utils/safemap_test.go deleted file mode 100644 index 65085195..00000000 --- a/utils/safemap_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import "testing" - -var safeMap *BeeMap - -func TestNewBeeMap(t *testing.T) { - safeMap = NewBeeMap() - if safeMap == nil { - t.Fatal("expected to return non-nil BeeMap", "got", safeMap) - } -} - -func TestSet(t *testing.T) { - safeMap = NewBeeMap() - if ok := safeMap.Set("astaxie", 1); !ok { - t.Error("expected", true, "got", false) - } -} - -func TestReSet(t *testing.T) { - safeMap := NewBeeMap() - if ok := safeMap.Set("astaxie", 1); !ok { - t.Error("expected", true, "got", false) - } - // set diff value - if ok := safeMap.Set("astaxie", -1); !ok { - t.Error("expected", true, "got", false) - } - - // set same value - if ok := safeMap.Set("astaxie", -1); ok { - t.Error("expected", false, "got", true) - } -} - -func TestCheck(t *testing.T) { - if exists := safeMap.Check("astaxie"); !exists { - t.Error("expected", true, "got", false) - } -} - -func TestGet(t *testing.T) { - if val := safeMap.Get("astaxie"); val.(int) != 1 { - t.Error("expected value", 1, "got", val) - } -} - -func TestDelete(t *testing.T) { - safeMap.Delete("astaxie") - if exists := safeMap.Check("astaxie"); exists { - t.Error("expected element to be deleted") - } -} - -func TestItems(t *testing.T) { - safeMap := NewBeeMap() - safeMap.Set("astaxie", "hello") - for k, v := range safeMap.Items() { - key := k.(string) - value := v.(string) - if key != "astaxie" { - t.Error("expected the key should be astaxie") - } - if value != "hello" { - t.Error("expected the value should be hello") - } - } -} - -func TestCount(t *testing.T) { - if count := safeMap.Count(); count != 0 { - t.Error("expected count to be", 0, "got", count) - } -} diff --git a/utils/slice.go b/utils/slice.go deleted file mode 100644 index 8f2cef98..00000000 --- a/utils/slice.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "math/rand" - "time" -) - -type reducetype func(interface{}) interface{} -type filtertype func(interface{}) bool - -// InSlice checks given string in string slice or not. -func InSlice(v string, sl []string) bool { - for _, vv := range sl { - if vv == v { - return true - } - } - return false -} - -// InSliceIface checks given interface in interface slice. -func InSliceIface(v interface{}, sl []interface{}) bool { - for _, vv := range sl { - if vv == v { - return true - } - } - return false -} - -// SliceRandList generate an int slice from min to max. -func SliceRandList(min, max int) []int { - if max < min { - min, max = max, min - } - length := max - min + 1 - t0 := time.Now() - rand.Seed(int64(t0.Nanosecond())) - list := rand.Perm(length) - for index := range list { - list[index] += min - } - return list -} - -// SliceMerge merges interface slices to one slice. -func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { - c = append(slice1, slice2...) - return -} - -// SliceReduce generates a new slice after parsing every value by reduce function -func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { - for _, v := range slice { - dslice = append(dslice, a(v)) - } - return -} - -// SliceRand returns random one from slice. -func SliceRand(a []interface{}) (b interface{}) { - randnum := rand.Intn(len(a)) - b = a[randnum] - return -} - -// SliceSum sums all values in int64 slice. -func SliceSum(intslice []int64) (sum int64) { - for _, v := range intslice { - sum += v - } - return -} - -// SliceFilter generates a new slice after filter function. -func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { - for _, v := range slice { - if a(v) { - ftslice = append(ftslice, v) - } - } - return -} - -// SliceDiff returns diff slice of slice1 - slice2. -func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { - for _, v := range slice1 { - if !InSliceIface(v, slice2) { - diffslice = append(diffslice, v) - } - } - return -} - -// SliceIntersect returns slice that are present in all the slice1 and slice2. -func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { - for _, v := range slice1 { - if InSliceIface(v, slice2) { - diffslice = append(diffslice, v) - } - } - return -} - -// SliceChunk separates one slice to some sized slice. -func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { - if size >= len(slice) { - chunkslice = append(chunkslice, slice) - return - } - end := size - for i := 0; i <= (len(slice) - size); i += size { - chunkslice = append(chunkslice, slice[i:end]) - end += size - } - return -} - -// SliceRange generates a new slice from begin to end with step duration of int64 number. -func SliceRange(start, end, step int64) (intslice []int64) { - for i := start; i <= end; i += step { - intslice = append(intslice, i) - } - return -} - -// SlicePad prepends size number of val into slice. -func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { - if size <= len(slice) { - return slice - } - for i := 0; i < (size - len(slice)); i++ { - slice = append(slice, val) - } - return slice -} - -// SliceUnique cleans repeated values in slice. -func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { - for _, v := range slice { - if !InSliceIface(v, uniqueslice) { - uniqueslice = append(uniqueslice, v) - } - } - return -} - -// SliceShuffle shuffles a slice. -func SliceShuffle(slice []interface{}) []interface{} { - for i := 0; i < len(slice); i++ { - a := rand.Intn(len(slice)) - b := rand.Intn(len(slice)) - slice[a], slice[b] = slice[b], slice[a] - } - return slice -} diff --git a/utils/slice_test.go b/utils/slice_test.go deleted file mode 100644 index 142dec96..00000000 --- a/utils/slice_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "testing" -) - -func TestInSlice(t *testing.T) { - sl := []string{"A", "b"} - if !InSlice("A", sl) { - t.Error("should be true") - } - if InSlice("B", sl) { - t.Error("should be false") - } -} diff --git a/utils/testdata/grepe.test b/utils/testdata/grepe.test deleted file mode 100644 index 6c014c40..00000000 --- a/utils/testdata/grepe.test +++ /dev/null @@ -1,7 +0,0 @@ -# empty lines - - - -hello -# comment -world diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index 3874b803..00000000 --- a/utils/utils.go +++ /dev/null @@ -1,89 +0,0 @@ -package utils - -import ( - "os" - "path/filepath" - "regexp" - "runtime" - "strconv" - "strings" -) - -// GetGOPATHs returns all paths in GOPATH variable. -func GetGOPATHs() []string { - gopath := os.Getenv("GOPATH") - if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { - gopath = defaultGOPATH() - } - return filepath.SplitList(gopath) -} - -func compareGoVersion(a, b string) int { - reg := regexp.MustCompile("^\\d*") - - a = strings.TrimPrefix(a, "go") - b = strings.TrimPrefix(b, "go") - - versionsA := strings.Split(a, ".") - versionsB := strings.Split(b, ".") - - for i := 0; i < len(versionsA) && i < len(versionsB); i++ { - versionA := versionsA[i] - versionB := versionsB[i] - - vA, err := strconv.Atoi(versionA) - if err != nil { - str := reg.FindString(versionA) - if str != "" { - vA, _ = strconv.Atoi(str) - } else { - vA = -1 - } - } - - vB, err := strconv.Atoi(versionB) - if err != nil { - str := reg.FindString(versionB) - if str != "" { - vB, _ = strconv.Atoi(str) - } else { - vB = -1 - } - } - - if vA > vB { - // vA = 12, vB = 8 - return 1 - } else if vA < vB { - // vA = 6, vB = 8 - return -1 - } else if vA == -1 { - // vA = rc1, vB = rc3 - return strings.Compare(versionA, versionB) - } - - // vA = vB = 8 - continue - } - - if len(versionsA) > len(versionsB) { - return 1 - } else if len(versionsA) == len(versionsB) { - return 0 - } - - return -1 -} - -func defaultGOPATH() string { - env := "HOME" - if runtime.GOOS == "windows" { - env = "USERPROFILE" - } else if runtime.GOOS == "plan9" { - env = "home" - } - if home := os.Getenv(env); home != "" { - return filepath.Join(home, "go") - } - return "" -} diff --git a/utils/utils_test.go b/utils/utils_test.go deleted file mode 100644 index ced6f63f..00000000 --- a/utils/utils_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package utils - -import ( - "testing" -) - -func TestCompareGoVersion(t *testing.T) { - targetVersion := "go1.8" - if compareGoVersion("go1.12.4", targetVersion) != 1 { - t.Error("should be 1") - } - - if compareGoVersion("go1.8.7", targetVersion) != 1 { - t.Error("should be 1") - } - - if compareGoVersion("go1.8", targetVersion) != 0 { - t.Error("should be 0") - } - - if compareGoVersion("go1.7.6", targetVersion) != -1 { - t.Error("should be -1") - } - - if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { - t.Error("should be 1") - } - - if compareGoVersion("go1.8rc1", targetVersion) != 0 { - t.Error("should be 0") - } - - if compareGoVersion("go1.7rc1", targetVersion) != -1 { - t.Error("should be -1") - } -} diff --git a/validation/README.md b/validation/README.md deleted file mode 100644 index 43373e47..00000000 --- a/validation/README.md +++ /dev/null @@ -1,147 +0,0 @@ -validation -============== - -validation is a form validation for a data validation and error collecting using Go. - -## Installation and tests - -Install: - - go get github.com/astaxie/beego/validation - -Test: - - go test github.com/astaxie/beego/validation - -## Example - -Direct Use: - - import ( - "github.com/astaxie/beego/validation" - "log" - ) - - type User struct { - Name string - Age int - } - - func main() { - u := User{"man", 40} - valid := validation.Validation{} - valid.Required(u.Name, "name") - valid.MaxSize(u.Name, 15, "nameMax") - valid.Range(u.Age, 0, 140, "age") - if valid.HasErrors() { - // validation does not pass - // print invalid message - for _, err := range valid.Errors { - log.Println(err.Key, err.Message) - } - } - // or use like this - if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { - log.Println(v.Error.Key, v.Error.Message) - } - } - -Struct Tag Use: - - import ( - "github.com/astaxie/beego/validation" - ) - - // validation function follow with "valid" tag - // functions divide with ";" - // parameters in parentheses "()" and divide with "," - // Match function's pattern string must in "//" - type user struct { - Id int - Name string `valid:"Required;Match(/^(test)?\\w*@;com$/)"` - Age int `valid:"Required;Range(1, 140)"` - } - - func main() { - valid := validation.Validation{} - // ignore empty field valid - // see CanSkipFuncs - // valid := validation.Validation{RequiredFirst:true} - u := user{Name: "test", Age: 40} - b, err := valid.Valid(u) - if err != nil { - // handle error - } - if !b { - // validation does not pass - // blabla... - } - } - -Use custom function: - - import ( - "github.com/astaxie/beego/validation" - ) - - type user struct { - Id int - Name string `valid:"Required;IsMe"` - Age int `valid:"Required;Range(1, 140)"` - } - - func IsMe(v *validation.Validation, obj interface{}, key string) { - name, ok:= obj.(string) - if !ok { - // wrong use case? - return - } - - if name != "me" { - // valid false - v.SetError("Name", "is not me!") - } - } - - func main() { - valid := validation.Validation{} - if err := validation.AddCustomFunc("IsMe", IsMe); err != nil { - // hadle error - } - u := user{Name: "test", Age: 40} - b, err := valid.Valid(u) - if err != nil { - // handle error - } - if !b { - // validation does not pass - // blabla... - } - } - -Struct Tag Functions: - - Required - Min(min int) - Max(max int) - Range(min, max int) - MinSize(min int) - MaxSize(max int) - Length(length int) - Alpha - Numeric - AlphaNumeric - Match(pattern string) - AlphaDash - Email - IP - Base64 - Mobile - Tel - Phone - ZipCode - - -## LICENSE - -BSD License http://creativecommons.org/licenses/BSD/ diff --git a/validation/util.go b/validation/util.go deleted file mode 100644 index 918b206c..00000000 --- a/validation/util.go +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" -) - -const ( - // ValidTag struct tag - ValidTag = "valid" - - LabelTag = "label" - - wordsize = 32 << (^uint(0) >> 32 & 1) -) - -var ( - // key: function name - // value: the number of parameters - funcs = make(Funcs) - - // doesn't belong to validation functions - unFuncs = map[string]bool{ - "Clear": true, - "HasErrors": true, - "ErrorMap": true, - "Error": true, - "apply": true, - "Check": true, - "Valid": true, - "NoMatch": true, - } - // ErrInt64On32 show 32 bit platform not support int64 - ErrInt64On32 = fmt.Errorf("not support int64 on 32-bit platform") -) - -func init() { - v := &Validation{} - t := reflect.TypeOf(v) - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - if !unFuncs[m.Name] { - funcs[m.Name] = m.Func - } - } -} - -// CustomFunc is for custom validate function -type CustomFunc func(v *Validation, obj interface{}, key string) - -// AddCustomFunc Add a custom function to validation -// The name can not be: -// Clear -// HasErrors -// ErrorMap -// Error -// Check -// Valid -// NoMatch -// If the name is same with exists function, it will replace the origin valid function -func AddCustomFunc(name string, f CustomFunc) error { - if unFuncs[name] { - return fmt.Errorf("invalid function name: %s", name) - } - - funcs[name] = reflect.ValueOf(f) - return nil -} - -// ValidFunc Valid function type -type ValidFunc struct { - Name string - Params []interface{} -} - -// Funcs Validate function map -type Funcs map[string]reflect.Value - -// Call validate values with named type string -func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - if _, ok := f[name]; !ok { - err = fmt.Errorf("%s does not exist", name) - return - } - if len(params) != f[name].Type().NumIn() { - err = fmt.Errorf("The number of params is not adapted") - return - } - in := make([]reflect.Value, len(params)) - for k, param := range params { - in[k] = reflect.ValueOf(param) - } - result = f[name].Call(in) - return -} - -func isStruct(t reflect.Type) bool { - return t.Kind() == reflect.Struct -} - -func isStructPtr(t reflect.Type) bool { - return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct -} - -func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { - tag := f.Tag.Get(ValidTag) - label := f.Tag.Get(LabelTag) - if len(tag) == 0 { - return - } - if vfs, tag, err = getRegFuncs(tag, f.Name); err != nil { - return - } - fs := strings.Split(tag, ";") - for _, vfunc := range fs { - var vf ValidFunc - if len(vfunc) == 0 { - continue - } - vf, err = parseFunc(vfunc, f.Name, label) - if err != nil { - return - } - vfs = append(vfs, vf) - } - return -} - -// Get Match function -// May be get NoMatch function in the future -func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) { - tag = strings.TrimSpace(tag) - index := strings.Index(tag, "Match(/") - if index == -1 { - str = tag - return - } - end := strings.LastIndex(tag, "/)") - if end < index { - err = fmt.Errorf("invalid Match function") - return - } - reg, err := regexp.Compile(tag[index+len("Match(/") : end]) - if err != nil { - return - } - vfs = []ValidFunc{{"Match", []interface{}{reg, key + ".Match"}}} - str = strings.TrimSpace(tag[:index]) + strings.TrimSpace(tag[end+len("/)"):]) - return -} - -func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - - vfunc = strings.TrimSpace(vfunc) - start := strings.Index(vfunc, "(") - var num int - - // doesn't need parameter valid function - if start == -1 { - if num, err = numIn(vfunc); err != nil { - return - } - if num != 0 { - err = fmt.Errorf("%s require %d parameters", vfunc, num) - return - } - v = ValidFunc{vfunc, []interface{}{key + "." + vfunc + "." + label}} - return - } - - end := strings.Index(vfunc, ")") - if end == -1 { - err = fmt.Errorf("invalid valid function") - return - } - - name := strings.TrimSpace(vfunc[:start]) - if num, err = numIn(name); err != nil { - return - } - - params := strings.Split(vfunc[start+1:end], ",") - // the num of param must be equal - if num != len(params) { - err = fmt.Errorf("%s require %d parameters", name, num) - return - } - - tParams, err := trim(name, key+"."+name+"."+label, params) - if err != nil { - return - } - v = ValidFunc{name, tParams} - return -} - -func numIn(name string) (num int, err error) { - fn, ok := funcs[name] - if !ok { - err = fmt.Errorf("doesn't exists %s valid function", name) - return - } - // sub *Validation obj and key - num = fn.Type().NumIn() - 3 - return -} - -func trim(name, key string, s []string) (ts []interface{}, err error) { - ts = make([]interface{}, len(s), len(s)+1) - fn, ok := funcs[name] - if !ok { - err = fmt.Errorf("doesn't exists %s valid function", name) - return - } - for i := 0; i < len(s); i++ { - var param interface{} - // skip *Validation and obj params - if param, err = parseParam(fn.Type().In(i+2), strings.TrimSpace(s[i])); err != nil { - return - } - ts[i] = param - } - ts = append(ts, key) - return -} - -// modify the parameters's type to adapt the function input parameters' type -func parseParam(t reflect.Type, s string) (i interface{}, err error) { - switch t.Kind() { - case reflect.Int: - i, err = strconv.Atoi(s) - case reflect.Int64: - if wordsize == 32 { - return nil, ErrInt64On32 - } - i, err = strconv.ParseInt(s, 10, 64) - case reflect.Int32: - var v int64 - v, err = strconv.ParseInt(s, 10, 32) - if err == nil { - i = int32(v) - } - case reflect.Int16: - var v int64 - v, err = strconv.ParseInt(s, 10, 16) - if err == nil { - i = int16(v) - } - case reflect.Int8: - var v int64 - v, err = strconv.ParseInt(s, 10, 8) - if err == nil { - i = int8(v) - } - case reflect.String: - i = s - case reflect.Ptr: - if t.Elem().String() != "regexp.Regexp" { - err = fmt.Errorf("not support %s", t.Elem().String()) - return - } - i, err = regexp.Compile(s) - default: - err = fmt.Errorf("not support %s", t.Kind().String()) - } - return -} - -func mergeParam(v *Validation, obj interface{}, params []interface{}) []interface{} { - return append([]interface{}{v, obj}, params...) -} diff --git a/validation/util_test.go b/validation/util_test.go deleted file mode 100644 index 58ca38db..00000000 --- a/validation/util_test.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "log" - "reflect" - "testing" -) - -type user struct { - ID int - Tag string `valid:"Maxx(aa)"` - Name string `valid:"Required;"` - Age int `valid:"Required; Range(1, 140)"` - match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` -} - -func TestGetValidFuncs(t *testing.T) { - u := user{Name: "test", Age: 1} - tf := reflect.TypeOf(u) - var vfs []ValidFunc - var err error - - f, _ := tf.FieldByName("ID") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 0 { - t.Fatal("should get none ValidFunc") - } - - f, _ = tf.FieldByName("Tag") - if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" { - t.Fatal(err) - } - - f, _ = tf.FieldByName("Name") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 1 { - t.Fatal("should get 1 ValidFunc") - } - if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { - t.Error("Required funcs should be got") - } - - f, _ = tf.FieldByName("Age") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 2 { - t.Fatal("should get 2 ValidFunc") - } - if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { - t.Error("Required funcs should be got") - } - if vfs[1].Name != "Range" && len(vfs[1].Params) != 2 { - t.Error("Range funcs should be got") - } - - f, _ = tf.FieldByName("match") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - if len(vfs) != 3 { - t.Fatal("should get 3 ValidFunc but now is", len(vfs)) - } -} - -type User struct { - Name string `valid:"Required;MaxSize(5)" ` - Sex string `valid:"Required;" label:"sex_label"` - Age int `valid:"Required;Range(1, 140);" label:"age_label"` -} - -func TestValidation(t *testing.T) { - u := User{"man1238888456", "", 1140} - valid := Validation{} - b, err := valid.Valid(&u) - if err != nil { - // handle error - } - if !b { - // validation does not pass - // blabla... - for _, err := range valid.Errors { - log.Println(err.Key, err.Message) - } - if len(valid.Errors) != 3 { - t.Error("must be has 3 error") - } - } else { - t.Error("must be has 3 error") - } -} - -func TestCall(t *testing.T) { - u := user{Name: "test", Age: 180} - tf := reflect.TypeOf(u) - var vfs []ValidFunc - var err error - f, _ := tf.FieldByName("Age") - if vfs, err = getValidFuncs(f); err != nil { - t.Fatal(err) - } - valid := &Validation{} - vfs[1].Params = append([]interface{}{valid, u.Age}, vfs[1].Params...) - if _, err = funcs.Call(vfs[1].Name, vfs[1].Params...); err != nil { - t.Fatal(err) - } - if len(valid.Errors) != 1 { - t.Error("age out of range should be has an error") - } -} diff --git a/validation/validation.go b/validation/validation.go deleted file mode 100644 index 190e0f0e..00000000 --- a/validation/validation.go +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package validation for validations -// -// import ( -// "github.com/astaxie/beego/validation" -// "log" -// ) -// -// type User struct { -// Name string -// Age int -// } -// -// func main() { -// u := User{"man", 40} -// valid := validation.Validation{} -// valid.Required(u.Name, "name") -// valid.MaxSize(u.Name, 15, "nameMax") -// valid.Range(u.Age, 0, 140, "age") -// if valid.HasErrors() { -// // validation does not pass -// // print invalid message -// for _, err := range valid.Errors { -// log.Println(err.Key, err.Message) -// } -// } -// // or use like this -// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { -// log.Println(v.Error.Key, v.Error.Message) -// } -// } -// -// more info: http://beego.me/docs/mvc/controller/validation.md -package validation - -import ( - "fmt" - "reflect" - "regexp" - "strings" -) - -// ValidFormer valid interface -type ValidFormer interface { - Valid(*Validation) -} - -// Error show the error -type Error struct { - Message, Key, Name, Field, Tmpl string - Value interface{} - LimitValue interface{} -} - -// String Returns the Message. -func (e *Error) String() string { - if e == nil { - return "" - } - return e.Message -} - -// Implement Error interface. -// Return e.String() -func (e *Error) Error() string { return e.String() } - -// Result is returned from every validation method. -// It provides an indication of success, and a pointer to the Error (if any). -type Result struct { - Error *Error - Ok bool -} - -// Key Get Result by given key string. -func (r *Result) Key(key string) *Result { - if r.Error != nil { - r.Error.Key = key - } - return r -} - -// Message Set Result message by string or format string with args -func (r *Result) Message(message string, args ...interface{}) *Result { - if r.Error != nil { - if len(args) == 0 { - r.Error.Message = message - } else { - r.Error.Message = fmt.Sprintf(message, args...) - } - } - return r -} - -// A Validation context manages data validation and error messages. -type Validation struct { - // if this field set true, in struct tag valid - // if the struct field vale is empty - // it will skip those valid functions, see CanSkipFuncs - RequiredFirst bool - - Errors []*Error - ErrorsMap map[string][]*Error -} - -// Clear Clean all ValidationError. -func (v *Validation) Clear() { - v.Errors = []*Error{} - v.ErrorsMap = nil -} - -// HasErrors Has ValidationError nor not. -func (v *Validation) HasErrors() bool { - return len(v.Errors) > 0 -} - -// ErrorMap Return the errors mapped by key. -// If there are multiple validation errors associated with a single key, the -// first one "wins". (Typically the first validation will be the more basic). -func (v *Validation) ErrorMap() map[string][]*Error { - return v.ErrorsMap -} - -// Error Add an error to the validation context. -func (v *Validation) Error(message string, args ...interface{}) *Result { - result := (&Result{ - Ok: false, - Error: &Error{}, - }).Message(message, args...) - v.Errors = append(v.Errors, result.Error) - return result -} - -// Required Test that the argument is non-nil and non-empty (if string or list) -func (v *Validation) Required(obj interface{}, key string) *Result { - return v.apply(Required{key}, obj) -} - -// Min Test that the obj is greater than min if obj's type is int -func (v *Validation) Min(obj interface{}, min int, key string) *Result { - return v.apply(Min{min, key}, obj) -} - -// Max Test that the obj is less than max if obj's type is int -func (v *Validation) Max(obj interface{}, max int, key string) *Result { - return v.apply(Max{max, key}, obj) -} - -// Range Test that the obj is between mni and max if obj's type is int -func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { - return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) -} - -// MinSize Test that the obj is longer than min size if type is string or slice -func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { - return v.apply(MinSize{min, key}, obj) -} - -// MaxSize Test that the obj is shorter than max size if type is string or slice -func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { - return v.apply(MaxSize{max, key}, obj) -} - -// Length Test that the obj is same length to n if type is string or slice -func (v *Validation) Length(obj interface{}, n int, key string) *Result { - return v.apply(Length{n, key}, obj) -} - -// Alpha Test that the obj is [a-zA-Z] if type is string -func (v *Validation) Alpha(obj interface{}, key string) *Result { - return v.apply(Alpha{key}, obj) -} - -// Numeric Test that the obj is [0-9] if type is string -func (v *Validation) Numeric(obj interface{}, key string) *Result { - return v.apply(Numeric{key}, obj) -} - -// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string -func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { - return v.apply(AlphaNumeric{key}, obj) -} - -// Match Test that the obj matches regexp if type is string -func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { - return v.apply(Match{regex, key}, obj) -} - -// NoMatch Test that the obj doesn't match regexp if type is string -func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { - return v.apply(NoMatch{Match{Regexp: regex}, key}, obj) -} - -// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string -func (v *Validation) AlphaDash(obj interface{}, key string) *Result { - return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj) -} - -// Email Test that the obj is email address if type is string -func (v *Validation) Email(obj interface{}, key string) *Result { - return v.apply(Email{Match{Regexp: emailPattern}, key}, obj) -} - -// IP Test that the obj is IP address if type is string -func (v *Validation) IP(obj interface{}, key string) *Result { - return v.apply(IP{Match{Regexp: ipPattern}, key}, obj) -} - -// Base64 Test that the obj is base64 encoded if type is string -func (v *Validation) Base64(obj interface{}, key string) *Result { - return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj) -} - -// Mobile Test that the obj is chinese mobile number if type is string -func (v *Validation) Mobile(obj interface{}, key string) *Result { - return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj) -} - -// Tel Test that the obj is chinese telephone number if type is string -func (v *Validation) Tel(obj interface{}, key string) *Result { - return v.apply(Tel{Match{Regexp: telPattern}, key}, obj) -} - -// Phone Test that the obj is chinese mobile or telephone number if type is string -func (v *Validation) Phone(obj interface{}, key string) *Result { - return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, - Tel{Match: Match{Regexp: telPattern}}, key}, obj) -} - -// ZipCode Test that the obj is chinese zip code if type is string -func (v *Validation) ZipCode(obj interface{}, key string) *Result { - return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) -} - -func (v *Validation) apply(chk Validator, obj interface{}) *Result { - if nil == obj { - if chk.IsSatisfied(obj) { - return &Result{Ok: true} - } - } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { - if reflect.ValueOf(obj).IsNil() { - if chk.IsSatisfied(nil) { - return &Result{Ok: true} - } - } else { - if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { - return &Result{Ok: true} - } - } - } else if chk.IsSatisfied(obj) { - return &Result{Ok: true} - } - - // Add the error to the validation context. - key := chk.GetKey() - Name := key - Field := "" - Label := "" - parts := strings.Split(key, ".") - if len(parts) == 3 { - Field = parts[0] - Name = parts[1] - Label = parts[2] - if len(Label) == 0 { - Label = Field - } - } - - err := &Error{ - Message: Label + " " + chk.DefaultMessage(), - Key: key, - Name: Name, - Field: Field, - Value: obj, - Tmpl: MessageTmpls[Name], - LimitValue: chk.GetLimitValue(), - } - v.setError(err) - - // Also return it in the result. - return &Result{ - Ok: false, - Error: err, - } -} - -// key must like aa.bb.cc or aa.bb. -// AddError adds independent error message for the provided key -func (v *Validation) AddError(key, message string) { - Name := key - Field := "" - - Label := "" - parts := strings.Split(key, ".") - if len(parts) == 3 { - Field = parts[0] - Name = parts[1] - Label = parts[2] - if len(Label) == 0 { - Label = Field - } - } - - err := &Error{ - Message: Label + " " + message, - Key: key, - Name: Name, - Field: Field, - } - v.setError(err) -} - -func (v *Validation) setError(err *Error) { - v.Errors = append(v.Errors, err) - if v.ErrorsMap == nil { - v.ErrorsMap = make(map[string][]*Error) - } - if _, ok := v.ErrorsMap[err.Field]; !ok { - v.ErrorsMap[err.Field] = []*Error{} - } - v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err) -} - -// SetError Set error message for one field in ValidationError -func (v *Validation) SetError(fieldName string, errMsg string) *Error { - err := &Error{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg} - v.setError(err) - return err -} - -// Check Apply a group of validators to a field, in order, and return the -// ValidationResult from the first one that fails, or the last one that -// succeeds. -func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { - var result *Result - for _, check := range checks { - result = v.apply(check, obj) - if !result.Ok { - return result - } - } - return result -} - -// Valid Validate a struct. -// the obj parameter must be a struct or a struct pointer -func (v *Validation) Valid(obj interface{}) (b bool, err error) { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - switch { - case isStruct(objT): - case isStructPtr(objT): - objT = objT.Elem() - objV = objV.Elem() - default: - err = fmt.Errorf("%v must be a struct or a struct pointer", obj) - return - } - - for i := 0; i < objT.NumField(); i++ { - var vfs []ValidFunc - if vfs, err = getValidFuncs(objT.Field(i)); err != nil { - return - } - - var hasRequired bool - for _, vf := range vfs { - if vf.Name == "Required" { - hasRequired = true - } - - currentField := objV.Field(i).Interface() - if objV.Field(i).Kind() == reflect.Ptr { - if objV.Field(i).IsNil() { - currentField = "" - } else { - currentField = objV.Field(i).Elem().Interface() - } - } - - chk := Required{""}.IsSatisfied(currentField) - if !hasRequired && v.RequiredFirst && !chk { - if _, ok := CanSkipFuncs[vf.Name]; ok { - continue - } - } - - if _, err = funcs.Call(vf.Name, - mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil { - return - } - } - } - - if !v.HasErrors() { - if form, ok := obj.(ValidFormer); ok { - form.Valid(v) - } - } - - return !v.HasErrors(), nil -} - -// RecursiveValid Recursively validate a struct. -// Step1: Validate by v.Valid -// Step2: If pass on step1, then reflect obj's fields -// Step3: Do the Recursively validation to all struct or struct pointer fields -func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { - //Step 1: validate obj itself firstly - // fails if objc is not struct - pass, err := v.Valid(objc) - if err != nil || !pass { - return pass, err // Stop recursive validation - } - // Step 2: Validate struct's struct fields - objT := reflect.TypeOf(objc) - objV := reflect.ValueOf(objc) - - if isStructPtr(objT) { - objT = objT.Elem() - objV = objV.Elem() - } - - for i := 0; i < objT.NumField(); i++ { - - t := objT.Field(i).Type - - // Recursive applies to struct or pointer to structs fields - if isStruct(t) || isStructPtr(t) { - // Step 3: do the recursive validation - // Only valid the Public field recursively - if objV.Field(i).CanInterface() { - pass, err = v.RecursiveValid(objV.Field(i).Interface()) - } - } - } - return pass, err -} - -func (v *Validation) CanSkipAlso(skipFunc string) { - if _, ok := CanSkipFuncs[skipFunc]; !ok { - CanSkipFuncs[skipFunc] = struct{}{} - } -} diff --git a/validation/validation_test.go b/validation/validation_test.go deleted file mode 100644 index b4b5b1b6..00000000 --- a/validation/validation_test.go +++ /dev/null @@ -1,609 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "regexp" - "testing" - "time" -) - -func TestRequired(t *testing.T) { - valid := Validation{} - - if valid.Required(nil, "nil").Ok { - t.Error("nil object should be false") - } - if !valid.Required(true, "bool").Ok { - t.Error("Bool value should always return true") - } - if !valid.Required(false, "bool").Ok { - t.Error("Bool value should always return true") - } - if valid.Required("", "string").Ok { - t.Error("\"'\" string should be false") - } - if valid.Required(" ", "string").Ok { - t.Error("\" \" string should be false") // For #2361 - } - if valid.Required("\n", "string").Ok { - t.Error("new line string should be false") // For #2361 - } - if !valid.Required("astaxie", "string").Ok { - t.Error("string should be true") - } - if valid.Required(0, "zero").Ok { - t.Error("Integer should not be equal 0") - } - if !valid.Required(1, "int").Ok { - t.Error("Integer except 0 should be true") - } - if !valid.Required(time.Now(), "time").Ok { - t.Error("time should be true") - } - if valid.Required([]string{}, "emptySlice").Ok { - t.Error("empty slice should be false") - } - if !valid.Required([]interface{}{"ok"}, "slice").Ok { - t.Error("slice should be true") - } -} - -func TestMin(t *testing.T) { - valid := Validation{} - - if valid.Min(-1, 0, "min0").Ok { - t.Error("-1 is less than the minimum value of 0 should be false") - } - if !valid.Min(1, 0, "min0").Ok { - t.Error("1 is greater or equal than the minimum value of 0 should be true") - } -} - -func TestMax(t *testing.T) { - valid := Validation{} - - if valid.Max(1, 0, "max0").Ok { - t.Error("1 is greater than the minimum value of 0 should be false") - } - if !valid.Max(-1, 0, "max0").Ok { - t.Error("-1 is less or equal than the maximum value of 0 should be true") - } -} - -func TestRange(t *testing.T) { - valid := Validation{} - - if valid.Range(-1, 0, 1, "range0_1").Ok { - t.Error("-1 is between 0 and 1 should be false") - } - if !valid.Range(1, 0, 1, "range0_1").Ok { - t.Error("1 is between 0 and 1 should be true") - } -} - -func TestMinSize(t *testing.T) { - valid := Validation{} - - if valid.MinSize("", 1, "minSize1").Ok { - t.Error("the length of \"\" is less than the minimum value of 1 should be false") - } - if !valid.MinSize("ok", 1, "minSize1").Ok { - t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") - } - if valid.MinSize([]string{}, 1, "minSize1").Ok { - t.Error("the length of empty slice is less than the minimum value of 1 should be false") - } - if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { - t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") - } -} - -func TestMaxSize(t *testing.T) { - valid := Validation{} - - if valid.MaxSize("ok", 1, "maxSize1").Ok { - t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize("", 1, "maxSize1").Ok { - t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") - } - if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { - t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { - t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") - } -} - -func TestLength(t *testing.T) { - valid := Validation{} - - if valid.Length("", 1, "length1").Ok { - t.Error("the length of \"\" must equal 1 should be false") - } - if !valid.Length("1", 1, "length1").Ok { - t.Error("the length of \"1\" must equal 1 should be true") - } - if valid.Length([]string{}, 1, "length1").Ok { - t.Error("the length of empty slice must equal 1 should be false") - } - if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { - t.Error("the length of [\"ok\"] must equal 1 should be true") - } -} - -func TestAlpha(t *testing.T) { - valid := Validation{} - - if valid.Alpha("a,1-@ $", "alpha").Ok { - t.Error("\"a,1-@ $\" are valid alpha characters should be false") - } - if !valid.Alpha("abCD", "alpha").Ok { - t.Error("\"abCD\" are valid alpha characters should be true") - } -} - -func TestNumeric(t *testing.T) { - valid := Validation{} - - if valid.Numeric("a,1-@ $", "numeric").Ok { - t.Error("\"a,1-@ $\" are valid numeric characters should be false") - } - if !valid.Numeric("1234", "numeric").Ok { - t.Error("\"1234\" are valid numeric characters should be true") - } -} - -func TestAlphaNumeric(t *testing.T) { - valid := Validation{} - - if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { - t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") - } - if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { - t.Error("\"1234aB\" are valid alpha or numeric characters should be true") - } -} - -func TestMatch(t *testing.T) { - valid := Validation{} - - if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { - t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") - } - if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { - t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") - } -} - -func TestNoMatch(t *testing.T) { - valid := Validation{} - - if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { - t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") - } - if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { - t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") - } -} - -func TestAlphaDash(t *testing.T) { - valid := Validation{} - - if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { - t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") - } - if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { - t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") - } -} - -func TestEmail(t *testing.T) { - valid := Validation{} - - if valid.Email("not@a email", "email").Ok { - t.Error("\"not@a email\" is a valid email address should be false") - } - if !valid.Email("suchuangji@gmail.com", "email").Ok { - t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") - } - if valid.Email("@suchuangji@gmail.com", "email").Ok { - t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") - } - if valid.Email("suchuangji@gmail.com ok", "email").Ok { - t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") - } -} - -func TestIP(t *testing.T) { - valid := Validation{} - - if valid.IP("11.255.255.256", "IP").Ok { - t.Error("\"11.255.255.256\" is a valid ip address should be false") - } - if !valid.IP("01.11.11.11", "IP").Ok { - t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") - } -} - -func TestBase64(t *testing.T) { - valid := Validation{} - - if valid.Base64("suchuangji@gmail.com", "base64").Ok { - t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") - } - if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { - t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") - } -} - -func TestMobile(t *testing.T) { - valid := Validation{} - - validMobiles := []string{ - "19800008888", - "18800008888", - "18000008888", - "8618300008888", - "+8614700008888", - "17300008888", - "+8617100008888", - "8617500008888", - "8617400008888", - "16200008888", - "16500008888", - "16600008888", - "16700008888", - "13300008888", - "14900008888", - "15300008888", - "17300008888", - "17700008888", - "18000008888", - "18900008888", - "19100008888", - "19900008888", - "19300008888", - "13000008888", - "13100008888", - "13200008888", - "14500008888", - "15500008888", - "15600008888", - "16600008888", - "17100008888", - "17500008888", - "17600008888", - "18500008888", - "18600008888", - "13400008888", - "13500008888", - "13600008888", - "13700008888", - "13800008888", - "13900008888", - "14700008888", - "15000008888", - "15100008888", - "15200008888", - "15800008888", - "15900008888", - "17200008888", - "17800008888", - "18200008888", - "18300008888", - "18400008888", - "18700008888", - "18800008888", - "19800008888", - } - - for _, m := range validMobiles { - if !valid.Mobile(m, "mobile").Ok { - t.Error(m + " is a valid mobile phone number should be true") - } - } -} - -func TestTel(t *testing.T) { - valid := Validation{} - - if valid.Tel("222-00008888", "telephone").Ok { - t.Error("\"222-00008888\" is a valid telephone number should be false") - } - if !valid.Tel("022-70008888", "telephone").Ok { - t.Error("\"022-70008888\" is a valid telephone number should be true") - } - if !valid.Tel("02270008888", "telephone").Ok { - t.Error("\"02270008888\" is a valid telephone number should be true") - } - if !valid.Tel("70008888", "telephone").Ok { - t.Error("\"70008888\" is a valid telephone number should be true") - } -} - -func TestPhone(t *testing.T) { - valid := Validation{} - - if valid.Phone("222-00008888", "phone").Ok { - t.Error("\"222-00008888\" is a valid phone number should be false") - } - if !valid.Mobile("+8614700008888", "phone").Ok { - t.Error("\"+8614700008888\" is a valid phone number should be true") - } - if !valid.Tel("02270008888", "phone").Ok { - t.Error("\"02270008888\" is a valid phone number should be true") - } -} - -func TestZipCode(t *testing.T) { - valid := Validation{} - - if valid.ZipCode("", "zipcode").Ok { - t.Error("\"00008888\" is a valid zipcode should be false") - } - if !valid.ZipCode("536000", "zipcode").Ok { - t.Error("\"536000\" is a valid zipcode should be true") - } -} - -func TestValid(t *testing.T) { - type user struct { - ID int - Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` - Age int `valid:"Required;Range(1, 140)"` - } - valid := Validation{} - - u := user{Name: "test@/test/;com", Age: 40} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Error("validation should be passed") - } - - uptr := &user{Name: "test", Age: 40} - valid.Clear() - b, err = valid.Valid(uptr) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Name.Match" { - t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) - } - - u = user{Name: "test@/test/;com", Age: 180} - valid.Clear() - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Age.Range." { - t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) - } -} - -func TestRecursiveValid(t *testing.T) { - type User struct { - ID int - Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` - Age int `valid:"Required;Range(1, 140)"` - } - - type AnonymouseUser struct { - ID2 int - Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` - Age2 int `valid:"Required;Range(1, 140)"` - } - - type Account struct { - Password string `valid:"Required"` - U User - AnonymouseUser - } - valid := Validation{} - - u := Account{Password: "abc123_", U: User{}} - b, err := valid.RecursiveValid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } -} - -func TestSkipValid(t *testing.T) { - type User struct { - ID int - - Email string `valid:"Email"` - ReqEmail string `valid:"Required;Email"` - - IP string `valid:"IP"` - ReqIP string `valid:"Required;IP"` - - Mobile string `valid:"Mobile"` - ReqMobile string `valid:"Required;Mobile"` - - Tel string `valid:"Tel"` - ReqTel string `valid:"Required;Tel"` - - Phone string `valid:"Phone"` - ReqPhone string `valid:"Required;Phone"` - - ZipCode string `valid:"ZipCode"` - ReqZipCode string `valid:"Required;ZipCode"` - } - - u := User{ - ReqEmail: "a@a.com", - ReqIP: "127.0.0.1", - ReqMobile: "18888888888", - ReqTel: "02088888888", - ReqPhone: "02088888888", - ReqZipCode: "510000", - } - - valid := Validation{} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - valid = Validation{RequiredFirst: true} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } -} - -func TestPointer(t *testing.T) { - type User struct { - ID int - - Email *string `valid:"Email"` - ReqEmail *string `valid:"Required;Email"` - } - - u := User{ - ReqEmail: nil, - Email: nil, - } - - valid := Validation{} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - validEmail := "a@a.com" - u = User{ - ReqEmail: &validEmail, - Email: nil, - } - - valid = Validation{RequiredFirst: true} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } - - u = User{ - ReqEmail: &validEmail, - Email: nil, - } - - valid = Validation{} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - invalidEmail := "a@a" - u = User{ - ReqEmail: &validEmail, - Email: &invalidEmail, - } - - valid = Validation{RequiredFirst: true} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - u = User{ - ReqEmail: &validEmail, - Email: &invalidEmail, - } - - valid = Validation{} - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } -} - -func TestCanSkipAlso(t *testing.T) { - type User struct { - ID int - - Email string `valid:"Email"` - ReqEmail string `valid:"Required;Email"` - MatchRange int `valid:"Range(10, 20)"` - } - - u := User{ - ReqEmail: "a@a.com", - Email: "", - MatchRange: 0, - } - - valid := Validation{RequiredFirst: true} - b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } - - valid = Validation{RequiredFirst: true} - valid.CanSkipAlso("Range") - b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } - -} diff --git a/validation/validators.go b/validation/validators.go deleted file mode 100644 index 38b6f1aa..00000000 --- a/validation/validators.go +++ /dev/null @@ -1,738 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validation - -import ( - "fmt" - "github.com/astaxie/beego/logs" - "reflect" - "regexp" - "strings" - "sync" - "time" - "unicode/utf8" -) - -// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty -var CanSkipFuncs = map[string]struct{}{ - "Email": {}, - "IP": {}, - "Mobile": {}, - "Tel": {}, - "Phone": {}, - "ZipCode": {}, -} - -// MessageTmpls store commond validate template -var MessageTmpls = map[string]string{ - "Required": "Can not be empty", - "Min": "Minimum is %d", - "Max": "Maximum is %d", - "Range": "Range is %d to %d", - "MinSize": "Minimum size is %d", - "MaxSize": "Maximum size is %d", - "Length": "Required length is %d", - "Alpha": "Must be valid alpha characters", - "Numeric": "Must be valid numeric characters", - "AlphaNumeric": "Must be valid alpha or numeric characters", - "Match": "Must match %s", - "NoMatch": "Must not match %s", - "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", - "Email": "Must be a valid email address", - "IP": "Must be a valid ip address", - "Base64": "Must be valid base64 characters", - "Mobile": "Must be valid mobile number", - "Tel": "Must be valid telephone number", - "Phone": "Must be valid telephone or mobile phone number", - "ZipCode": "Must be valid zipcode", -} - -var once sync.Once - -// SetDefaultMessage set default messages -// if not set, the default messages are -// "Required": "Can not be empty", -// "Min": "Minimum is %d", -// "Max": "Maximum is %d", -// "Range": "Range is %d to %d", -// "MinSize": "Minimum size is %d", -// "MaxSize": "Maximum size is %d", -// "Length": "Required length is %d", -// "Alpha": "Must be valid alpha characters", -// "Numeric": "Must be valid numeric characters", -// "AlphaNumeric": "Must be valid alpha or numeric characters", -// "Match": "Must match %s", -// "NoMatch": "Must not match %s", -// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", -// "Email": "Must be a valid email address", -// "IP": "Must be a valid ip address", -// "Base64": "Must be valid base64 characters", -// "Mobile": "Must be valid mobile number", -// "Tel": "Must be valid telephone number", -// "Phone": "Must be valid telephone or mobile phone number", -// "ZipCode": "Must be valid zipcode", -func SetDefaultMessage(msg map[string]string) { - if len(msg) == 0 { - return - } - - once.Do(func() { - for name := range msg { - MessageTmpls[name] = msg[name] - } - }) - logs.Warn(`you must SetDefaultMessage at once`) -} - -// Validator interface -type Validator interface { - IsSatisfied(interface{}) bool - DefaultMessage() string - GetKey() string - GetLimitValue() interface{} -} - -// Required struct -type Required struct { - Key string -} - -// IsSatisfied judge whether obj has value -func (r Required) IsSatisfied(obj interface{}) bool { - if obj == nil { - return false - } - - if str, ok := obj.(string); ok { - return len(strings.TrimSpace(str)) > 0 - } - if _, ok := obj.(bool); ok { - return true - } - if i, ok := obj.(int); ok { - return i != 0 - } - if i, ok := obj.(uint); ok { - return i != 0 - } - if i, ok := obj.(int8); ok { - return i != 0 - } - if i, ok := obj.(uint8); ok { - return i != 0 - } - if i, ok := obj.(int16); ok { - return i != 0 - } - if i, ok := obj.(uint16); ok { - return i != 0 - } - if i, ok := obj.(uint32); ok { - return i != 0 - } - if i, ok := obj.(int32); ok { - return i != 0 - } - if i, ok := obj.(int64); ok { - return i != 0 - } - if i, ok := obj.(uint64); ok { - return i != 0 - } - if t, ok := obj.(time.Time); ok { - return !t.IsZero() - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() > 0 - } - return true -} - -// DefaultMessage return the default error message -func (r Required) DefaultMessage() string { - return MessageTmpls["Required"] -} - -// GetKey return the r.Key -func (r Required) GetKey() string { - return r.Key -} - -// GetLimitValue return nil now -func (r Required) GetLimitValue() interface{} { - return nil -} - -// Min check struct -type Min struct { - Min int - Key string -} - -// IsSatisfied judge whether obj is valid -// not support int64 on 32-bit platform -func (m Min) IsSatisfied(obj interface{}) bool { - var v int - switch obj.(type) { - case int64: - if wordsize == 32 { - return false - } - v = int(obj.(int64)) - case int: - v = obj.(int) - case int32: - v = int(obj.(int32)) - case int16: - v = int(obj.(int16)) - case int8: - v = int(obj.(int8)) - default: - return false - } - - return v >= m.Min -} - -// DefaultMessage return the default min error message -func (m Min) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Min"], m.Min) -} - -// GetKey return the m.Key -func (m Min) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value, Min -func (m Min) GetLimitValue() interface{} { - return m.Min -} - -// Max validate struct -type Max struct { - Max int - Key string -} - -// IsSatisfied judge whether obj is valid -// not support int64 on 32-bit platform -func (m Max) IsSatisfied(obj interface{}) bool { - var v int - switch obj.(type) { - case int64: - if wordsize == 32 { - return false - } - v = int(obj.(int64)) - case int: - v = obj.(int) - case int32: - v = int(obj.(int32)) - case int16: - v = int(obj.(int16)) - case int8: - v = int(obj.(int8)) - default: - return false - } - - return v <= m.Max -} - -// DefaultMessage return the default max error message -func (m Max) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Max"], m.Max) -} - -// GetKey return the m.Key -func (m Max) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value, Max -func (m Max) GetLimitValue() interface{} { - return m.Max -} - -// Range Requires an integer to be within Min, Max inclusive. -type Range struct { - Min - Max - Key string -} - -// IsSatisfied judge whether obj is valid -// not support int64 on 32-bit platform -func (r Range) IsSatisfied(obj interface{}) bool { - return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj) -} - -// DefaultMessage return the default Range error message -func (r Range) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Range"], r.Min.Min, r.Max.Max) -} - -// GetKey return the m.Key -func (r Range) GetKey() string { - return r.Key -} - -// GetLimitValue return the limit value, Max -func (r Range) GetLimitValue() interface{} { - return []int{r.Min.Min, r.Max.Max} -} - -// MinSize Requires an array or string to be at least a given length. -type MinSize struct { - Min int - Key string -} - -// IsSatisfied judge whether obj is valid -func (m MinSize) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - return utf8.RuneCountInString(str) >= m.Min - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() >= m.Min - } - return false -} - -// DefaultMessage return the default MinSize error message -func (m MinSize) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["MinSize"], m.Min) -} - -// GetKey return the m.Key -func (m MinSize) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m MinSize) GetLimitValue() interface{} { - return m.Min -} - -// MaxSize Requires an array or string to be at most a given length. -type MaxSize struct { - Max int - Key string -} - -// IsSatisfied judge whether obj is valid -func (m MaxSize) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - return utf8.RuneCountInString(str) <= m.Max - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() <= m.Max - } - return false -} - -// DefaultMessage return the default MaxSize error message -func (m MaxSize) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["MaxSize"], m.Max) -} - -// GetKey return the m.Key -func (m MaxSize) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m MaxSize) GetLimitValue() interface{} { - return m.Max -} - -// Length Requires an array or string to be exactly a given length. -type Length struct { - N int - Key string -} - -// IsSatisfied judge whether obj is valid -func (l Length) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - return utf8.RuneCountInString(str) == l.N - } - v := reflect.ValueOf(obj) - if v.Kind() == reflect.Slice { - return v.Len() == l.N - } - return false -} - -// DefaultMessage return the default Length error message -func (l Length) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Length"], l.N) -} - -// GetKey return the m.Key -func (l Length) GetKey() string { - return l.Key -} - -// GetLimitValue return the limit value -func (l Length) GetLimitValue() interface{} { - return l.N -} - -// Alpha check the alpha -type Alpha struct { - Key string -} - -// IsSatisfied judge whether obj is valid -func (a Alpha) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - for _, v := range str { - if ('Z' < v || v < 'A') && ('z' < v || v < 'a') { - return false - } - } - return true - } - return false -} - -// DefaultMessage return the default Length error message -func (a Alpha) DefaultMessage() string { - return MessageTmpls["Alpha"] -} - -// GetKey return the m.Key -func (a Alpha) GetKey() string { - return a.Key -} - -// GetLimitValue return the limit value -func (a Alpha) GetLimitValue() interface{} { - return nil -} - -// Numeric check number -type Numeric struct { - Key string -} - -// IsSatisfied judge whether obj is valid -func (n Numeric) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - for _, v := range str { - if '9' < v || v < '0' { - return false - } - } - return true - } - return false -} - -// DefaultMessage return the default Length error message -func (n Numeric) DefaultMessage() string { - return MessageTmpls["Numeric"] -} - -// GetKey return the n.Key -func (n Numeric) GetKey() string { - return n.Key -} - -// GetLimitValue return the limit value -func (n Numeric) GetLimitValue() interface{} { - return nil -} - -// AlphaNumeric check alpha and number -type AlphaNumeric struct { - Key string -} - -// IsSatisfied judge whether obj is valid -func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { - if str, ok := obj.(string); ok { - for _, v := range str { - if ('Z' < v || v < 'A') && ('z' < v || v < 'a') && ('9' < v || v < '0') { - return false - } - } - return true - } - return false -} - -// DefaultMessage return the default Length error message -func (a AlphaNumeric) DefaultMessage() string { - return MessageTmpls["AlphaNumeric"] -} - -// GetKey return the a.Key -func (a AlphaNumeric) GetKey() string { - return a.Key -} - -// GetLimitValue return the limit value -func (a AlphaNumeric) GetLimitValue() interface{} { - return nil -} - -// Match Requires a string to match a given regex. -type Match struct { - Regexp *regexp.Regexp - Key string -} - -// IsSatisfied judge whether obj is valid -func (m Match) IsSatisfied(obj interface{}) bool { - return m.Regexp.MatchString(fmt.Sprintf("%v", obj)) -} - -// DefaultMessage return the default Match error message -func (m Match) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["Match"], m.Regexp.String()) -} - -// GetKey return the m.Key -func (m Match) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m Match) GetLimitValue() interface{} { - return m.Regexp.String() -} - -// NoMatch Requires a string to not match a given regex. -type NoMatch struct { - Match - Key string -} - -// IsSatisfied judge whether obj is valid -func (n NoMatch) IsSatisfied(obj interface{}) bool { - return !n.Match.IsSatisfied(obj) -} - -// DefaultMessage return the default NoMatch error message -func (n NoMatch) DefaultMessage() string { - return fmt.Sprintf(MessageTmpls["NoMatch"], n.Regexp.String()) -} - -// GetKey return the n.Key -func (n NoMatch) GetKey() string { - return n.Key -} - -// GetLimitValue return the limit value -func (n NoMatch) GetLimitValue() interface{} { - return n.Regexp.String() -} - -var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) - -// AlphaDash check not Alpha -type AlphaDash struct { - NoMatch - Key string -} - -// DefaultMessage return the default AlphaDash error message -func (a AlphaDash) DefaultMessage() string { - return MessageTmpls["AlphaDash"] -} - -// GetKey return the n.Key -func (a AlphaDash) GetKey() string { - return a.Key -} - -// GetLimitValue return the limit value -func (a AlphaDash) GetLimitValue() interface{} { - return nil -} - -var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) - -// Email check struct -type Email struct { - Match - Key string -} - -// DefaultMessage return the default Email error message -func (e Email) DefaultMessage() string { - return MessageTmpls["Email"] -} - -// GetKey return the n.Key -func (e Email) GetKey() string { - return e.Key -} - -// GetLimitValue return the limit value -func (e Email) GetLimitValue() interface{} { - return nil -} - -var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) - -// IP check struct -type IP struct { - Match - Key string -} - -// DefaultMessage return the default IP error message -func (i IP) DefaultMessage() string { - return MessageTmpls["IP"] -} - -// GetKey return the i.Key -func (i IP) GetKey() string { - return i.Key -} - -// GetLimitValue return the limit value -func (i IP) GetLimitValue() interface{} { - return nil -} - -var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) - -// Base64 check struct -type Base64 struct { - Match - Key string -} - -// DefaultMessage return the default Base64 error message -func (b Base64) DefaultMessage() string { - return MessageTmpls["Base64"] -} - -// GetKey return the b.Key -func (b Base64) GetKey() string { - return b.Key -} - -// GetLimitValue return the limit value -func (b Base64) GetLimitValue() interface{} { - return nil -} - -// just for chinese mobile phone number -var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`) - -// Mobile check struct -type Mobile struct { - Match - Key string -} - -// DefaultMessage return the default Mobile error message -func (m Mobile) DefaultMessage() string { - return MessageTmpls["Mobile"] -} - -// GetKey return the m.Key -func (m Mobile) GetKey() string { - return m.Key -} - -// GetLimitValue return the limit value -func (m Mobile) GetLimitValue() interface{} { - return nil -} - -// just for chinese telephone number -var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) - -// Tel check telephone struct -type Tel struct { - Match - Key string -} - -// DefaultMessage return the default Tel error message -func (t Tel) DefaultMessage() string { - return MessageTmpls["Tel"] -} - -// GetKey return the t.Key -func (t Tel) GetKey() string { - return t.Key -} - -// GetLimitValue return the limit value -func (t Tel) GetLimitValue() interface{} { - return nil -} - -// Phone just for chinese telephone or mobile phone number -type Phone struct { - Mobile - Tel - Key string -} - -// IsSatisfied judge whether obj is valid -func (p Phone) IsSatisfied(obj interface{}) bool { - return p.Mobile.IsSatisfied(obj) || p.Tel.IsSatisfied(obj) -} - -// DefaultMessage return the default Phone error message -func (p Phone) DefaultMessage() string { - return MessageTmpls["Phone"] -} - -// GetKey return the p.Key -func (p Phone) GetKey() string { - return p.Key -} - -// GetLimitValue return the limit value -func (p Phone) GetLimitValue() interface{} { - return nil -} - -// just for chinese zipcode -var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`) - -// ZipCode check the zip struct -type ZipCode struct { - Match - Key string -} - -// DefaultMessage return the default Zip error message -func (z ZipCode) DefaultMessage() string { - return MessageTmpls["ZipCode"] -} - -// GetKey return the z.Key -func (z ZipCode) GetKey() string { - return z.Key -} - -// GetLimitValue return the limit value -func (z ZipCode) GetLimitValue() interface{} { - return nil -}