From 9c51952db485cb32a6658df173f622f6930199cd Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 22:50:08 +0800 Subject: [PATCH 1/3] Move package --- orm/cmd_utils.go | 10 +- orm/db_alias.go | 87 +- orm/orm_alias_adapt_test.go | 46 - pkg/admin.go | 458 +++++++ pkg/admin_test.go | 239 ++++ pkg/adminui.go | 356 ++++++ pkg/app.go | 496 ++++++++ pkg/beego.go | 123 ++ pkg/build_info.go | 27 + pkg/cache/README.md | 59 + pkg/cache/cache.go | 103 ++ pkg/cache/cache_test.go | 191 +++ pkg/cache/conv.go | 100 ++ pkg/cache/conv_test.go | 143 +++ pkg/cache/file.go | 258 ++++ pkg/cache/memcache/memcache.go | 188 +++ pkg/cache/memcache/memcache_test.go | 108 ++ pkg/cache/memory.go | 256 ++++ pkg/cache/redis/redis.go | 272 +++++ pkg/cache/redis/redis_test.go | 144 +++ pkg/cache/ssdb/ssdb.go | 231 ++++ pkg/cache/ssdb/ssdb_test.go | 104 ++ pkg/config.go | 524 ++++++++ pkg/config/config.go | 242 ++++ pkg/config/config_test.go | 55 + pkg/config/env/env.go | 87 ++ pkg/config/env/env_test.go | 75 ++ pkg/config/fake.go | 134 +++ pkg/config/ini.go | 504 ++++++++ pkg/config/ini_test.go | 190 +++ pkg/config/json.go | 269 +++++ pkg/config/json_test.go | 222 ++++ pkg/config/xml/xml.go | 228 ++++ pkg/config/xml/xml_test.go | 125 ++ pkg/config/yaml/yaml.go | 316 +++++ pkg/config/yaml/yaml_test.go | 115 ++ pkg/config_test.go | 146 +++ pkg/context/acceptencoder.go | 232 ++++ pkg/context/acceptencoder_test.go | 59 + pkg/context/context.go | 263 +++++ pkg/context/context_test.go | 47 + pkg/context/input.go | 689 +++++++++++ pkg/context/input_test.go | 217 ++++ pkg/context/output.go | 408 +++++++ pkg/context/param/conv.go | 78 ++ pkg/context/param/methodparams.go | 69 ++ pkg/context/param/options.go | 37 + pkg/context/param/parsers.go | 149 +++ pkg/context/param/parsers_test.go | 84 ++ pkg/context/renderer.go | 12 + pkg/context/response.go | 27 + pkg/controller.go | 706 +++++++++++ pkg/controller_test.go | 181 +++ pkg/doc.go | 17 + pkg/error.go | 488 ++++++++ pkg/error_test.go | 88 ++ pkg/filter.go | 44 + pkg/filter_test.go | 68 ++ pkg/flash.go | 110 ++ pkg/flash_test.go | 54 + pkg/fs.go | 74 ++ pkg/grace/grace.go | 166 +++ pkg/grace/server.go | 356 ++++++ pkg/hooks.go | 104 ++ pkg/httplib/README.md | 97 ++ pkg/httplib/httplib.go | 654 ++++++++++ pkg/httplib/httplib_test.go | 286 +++++ pkg/log.go | 127 ++ pkg/logs/README.md | 72 ++ pkg/logs/accesslog.go | 83 ++ pkg/logs/alils/alils.go | 186 +++ pkg/logs/alils/config.go | 13 + pkg/logs/alils/log.pb.go | 1038 ++++++++++++++++ pkg/logs/alils/log_config.go | 42 + pkg/logs/alils/log_project.go | 819 +++++++++++++ pkg/logs/alils/log_store.go | 271 +++++ pkg/logs/alils/machine_group.go | 91 ++ pkg/logs/alils/request.go | 62 + pkg/logs/alils/signature.go | 111 ++ pkg/logs/conn.go | 119 ++ pkg/logs/conn_test.go | 79 ++ pkg/logs/console.go | 99 ++ pkg/logs/console_test.go | 64 + pkg/logs/es/es.go | 102 ++ pkg/logs/file.go | 409 +++++++ pkg/logs/file_test.go | 420 +++++++ pkg/logs/jianliao.go | 72 ++ pkg/logs/log.go | 669 +++++++++++ pkg/logs/logger.go | 176 +++ pkg/logs/logger_test.go | 57 + pkg/logs/multifile.go | 119 ++ pkg/logs/multifile_test.go | 78 ++ pkg/logs/slack.go | 60 + pkg/logs/smtp.go | 149 +++ pkg/logs/smtp_test.go | 27 + pkg/metric/prometheus.go | 99 ++ pkg/metric/prometheus_test.go | 42 + pkg/migration/ddl.go | 395 +++++++ pkg/migration/doc.go | 32 + pkg/migration/migration.go | 330 ++++++ pkg/mime.go | 556 +++++++++ pkg/namespace.go | 396 +++++++ pkg/namespace_test.go | 168 +++ pkg/parser.go | 591 +++++++++ pkg/plugins/apiauth/apiauth.go | 165 +++ pkg/plugins/apiauth/apiauth_test.go | 20 + pkg/plugins/auth/basic.go | 107 ++ pkg/plugins/authz/authz.go | 86 ++ pkg/plugins/authz/authz_model.conf | 14 + pkg/plugins/authz/authz_policy.csv | 7 + pkg/plugins/authz/authz_test.go | 107 ++ pkg/plugins/cors/cors.go | 228 ++++ pkg/plugins/cors/cors_test.go | 253 ++++ pkg/policy.go | 97 ++ pkg/router.go | 1052 +++++++++++++++++ pkg/router_test.go | 732 ++++++++++++ pkg/session/README.md | 114 ++ pkg/session/couchbase/sess_couchbase.go | 247 ++++ pkg/session/ledis/ledis_session.go | 173 +++ pkg/session/memcache/sess_memcache.go | 230 ++++ pkg/session/mysql/sess_mysql.go | 228 ++++ pkg/session/postgres/sess_postgresql.go | 243 ++++ pkg/session/redis/sess_redis.go | 261 ++++ pkg/session/redis_cluster/redis_cluster.go | 220 ++++ .../redis_sentinel/sess_redis_sentinel.go | 234 ++++ .../sess_redis_sentinel_test.go | 90 ++ pkg/session/sess_cookie.go | 180 +++ pkg/session/sess_cookie_test.go | 105 ++ pkg/session/sess_file.go | 315 +++++ pkg/session/sess_file_test.go | 387 ++++++ pkg/session/sess_mem.go | 196 +++ pkg/session/sess_mem_test.go | 58 + pkg/session/sess_test.go | 131 ++ pkg/session/sess_utils.go | 207 ++++ pkg/session/session.go | 377 ++++++ pkg/session/ssdb/sess_ssdb.go | 199 ++++ pkg/staticfile.go | 234 ++++ pkg/staticfile_test.go | 99 ++ pkg/swagger/swagger.go | 174 +++ pkg/template.go | 406 +++++++ pkg/template_test.go | 316 +++++ pkg/templatefunc.go | 780 ++++++++++++ pkg/templatefunc_test.go | 380 ++++++ pkg/testdata/Makefile | 2 + pkg/testdata/bindata.go | 296 +++++ pkg/testdata/views/blocks/block.tpl | 3 + pkg/testdata/views/header.tpl | 3 + pkg/testdata/views/index.tpl | 15 + pkg/testing/assertions.go | 15 + pkg/testing/client.go | 65 + pkg/toolbox/healthcheck.go | 48 + pkg/toolbox/profile.go | 184 +++ pkg/toolbox/profile_test.go | 28 + pkg/toolbox/statistics.go | 149 +++ pkg/toolbox/statistics_test.go | 40 + pkg/toolbox/task.go | 640 ++++++++++ pkg/toolbox/task_test.go | 85 ++ pkg/tree.go | 585 +++++++++ pkg/tree_test.go | 306 +++++ pkg/unregroute_test.go | 226 ++++ pkg/utils/caller.go | 25 + pkg/utils/caller_test.go | 28 + pkg/utils/captcha/LICENSE | 19 + pkg/utils/captcha/README.md | 45 + pkg/utils/captcha/captcha.go | 270 +++++ pkg/utils/captcha/image.go | 501 ++++++++ pkg/utils/captcha/image_test.go | 52 + pkg/utils/captcha/siprng.go | 277 +++++ pkg/utils/captcha/siprng_test.go | 33 + pkg/utils/debug.go | 478 ++++++++ pkg/utils/debug_test.go | 46 + pkg/utils/file.go | 101 ++ pkg/utils/file_test.go | 75 ++ pkg/utils/mail.go | 424 +++++++ pkg/utils/mail_test.go | 41 + pkg/utils/pagination/controller.go | 26 + pkg/utils/pagination/doc.go | 58 + pkg/utils/pagination/paginator.go | 189 +++ pkg/utils/pagination/utils.go | 34 + pkg/utils/rand.go | 44 + pkg/utils/rand_test.go | 33 + pkg/utils/safemap.go | 91 ++ pkg/utils/safemap_test.go | 89 ++ pkg/utils/slice.go | 170 +++ pkg/utils/slice_test.go | 29 + pkg/utils/testdata/grepe.test | 7 + pkg/utils/utils.go | 89 ++ pkg/utils/utils_test.go | 36 + pkg/validation/README.md | 147 +++ pkg/validation/util.go | 298 +++++ pkg/validation/util_test.go | 128 ++ pkg/validation/validation.go | 456 +++++++ pkg/validation/validation_test.go | 609 ++++++++++ pkg/validation/validators.go | 738 ++++++++++++ 194 files changed, 39077 insertions(+), 69 deletions(-) delete mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/admin.go create mode 100644 pkg/admin_test.go create mode 100644 pkg/adminui.go create mode 100644 pkg/app.go create mode 100644 pkg/beego.go create mode 100644 pkg/build_info.go create mode 100644 pkg/cache/README.md create mode 100644 pkg/cache/cache.go create mode 100644 pkg/cache/cache_test.go create mode 100644 pkg/cache/conv.go create mode 100644 pkg/cache/conv_test.go create mode 100644 pkg/cache/file.go create mode 100644 pkg/cache/memcache/memcache.go create mode 100644 pkg/cache/memcache/memcache_test.go create mode 100644 pkg/cache/memory.go create mode 100644 pkg/cache/redis/redis.go create mode 100644 pkg/cache/redis/redis_test.go create mode 100644 pkg/cache/ssdb/ssdb.go create mode 100644 pkg/cache/ssdb/ssdb_test.go create mode 100644 pkg/config.go create mode 100644 pkg/config/config.go create mode 100644 pkg/config/config_test.go create mode 100644 pkg/config/env/env.go create mode 100644 pkg/config/env/env_test.go create mode 100644 pkg/config/fake.go create mode 100644 pkg/config/ini.go create mode 100644 pkg/config/ini_test.go create mode 100644 pkg/config/json.go create mode 100644 pkg/config/json_test.go create mode 100644 pkg/config/xml/xml.go create mode 100644 pkg/config/xml/xml_test.go create mode 100644 pkg/config/yaml/yaml.go create mode 100644 pkg/config/yaml/yaml_test.go create mode 100644 pkg/config_test.go create mode 100644 pkg/context/acceptencoder.go create mode 100644 pkg/context/acceptencoder_test.go create mode 100644 pkg/context/context.go create mode 100644 pkg/context/context_test.go create mode 100644 pkg/context/input.go create mode 100644 pkg/context/input_test.go create mode 100644 pkg/context/output.go create mode 100644 pkg/context/param/conv.go create mode 100644 pkg/context/param/methodparams.go create mode 100644 pkg/context/param/options.go create mode 100644 pkg/context/param/parsers.go create mode 100644 pkg/context/param/parsers_test.go create mode 100644 pkg/context/renderer.go create mode 100644 pkg/context/response.go create mode 100644 pkg/controller.go create mode 100644 pkg/controller_test.go create mode 100644 pkg/doc.go create mode 100644 pkg/error.go create mode 100644 pkg/error_test.go create mode 100644 pkg/filter.go create mode 100644 pkg/filter_test.go create mode 100644 pkg/flash.go create mode 100644 pkg/flash_test.go create mode 100644 pkg/fs.go create mode 100644 pkg/grace/grace.go create mode 100644 pkg/grace/server.go create mode 100644 pkg/hooks.go create mode 100644 pkg/httplib/README.md create mode 100644 pkg/httplib/httplib.go create mode 100644 pkg/httplib/httplib_test.go create mode 100644 pkg/log.go create mode 100644 pkg/logs/README.md create mode 100644 pkg/logs/accesslog.go create mode 100644 pkg/logs/alils/alils.go create mode 100755 pkg/logs/alils/config.go create mode 100755 pkg/logs/alils/log.pb.go create mode 100755 pkg/logs/alils/log_config.go create mode 100755 pkg/logs/alils/log_project.go create mode 100755 pkg/logs/alils/log_store.go create mode 100755 pkg/logs/alils/machine_group.go create mode 100755 pkg/logs/alils/request.go create mode 100755 pkg/logs/alils/signature.go create mode 100644 pkg/logs/conn.go create mode 100644 pkg/logs/conn_test.go create mode 100644 pkg/logs/console.go create mode 100644 pkg/logs/console_test.go create mode 100644 pkg/logs/es/es.go create mode 100644 pkg/logs/file.go create mode 100644 pkg/logs/file_test.go create mode 100644 pkg/logs/jianliao.go create mode 100644 pkg/logs/log.go create mode 100644 pkg/logs/logger.go create mode 100644 pkg/logs/logger_test.go create mode 100644 pkg/logs/multifile.go create mode 100644 pkg/logs/multifile_test.go create mode 100644 pkg/logs/slack.go create mode 100644 pkg/logs/smtp.go create mode 100644 pkg/logs/smtp_test.go create mode 100644 pkg/metric/prometheus.go create mode 100644 pkg/metric/prometheus_test.go create mode 100644 pkg/migration/ddl.go create mode 100644 pkg/migration/doc.go create mode 100644 pkg/migration/migration.go create mode 100644 pkg/mime.go create mode 100644 pkg/namespace.go create mode 100644 pkg/namespace_test.go create mode 100644 pkg/parser.go create mode 100644 pkg/plugins/apiauth/apiauth.go create mode 100644 pkg/plugins/apiauth/apiauth_test.go create mode 100644 pkg/plugins/auth/basic.go create mode 100644 pkg/plugins/authz/authz.go create mode 100644 pkg/plugins/authz/authz_model.conf create mode 100644 pkg/plugins/authz/authz_policy.csv create mode 100644 pkg/plugins/authz/authz_test.go create mode 100644 pkg/plugins/cors/cors.go create mode 100644 pkg/plugins/cors/cors_test.go create mode 100644 pkg/policy.go create mode 100644 pkg/router.go create mode 100644 pkg/router_test.go create mode 100644 pkg/session/README.md create mode 100644 pkg/session/couchbase/sess_couchbase.go create mode 100644 pkg/session/ledis/ledis_session.go create mode 100644 pkg/session/memcache/sess_memcache.go create mode 100644 pkg/session/mysql/sess_mysql.go create mode 100644 pkg/session/postgres/sess_postgresql.go create mode 100644 pkg/session/redis/sess_redis.go create mode 100644 pkg/session/redis_cluster/redis_cluster.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel_test.go create mode 100644 pkg/session/sess_cookie.go create mode 100644 pkg/session/sess_cookie_test.go create mode 100644 pkg/session/sess_file.go create mode 100644 pkg/session/sess_file_test.go create mode 100644 pkg/session/sess_mem.go create mode 100644 pkg/session/sess_mem_test.go create mode 100644 pkg/session/sess_test.go create mode 100644 pkg/session/sess_utils.go create mode 100644 pkg/session/session.go create mode 100644 pkg/session/ssdb/sess_ssdb.go create mode 100644 pkg/staticfile.go create mode 100644 pkg/staticfile_test.go create mode 100644 pkg/swagger/swagger.go create mode 100644 pkg/template.go create mode 100644 pkg/template_test.go create mode 100644 pkg/templatefunc.go create mode 100644 pkg/templatefunc_test.go create mode 100644 pkg/testdata/Makefile create mode 100644 pkg/testdata/bindata.go create mode 100644 pkg/testdata/views/blocks/block.tpl create mode 100644 pkg/testdata/views/header.tpl create mode 100644 pkg/testdata/views/index.tpl create mode 100644 pkg/testing/assertions.go create mode 100644 pkg/testing/client.go create mode 100644 pkg/toolbox/healthcheck.go create mode 100644 pkg/toolbox/profile.go create mode 100644 pkg/toolbox/profile_test.go create mode 100644 pkg/toolbox/statistics.go create mode 100644 pkg/toolbox/statistics_test.go create mode 100644 pkg/toolbox/task.go create mode 100644 pkg/toolbox/task_test.go create mode 100644 pkg/tree.go create mode 100644 pkg/tree_test.go create mode 100644 pkg/unregroute_test.go create mode 100644 pkg/utils/caller.go create mode 100644 pkg/utils/caller_test.go create mode 100644 pkg/utils/captcha/LICENSE create mode 100644 pkg/utils/captcha/README.md create mode 100644 pkg/utils/captcha/captcha.go create mode 100644 pkg/utils/captcha/image.go create mode 100644 pkg/utils/captcha/image_test.go create mode 100644 pkg/utils/captcha/siprng.go create mode 100644 pkg/utils/captcha/siprng_test.go create mode 100644 pkg/utils/debug.go create mode 100644 pkg/utils/debug_test.go create mode 100644 pkg/utils/file.go create mode 100644 pkg/utils/file_test.go create mode 100644 pkg/utils/mail.go create mode 100644 pkg/utils/mail_test.go create mode 100644 pkg/utils/pagination/controller.go create mode 100644 pkg/utils/pagination/doc.go create mode 100644 pkg/utils/pagination/paginator.go create mode 100644 pkg/utils/pagination/utils.go create mode 100644 pkg/utils/rand.go create mode 100644 pkg/utils/rand_test.go create mode 100644 pkg/utils/safemap.go create mode 100644 pkg/utils/safemap_test.go create mode 100644 pkg/utils/slice.go create mode 100644 pkg/utils/slice_test.go create mode 100644 pkg/utils/testdata/grepe.test create mode 100644 pkg/utils/utils.go create mode 100644 pkg/utils/utils_test.go create mode 100644 pkg/validation/README.md create mode 100644 pkg/validation/util.go create mode 100644 pkg/validation/util_test.go create mode 100644 pkg/validation/validation.go create mode 100644 pkg/validation/validation_test.go create mode 100644 pkg/validation/validators.go diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index eac85091..61f17346 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -178,9 +178,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex column += " " + "NOT NULL" } - // if fi.initial.String() != "" { + //if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() - // } + //} // Append attribute DEFAULT column += getColumnDefault(fi) @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) } columns = append(columns, column) diff --git a/orm/db_alias.go b/orm/db_alias.go index a84070b4..bf6c350c 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,21 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" + lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" - - lru "github.com/hashicorp/golang-lru" - - "github.com/astaxie/beego/pkg/common" - orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. @@ -68,7 +63,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, // https://github.com/rana/ora + "ora": DROracle, //https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -124,7 +119,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -// su must call release to release *sql.Stmt after using +//su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -294,26 +289,82 @@ func detectTZ(al *alias) { } } +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + // AddAliasWthDB add a aliasName for the drivename -// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - return orm2.AddAliasWthDB(aliasName, driverName, db) + _, err := addAliasWthDB(aliasName, driverName, db) + return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - kvs := make([]common.KV, 0, 2) + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + for i, v := range params { switch i { case 0: - kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) + SetMaxIdleConns(al.Name, v) case 1: - kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) - case 2: - kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) + SetMaxOpenConns(al.Name, v) } } - return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. @@ -373,7 +424,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -393,7 +444,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -// garbage recycle for stmt +//garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go deleted file mode 100644 index d7724527..00000000 --- a/orm/orm_alias_adapt_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "os" - "testing" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" -) - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) - assert.Nil(t, err) -} diff --git a/pkg/admin.go b/pkg/admin.go new file mode 100644 index 00000000..db52647e --- /dev/null +++ b/pkg/admin.go @@ -0,0 +1,458 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "os" + "reflect" + "strconv" + "text/template" + "time" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// BeeAdminApp is the default adminApp used by admin module. +var beeAdminApp *adminApp + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + beeAdminApp = &adminApp{ + routers: make(map[string]http.HandlerFunc), + } + // keep in mind that all data should be html escaped to avoid XSS attack + beeAdminApp.Route("/", adminIndex) + beeAdminApp.Route("/qps", qpsIndex) + beeAdminApp.Route("/prof", profIndex) + beeAdminApp.Route("/healthcheck", healthcheck) + beeAdminApp.Route("/task", taskStatus) + beeAdminApp.Route("/listconf", listConf) + beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) + FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } +} + +// AdminIndex is the default http.Handler for admin module. +// it matches url pattern "/". +func adminIndex(rw http.ResponseWriter, _ *http.Request) { + writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) +} + +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. +func qpsIndex(rw http.ResponseWriter, _ *http.Request) { + data := make(map[interface{}]interface{}) + data["Content"] = toolbox.StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(M); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + + writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) +} + +// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. +// it's registered with url pattern "/listconf" in admin module. +func listConf(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + rw.Write([]byte("command not support")) + return + } + + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(M) + list("BConfig", BConfig, m) + m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) + m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + + data["Content"] = m + + tmpl.Execute(rw, data) + + case "router": + content := PrintTree() + content["Fields"] = []string{ + "Router Pattern", + "Methods", + "Controller", + } + data["Content"] = content + data["Title"] = "Routers" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + var ( + content = M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + filterTypes = []string{} + filterTypeData = make(M) + ) + + if BeeApp.Handlers.enableFilter { + var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router"} { + if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { + filterType = fr + filterTypes = append(filterTypes, filterType) + resultList := new([][]string) + for _, f := range bf { + var result = []string{ + // void xss + template.HTMLEscapeString(f.pattern), + template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[filterType] = resultList + } + } + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) + } +} + +func list(root string, p interface{}, m M) { + pt := reflect.TypeOf(p) + pv := reflect.ValueOf(p) + if pt.Kind() == reflect.Ptr { + pt = pt.Elem() + pv = pv.Elem() + } + for i := 0; i < pv.NumField(); i++ { + var key string + if root == "" { + key = pt.Field(i).Name + } else { + key = root + "." + pt.Field(i).Name + } + if pv.Field(i).Kind() == reflect.Struct { + list(key, pv.Field(i).Interface(), m) + } else { + m[key] = pv.Field(i).Interface() + } + } +} + +// PrintTree prints all registered routers. +func PrintTree() M { + var ( + content = M{} + methods = []string{} + methodsData = make(M) + ) + for method, t := range BeeApp.Handlers.routers { + + resultList := new([][]string) + + printTree(resultList, t) + + methods = append(methods, template.HTMLEscapeString(method)) + methodsData[template.HTMLEscapeString(method)] = resultList + } + + content["Data"] = methodsData + content["Methods"] = methods + return content +} + +func printTree(resultList *[][]string, t *Tree) { + for _, tr := range t.fixrouters { + printTree(resultList, tr) + } + if t.wildcard != nil { + printTree(resultList, t.wildcard) + } + for _, l := range t.leaves { + if v, ok := l.runObject.(*ControllerInfo); ok { + if v.routerType == routerTypeBeego { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + template.HTMLEscapeString(v.controllerType.String()), + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeRESTFul { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + "", + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeHandler { + var result = []string{ + template.HTMLEscapeString(v.pattern), + "", + "", + } + *resultList = append(*resultList, result) + } + } + } +} + +// ProfIndex is a http.Handler for showing profile command. +// it's in url pattern "/prof" in admin module. +func profIndex(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + return + } + + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + toolbox.ProcessInput(command, &result) + data["Content"] = template.HTMLEscapeString(result.String()) + + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(rw, dataJSON) + return + } + + data["Title"] = template.HTMLEscapeString(command) + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + writeTemplate(rw, data, profillingTpl, defaultTpl) +} + +// Healthcheck is a http.Handler calling health checking and showing the result. +// it's in "/healthcheck" pattern in admin module. +func healthcheck(rw http.ResponseWriter, r *http.Request) { + var ( + result []string + data = make(map[interface{}]interface{}) + resultList = new([][]string) + content = M{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) + + for name, h := range toolbox.AdminCheckList { + if err := h.Check(); err != nil { + result = []string{ + "error", + template.HTMLEscapeString(name), + template.HTMLEscapeString(err.Error()), + } + } else { + result = []string{ + "success", + template.HTMLEscapeString(name), + "OK", + } + } + *resultList = append(*resultList, result) + } + + queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) + + if shouldReturnJSON { + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) + + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(rw, jsonResponse) + } + return + } + + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Health Check" + + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) +} + +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) + + for i, healthCheckResult := range *healthCheckResults { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] + + response[i] = currentResultMap + } + + return response + +} + +func writeJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + +// TaskStatus is a http.Handler with running task status (task name, status and the last execution). +// it's in "/task" pattern in admin module. +func taskStatus(rw http.ResponseWriter, req *http.Request) { + data := make(map[interface{}]interface{}) + + // Run Task + req.ParseForm() + taskname := req.Form.Get("taskname") + if taskname != "" { + if t, ok := toolbox.AdminTaskList[taskname]; ok { + if err := t.Run(); err != nil { + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} + } + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} + } else { + data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} + } + } + + // List Tasks + content := make(M) + resultList := new([][]string) + var fields = []string{ + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", + } + for tname, tk := range toolbox.AdminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec()), + template.HTMLEscapeString(tk.GetStatus()), + template.HTMLEscapeString(tk.GetPrev().String()), + } + *resultList = append(*resultList, result) + } + + content["Fields"] = fields + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Tasks" + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) +} + +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } + tmpl.Execute(rw, data) +} + +// adminApp is an http.HandlerFunc map used as beeAdminApp. +type adminApp struct { + routers map[string]http.HandlerFunc +} + +// Route adds http.HandlerFunc to adminApp with url pattern. +func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { + admin.routers[pattern] = f +} + +// Run adminApp http server. +// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. +func (admin *adminApp) Run() { + if len(toolbox.AdminTaskList) > 0 { + toolbox.StartTask() + } + addr := BConfig.Listen.AdminAddr + + if BConfig.Listen.AdminPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) + } + for p, f := range admin.routers { + http.Handle(p, f) + } + logs.Info("Admin server Running on %s", addr) + + var err error + if BConfig.Listen.Graceful { + err = grace.ListenAndServe(addr, nil) + } else { + err = http.ListenAndServe(addr, nil) + } + if err != nil { + logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + } +} diff --git a/pkg/admin_test.go b/pkg/admin_test.go new file mode 100644 index 00000000..3f3612e4 --- /dev/null +++ b/pkg/admin_test.go @@ -0,0 +1,239 @@ +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "github.com/astaxie/beego/toolbox" +) + +type SampleDatabaseCheck struct { +} + +type SampleCacheCheck struct { +} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + +func TestList_01(t *testing.T) { + m := make(M) + list("BConfig", BConfig, m) + t.Log(m) + om := oldMap() + for k, v := range om { + if fmt.Sprint(m[k]) != fmt.Sprint(v) { + t.Log(k, "old-key", v, "new-key", m[k]) + t.FailNow() + } + } +} + +func oldMap() M { + m := make(M) + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize + m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs + m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + return m +} + +func TestWriteJSON(t *testing.T) { + t.Log("Testing the adding of JSON to the response") + + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + writeJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + + if err != nil { + t.Fatal("Could not decode response body into slice.") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) + } + } +} + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } + +} + +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + []string{ + "error", + "Database", + "Error occured whie starting the db", + }, + []string{ + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + + } + +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { + t.Errorf("handler returned unexpected body: got %v want %v", + decodedResponseBody, expectedResponseBody) + } + +} diff --git a/pkg/adminui.go b/pkg/adminui.go new file mode 100644 index 00000000..cdcdef33 --- /dev/null +++ b/pkg/adminui.go @@ -0,0 +1,356 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var indexTpl = ` +{{define "content"}} +

Beego Admin Dashboard

+

+For detail usage please check our document: +

+

+Toolbox +

+

+Live Monitor +

+{{.Content}} +{{end}}` + +var profillingTpl = ` +{{define "content"}} +

{{.Title}}

+
+
{{.Content}}
+
+{{end}}` + +var defaultScriptsTpl = `` + +var gcAjaxTpl = ` +{{define "scripts"}} + +{{end}} +` + +var qpsTpl = `{{define "content"}} +

Requests statistics

+ + + + {{range .Content.Fields}} + + {{end}} + + + + + {{range $i, $elem := .Content.Data}} + + + + + + + + + + + {{end}} + + +
+ {{.}} +
{{index $elem 0}}{{index $elem 1}}{{index $elem 2}}{{index $elem 4}}{{index $elem 6}}{{index $elem 8}}{{index $elem 10}}
+{{end}}` + +var configTpl = ` +{{define "content"}} +

Configurations

+
+{{range $index, $elem := .Content}}
+{{$index}}={{$elem}}
+{{end}}
+
+{{end}} +` + +var routerAndFilterTpl = `{{define "content"}} + + +

{{.Title}}

+ +{{range .Content.Methods}} + +
+
{{.}}
+
+ + + + {{range $.Content.Fields}} + + {{end}} + + + + + {{$slice := index $.Content.Data .}} + {{range $i, $elem := $slice}} + + + {{range $elem}} + + {{end}} + + + {{end}} + + +
+ {{.}} +
+ {{.}} +
+
+
+{{end}} + + +{{end}}` + +var tasksTpl = `{{define "content"}} + +

{{.Title}}

+ +{{if .Message }} +{{ $messageType := index .Message 0}} +

+{{index .Message 1}} +

+{{end}} + + + + + +{{range .Content.Fields}} + +{{end}} + + + + +{{range $i, $slice := .Content.Data}} + + {{range $slice}} + + {{end}} + + +{{end}} + +
+{{.}} +
+ {{.}} + + Run +
+ +{{end}}` + +var healthCheckTpl = ` +{{define "content"}} + +

{{.Title}}

+ + + +{{range .Content.Fields}} + +{{end}} + + + +{{range $i, $slice := .Content.Data}} + {{ $header := index $slice 0}} + {{ if eq "success" $header}} + + {{else if eq "error" $header}} + + {{else}} + + {{end}} + {{range $j, $elem := $slice}} + {{if ne $j 0}} + + {{end}} + {{end}} + + +{{end}} + + +
+ {{.}} +
+ {{$elem}} + + {{$header}} +
+{{end}}` + +// The base dashboardTpl +var dashboardTpl = ` + + + + + + + + + + +Welcome to Beego Admin Dashboard + + + + + + + + + + + + + +
+{{template "content" .}} +
+ + + + + + + +{{template "scripts" .}} + + +` diff --git a/pkg/app.go b/pkg/app.go new file mode 100644 index 00000000..f3fe6f7b --- /dev/null +++ b/pkg/app.go @@ -0,0 +1,496 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/fcgi" + "os" + "path" + "strings" + "time" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" + "golang.org/x/crypto/acme/autocert" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = NewApp() +} + +// App defines beego application with a new PatternServeMux. +type App struct { + Handlers *ControllerRegister + Server *http.Server +} + +// NewApp returns a new beego application. +func NewApp() *App { + cr := NewControllerRegister() + app := &App{Handlers: cr, Server: &http.Server{}} + return app +} + +// MiddleWare function for http.Handler +type MiddleWare func(http.Handler) http.Handler + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + addr := BConfig.Listen.HTTPAddr + + if BConfig.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) + } + + var ( + err error + l net.Listener + endRunning = make(chan bool, 1) + ) + + // run cgi server + if BConfig.Listen.EnableFcgi { + if BConfig.Listen.EnableStdIo { + if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O + logs.Info("Use FCGI via standard I/O") + } else { + logs.Critical("Cannot use FCGI via standard I/O", err) + } + return + } + if BConfig.Listen.HTTPPort == 0 { + // remove the Socket file before start + if utils.FileExists(addr) { + os.Remove(addr) + } + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + if err != nil { + logs.Critical("Listen: ", err) + } + if err = fcgi.Serve(l, app.Handlers); err != nil { + logs.Critical("fcgi.Serve: ", err) + } + return + } + + app.Server.Handler = app.Handlers + for i := len(mws) - 1; i >= 0; i-- { + if mws[i] == nil { + continue + } + app.Server.Handler = mws[i](app.Server.Handler) + } + app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") + + // run graceful mode + if BConfig.Listen.Graceful { + httpsAddr := BConfig.Listen.HTTPSAddr + app.Server.Addr = httpsAddr + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + app.Server.Addr = httpsAddr + } + server := grace.NewServer(httpsAddr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.EnableMutualHTTPS { + if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } else { + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } + if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } + endRunning <- true + }() + } + if BConfig.Listen.EnableHTTP { + go func() { + server := grace.NewServer(addr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.ListenTCP4 { + server.Network = "tcp4" + } + if err := server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + endRunning <- true + }() + } + <-endRunning + return + } + + // run normal mode + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + } else if BConfig.Listen.EnableHTTP { + logs.Info("Start https server error, conflict with http. Please reset https port") + return + } + logs.Info("https server Running on https://%s", app.Server.Addr) + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } else if BConfig.Listen.EnableMutualHTTPS { + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) + if err != nil { + logs.Info("MutualHTTPS should provide TrustCaFile") + return + } + pool.AppendCertsFromPEM(data) + app.Server.TLSConfig = &tls.Config{ + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + } + if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + + } + if BConfig.Listen.EnableHTTP { + go func() { + app.Server.Addr = addr + logs.Info("http server Running on http://%s", app.Server.Addr) + if BConfig.Listen.ListenTCP4 { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + if err = app.Server.Serve(ln); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + if err := app.Server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } + }() + } + <-endRunning +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + BeeApp.Handlers.Add(rootpath, c, mappingMethods...) + return BeeApp +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + subPaths := splitPath(fixedRoute) + if method == "" || method == "*" { + for m := range HTTPMETHOD { + if _, ok := BeeApp.Handlers.routers[m]; !ok { + continue + } + if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) + continue + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) + } + return BeeApp + } + // Single HTTP method + um := strings.ToUpper(method) + if _, ok := BeeApp.Handlers.routers[um]; ok { + if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) + return BeeApp + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) + } + return BeeApp +} + +func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { + for i := range entryPointTree.fixrouters { + if entryPointTree.fixrouters[i].prefix == paths[0] { + if len(paths) == 1 { + if len(entryPointTree.fixrouters[i].fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.fixrouters[i].leaves) > 0 { + entryPointTree.fixrouters[i].leaves[0] = nil + entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] + } + } else { + // Remove the *Tree from the fixrouters slice + entryPointTree.fixrouters[i] = nil + + if i == len(entryPointTree.fixrouters)-1 { + entryPointTree.fixrouters = entryPointTree.fixrouters[:i] + } else { + entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) + } + } + return + } + findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) + } + } +} + +func findAndRemoveSingleTree(entryPointTree *Tree) { + if entryPointTree == nil { + return + } + if len(entryPointTree.fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.leaves) > 0 { + entryPointTree.leaves[0] = nil + entryPointTree.leaves = entryPointTree.leaves[1:] + } + } +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +//} +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + BeeApp.Handlers.Include(cList...) + return BeeApp +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + Router(rootpath, c) + Router(path.Join(rootpath, ":objectId"), c) + return BeeApp +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + BeeApp.Handlers.AddAuto(c) + return BeeApp +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + BeeApp.Handlers.AddAutoPrefix(prefix, c) + return BeeApp +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Get(rootpath, f) + return BeeApp +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Post(rootpath, f) + return BeeApp +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Delete(rootpath, f) + return BeeApp +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Put(rootpath, f) + return BeeApp +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Head(rootpath, f) + return BeeApp +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Options(rootpath, f) + return BeeApp +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Patch(rootpath, f) + return BeeApp +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Any(rootpath, f) + return BeeApp +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + BeeApp.Handlers.Handler(rootpath, h, options...) + return BeeApp +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) + return BeeApp +} diff --git a/pkg/beego.go b/pkg/beego.go new file mode 100644 index 00000000..8ebe0bab --- /dev/null +++ b/pkg/beego.go @@ -0,0 +1,123 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "os" + "path/filepath" + "strconv" + "strings" +) + +const ( + // VERSION represent beego web framework version. + VERSION = "1.12.2" + + // DEV is for develop + DEV = "dev" + // PROD is for production + PROD = "prod" +) + +// M is Map shortcut +type M map[string]interface{} + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + hooks = append(hooks, hf...) +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + + initBeforeHTTPRun() + + if len(params) > 0 && params[0] != "" { + strs := strings.Split(params[0], ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BConfig.Listen.Domains = params + } + + BeeApp.Run() +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + initBeforeHTTPRun() + + strs := strings.Split(addr, ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + BConfig.Listen.Domains = []string{strs[0]} + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BeeApp.Run(mws...) +} + +func initBeforeHTTPRun() { + //init hooks + AddAPPStartHook( + registerMime, + registerDefaultErrorHandler, + registerSession, + registerTemplate, + registerAdmin, + registerGzip, + ) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } + } +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + path := filepath.Join(ap, "conf", "app.conf") + os.Chdir(ap) + InitBeegoBeforeTest(path) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { + panic(err) + } + BConfig.RunMode = "test" + initBeforeHTTPRun() +} diff --git a/pkg/build_info.go b/pkg/build_info.go new file mode 100644 index 00000000..6dc2835e --- /dev/null +++ b/pkg/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/cache/README.md b/pkg/cache/README.md new file mode 100644 index 00000000..b467760a --- /dev/null +++ b/pkg/cache/README.md @@ -0,0 +1,59 @@ +## cache +cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/cache + + +## What adapters are supported? + +As of now this cache support memory, Memcache and Redis. + + +## How to use it? + +First you must import it + + import ( + "github.com/astaxie/beego/cache" + ) + +Then init a Cache (example with memory adapter) + + bm, err := cache.NewCache("memory", `{"interval":60}`) + +Use it like this: + + bm.Put("astaxie", 1, 10 * time.Second) + bm.Get("astaxie") + bm.IsExist("astaxie") + bm.Delete("astaxie") + + +## Memory adapter + +Configure memory adapter like this: + + {"interval":60} + +interval means the gc time. The cache will check at each time interval, whether item has expired. + + +## Memcache adapter + +Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. + +Configure like this: + + {"conn":"127.0.0.1:11211"} + + +## Redis adapter + +Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. + +Configure like this: + + {"conn":":6039"} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 00000000..82585c4e --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,103 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + "time" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/cache/conv.go b/pkg/cache/conv.go new file mode 100644 index 00000000..87800586 --- /dev/null +++ b/pkg/cache/conv.go @@ -0,0 +1,100 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "fmt" + "strconv" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + switch result := v.(type) { + case string: + return result + case []byte: + return string(result) + default: + if v != nil { + return fmt.Sprint(result) + } + } + return "" +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + switch result := v.(type) { + case int: + return result + case int32: + return int(result) + case int64: + return int(result) + default: + if d := GetString(v); d != "" { + value, _ := strconv.Atoi(d) + return value + } + } + return 0 +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + switch result := v.(type) { + case int: + return int64(result) + case int32: + return int64(result) + case int64: + return result + default: + + if d := GetString(v); d != "" { + value, _ := strconv.ParseInt(d, 10, 64) + return value + } + } + return 0 +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + switch result := v.(type) { + case float64: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseFloat(d, 64) + return value + } + } + return 0 +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + switch result := v.(type) { + case bool: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseBool(d) + return value + } + } + return false +} diff --git a/pkg/cache/conv_test.go b/pkg/cache/conv_test.go new file mode 100644 index 00000000..b90e224a --- /dev/null +++ b/pkg/cache/conv_test.go @@ -0,0 +1,143 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "testing" +) + +func TestGetString(t *testing.T) { + var t1 = "test1" + if "test1" != GetString(t1) { + t.Error("get string from string error") + } + var t2 = []byte("test2") + if "test2" != GetString(t2) { + t.Error("get string from byte array error") + } + var t3 = 1 + if "1" != GetString(t3) { + t.Error("get string from int error") + } + var t4 int64 = 1 + if "1" != GetString(t4) { + t.Error("get string from int64 error") + } + var t5 = 1.1 + if "1.1" != GetString(t5) { + t.Error("get string from float64 error") + } + + if "" != GetString(nil) { + t.Error("get string from nil error") + } +} + +func TestGetInt(t *testing.T) { + var t1 = 1 + if 1 != GetInt(t1) { + t.Error("get int from int error") + } + var t2 int32 = 32 + if 32 != GetInt(t2) { + t.Error("get int from int32 error") + } + var t3 int64 = 64 + if 64 != GetInt(t3) { + t.Error("get int from int64 error") + } + var t4 = "128" + if 128 != GetInt(t4) { + t.Error("get int from num string error") + } + if 0 != GetInt(nil) { + t.Error("get int from nil error") + } +} + +func TestGetInt64(t *testing.T) { + var i int64 = 1 + var t1 = 1 + if i != GetInt64(t1) { + t.Error("get int64 from int error") + } + var t2 int32 = 1 + if i != GetInt64(t2) { + t.Error("get int64 from int32 error") + } + var t3 int64 = 1 + if i != GetInt64(t3) { + t.Error("get int64 from int64 error") + } + var t4 = "1" + if i != GetInt64(t4) { + t.Error("get int64 from num string error") + } + if 0 != GetInt64(nil) { + t.Error("get int64 from nil") + } +} + +func TestGetFloat64(t *testing.T) { + var f = 1.11 + var t1 float32 = 1.11 + if f != GetFloat64(t1) { + t.Error("get float64 from float32 error") + } + var t2 = 1.11 + if f != GetFloat64(t2) { + t.Error("get float64 from float64 error") + } + var t3 = "1.11" + if f != GetFloat64(t3) { + t.Error("get float64 from string error") + } + + var f2 float64 = 1 + var t4 = 1 + if f2 != GetFloat64(t4) { + t.Error("get float64 from int error") + } + + if 0 != GetFloat64(nil) { + t.Error("get float64 from nil error") + } +} + +func TestGetBool(t *testing.T) { + var t1 = true + if !GetBool(t1) { + t.Error("get bool from bool error") + } + var t2 = "true" + if !GetBool(t2) { + t.Error("get bool from string error") + } + if GetBool(nil) { + t.Error("get bool from nil error") + } +} + +func byteArrayEquals(a []byte, b []byte) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/pkg/cache/file.go b/pkg/cache/file.go new file mode 100644 index 00000000..6f12d3ee --- /dev/null +++ b/pkg/cache/file.go @@ -0,0 +1,258 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strconv" + "time" +) + +// FileCacheItem is basic unit of file cache adapter. +// it contains data and expire time. +type FileCacheItem struct { + Data interface{} + Lastaccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC will start and begin gc for file cache. +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} +func (fc *FileCache) StartAndGC(config string) error { + + cfg := make(map[string]string) + err := json.Unmarshal([]byte(config), &cfg) + if err != nil { + return err + } + if _, ok := cfg["CachePath"]; !ok { + cfg["CachePath"] = FileCachePath + } + if _, ok := cfg["FileSuffix"]; !ok { + cfg["FileSuffix"] = FileCacheFileSuffix + } + if _, ok := cfg["DirectoryLevel"]; !ok { + cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) + } + if _, ok := cfg["EmbedExpiry"]; !ok { + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg["CachePath"] + fc.FileSuffix = cfg["FileSuffix"] + fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) + fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) + + fc.Init() + return nil +} + +// Init will make new dir for file cache if not exist. +func (fc *FileCache) Init() { + if ok, _ := exists(fc.CachePath); !ok { // todo : error handle + _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle + } +} + +// get cached file name. it's md5 encoded. +func (fc *FileCache) getCacheFileName(key string) string { + m := md5.New() + io.WriteString(m, key) + keyMd5 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keyMd5[0:2]) + } + + if ok, _ := exists(cachePath); !ok { // todo : error handle + _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) +} + +// Get value from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) Get(key string) interface{} { + fileData, err := FileGetContents(fc.getCacheFileName(key)) + if err != nil { + return "" + } + var to FileCacheItem + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { + return "" + } + return to.Data +} + +// GetMulti gets values from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) GetMulti(keys []string) []interface{} { + var rc []interface{} + for _, key := range keys { + rc = append(rc, fc.Get(key)) + } + return rc +} + +// Put value into file cache. +// timeout means how long to keep this file, unit of ms. +// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == time.Duration(fc.EmbedExpiry) { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.Lastaccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + return FilePutContents(fc.getCacheFileName(key), data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(key string) error { + filename := fc.getCacheFileName(key) + if ok, _ := exists(filename); ok { + return os.Remove(filename) + } + return nil +} + +// Incr will increase cached int value. +// fc value is saving forever unless Delete. +func (fc *FileCache) Incr(key string) error { + data := fc.Get(key) + var incr int + if reflect.TypeOf(data).Name() != "int" { + incr = 0 + } else { + incr = data.(int) + 1 + } + fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// Decr will decrease cached int value. +func (fc *FileCache) Decr(key string) error { + data := fc.Get(key) + var decr int + if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { + decr = 0 + } else { + decr = data.(int) - 1 + } + fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// IsExist check value is exist. +func (fc *FileCache) IsExist(key string) bool { + ret, _ := exists(fc.getCacheFileName(key)) + return ret +} + +// ClearAll will clean cached files. +// not implemented. +func (fc *FileCache) ClearAll() error { + return nil +} + +// check file exist. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// FileGetContents Get bytes to file. +// if non-exist, create this file. +func FileGetContents(filename string) (data []byte, e error) { + return ioutil.ReadFile(filename) +} + +// FilePutContents Put bytes to file. +// if non-exist, create this file. +func FilePutContents(filename string, content []byte) error { + return ioutil.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(&to) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go new file mode 100644 index 00000000..19116bfa --- /dev/null +++ b/pkg/cache/memcache/memcache.go @@ -0,0 +1,188 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for cache provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/memcache" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package memcache + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/astaxie/beego/cache" + "github.com/bradfitz/gomemcache/memcache" +) + +// Cache Memcache adapter. +type Cache struct { + conn *memcache.Client + conninfo []string +} + +// NewMemCache create new memcache adapter. +func NewMemCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + if item, err := rc.conn.Get(key); err == nil { + return item.Value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var rv []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv + } + } + mv, err := rc.conn.GetMulti(keys) + if err == nil { + for _, v := range mv { + rv = append(rv, v.Value) + } + return rv + } + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv +} + +// Put put value to memcache. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} + if v, ok := val.([]byte); ok { + item.Value = v + } else if str, ok := val.(string); ok { + item.Value = []byte(str) + } else { + return errors.New("val only support string and []byte") + } + return rc.conn.Set(&item) +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.Delete(key) +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Increment(key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Decrement(key, 1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + _, err := rc.conn.Get(key) + return err == nil +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.FlushAll() +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + rc.conn = memcache.New(rc.conninfo...) + return nil +} + +func init() { + cache.Register("memcache", NewMemCache) +} diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/cache/memcache/memcache_test.go new file mode 100644 index 00000000..d9129b69 --- /dev/null +++ b/pkg/cache/memcache/memcache_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memcache + +import ( + _ "github.com/bradfitz/gomemcache/memcache" + + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestMemcacheCache(t *testing.T) { + bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie").([]byte); string(v) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { + t.Error("GetMulti ERROR") + } + if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go new file mode 100644 index 00000000..d8314e3c --- /dev/null +++ b/pkg/cache/memory.go @@ -0,0 +1,256 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "encoding/json" + "errors" + "sync" + "time" +) + +var ( + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute +) + +// MemoryItem store memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. +// it contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get cache from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) Get(name string) interface{} { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[name]; ok { + if itm.isExpire() { + return nil + } + return itm.val + } + return nil +} + +// GetMulti gets caches from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) GetMulti(names []string) []interface{} { + var rc []interface{} + for _, name := range names { + rc = append(rc, bc.Get(name)) + } + return rc +} + +// Put cache to memory. +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[name] = &MemoryItem{ + val: value, + createdTime: time.Now(), + lifespan: lifespan, + } + return nil +} + +// Delete cache in memory. +func (bc *MemoryCache) Delete(name string) error { + bc.Lock() + defer bc.Unlock() + if _, ok := bc.items[name]; !ok { + return errors.New("key not exist") + } + delete(bc.items, name) + if _, ok := bc.items[name]; ok { + return errors.New("delete key error") + } + return nil +} + +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val + 1 + case int32: + itm.val = val + 1 + case int64: + itm.val = val + 1 + case uint: + itm.val = val + 1 + case uint32: + itm.val = val + 1 + case uint64: + itm.val = val + 1 + default: + return errors.New("item val is not (u)int (u)int32 (u)int64") + } + return nil +} + +// Decr decrease counter in memory. +func (bc *MemoryCache) Decr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val - 1 + case int64: + itm.val = val - 1 + case int32: + itm.val = val - 1 + case uint: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint32: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint64: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + default: + return errors.New("item val is not int int64 int32") + } + return nil +} + +// IsExist check cache exist in memory. +func (bc *MemoryCache) IsExist(name string) bool { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false +} + +// ClearAll will delete all cache in memory. +func (bc *MemoryCache) ClearAll() error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC start memory cache. it will check expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + bc.RLock() + if bc.items == nil { + bc.RUnlock() + return + } + bc.RUnlock() + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns key list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// clearItems removes all the items which key in keys. +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go new file mode 100644 index 00000000..56faf211 --- /dev/null +++ b/pkg/cache/redis/redis.go @@ -0,0 +1,272 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for cache provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/redis" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package redis + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/gomodule/redigo/redis" + + "github.com/astaxie/beego/cache" + "strings" +) + +var ( + // DefaultKey the collection name of redis for cache adapter. + DefaultKey = "beecacheRedis" +) + +// Cache is Redis cache adapter. +type Cache struct { + p *redis.Pool // redis connection pool + conninfo string + dbNum int + key string + password string + maxIdle int + + //the timeout to a value less than the redis server's timeout. + timeout time.Duration +} + +// NewRedisCache create new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return &Cache{key: DefaultKey} +} + +// actually do the redis cmds, args[0] must be the key name. +func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { + if len(args) < 1 { + return nil, errors.New("missing required arguments") + } + args[0] = rc.associate(args[0]) + c := rc.p.Get() + defer c.Close() + + return c.Do(commandName, args...) +} + +// associate with config key. +func (rc *Cache) associate(originKey interface{}) string { + return fmt.Sprintf("%s:%s", rc.key, originKey) +} + +// Get cache from redis. +func (rc *Cache) Get(key string) interface{} { + if v, err := rc.do("GET", key); err == nil { + return v + } + return nil +} + +// GetMulti get cache from redis. +func (rc *Cache) GetMulti(keys []string) []interface{} { + c := rc.p.Get() + defer c.Close() + var args []interface{} + for _, key := range keys { + args = append(args, rc.associate(key)) + } + values, err := redis.Values(c.Do("MGET", args...)) + if err != nil { + return nil + } + return values +} + +// Put put cache to redis. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) + return err +} + +// Delete delete cache in redis. +func (rc *Cache) Delete(key string) error { + _, err := rc.do("DEL", key) + return err +} + +// IsExist check cache's existence in redis. +func (rc *Cache) IsExist(key string) bool { + v, err := redis.Bool(rc.do("EXISTS", key)) + if err != nil { + return false + } + return v +} + +// Incr increase counter in redis. +func (rc *Cache) Incr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, 1)) + return err +} + +// Decr decrease counter in redis. +func (rc *Cache) Decr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, -1)) + return err +} + +// ClearAll clean all cache in redis. delete this redis collection. +func (rc *Cache) ClearAll() error { + cachedKeys, err := rc.Scan(rc.key + ":*") + if err != nil { + return err + } + c := rc.p.Get() + defer c.Close() + for _, str := range cachedKeys { + if _, err = c.Do("DEL", str); err != nil { + return err + } + } + return err +} + +// Scan scan all keys matching the pattern. a better choice than `keys` +func (rc *Cache) Scan(pattern string) (keys []string, err error) { + c := rc.p.Get() + defer c.Close() + var ( + cursor uint64 = 0 // start + result []interface{} + list []string + ) + for { + result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) + if err != nil { + return + } + list, err = redis.Strings(result[1], nil) + if err != nil { + return + } + keys = append(keys, list...) + cursor, err = redis.Uint64(result[0], nil) + if err != nil { + return + } + if cursor == 0 { // over + return + } + } +} + +// StartAndGC start redis cache adapter. +// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} +// the cache item in redis are stored forever, +// so no gc operation. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + + if _, ok := cf["key"]; !ok { + cf["key"] = DefaultKey + } + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + + // Format redis://@: + cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) + if i := strings.Index(cf["conn"], "@"); i > -1 { + cf["password"] = cf["conn"][0:i] + cf["conn"] = cf["conn"][i+1:] + } + + if _, ok := cf["dbNum"]; !ok { + cf["dbNum"] = "0" + } + if _, ok := cf["password"]; !ok { + cf["password"] = "" + } + if _, ok := cf["maxIdle"]; !ok { + cf["maxIdle"] = "3" + } + if _, ok := cf["timeout"]; !ok { + cf["timeout"] = "180s" + } + rc.key = cf["key"] + rc.conninfo = cf["conn"] + rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) + rc.password = cf["password"] + rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) + + if v, err := time.ParseDuration(cf["timeout"]); err == nil { + rc.timeout = v + } else { + rc.timeout = 180 * time.Second + } + + rc.connectInit() + + c := rc.p.Get() + defer c.Close() + + return c.Err() +} + +// connect to redis. +func (rc *Cache) connectInit() { + dialFunc := func() (c redis.Conn, err error) { + c, err = redis.Dial("tcp", rc.conninfo) + if err != nil { + return nil, err + } + + if rc.password != "" { + if _, err := c.Do("AUTH", rc.password); err != nil { + c.Close() + return nil, err + } + } + + _, selecterr := c.Do("SELECT", rc.dbNum) + if selecterr != nil { + c.Close() + return nil, selecterr + } + return + } + // initialize a new pool + rc.p = &redis.Pool{ + MaxIdle: rc.maxIdle, + IdleTimeout: rc.timeout, + Dial: dialFunc, + } +} + +func init() { + cache.Register("redis", NewRedisCache) +} diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go new file mode 100644 index 00000000..60a19180 --- /dev/null +++ b/pkg/cache/redis/redis_test.go @@ -0,0 +1,144 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "fmt" + "testing" + "time" + + "github.com/astaxie/beego/cache" + "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" +) + +func TestRedisCache(t *testing.T) { + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[0], nil); v != "author" { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[1], nil); v != "author1" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} + +func TestCache_Scan(t *testing.T) { + timeoutDuration := 10 * time.Second + // init + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + // insert all + for i := 0; i < 10000; i++ { + if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { + t.Error("set Error", err) + } + } + // scan all for the first time + keys, err := bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + + assert.Equal(t, 10000, len(keys), "scan all error") + + // clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } + + // scan all for the second time + keys, err = bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + if len(keys) != 0 { + t.Error("scan all err") + } +} diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/cache/ssdb/ssdb.go new file mode 100644 index 00000000..fa2ce04b --- /dev/null +++ b/pkg/cache/ssdb/ssdb.go @@ -0,0 +1,231 @@ +package ssdb + +import ( + "encoding/json" + "errors" + "strconv" + "strings" + "time" + + "github.com/ssdb/gossdb/ssdb" + + "github.com/astaxie/beego/cache" +) + +// Cache SSDB adapter +type Cache struct { + conn *ssdb.Client + conninfo []string +} + +//NewSsdbCache create new ssdb adapter. +func NewSsdbCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil + } + } + value, err := rc.conn.Get(key) + if err == nil { + return value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var values []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + values = append(values, err) + } + return values + } + } + res, err := rc.conn.Do("multi_get", keys) + resSize := len(res) + if err == nil { + for i := 1; i < resSize; i += 2 { + values = append(values, res[i+1]) + } + return values + } + for i := 0; i < size; i++ { + values = append(values, err) + } + return values +} + +// DelMulti get value from memcache. +func (rc *Cache) DelMulti(keys []string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("multi_del", keys) + return err +} + +// Put put value to memcache. only support string. +func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + v, ok := value.(string) + if !ok { + return errors.New("value must string") + } + var resp []string + var err error + ttl := int(timeout / time.Second) + if ttl < 0 { + resp, err = rc.conn.Do("set", key, v) + } else { + resp, err = rc.conn.Do("setx", key, v, ttl) + } + if err != nil { + return err + } + if len(resp) == 2 && resp[0] == "ok" { + return nil + } + return errors.New("bad response") +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Del(key) + return err +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, -1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + resp, err := rc.conn.Do("exists", key) + if err != nil { + return false + } + if len(resp) == 2 && resp[1] == "1" { + return true + } + return false + +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + keyStart, keyEnd, limit := "", "", 50 + resp, err := rc.Scan(keyStart, keyEnd, limit) + for err == nil { + size := len(resp) + if size == 1 { + return nil + } + keys := []string{} + for i := 1; i < size; i += 2 { + keys = append(keys, resp[i]) + } + _, e := rc.conn.Do("multi_del", keys) + if e != nil { + return e + } + keyStart = resp[size-2] + resp, err = rc.Scan(keyStart, keyEnd, limit) + } + return err +} + +// Scan key all cached in ssdb. +func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil, err + } + } + resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) + if err != nil { + return nil, err + } + return resp, nil +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + conninfoArray := strings.Split(rc.conninfo[0], ":") + host := conninfoArray[0] + port, e := strconv.Atoi(conninfoArray[1]) + if e != nil { + return e + } + var err error + rc.conn, err = ssdb.Connect(host, port) + return err +} + +func init() { + cache.Register("ssdb", NewSsdbCache) +} diff --git a/pkg/cache/ssdb/ssdb_test.go b/pkg/cache/ssdb/ssdb_test.go new file mode 100644 index 00000000..dd474960 --- /dev/null +++ b/pkg/cache/ssdb/ssdb_test.go @@ -0,0 +1,104 @@ +package ssdb + +import ( + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestSsdbcacheCache(t *testing.T) { + ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) + if err != nil { + t.Error("init err") + } + + // test put and exist + if ssdb.IsExist("ssdb") { + t.Error("check err") + } + timeoutDuration := 10 * time.Second + //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + + // Get test done + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v := ssdb.Get("ssdb"); v != "ssdb" { + t.Error("get Error") + } + + //inc/dec test done + if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if err = ssdb.Incr("ssdb"); err != nil { + t.Error("incr Error", err) + } + + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + + if err = ssdb.Decr("ssdb"); err != nil { + t.Error("decr error") + } + + // test del + if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + if err := ssdb.Delete("ssdb"); err == nil { + if ssdb.IsExist("ssdb") { + t.Error("delete err") + } + } + + //test string + if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + if v := ssdb.Get("ssdb").(string); v != "ssdb" { + t.Error("get err") + } + + //test GetMulti done + if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb1") { + t.Error("check err") + } + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + if len(vv) != 2 { + t.Error("getmulti error") + } + if vv[0].(string) != "ssdb" { + t.Error("getmulti error") + } + if vv[1].(string) != "ssdb1" { + t.Error("getmulti error") + } + + // test clear all done + if err = ssdb.ClearAll(); err != nil { + t.Error("clear all err") + } + if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { + t.Error("check err") + } +} diff --git a/pkg/config.go b/pkg/config.go new file mode 100644 index 00000000..b6c9a99c --- /dev/null +++ b/pkg/config.go @@ -0,0 +1,524 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" + "github.com/astaxie/beego/utils" +) + +// Config is the main struct for BConfig +type Config struct { + AppName string //Application name + RunMode string //Running Mode: dev | prod + RouterCaseSensitive bool + ServerName string + RecoverPanic bool + RecoverFunc func(*context.Context) + CopyRequestBody bool + EnableGzip bool + MaxMemory int64 + EnableErrorsShow bool + EnableErrorsRender bool + Listen Listen + WebConfig WebConfig + Log LogConfig +} + +// Listen holds for http and https related config +type Listen struct { + Graceful bool // Graceful means use graceful module to start the server + ServerTimeOut int64 + ListenTCP4 bool + EnableHTTP bool + HTTPAddr string + HTTPPort int + AutoTLS bool + Domains []string + TLSCacheDir string + EnableHTTPS bool + EnableMutualHTTPS bool + HTTPSAddr string + HTTPSPort int + HTTPSCertFile string + HTTPSKeyFile string + TrustCaFile string + EnableAdmin bool + AdminAddr string + AdminPort int + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O +} + +// WebConfig holds web related config +type WebConfig struct { + AutoRender bool + EnableDocs bool + FlashName string + FlashSeparator string + DirectoryIndex bool + StaticDir map[string]string + StaticExtensionsToGzip []string + StaticCacheFileSize int + StaticCacheFileNum int + TemplateLeft string + TemplateRight string + ViewsPath string + EnableXSRF bool + XSRFKey string + XSRFExpire int + Session SessionConfig +} + +// SessionConfig holds session related config +type SessionConfig struct { + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader string + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params +} + +// LogConfig holds Log related config +type LogConfig struct { + AccessLogs bool + EnableStaticLogs bool //log static files requests default: false + AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string + FileLineNum bool + Outputs map[string]string // Store Adaptor : config +} + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = newBConfig() + var err error + if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { + panic(err) + } + WorkPath, err = os.Getwd() + if err != nil { + panic(err) + } + var filename = "app.conf" + if os.Getenv("BEEGO_RUNMODE") != "" { + filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" + } + appConfigPath = filepath.Join(WorkPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + appConfigPath = filepath.Join(AppPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} + return + } + } + if err = parseConfig(appConfigPath); err != nil { + panic(err) + } +} + +func recoverPanic(ctx *context.Context) { + if err := recover(); err != nil { + if err == ErrAbort { + return + } + if !BConfig.RecoverPanic { + panic(err) + } + if BConfig.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), ctx) + return + } + } + var stack string + logs.Critical("the request url is ", ctx.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) + } + if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { + showErr(err, ctx, stack) + } + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } else { + ctx.ResponseWriter.WriteHeader(500) + } + } +} + +func newBConfig() *Config { + return &Config{ + AppName: "beego", + RunMode: PROD, + RouterCaseSensitive: true, + ServerName: "beegoServer:" + VERSION, + RecoverPanic: true, + RecoverFunc: recoverPanic, + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, //64MB + EnableErrorsShow: true, + EnableErrorsRender: true, + Listen: Listen{ + Graceful: false, + ServerTimeOut: 0, + ListenTCP4: false, + EnableHTTP: true, + AutoTLS: false, + Domains: []string{}, + TLSCacheDir: ".", + HTTPAddr: "", + HTTPPort: 8080, + EnableHTTPS: false, + HTTPSAddr: "", + HTTPSPort: 10443, + HTTPSCertFile: "", + HTTPSKeyFile: "", + EnableAdmin: false, + AdminAddr: "", + AdminPort: 8088, + EnableFcgi: false, + EnableStdIo: false, + }, + WebConfig: WebConfig{ + AutoRender: true, + EnableDocs: false, + FlashName: "BEEGO_FLASH", + FlashSeparator: "BEEGOFLASH", + DirectoryIndex: false, + StaticDir: map[string]string{"/static": "static"}, + StaticExtensionsToGzip: []string{".css", ".js"}, + StaticCacheFileSize: 1024 * 100, + StaticCacheFileNum: 1000, + TemplateLeft: "{{", + TemplateRight: "}}", + ViewsPath: "views", + EnableXSRF: false, + XSRFKey: "beegoxsrf", + XSRFExpire: 0, + Session: SessionConfig{ + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + }, + }, + Log: LogConfig{ + AccessLogs: false, + EnableStaticLogs: false, + AccessLogsFormat: "APACHE_FORMAT", + FileLineNum: true, + Outputs: map[string]string{"console": ""}, + }, + } +} + +// now only support ini, next will support json. +func parseConfig(appConfigPath string) (err error) { + AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) + if err != nil { + return err + } + return assignConfig(AppConfig) +} + +func assignConfig(ac config.Configer) error { + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + // set the run mode first + if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { + BConfig.RunMode = envRunMode + } else if runMode := ac.String("RunMode"); runMode != "" { + BConfig.RunMode = runMode + } + + if sd := ac.String("StaticDir"); sd != "" { + BConfig.WebConfig.StaticDir = map[string]string{} + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] + } else { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] + } + } + } + + if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { + extensions := strings.Split(sgz, ",") + fileExts := []string{} + for _, ext := range extensions { + ext = strings.TrimSpace(ext) + if ext == "" { + continue + } + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + fileExts = append(fileExts, ext) + } + if len(fileExts) > 0 { + BConfig.WebConfig.StaticExtensionsToGzip = fileExts + } + } + + if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + BConfig.WebConfig.StaticCacheFileSize = sfs + } + + if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + BConfig.WebConfig.StaticCacheFileNum = sfn + } + + if lo := ac.String("LogOutputs"); lo != "" { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) + los := strings.Split(lo, ";") + for _, v := range los { + if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { + BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] + } else { + continue + } + } + } + + //init log + logs.Reset() + for adaptor, config := range BConfig.Log.Outputs { + err := logs.SetLogger(adaptor, config) + if err != nil { + fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) + } + } + logs.SetLogFuncCall(BConfig.Log.FileLineNum) + + return nil +} + +func assignSingleConfig(p interface{}, ac config.Configer) { + pt := reflect.TypeOf(p) + if pt.Kind() != reflect.Ptr { + return + } + pt = pt.Elem() + if pt.Kind() != reflect.Struct { + return + } + pv := reflect.ValueOf(p).Elem() + + for i := 0; i < pt.NumField(); i++ { + pf := pv.Field(i) + if !pf.CanSet() { + continue + } + name := pt.Field(i).Name + switch pf.Kind() { + case reflect.String: + pf.SetString(ac.DefaultString(name, pf.String())) + case reflect.Int, reflect.Int64: + pf.SetInt(ac.DefaultInt64(name, pf.Int())) + case reflect.Bool: + pf.SetBool(ac.DefaultBool(name, pf.Bool())) + case reflect.Struct: + default: + //do nothing here + } + } + +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + absConfigPath, err := filepath.Abs(configPath) + if err != nil { + return err + } + + if !utils.FileExists(absConfigPath) { + return fmt.Errorf("the target config file: %s don't exist", configPath) + } + + appConfigPath = absConfigPath + appConfigProvider = adapterName + + return parseConfig(appConfigPath) +} + +type beegoAppConfig struct { + innerConfig config.Configer +} + +func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { + ac, err := config.NewConfig(appConfigProvider, appConfigPath) + if err != nil { + return nil, err + } + return &beegoAppConfig{ac}, nil +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { + return v + } + return b.innerConfig.String(key) +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { + return v + } + return b.innerConfig.Strings(key) +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int(key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int64(key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Bool(key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Float(key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 00000000..bfd79e85 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,242 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +//Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +//More docs http://beego.me/docs/module/config.md +package config + +import ( + "fmt" + "os" + "reflect" + "time" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error //support section::key type in given key when using ini type. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string //get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string //get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + if adapter == nil { + panic("config: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("config: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.Parse(filename) +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.ParseData(data) +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + for k, v := range m { + switch value := v.(type) { + case string: + m[k] = ExpandValueEnv(value) + case map[string]interface{}: + m[k] = ExpandValueEnvForMap(value) + case map[string]string: + for k2, v2 := range value { + value[k2] = ExpandValueEnv(v2) + } + m[k] = value + } + } + return m +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) (realValue string) { + realValue = value + + vLen := len(value) + // 3 = ${} + if vLen < 3 { + return + } + // Need start with "${" and end with "}", then return. + if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { + return + } + + key := "" + defaultV := "" + // value start with "${" + for i := 2; i < vLen; i++ { + if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { + key = value[2:i] + defaultV = value[i+2 : vLen-1] // other string is default value. + break + } else if value[i] == '}' { + key = value[2:i] + break + } + } + + realValue = os.Getenv(key) + if realValue == "" { + realValue = defaultV + } + + return +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + if val != nil { + switch v := val.(type) { + case bool: + return v, nil + case string: + switch v { + case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off": + return false, nil + } + case int8, int32, int64: + strV := fmt.Sprintf("%d", v) + if strV == "1" { + return true, nil + } else if strV == "0" { + return false, nil + } + case float64: + if v == 1.0 { + return true, nil + } else if v == 0.0 { + return false, nil + } + } + return false, fmt.Errorf("parsing %q: invalid syntax", val) + } + return false, fmt.Errorf("parsing : invalid syntax") +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + switch y := x.(type) { + + // Handle dates with special logic + // This needs to come above the fmt.Stringer + // test since time.Time's have a .String() + // method + case time.Time: + return y.Format("A Monday") + + // Handle type string + case string: + return y + + // Handle type with .String() method + case fmt.Stringer: + return y.String() + + // Handle type with .Error() method + case error: + return y.Error() + + } + + // Handle named string type + if v := reflect.ValueOf(x); v.Kind() == reflect.String { + return v.String() + } + + // Fallback to fmt package for anything else like numeric types + return fmt.Sprint(x) +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/config/env/env.go b/pkg/config/env/env.go new file mode 100644 index 00000000..34f094fe --- /dev/null +++ b/pkg/config/env/env.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/utils" +) + +var env *utils.BeeMap + +func init() { + env = utils.NewBeeMap() + for _, e := range os.Environ() { + splits := strings.Split(e, "=") + env.Set(splits[0], os.Getenv(splits[0])) + } +} + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + if val := env.Get(key); val != nil { + return val.(string) + } + return defVal +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + if val := env.Get(key); val != nil { + return val.(string), nil + } + return "", fmt.Errorf("no env variable with %s", key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + err := os.Setenv(key, value) + if err != nil { + return err + } + env.Set(key, value) + return nil +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + items := env.Items() + envs := make(map[string]string, env.Count()) + + for key, val := range items { + switch key := key.(type) { + case string: + switch val := val.(type) { + case string: + envs[key] = val + } + } + } + return envs +} diff --git a/pkg/config/env/env_test.go b/pkg/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/config/fake.go b/pkg/config/fake.go new file mode 100644 index 00000000..d21ab820 --- /dev/null +++ b/pkg/config/fake.go @@ -0,0 +1,134 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "strconv" + "strings" +) + +type fakeConfigContainer struct { + data map[string]string +} + +func (c *fakeConfigContainer) getData(key string) string { + return c.data[strings.ToLower(key)] +} + +func (c *fakeConfigContainer) Set(key, val string) error { + c.data[strings.ToLower(key)] = val + return nil +} + +func (c *fakeConfigContainer) String(key string) string { + return c.getData(key) +} + +func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getData(key), 10, 64) +} + +func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getData(key), 64) +} + +func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return nil, errors.New("key not find") +} + +func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { + return nil, errors.New("not implement in the fakeConfigContainer") +} + +func (c *fakeConfigContainer) SaveConfigFile(filename string) error { + return errors.New("not implement in the fakeConfigContainer") +} + +var _ Configer = new(fakeConfigContainer) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + return &fakeConfigContainer{ + data: make(map[string]string), + } +} diff --git a/pkg/config/ini.go b/pkg/config/ini.go new file mode 100644 index 00000000..002e5e05 --- /dev/null +++ b/pkg/config/ini.go @@ -0,0 +1,504 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bufio" + "bytes" + "errors" + "io" + "io/ioutil" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" +) + +var ( + defaultSection = "default" // default section means if some ini items not in a section, make them in default section, + bNumComment = []byte{'#'} // number signal + bSemComment = []byte{';'} // semicolon signal + bEmpty = []byte{} + bEqual = []byte{'='} // equal signal + bDQuote = []byte{'"'} // quote signal + sectionStart = []byte{'['} // section start signal + sectionEnd = []byte{']'} // section end signal + lineBreak = "\n" +) + +// IniConfig implements Config to parse ini file. +type IniConfig struct { +} + +// Parse creates a new Config and parses the file configuration from the named file. +func (ini *IniConfig) Parse(name string) (Configer, error) { + return ini.parseFile(name) +} + +func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { + data, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + + return ini.parseData(filepath.Dir(name), data) +} + +func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { + cfg := &IniConfigContainer{ + data: make(map[string]map[string]string), + sectionComment: make(map[string]string), + keyComment: make(map[string]string), + RWMutex: sync.RWMutex{}, + } + cfg.Lock() + defer cfg.Unlock() + + var comment bytes.Buffer + buf := bufio.NewReader(bytes.NewBuffer(data)) + // check the BOM + head, err := buf.Peek(3) + if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { + for i := 1; i <= 3; i++ { + buf.ReadByte() + } + } + section := defaultSection + tmpBuf := bytes.NewBuffer(nil) + for { + tmpBuf.Reset() + + shouldBreak := false + for { + tmp, isPrefix, err := buf.ReadLine() + if err == io.EOF { + shouldBreak = true + break + } + + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } + + tmpBuf.Write(tmp) + if isPrefix { + continue + } + + if !isPrefix { + break + } + } + if shouldBreak { + break + } + + line := tmpBuf.Bytes() + line = bytes.TrimSpace(line) + if bytes.Equal(line, bEmpty) { + continue + } + var bComment []byte + switch { + case bytes.HasPrefix(line, bNumComment): + bComment = bNumComment + case bytes.HasPrefix(line, bSemComment): + bComment = bSemComment + } + if bComment != nil { + line = bytes.TrimLeft(line, string(bComment)) + // Need append to a new line if multi-line comments. + if comment.Len() > 0 { + comment.WriteByte('\n') + } + comment.Write(line) + continue + } + + if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) { + section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive + if comment.Len() > 0 { + cfg.sectionComment[section] = comment.String() + comment.Reset() + } + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + continue + } + + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + keyValue := bytes.SplitN(line, bEqual, 2) + + key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive + key = strings.ToLower(key) + + // handle include "other.conf" + if len(keyValue) == 1 && strings.HasPrefix(key, "include") { + + includefiles := strings.Fields(key) + if includefiles[0] == "include" && len(includefiles) == 2 { + + otherfile := strings.Trim(includefiles[1], "\"") + if !filepath.IsAbs(otherfile) { + otherfile = filepath.Join(dir, otherfile) + } + + i, err := ini.parseFile(otherfile) + if err != nil { + return nil, err + } + + for sec, dt := range i.data { + if _, ok := cfg.data[sec]; !ok { + cfg.data[sec] = make(map[string]string) + } + for k, v := range dt { + cfg.data[sec][k] = v + } + } + + for sec, comm := range i.sectionComment { + cfg.sectionComment[sec] = comm + } + + for k, comm := range i.keyComment { + cfg.keyComment[k] = comm + } + + continue + } + } + + if len(keyValue) != 2 { + return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") + } + val := bytes.TrimSpace(keyValue[1]) + if bytes.HasPrefix(val, bDQuote) { + val = bytes.Trim(val, `"`) + } + + cfg.data[section][key] = ExpandValueEnv(string(val)) + if comment.Len() > 0 { + cfg.keyComment[section+"."+key] = comment.String() + comment.Reset() + } + + } + return cfg, nil +} + +// ParseData parse ini the data +// When include other.conf,other.conf is either absolute directory +// or under beego in default temporary directory(/tmp/beego[-username]). +func (ini *IniConfig) ParseData(data []byte) (Configer, error) { + dir := "beego" + currentUser, err := user.Current() + if err == nil { + dir = "beego-" + currentUser.Username + } + dir = filepath.Join(os.TempDir(), dir) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, err + } + + return ini.parseData(dir, data) +} + +// IniConfigContainer A Config represents the ini configuration. +// When set and get value, support key as section:name type. +type IniConfigContainer struct { + data map[string]map[string]string // section=> key:val + sectionComment map[string]string // section : comment + keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *IniConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getdata(key)) +} + +// DefaultBool returns the boolean value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *IniConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getdata(key)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *IniConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getdata(key), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *IniConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getdata(key), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *IniConfigContainer) String(key string) string { + return c.getdata(key) +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *IniConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v, nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file. +// +// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. +func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + // Get section or key comments. Fixed #1607 + getCommentStr := func(section, key string) string { + var ( + comment string + ok bool + ) + if len(key) == 0 { + comment, ok = c.sectionComment[section] + } else { + comment, ok = c.keyComment[section+"."+key] + } + + if ok { + // Empty comment + if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 { + return string(bNumComment) + } + prefix := string(bNumComment) + // Add the line head character "#" + return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1) + } + return "" + } + + buf := bytes.NewBuffer(nil) + // Save default section at first place + if dt, ok := c.data[defaultSection]; ok { + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(defaultSection, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + // Save named sections + for section, dt := range c.data { + if section != defaultSection { + // Write section comments. + if v := getCommentStr(section, ""); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write section name. + if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { + return err + } + + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(section, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + } + _, err = buf.WriteTo(f) + return err +} + +// Set writes a new value for key. +// if write to one section, the key need be "section::key". +// if the section is not existed, it panics. +func (c *IniConfigContainer) Set(key, value string) error { + c.Lock() + defer c.Unlock() + if len(key) == 0 { + return errors.New("key is empty") + } + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + + if _, ok := c.data[section]; !ok { + c.data[section] = make(map[string]string) + } + c.data[section][k] = value + return nil +} + +// DIY returns the raw value by a given key. +func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return v, errors.New("key not find") +} + +// section.key or key +func (c *IniConfigContainer) getdata(key string) string { + if len(key) == 0 { + return "" + } + c.RLock() + defer c.RUnlock() + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + if v, ok := c.data[section]; ok { + if vv, ok := v[k]; ok { + return vv + } + } + return "" +} + +func init() { + Register("ini", &IniConfig{}) +} diff --git a/pkg/config/ini_test.go b/pkg/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/config/json.go b/pkg/config/json.go new file mode 100644 index 00000000..c4ef25cd --- /dev/null +++ b/pkg/config/json.go @@ -0,0 +1,269 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" +) + +// JSONConfig is a json config parser and implements Config interface. +type JSONConfig struct { +} + +// Parse returns a ConfigContainer with parsed json config map. +func (js *JSONConfig) Parse(filename string) (Configer, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + content, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + return js.ParseData(content) +} + +// ParseData returns a ConfigContainer with json string +func (js *JSONConfig) ParseData(data []byte) (Configer, error) { + x := &JSONConfigContainer{ + data: make(map[string]interface{}), + } + err := json.Unmarshal(data, &x.data) + if err != nil { + var wrappingArray []interface{} + err2 := json.Unmarshal(data, &wrappingArray) + if err2 != nil { + return nil, err + } + x.data["rootArray"] = wrappingArray + } + + x.data = ExpandValueEnvForMap(x.data) + + return x, nil +} + +// JSONConfigContainer A Config represents the json configuration. +// Only when get value, support key as section:name type. +type JSONConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *JSONConfigContainer) Bool(key string) (bool, error) { + val := c.getData(key) + if val != nil { + return ParseBool(val) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { + if v, err := c.Bool(key); err == nil { + return v + } + return defaultval +} + +// Int returns the integer value for a given key. +func (c *JSONConfigContainer) Int(key string) (int, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int(v), nil + } else if v, ok := val.(string); ok { + return strconv.Atoi(v) + } + return 0, errors.New("not valid value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { + if v, err := c.Int(key); err == nil { + return v + } + return defaultval +} + +// Int64 returns the int64 value for a given key. +func (c *JSONConfigContainer) Int64(key string) (int64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int64(v), nil + } + return 0, errors.New("not int64 value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + if v, err := c.Int64(key); err == nil { + return v + } + return defaultval +} + +// Float returns the float value for a given key. +func (c *JSONConfigContainer) Float(key string) (float64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return v, nil + } + return 0.0, errors.New("not float64 value") + } + return 0.0, errors.New("not exist key:" + key) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + if v, err := c.Float(key); err == nil { + return v + } + return defaultval +} + +// String returns the string value for a given key. +func (c *JSONConfigContainer) String(key string) string { + val := c.getData(key) + if val != nil { + if v, ok := val.(string); ok { + return v + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { + // TODO FIXME should not use "" to replace non existence + if v := c.String(key); v != "" { + return v + } + return defaultval +} + +// Strings returns the []string value for a given key. +func (c *JSONConfigContainer) Strings(key string) []string { + stringVal := c.String(key) + if stringVal == "" { + return nil + } + return strings.Split(c.String(key), ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { + if v := c.Strings(key); v != nil { + return v + } + return defaultval +} + +// GetSection returns map for the given section +func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("nonexist section " + section) +} + +// SaveConfigFile save the config into file +func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := json.MarshalIndent(c.data, "", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *JSONConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { + val := c.getData(key) + if val != nil { + return val, nil + } + return nil, errors.New("not exist key") +} + +// section.key or key +func (c *JSONConfigContainer) getData(key string) interface{} { + if len(key) == 0 { + return nil + } + + c.RLock() + defer c.RUnlock() + + sectionKeys := strings.Split(key, "::") + if len(sectionKeys) >= 2 { + curValue, ok := c.data[sectionKeys[0]] + if !ok { + return nil + } + for _, key := range sectionKeys[1:] { + if v, ok := curValue.(map[string]interface{}); ok { + if curValue, ok = v[key]; !ok { + return nil + } + } + } + return curValue + } + if v, ok := c.data[key]; ok { + return v + } + return nil +} + +func init() { + Register("json", &JSONConfig{}) +} diff --git a/pkg/config/json_test.go b/pkg/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go new file mode 100644 index 00000000..494242d3 --- /dev/null +++ b/pkg/config/xml/xml.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +//More docs http://beego.me/docs/module/config.md +package xml + +import ( + "encoding/xml" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/x2j" +) + +// Config is a xml config parser and implements Config interface. +// xml configurations should be included in tag. +// only support key/value pair as value as each item. +type Config struct{} + +// Parse returns a ConfigContainer with parsed xml config map. +func (xc *Config) Parse(filename string) (config.Configer, error) { + context, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + return xc.ParseData(context) +} + +// ParseData xml data +func (xc *Config) ParseData(data []byte) (config.Configer, error) { + x := &ConfigContainer{data: make(map[string]interface{})} + + d, err := x2j.DocToMap(string(data)) + if err != nil { + return nil, err + } + + x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) + + return x, nil +} + +// ConfigContainer A Config represents the xml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.Mutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + if v := c.data[key]; v != nil { + return config.ParseBool(v) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.data[key].(string)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.data[key].(string), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v + +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.data[key].(string), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, ok := c.data[key].(string); ok { + return v + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section].(map[string]interface{}); ok { + mapstr := make(map[string]string) + for k, val := range v { + mapstr[k] = config.ToString(val) + } + return mapstr, nil + } + return nil, fmt.Errorf("section '%s' not found", section) +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := xml.MarshalIndent(c.data, " ", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[key]; ok { + return v, nil + } + return nil, errors.New("not exist key") +} + +func init() { + config.Register("xml", &Config{}) +} diff --git a/pkg/config/xml/xml_test.go b/pkg/config/xml/xml_test.go new file mode 100644 index 00000000..346c866e --- /dev/null +++ b/pkg/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go new file mode 100644 index 00000000..5def2da3 --- /dev/null +++ b/pkg/config/yaml/yaml.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +//More docs http://beego.me/docs/module/config.md +package yaml + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/goyaml2" +) + +// Config is a yaml config parser and implements Config interface. +type Config struct{} + +// Parse returns a ConfigContainer with parsed yaml config map. +func (yaml *Config) Parse(filename string) (y config.Configer, err error) { + cnf, err := ReadYmlReader(filename) + if err != nil { + return + } + y = &ConfigContainer{ + data: cnf, + } + return +} + +// ParseData parse yaml data +func (yaml *Config) ParseData(data []byte) (config.Configer, error) { + cnf, err := parseYML(data) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: cnf, + }, nil +} + +// ReadYmlReader Read yaml file to map. +// if json like, use json package, unless goyaml2 package. +func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { + buf, err := ioutil.ReadFile(path) + if err != nil { + return + } + + return parseYML(buf) +} + +// parseYML parse yaml formatted []byte to map. +func parseYML(buf []byte) (cnf map[string]interface{}, err error) { + if len(buf) < 3 { + return + } + + if string(buf[0:1]) == "{" { + log.Println("Look like a Json, try json umarshal") + err = json.Unmarshal(buf, &cnf) + if err == nil { + log.Println("It is Json Map") + return + } + } + + data, err := goyaml2.Read(bytes.NewReader(buf)) + if err != nil { + log.Println("Goyaml2 ERR>", string(buf), err) + return + } + + if data == nil { + log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) + return + } + cnf, ok := data.(map[string]interface{}) + if !ok { + log.Println("Not a Map? >> ", string(buf), data) + cnf = nil + } + cnf = config.ExpandValueEnvForMap(cnf) + return +} + +// ConfigContainer A Config represents the yaml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + v, err := c.getData(key) + if err != nil { + return false, err + } + return config.ParseBool(v) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int); ok { + return vv, nil + } else if vv, ok := v.(int64); ok { + return int(vv), nil + } + return 0, errors.New("not int value") +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int64); ok { + return vv, nil + } + return 0, errors.New("not bool value") +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + if v, err := c.getData(key); err != nil { + return 0.0, err + } else if vv, ok := v.(float64); ok { + return vv, nil + } else if vv, ok := v.(int); ok { + return float64(vv), nil + } else if vv, ok := v.(int64); ok { + return float64(vv), nil + } + return 0.0, errors.New("not float64 value") +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, err := c.getData(key); err == nil { + if vv, ok := v.(string); ok { + return vv + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + err = goyaml2.Write(f, c.data) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + return c.getData(key) +} + +func (c *ConfigContainer) getData(key string) (interface{}, error) { + + if len(key) == 0 { + return nil, errors.New("key is empty") + } + c.RLock() + defer c.RUnlock() + + keys := strings.Split(key, ".") + tmpData := c.data + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + if idx == len(keys) - 1 { + return tmpData, nil + } + } + default: + { + return v, nil + } + + } + } + } + return nil, fmt.Errorf("not exist key %q", key) +} + +func init() { + config.Register("yaml", &Config{}) +} diff --git a/pkg/config/yaml/yaml_test.go b/pkg/config/yaml/yaml_test.go new file mode 100644 index 00000000..49cc1d1e --- /dev/null +++ b/pkg/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} diff --git a/pkg/config_test.go b/pkg/config_test.go new file mode 100644 index 00000000..5f71f1c3 --- /dev/null +++ b/pkg/config_test.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestDefaults(t *testing.T) { + if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { + t.Errorf("FlashName was not set to default.") + } + + if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { + t.Errorf("FlashName was not set to default.") + } +} + +func TestAssignConfig_01(t *testing.T) { + _BConfig := &Config{} + _BConfig.AppName = "beego_test" + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) + assignSingleConfig(_BConfig, ac) + if _BConfig.AppName != "beego_json" { + t.Log(_BConfig) + t.FailNow() + } +} + +func TestAssignConfig_02(t *testing.T) { + _BConfig := &Config{} + bs, _ := json.Marshal(newBConfig()) + + jsonMap := M{} + json.Unmarshal(bs, &jsonMap) + + configMap := M{} + for k, v := range jsonMap { + if reflect.TypeOf(v).Kind() == reflect.Map { + for k1, v1 := range v.(M) { + if reflect.TypeOf(v1).Kind() == reflect.Map { + for k2, v2 := range v1.(M) { + configMap[k2] = v2 + } + } else { + configMap[k1] = v1 + } + } + } else { + configMap[k] = v + } + } + configMap["MaxMemory"] = 1024 + configMap["Graceful"] = true + configMap["XSRFExpire"] = 32 + configMap["SessionProviderConfig"] = "file" + configMap["FileLineNum"] = true + + jcf := &config.JSONConfig{} + bs, _ = json.Marshal(configMap) + ac, _ := jcf.ParseData(bs) + + for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + + if _BConfig.MaxMemory != 1024 { + t.Log(_BConfig.MaxMemory) + t.FailNow() + } + + if !_BConfig.Listen.Graceful { + t.Log(_BConfig.Listen.Graceful) + t.FailNow() + } + + if _BConfig.WebConfig.XSRFExpire != 32 { + t.Log(_BConfig.WebConfig.XSRFExpire) + t.FailNow() + } + + if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { + t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) + t.FailNow() + } + + if !_BConfig.Log.FileLineNum { + t.Log(_BConfig.Log.FileLineNum) + t.FailNow() + } + +} + +func TestAssignConfig_03(t *testing.T) { + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) + ac.Set("AppName", "test_app") + ac.Set("RunMode", "online") + ac.Set("StaticDir", "download:down download2:down2") + ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set("StaticCacheFileSize", "87456") + ac.Set("StaticCacheFileNum", "1254") + assignConfig(ac) + + t.Logf("%#v", BConfig) + + if BConfig.AppName != "test_app" { + t.FailNow() + } + + if BConfig.RunMode != "online" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download"] != "down" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download2"] != "down2" { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileSize != 87456 { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileNum != 1254 { + t.FailNow() + } + if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { + t.FailNow() + } +} diff --git a/pkg/context/acceptencoder.go b/pkg/context/acceptencoder.go new file mode 100644 index 00000000..b4e2492c --- /dev/null +++ b/pkg/context/acceptencoder.go @@ -0,0 +1,232 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "io" + "net/http" + "os" + "strconv" + "strings" + "sync" +) + +var ( + //Default size==20B same as nginx + defaultGzipMinLength = 20 + //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + gzipMinLength = defaultGzipMinLength + //The compression level used for deflate compression. (0-9). + gzipCompressLevel int + //List of HTTP methods to compress. If not set, only GET requests are compressed. + includedMethods map[string]bool + getMethodOnly bool +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + if minLength >= 0 { + gzipMinLength = minLength + } + gzipCompressLevel = compressLevel + if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { + gzipCompressLevel = flate.BestSpeed + } + getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") + includedMethods = make(map[string]bool, len(methods)) + for _, v := range methods { + includedMethods[strings.ToUpper(v)] = true + } +} + +type resetWriter interface { + io.Writer + Reset(w io.Writer) +} + +type nopResetWriter struct { + io.Writer +} + +func (n nopResetWriter) Reset(w io.Writer) { + //do nothing +} + +type acceptEncoder struct { + name string + levelEncode func(int) resetWriter + customCompressLevelPool *sync.Pool + bestCompressionPool *sync.Pool +} + +func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return nopResetWriter{wr} + } + var rwr resetWriter + switch level { + case flate.BestSpeed: + rwr = ac.customCompressLevelPool.Get().(resetWriter) + case flate.BestCompression: + rwr = ac.bestCompressionPool.Get().(resetWriter) + default: + rwr = ac.levelEncode(level) + } + rwr.Reset(wr) + return rwr +} + +func (ac acceptEncoder) put(wr resetWriter, level int) { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return + } + wr.Reset(nil) + + //notice + //compressionLevel==BestCompression DOES NOT MATTER + //sync.Pool will not memory leak + + switch level { + case gzipCompressLevel: + ac.customCompressLevelPool.Put(wr) + case flate.BestCompression: + ac.bestCompressionPool.Put(wr) + } +} + +var ( + noneCompressEncoder = acceptEncoder{"", nil, nil, nil} + gzipCompressEncoder = acceptEncoder{ + name: "gzip", + levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } + + //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed + //deflate + //The "zlib" format defined in RFC 1950 [31] in combination with + //the "deflate" compression mechanism described in RFC 1951 [29]. + deflateCompressEncoder = acceptEncoder{ + name: "deflate", + levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } +) + +var ( + encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress + } +) + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return writeLevel(encoding, writer, file, flate.BestCompression) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + if encoding == "" || len(content) < gzipMinLength { + _, err := writer.Write(content) + return false, "", err + } + return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) +} + +// writeLevel reads from reader,writes to writer by specific encoding and compress level +// the compress level is defined by deflate package +func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { + var outputWriter resetWriter + var err error + var ce = noneCompressEncoder + + if cf, ok := encoderMap[encoding]; ok { + ce = cf + } + encoding = ce.name + outputWriter = ce.encode(writer, level) + defer ce.put(outputWriter, level) + + _, err = io.Copy(outputWriter, reader) + if err != nil { + return false, "", err + } + + switch outputWriter.(type) { + case io.WriteCloser: + outputWriter.(io.WriteCloser).Close() + } + return encoding != "", encoding, nil +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + if r == nil { + return "" + } + if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { + return parseEncoding(r) + } + return "" +} + +type q struct { + name string + value float64 +} + +func parseEncoding(r *http.Request) string { + acceptEncoding := r.Header.Get("Accept-Encoding") + if acceptEncoding == "" { + return "" + } + var lastQ q + for _, v := range strings.Split(acceptEncoding, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } + vs := strings.Split(v, ";") + var cf acceptEncoder + var ok bool + if cf, ok = encoderMap[vs[0]]; !ok { + continue + } + if len(vs) == 1 { + return cf.name + } + if len(vs) == 2 { + f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) + if f == 0 { + continue + } + if f > lastQ.value { + lastQ = q{cf.name, f} + } + } + } + return lastQ.name +} diff --git a/pkg/context/acceptencoder_test.go b/pkg/context/acceptencoder_test.go new file mode 100644 index 00000000..e3d61e27 --- /dev/null +++ b/pkg/context/acceptencoder_test.go @@ -0,0 +1,59 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "testing" +) + +func Test_ExtractEncoding(t *testing.T) { + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" { + t.Fail() + } +} diff --git a/pkg/context/context.go b/pkg/context/context.go new file mode 100644 index 00000000..de248ed2 --- /dev/null +++ b/pkg/context/context.go @@ -0,0 +1,263 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/utils" +) + +//commonly used mime-types +const ( + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return &Context{ + Input: NewInput(), + Output: NewOutput(), + } +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context struct { + Input *BeegoInput + Output *BeegoOutput + Request *http.Request + ResponseWriter *Response + _xsrfToken string +} + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + ctx.Request = r + if ctx.ResponseWriter == nil { + ctx.ResponseWriter = &Response{} + } + ctx.ResponseWriter.reset(rw) + ctx.Input.Reset(ctx) + ctx.Output.Reset(ctx) + ctx._xsrfToken = "" +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) + panic(body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + ctx.ResponseWriter.Write([]byte(content)) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return ctx.Input.Cookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + ctx.Output.Cookie(name, value, others...) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.Input.Cookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.Output.Cookie(name, cookie, others...) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + if ctx._xsrfToken == "" { + token, ok := ctx.GetSecureCookie(key, "_xsrf") + if !ok { + token = string(utils.RandomCreateBytes(32)) + ctx.SetSecureCookie(key, "_xsrf", token, expire) + } + ctx._xsrfToken = token + } + return ctx._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + token := ctx.Input.Query("_xsrf") + if token == "" { + token = ctx.Request.Header.Get("X-Xsrftoken") + } + if token == "" { + token = ctx.Request.Header.Get("X-Csrftoken") + } + if token == "" { + ctx.Abort(422, "422") + return false + } + if ctx._xsrfToken != token { + ctx.Abort(417, "417") + return false + } + return true +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + +//Response is a wrapper for the http.ResponseWriter +//started set to true if response was written to then don't execute other handler +type Response struct { + http.ResponseWriter + Started bool + Status int + Elapsed time.Duration +} + +func (r *Response) reset(rw http.ResponseWriter) { + r.ResponseWriter = rw + r.Status = 0 + r.Started = false +} + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + r.Started = true + return r.ResponseWriter.Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + if r.Status > 0 { + //prevent multiple response.WriteHeader calls + return + } + r.Status = code + r.Started = true + r.ResponseWriter.WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := r.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + if pusher, ok := r.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/pkg/context/context_test.go b/pkg/context/context_test.go new file mode 100644 index 00000000..7c0535e0 --- /dev/null +++ b/pkg/context/context_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} diff --git a/pkg/context/input.go b/pkg/context/input.go new file mode 100644 index 00000000..385549c1 --- /dev/null +++ b/pkg/context/input.go @@ -0,0 +1,689 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/gzip" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" +) + +// Regexes for checking the accept headers +// TODO make sure these are correct +var ( + acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) + acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) + acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) + maxParam = 50 +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput struct { + Context *Context + CruSession session.Store + pnames []string + pvalues []string + data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + dataLock sync.RWMutex + RequestBody []byte + RunMethod string + RunController reflect.Type +} + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return &BeegoInput{ + pnames: make([]string, 0, maxParam), + pvalues: make([]string, 0, maxParam), + data: make(map[interface{}]interface{}), + } +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + input.Context = ctx + input.CruSession = nil + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] + input.dataLock.Lock() + input.data = nil + input.dataLock.Unlock() + input.RequestBody = []byte{} +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return input.Context.Request.Proto +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return input.Context.Request.URL.EscapedPath() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return input.Scheme() + "://" + input.Domain() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { + return scheme + } + if input.Context.Request.URL.Scheme != "" { + return input.Context.Request.URL.Scheme + } + if input.Context.Request.TLS == nil { + return "http" + } + return "https" +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return input.Host() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + if input.Context.Request.Host != "" { + if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + return hostPart + } + return input.Context.Request.Host + } + return "localhost" +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return input.Context.Request.Method +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return input.Method() == method +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return input.Is("GET") +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return input.Is("POST") +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return input.Is("HEAD") +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return input.Is("OPTIONS") +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return input.Is("PUT") +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return input.Is("DELETE") +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return input.Is("PATCH") +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return input.Header("X-Requested-With") == "XMLHttpRequest" +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return input.Scheme() == "https" +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return input.Header("Upgrade") == "websocket" +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return strings.Contains(input.Header("Content-Type"), "multipart/form-data") +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return acceptsHTMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return acceptsXMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return acceptsJSONRegex.MatchString(input.Header("Accept")) +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return acceptsYAMLRegex.MatchString(input.Header("Accept")) +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + ips := input.Proxy() + if len(ips) > 0 && ips[0] != "" { + rip, _, err := net.SplitHostPort(ips[0]) + if err != nil { + rip = ips[0] + } + return rip + } + if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil { + return ip + } + return input.Context.Request.RemoteAddr +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + if ips := input.Header("X-Forwarded-For"); ips != "" { + return strings.Split(ips, ",") + } + return []string{} +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return input.Header("Referer") +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return input.Referer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + parts := strings.Split(input.Host(), ".") + if len(parts) >= 3 { + return strings.Join(parts[:len(parts)-2], ".") + } + return "" +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + port, _ := strconv.Atoi(portPart) + return port + } + return 80 +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return input.Header("User-Agent") +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return len(input.pnames) +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + // we cannot use url.PathEscape(input.pvalues[i]) + // for example, if the value is /a/b + // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb + // However, the value is used in ControllerRegister.ServeHTTP + // and split by "/", so function crash... + return input.pvalues[i] + } + } + return "" +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + m := make(map[string]string) + for i, v := range input.pnames { + if i <= len(input.pvalues) { + m[v] = input.pvalues[i] + } + } + return m +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + // check if already exists + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + input.pvalues[i] = val + return + } + } + input.pvalues = append(input.pvalues, val) + input.pnames = append(input.pnames, key) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + if val := input.Param(key); val != "" { + return val + } + if input.Context.Request.Form == nil { + input.dataLock.Lock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + input.dataLock.Unlock() + } + input.dataLock.RLock() + defer input.dataLock.RUnlock() + return input.Context.Request.Form.Get(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return input.Context.Request.Header.Get(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + ck, err := input.Context.Request.Cookie(key) + if err != nil { + return "" + } + return ck.Value +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return input.CruSession.Get(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + if input.Context.Request.Body == nil { + return []byte{} + } + + var requestbody []byte + safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} + if input.Header("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(safe) + if err != nil { + return nil + } + requestbody, _ = ioutil.ReadAll(reader) + } else { + requestbody, _ = ioutil.ReadAll(safe) + } + + input.Context.Request.Body.Close() + bf := bytes.NewBuffer(requestbody) + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) + input.RequestBody = requestbody + return requestbody +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + return input.data +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if v, ok := input.data[key]; ok { + return v + } + return nil +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + input.data[key] = val +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + // Parse the body depending on the content type. + if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + } else if err := input.Context.Request.ParseForm(); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + return nil +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr { + return errors.New("beego: non-pointer passed to Bind: " + key) + } + value = value.Elem() + if !value.CanSet() { + return errors.New("beego: non-settable variable passed to Bind: " + key) + } + typ := value.Type() + // Get real type if dest define with interface{}. + // e.g var dest interface{} dest=1.0 + if value.Kind() == reflect.Interface { + typ = value.Elem().Type() + } + rv := input.bind(key, typ) + if !rv.IsValid() { + return errors.New("beego: reflect value is empty") + } + value.Set(rv) + return nil +} + +func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindFloat(val, typ) + case reflect.String: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindString(val, typ) + case reflect.Bool: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&input.Context.Request.Form, key, typ) + case reflect.Struct: + rv = input.bindStruct(&input.Context.Request.Form, key, typ) + case reflect.Ptr: + rv = input.bindPoint(key, typ) + case reflect.Map: + rv = input.bindMap(&input.Context.Request.Form, key, typ) + } + return rv +} + +func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + rv = input.bindFloat(val, typ) + case reflect.String: + rv = input.bindString(val, typ) + case reflect.Bool: + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&url.Values{"": {val}}, "", typ) + case reflect.Struct: + rv = input.bindStruct(&url.Values{"": {val}}, "", typ) + case reflect.Ptr: + rv = input.bindPoint(val, typ) + case reflect.Map: + rv = input.bindMap(&url.Values{"": {val}}, "", typ) + } + return rv +} + +func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { + intValue, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetInt(intValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { + uintValue, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetUint(uintValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { + floatValue, err := strconv.ParseFloat(val, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetFloat(floatValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { + return reflect.ValueOf(val) +} + +func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "true", "on", "1": + return reflect.ValueOf(true) + } + return reflect.ValueOf(false) +} + +type sliceValue struct { + index int // Index extracted from brackets. If -1, no index was provided. + value reflect.Value // the bound value for this slice element. +} + +func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { + maxIndex := -1 + numNoIndex := 0 + sliceValues := []sliceValue{} + for reqKey, vals := range *params { + if !strings.HasPrefix(reqKey, key+"[") { + continue + } + // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) + index := -1 + leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) + if rightBracket > leftBracket+1 { + index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) + } + subKeyIndex := rightBracket + 1 + + // Handle the indexed case. + if index > -1 { + if index > maxIndex { + maxIndex = index + } + sliceValues = append(sliceValues, sliceValue{ + index: index, + value: input.bind(reqKey[:subKeyIndex], typ.Elem()), + }) + continue + } + + // It's an un-indexed element. (e.g. element[]) + numNoIndex += len(vals) + for _, val := range vals { + // Unindexed values can only be direct-bound. + sliceValues = append(sliceValues, sliceValue{ + index: -1, + value: input.bindValue(val, typ.Elem()), + }) + } + } + resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) + for _, sv := range sliceValues { + if sv.index != -1 { + resultArray.Index(sv.index).Set(sv.value) + } else { + resultArray = reflect.Append(resultArray, sv.value) + } + } + return resultArray +} + +func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { + result := reflect.New(typ).Elem() + fieldValues := make(map[string]reflect.Value) + for reqKey, val := range *params { + var fieldName string + if strings.HasPrefix(reqKey, key+".") { + fieldName = reqKey[len(key)+1:] + } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' { + fieldName = reqKey[len(key)+1 : len(reqKey)-1] + } else { + continue + } + + if _, ok := fieldValues[fieldName]; !ok { + // Time to bind this field. Get it and make sure we can set it. + fieldValue := result.FieldByName(fieldName) + if !fieldValue.IsValid() { + continue + } + if !fieldValue.CanSet() { + continue + } + boundVal := input.bindValue(val[0], fieldValue.Type()) + fieldValue.Set(boundVal) + fieldValues[fieldName] = boundVal + } + } + + return result +} + +func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { + return input.bind(key, typ.Elem()).Addr() +} + +func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { + var ( + result = reflect.MakeMap(typ) + keyType = typ.Key() + valueType = typ.Elem() + ) + for paramName, values := range *params { + if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { + continue + } + + key := paramName[len(key)+1 : len(paramName)-1] + result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) + } + return result +} diff --git a/pkg/context/input_test.go b/pkg/context/input_test.go new file mode 100644 index 00000000..3a6c2e7b --- /dev/null +++ b/pkg/context/input_test.go @@ -0,0 +1,217 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestBind(t *testing.T) { + type testItem struct { + field string + empty interface{} + want interface{} + } + type Human struct { + ID int + Nick string + Pwd string + Ms bool + } + + cases := []struct { + request string + valueGp []testItem + }{ + {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, + + {"/?p=", []testItem{{"p", "", ""}}}, + {"/?p=str", []testItem{{"p", "", "str"}}}, + + {"/?p=123", []testItem{{"p", 0, 123}}}, + {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, + + {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + + {"/?p=true", []testItem{{"p", false, true}}}, + {"/?p=ON", []testItem{{"p", false, true}}}, + {"/?p=on", []testItem{{"p", false, true}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + {"/?p=2", []testItem{{"p", false, false}}}, + {"/?p=false", []testItem{{"p", false, false}}}, + + {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, + {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, + + {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, + {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, + + {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, + + {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, + + {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, + {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, + {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + []testItem{{"human", []Human{}, []Human{ + {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, + {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, + }}}}, + + { + "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", + []testItem{ + {"id", 0, 123}, + {"isok", false, true}, + {"ft", 0.0, 1.2}, + {"ol", []int{}, []int{1, 2}}, + {"ul", []string{}, []string{"str", "array"}}, + {"human", Human{}, Human{Nick: "astaxie"}}, + }, + }, + } + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + for _, item := range c.valueGp { + got := item.empty + err := beegoInput.Bind(&got, item.field) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, item.want) { + t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) + } + } + + } +} + +func TestSubDomain(t *testing.T) { + r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + subdomain := beegoInput.SubDomains() + if subdomain != "www" { + t.Fatal("Subdomain parse error, got" + subdomain) + } + + r, _ = http.NewRequest("GET", "http://localhost/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + /* TODO Fix this + r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + */ + + r, _ = http.NewRequest("GET", "http://example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb.cc.dd" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } +} + +func TestParams(t *testing.T) { + inp := NewInput() + + inp.SetParam("p1", "val1_ver1") + inp.SetParam("p2", "val2_ver1") + inp.SetParam("p3", "val3_ver1") + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1") + } + if val := inp.Param("p3"); val != "val3_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1") + } + vals := inp.Params() + expected := map[string]string{ + "p1": "val1_ver1", + "p2": "val2_ver1", + "p3": "val3_ver1", + } + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + // overwriting existing params + inp.SetParam("p1", "val1_ver2") + inp.SetParam("p2", "val2_ver2") + expected = map[string]string{ + "p1": "val1_ver2", + "p2": "val2_ver2", + "p3": "val3_ver1", + } + vals = inp.Params() + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + + if val := inp.Param("p2"); val != "val2_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + +} +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) +} diff --git a/pkg/context/output.go b/pkg/context/output.go new file mode 100644 index 00000000..238dcf45 --- /dev/null +++ b/pkg/context/output.go @@ -0,0 +1,408 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "html/template" + "io" + "mime" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + yaml "gopkg.in/yaml.v2" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput struct { + Context *Context + Status int + EnableGzip bool +} + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return &BeegoOutput{} +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + output.Context = ctx + output.Status = 0 +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + output.Context.ResponseWriter.Header().Set(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + var encoding string + var buf = &bytes.Buffer{} + if output.EnableGzip { + encoding = ParseEncoding(output.Context.Request) + } + if b, n, _ := WriteBody(encoding, buf, content); b { + output.Header("Content-Encoding", n) + output.Header("Content-Length", strconv.Itoa(buf.Len())) + } else { + output.Header("Content-Length", strconv.Itoa(len(content))) + } + // Write status code if it has been set manually + // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" + if output.Status != 0 { + output.Context.ResponseWriter.WriteHeader(output.Status) + output.Status = 0 + } else { + output.Context.ResponseWriter.Started = true + } + io.Copy(output.Context.ResponseWriter, buf) + return nil +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) + + //fix cookie not work in IE + if len(others) > 0 { + var maxAge int64 + + switch v := others[0].(type) { + case int: + maxAge = int64(v) + case int32: + maxAge = int64(v) + case int64: + maxAge = v + } + + switch { + case maxAge > 0: + fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) + case maxAge < 0: + fmt.Fprintf(&b, "; Max-Age=0") + } + } + + // the settings below + // Path, Domain, Secure, HttpOnly + // can use nil skip set + + // default "/" + if len(others) > 1 { + if v, ok := others[1].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + } + } else { + fmt.Fprintf(&b, "; Path=%s", "/") + } + + // default empty + if len(others) > 2 { + if v, ok := others[2].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) + } + } + + // default empty + if len(others) > 3 { + var secure bool + switch v := others[3].(type) { + case bool: + secure = v + default: + if others[3] != nil { + secure = true + } + } + if secure { + fmt.Fprintf(&b, "; Secure") + } + } + + // default false. for session cookie default true + if len(others) > 4 { + if v, ok := others[4].(bool); ok && v { + fmt.Fprintf(&b, "; HttpOnly") + } + } + + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") + +func sanitizeValue(v string) string { + return cookieValueSanitizer.Replace(v) +} + +func jsonRenderer(value interface{}) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.JSON(value, false, false) + }) +} + +func errorRenderer(err error) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.SetStatus(500) + ctx.Output.Body([]byte(err.Error())) + }) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + output.Header("Content-Type", "application/json; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + if encoding { + content = []byte(stringsToJSON(string(content))) + } + return output.Body(content) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + output.Header("Content-Type", "application/x-yaml; charset=utf-8") + var content []byte + var err error + content, err = yaml.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/javascript; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + callback := output.Context.Input.Query("callback") + if callback == "" { + return errors.New(`"callback" parameter required`) + } + callback = template.JSEscapeString(callback) + callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback) + callbackContent.WriteString("(") + callbackContent.Write(content) + callbackContent.WriteString(");\r\n") + return output.Body(callbackContent.Bytes()) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/xml; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = xml.MarshalIndent(data, "", " ") + } else { + content, err = xml.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + accept := output.Context.Input.Header("Accept") + switch accept { + case ApplicationYAML: + output.YAML(data) + case ApplicationXML, TextXML: + output.XML(data, hasIndent) + default: + output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) + } +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + // check get file error, file not found or other error. + if _, err := os.Stat(file); err != nil { + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) + return + } + + var fName string + if len(filename) > 0 && filename[0] != "" { + fName = filename[0] + } else { + fName = filepath.Base(file) + } + //https://tools.ietf.org/html/rfc6266#section-4.3 + fn := url.PathEscape(fName) + if fName == fn { + fn = "filename=" + fn + } else { + /** + The parameters "filename" and "filename*" differ only in that + "filename*" uses the encoding defined in [RFC5987], allowing the use + of characters not present in the ISO-8859-1 character set + ([ISO-8859-1]). + */ + fn = "filename=" + fName + "; filename*=utf-8''" + fn + } + output.Header("Content-Disposition", "attachment; "+fn) + output.Header("Content-Description", "File Transfer") + output.Header("Content-Type", "application/octet-stream") + output.Header("Content-Transfer-Encoding", "binary") + output.Header("Expires", "0") + output.Header("Cache-Control", "must-revalidate") + output.Header("Pragma", "public") + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + ctype := mime.TypeByExtension(ext) + if ctype != "" { + output.Header("Content-Type", ctype) + } +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + output.Status = status +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return output.Status >= 200 && output.Status < 300 || output.Status == 304 +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return output.Status == 201 || output.Status == 204 || output.Status == 304 +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return output.Status == 200 +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return output.Status >= 200 && output.Status < 300 +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return output.Status == 403 +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return output.Status == 404 +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return output.Status >= 400 && output.Status < 500 +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return output.Status >= 500 && output.Status < 600 +} + +func stringsToJSON(str string) string { + var jsons bytes.Buffer + for _, r := range str { + rint := int(r) + if rint < 128 { + jsons.WriteRune(r) + } else { + jsons.WriteString("\\u") + if rint < 0x100 { + jsons.WriteString("00") + } else if rint < 0x1000 { + jsons.WriteString("0") + } + jsons.WriteString(strconv.FormatInt(int64(rint), 16)) + } + } + return jsons.String() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + output.Context.Input.CruSession.Set(name, value) +} diff --git a/pkg/context/param/conv.go b/pkg/context/param/conv.go new file mode 100644 index 00000000..c200e008 --- /dev/null +++ b/pkg/context/param/conv.go @@ -0,0 +1,78 @@ +package param + +import ( + "fmt" + "reflect" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + result = make([]reflect.Value, 0, len(methodParams)) + for i := 0; i < len(methodParams); i++ { + reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) + result = append(result, reflectValue) + } + return +} + +func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { + paramValue := getParamValue(param, ctx) + if paramValue == "" { + if param.required { + ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) + } else { + paramValue = param.defaultValue + } + } + + reflectValue, err := parseValue(param, paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.in { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil + } + parser := getParser(param, paramType) + value, err := parser.parse(paramValue, paramType) + if err != nil { + return result, err + } + + return safeConvert(reflect.ValueOf(value), paramType) +} + +func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + } + }() + result = value.Convert(t) + return +} diff --git a/pkg/context/param/methodparams.go b/pkg/context/param/methodparams.go new file mode 100644 index 00000000..cd6708a2 --- /dev/null +++ b/pkg/context/param/methodparams.go @@ -0,0 +1,69 @@ +package param + +import ( + "fmt" + "strings" +) + +//MethodParam keeps param information to be auto passed to controller methods +type MethodParam struct { + name string + in paramType + required bool + defaultValue string +} + +type paramType byte + +const ( + param paramType = iota + path + body + header +) + +//New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) +} + +func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { + param = &MethodParam{name: name} + for _, option := range opts { + option(param) + } + return +} + +//Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.in { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defaultValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/pkg/context/param/options.go b/pkg/context/param/options.go new file mode 100644 index 00000000..3d5ba013 --- /dev/null +++ b/pkg/context/param/options.go @@ -0,0 +1,37 @@ +package param + +import ( + "fmt" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be omitted from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + p.required = true +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + p.in = header +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + p.in = path +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + p.in = body +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return func(p *MethodParam) { + if defaultValue != nil { + p.defaultValue = fmt.Sprint(defaultValue) + } + } +} diff --git a/pkg/context/param/parsers.go b/pkg/context/param/parsers.go new file mode 100644 index 00000000..421aecf0 --- /dev/null +++ b/pkg/context/param/parsers.go @@ -0,0 +1,149 @@ +package param + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +type paramParser interface { + parse(value string, toType reflect.Type) (interface{}, error) +} + +func getParser(param *MethodParam, t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + if param.in == body { + return jsonParser{} + } + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + +type parserFunc func(value string, toType reflect.Type) (interface{}, error) + +func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { + return f(value, toType) +} + +type boolParser struct { +} + +func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.ParseBool(value) +} + +type stringParser struct { +} + +func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { + return value, nil +} + +type intParser struct { +} + +func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.Atoi(value) +} + +type floatParser struct { +} + +func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { + if toType.Kind() == reflect.Float32 { + res, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, err + } + return float32(res), nil + } + return strconv.ParseFloat(value, 64) +} + +type timeParser struct { +} + +func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { + result, err = time.Parse(time.RFC3339, value) + if err != nil { + result, err = time.Parse("2006-01-02", value) + } + return +} + +type jsonParser struct { +} + +func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { + pResult := reflect.New(toType) + v := pResult.Interface() + err := json.Unmarshal([]byte(value), v) + if err != nil { + return nil, err + } + return pResult.Elem().Interface(), nil +} + +func sliceParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + values := strings.Split(value, ",") + result := reflect.MakeSlice(toType, 0, len(values)) + elemType := toType.Elem() + for _, v := range values { + parsedValue, err := elemParser.parse(v, elemType) + if err != nil { + return nil, err + } + result = reflect.Append(result, reflect.ValueOf(parsedValue)) + } + return result.Interface(), nil + }) +} + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/pkg/context/param/parsers_test.go b/pkg/context/param/parsers_test.go new file mode 100644 index 00000000..7065a28e --- /dev/null +++ b/pkg/context/param/parsers_test.go @@ -0,0 +1,84 @@ +package param + +import "testing" +import "reflect" +import "time" + +type testDefinition struct { + strValue string + expectedValue interface{} + expectedParser paramParser +} + +func Test_Parsers(t *testing.T) { + + //ints + checkParser(testDefinition{"1", 1, intParser{}}, t) + checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) + checkParser(testDefinition{"1", uint64(1), intParser{}}, t) + + //floats + checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) + checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) + + //strings + checkParser(testDefinition{"AB", "AB", stringParser{}}, t) + checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) + + //bools + checkParser(testDefinition{"true", true, boolParser{}}, t) + checkParser(testDefinition{"0", false, boolParser{}}, t) + + //timeParser + checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) + checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) + + //json + checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { + X int + Y string + }{5, "Z"}, jsonParser{}}, t) + + //slice in query is parsed as comma delimited + checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) + + //slice in body is parsed as json + checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) + + //pointers + var someInt = 1 + checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) + + var someStruct = struct{ X int }{5} + checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) + +} + +func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { + toType := reflect.TypeOf(def.expectedValue) + var mp MethodParam + if len(methodParam) == 0 { + mp = MethodParam{} + } else { + mp = methodParam[0] + } + parser := getParser(&mp, toType) + + if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { + t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) + return + } + result, err := parser.parse(def.strValue, toType) + if err != nil { + t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) + return + } + convResult, err := safeConvert(reflect.ValueOf(result), toType) + if err != nil { + t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) + return + } + if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { + t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) + } +} diff --git a/pkg/context/renderer.go b/pkg/context/renderer.go new file mode 100644 index 00000000..36a7cb53 --- /dev/null +++ b/pkg/context/renderer.go @@ -0,0 +1,12 @@ +package context + +// Renderer defines an http response renderer +type Renderer interface { + Render(ctx *Context) +} + +type rendererFunc func(ctx *Context) + +func (f rendererFunc) Render(ctx *Context) { + f(ctx) +} diff --git a/pkg/context/response.go b/pkg/context/response.go new file mode 100644 index 00000000..9c3c715a --- /dev/null +++ b/pkg/context/response.go @@ -0,0 +1,27 @@ +package context + +import ( + "strconv" + + "net/http" +) + +const ( + //BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + //NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/pkg/controller.go b/pkg/controller.go new file mode 100644 index 00000000..0e8853b3 --- /dev/null +++ b/pkg/controller.go @@ -0,0 +1,706 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "errors" + "fmt" + "html/template" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "reflect" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/session" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = errors.New("user stop run") + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = make(map[string][]ControllerComments) +) + +// ControllerFilter store the filter for controller +type ControllerFilter struct { + Pattern string + Pos int + Filter FilterFunc + ReturnOnOutput bool + ResetParams bool +} + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments struct { + Pattern string + Pos int + Filter string // NOQA + ReturnOnOutput bool + ResetParams bool +} + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments struct { + ImportPath string + ImportAlias string +} + +// ControllerComments store the comment for the controller method +type ControllerComments struct { + Method string + Router string + Filters []*ControllerFilter + ImportComments []*ControllerImportComments + FilterComments []*ControllerFilterComments + AllowHTTPMethods []string + Params []map[string]string + MethodParams []*param.MethodParam +} + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice []ControllerComments + +func (p ControllerCommentsSlice) Len() int { return len(p) } +func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } +func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller struct { + // context data + Ctx *context.Context + Data map[interface{}]interface{} + + // route controller info + controllerName string + actionName string + methodMapping map[string]func() //method:routertree + AppController interface{} + + // template data + TplName string + ViewPath string + Layout string + LayoutSections map[string]string // the key is the section name and the value is the template name + TplPrefix string + TplExt string + EnableRender bool + + // xsrf data + _xsrfToken string + XSRFExpire int + EnableXSRF bool + + // session + CruSession session.Store +} + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface interface { + Init(ct *context.Context, controllerName, actionName string, app interface{}) + Prepare() + Get() + Post() + Delete() + Put() + Head() + Patch() + Options() + Trace() + Finish() + Render() error + XSRFToken() string + CheckXSRFCookie() bool + HandlerFunc(fn string) bool + URLMapping() +} + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + c.Layout = "" + c.TplName = "" + c.controllerName = controllerName + c.actionName = actionName + c.Ctx = ctx + c.TplExt = "tpl" + c.AppController = app + c.EnableRender = true + c.EnableXSRF = true + c.Data = ctx.Input.Data() + c.methodMapping = make(map[string]func()) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() {} + +// Finish runs after request function execution. +func (c *Controller) Finish() {} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + ts := func(h http.Header) (hs string) { + for k, v := range h { + hs += fmt.Sprintf("\r\n%s: %s", k, v) + } + return + } + hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) + c.Ctx.Output.Header("Content-Type", "message/http") + c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) + c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") + c.Ctx.WriteString(hs) +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + if v, ok := c.methodMapping[fnname]; ok { + v() + return true + } + return false +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() {} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + c.methodMapping[method] = fn +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + if !c.EnableRender { + return nil + } + rb, err := c.RenderBytes() + if err != nil { + return err + } + + if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { + c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + } + + return c.Ctx.Output.Body(rb) +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + b, e := c.RenderBytes() + return string(b), e +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + buf, err := c.renderTemplate() + //if the controller has set layout, then first get the tplName's content set the content to the layout + if err == nil && c.Layout != "" { + c.Data["LayoutContent"] = template.HTML(buf.String()) + + if c.LayoutSections != nil { + for sectionName, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + c.Data[sectionName] = "" + continue + } + buf.Reset() + err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) + if err != nil { + return nil, err + } + c.Data[sectionName] = template.HTML(buf.String()) + } + } + + buf.Reset() + ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) + } + return buf.Bytes(), err +} + +func (c *Controller) renderTemplate() (bytes.Buffer, error) { + var buf bytes.Buffer + if c.TplName == "" { + c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt + } + if c.TplPrefix != "" { + c.TplName = c.TplPrefix + c.TplName + } + if BConfig.RunMode == DEV { + buildFiles := []string{c.TplName} + if c.Layout != "" { + buildFiles = append(buildFiles, c.Layout) + if c.LayoutSections != nil { + for _, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + continue + } + buildFiles = append(buildFiles, sectionTpl) + } + } + } + BuildTemplate(c.viewPath(), buildFiles...) + } + return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) +} + +func (c *Controller) viewPath() string { + if c.ViewPath == "" { + return BConfig.WebConfig.ViewsPath + } + return c.ViewPath +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + LogAccess(c.Ctx, nil, code) + c.Ctx.Redirect(code, url) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + accept := c.Ctx.Input.Header("Accept") + switch accept { + case context.ApplicationYAML: + c.Data["yaml"] = data + case context.ApplicationXML, context.TextXML: + c.Data["xml"] = data + default: + c.Data["json"] = data + } +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + status, err := strconv.Atoi(code) + if err != nil { + status = 200 + } + c.CustomAbort(status, code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + // first panic from ErrorMaps, it is user defined error functions. + if _, ok := ErrorMaps[body]; ok { + c.Ctx.Output.Status = status + panic(body) + } + // last panic user string + c.Ctx.ResponseWriter.WriteHeader(status) + c.Ctx.ResponseWriter.Write([]byte(body)) + panic(ErrAbort) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + panic(ErrAbort) +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + if len(endpoint) == 0 { + return "" + } + if endpoint[0] == '.' { + return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) + } + return URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + var ( + hasIndent = BConfig.RunMode != PROD + hasEncoding = len(encoding) > 0 && encoding[0] + ) + + c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.XML(c.Data["xml"], hasIndent) +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + c.Ctx.Output.YAML(c.Data["yaml"]) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + hasIndent := BConfig.RunMode != PROD + hasEncoding := len(encoding) > 0 && encoding[0] + c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + if c.Ctx.Request.Form == nil { + c.Ctx.Request.ParseForm() + } + return c.Ctx.Request.Form +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return ParseForm(c.Input(), obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + if v := c.Ctx.Input.Query(key); v != "" { + return v + } + if len(def) > 0 { + return def[0] + } + return "" +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + var defv []string + if len(def) > 0 { + defv = def[0] + } + + if f := c.Input(); f == nil { + return defv + } else if vs := f[key]; len(vs) > 0 { + return vs + } + + return defv +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.Atoi(strv) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 8) + return int8(i64), err +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 8) + return uint8(u64), err +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 16) + return int16(i64), err +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 16) + return uint16(u64), err +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 32) + return int32(i64), err +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 32) + return uint32(u64), err +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseInt(strv, 10, 64) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseUint(strv, 10, 64) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseBool(strv) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseFloat(strv, 64) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return c.Ctx.Request.FormFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { + return files, nil + } + return nil, http.ErrMissingFile +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + file, _, err := c.Ctx.Request.FormFile(fromfile) + if err != nil { + return err + } + defer file.Close() + f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return err + } + defer f.Close() + io.Copy(f, file) + return nil +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + if c.CruSession == nil { + c.CruSession = c.Ctx.Input.CruSession + } + return c.CruSession +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Set(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + if c.CruSession == nil { + c.StartSession() + } + return c.CruSession.Get(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Delete(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + if c.CruSession != nil { + c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + } + c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) + c.Ctx.Input.CruSession = c.CruSession +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession = nil + GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return c.Ctx.Input.IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return c.Ctx.GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + c.Ctx.SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + if c._xsrfToken == "" { + expire := int64(BConfig.WebConfig.XSRFExpire) + if c.XSRFExpire > 0 { + expire = int64(c.XSRFExpire) + } + c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) + } + return c._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + if !c.EnableXSRF { + return true + } + return c.Ctx.CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return `` +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return c.controllerName, c.actionName +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go new file mode 100644 index 00000000..1e53416d --- /dev/null +++ b/pkg/controller_test.go @@ -0,0 +1,181 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "math" + "strconv" + "testing" + + "github.com/astaxie/beego/context" + "os" + "path/filepath" +) + +func TestGetInt(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt("age") + if val != 40 { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } +} + +func TestGetInt8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt8("age") + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } + //Output: int8 +} + +func TestGetInt16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt16("age") + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt32("age") + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt64("age") + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } +} + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} + +func TestAdditionalViewPaths(t *testing.T) { + dir1 := "_beeTmp" + dir2 := "_beeTmp2" + defer os.RemoveAll(dir1) + defer os.RemoveAll(dir2) + + dir1file := "file1.tpl" + dir2file := "file2.tpl" + + genFile := func(dir string, name string, content string) { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + defer f.Close() + f.WriteString(content) + f.Close() + } + + } + genFile(dir1, dir1file, `
{{.Content}}
`) + genFile(dir2, dir2file, `{{.Content}}`) + + AddViewPath(dir1) + AddViewPath(dir2) + + ctrl := Controller{ + TplName: "file1.tpl", + ViewPath: dir1, + } + ctrl.Data = map[interface{}]interface{}{ + "Content": "value2", + } + if result, err := ctrl.RenderString(); err != nil { + t.Fatal(err) + } else { + if result != "
value2
" { + t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) + } + } + + func() { + ctrl.TplName = "file2.tpl" + defer func() { + if r := recover(); r == nil { + t.Fatal("TestAdditionalViewPaths expected error") + } + }() + ctrl.RenderString() + }() + + ctrl.TplName = "file2.tpl" + ctrl.ViewPath = dir2 + ctrl.RenderString() +} diff --git a/pkg/doc.go b/pkg/doc.go new file mode 100644 index 00000000..8825bd29 --- /dev/null +++ b/pkg/doc.go @@ -0,0 +1,17 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/astaxie/beego" + + func main() { + beego.Run() + } + +more information: http://beego.me +*/ +package beego diff --git a/pkg/error.go b/pkg/error.go new file mode 100644 index 00000000..f268f723 --- /dev/null +++ b/pkg/error.go @@ -0,0 +1,488 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "html/template" + "net/http" + "reflect" + "runtime" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +// render default application error page with error and stack string. +func showErr(err interface{}, ctx *context.Context, stack string) { + t, _ := template.New("beegoerrortemp").Parse(tpl) + data := map[string]string{ + "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), + "RequestMethod": ctx.Input.Method(), + "RequestURL": ctx.Input.URI(), + "RemoteAddr": ctx.Input.IP(), + "Stack": stack, + "BeegoVersion": VERSION, + "GoVersion": runtime.Version(), + } + t.Execute(ctx.ResponseWriter, data) +} + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +type errorInfo struct { + controllerType reflect.Type + handler http.HandlerFunc + method string + errorType int +} + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = make(map[string]*errorInfo, 10) + +// show 401 unauthorized error. +func unauthorized(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 401, + "
The page you have requested can't be authorized."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 402 Payment Required +func paymentRequired(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 402, + "
The page you have requested Payment Required."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 403 forbidden error. +func forbidden(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 403, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    Your address may be blocked"+ + "
    The site may be disabled"+ + "
    You need to log in"+ + "
", + ) +} + +// show 422 missing xsrf token +func missingxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 422, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    '_xsrf' argument missing from POST"+ + "
", + ) +} + +// show 417 invalid xsrf token +func invalidxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 417, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    expected XSRF not found"+ + "
", + ) +} + +// show 404 not found error. +func notFound(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 404, + "
The page you have requested has flown the coop."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The page has moved"+ + "
    The page no longer exists"+ + "
    You were looking for your puppy and got lost"+ + "
    You like 404 pages"+ + "
", + ) +} + +// show 405 Method Not Allowed +func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 405, + "
The method you have requested Not Allowed."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+ + "
    The response MUST include an Allow header containing a list of valid methods for the requested resource."+ + "
", + ) +} + +// show 500 internal server error. +func internalServerError(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 500, + "
The page you have requested is down right now."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 501 Not Implemented. +func notImplemented(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 501, + "
The page you have requested is Not Implemented."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 502 Bad Gateway. +func badGateway(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 502, + "
The page you have requested is down right now."+ + "

    "+ + "
    The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 503 service unavailable error. +func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 503, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The page is overloaded"+ + "
    Please try again later."+ + "
", + ) +} + +// show 504 Gateway Timeout. +func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 504, + "
The page you have requested is unavailable"+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+ + "
    Please try again later."+ + "
", + ) +} + +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, + ) +} + +func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { + t, _ := template.New("beegoerrortemp").Parse(errtpl) + data := M{ + "Title": http.StatusText(errCode), + "BeegoVersion": VERSION, + "Content": template.HTML(errContent), + } + t.Execute(rw, data) +} + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + ErrorMaps[code] = &errorInfo{ + errorType: errorTypeHandler, + handler: h, + method: code, + } + return BeeApp +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + for i := 0; i < rt.NumMethod(); i++ { + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { + errName := strings.TrimPrefix(methodName, "Error") + ErrorMaps[errName] = &errorInfo{ + errorType: errorTypeController, + controllerType: ct, + method: methodName, + } + } + } + return BeeApp +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + exception(strconv.FormatUint(errCode, 10), ctx) +} + +// show error string as simple text message. +// if error string is empty, show 503 or 500 error as default. +func exception(errCode string, ctx *context.Context) { + atoi := func(code string) int { + v, err := strconv.Atoi(code) + if err == nil { + return v + } + if ctx.Output.Status == 0 { + return 503 + } + return ctx.Output.Status + } + + for _, ec := range []string{errCode, "503", "500"} { + if h, ok := ErrorMaps[ec]; ok { + executeError(h, ctx, atoi(ec)) + return + } + } + //if 50x error has been removed from errorMap + ctx.ResponseWriter.WriteHeader(atoi(errCode)) + ctx.WriteString(errCode) +} + +func executeError(err *errorInfo, ctx *context.Context, code int) { + //make sure to log the error in the access log + LogAccess(ctx, nil, code) + + if err.errorType == errorTypeHandler { + ctx.ResponseWriter.WriteHeader(code) + err.handler(ctx.ResponseWriter, ctx.Request) + return + } + if err.errorType == errorTypeController { + ctx.Output.SetStatus(code) + //Invoke the request handler + vc := reflect.New(err.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + //call the controller init function + execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) + + //call prepare function + execController.Prepare() + + execController.URLMapping() + + method := vc.MethodByName(err.method) + method.Call([]reflect.Value{}) + + //render template + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + panic(err) + } + } + + // finish all runrouter. release resource + execController.Finish() + } +} diff --git a/pkg/error_test.go b/pkg/error_test.go new file mode 100644 index 00000000..378aa953 --- /dev/null +++ b/pkg/error_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +type errorTestController struct { + Controller +} + +const parseCodeError = "parse code error" + +func (ec *errorTestController) Get() { + errorCode, err := ec.GetInt("code") + if err != nil { + ec.Abort(parseCodeError) + } + if errorCode != 0 { + ec.CustomAbort(errorCode, ec.GetString("code")) + } + ec.Abort("404") +} + +func TestErrorCode_01(t *testing.T) { + registerDefaultErrorHandler() + for k := range ErrorMaps { + r, _ := http.NewRequest("GET", "/error?code="+k, nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + code, _ := strconv.Atoi(k) + if w.Code != code { + t.Fail() + } + if !strings.Contains(w.Body.String(), http.StatusText(code)) { + t.Fail() + } + } +} + +func TestErrorCode_02(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=0", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 404 { + t.Fail() + } +} + +func TestErrorCode_03(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=panic", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 200 { + t.Fail() + } + if w.Body.String() != parseCodeError { + t.Fail() + } +} diff --git a/pkg/filter.go b/pkg/filter.go new file mode 100644 index 00000000..9cc6e913 --- /dev/null +++ b/pkg/filter.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import "github.com/astaxie/beego/context" + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter struct { + filterFunc FilterFunc + tree *Tree + pattern string + returnOnOutput bool + resetParams bool +} + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + isOk := f.tree.Match(url, ctx) + if isOk != nil { + if b, ok := isOk.(bool); ok { + return b + } + } + return false +} diff --git a/pkg/filter_test.go b/pkg/filter_test.go new file mode 100644 index 00000000..4ca4d2b8 --- /dev/null +++ b/pkg/filter_test.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/context" +) + +var FilterUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) +} + +func TestFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser) + handler.Add("/person/:last/:first", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am astaXie" { + t.Errorf("user define func can't run") + } +} + +var FilterAdminUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am admin")) +} + +// Filter pattern /admin/:all +// all url like /admin/ /admin/xie will all get filter + +func TestPatternTwo(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/ can't run") + } +} + +func TestPatternThree(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/astaxie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/astaxie can't run") + } +} diff --git a/pkg/flash.go b/pkg/flash.go new file mode 100644 index 00000000..a6485a17 --- /dev/null +++ b/pkg/flash.go @@ -0,0 +1,110 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "net/url" + "strings" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData struct { + Data map[string]string +} + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return &FlashData{ + Data: make(map[string]string), + } +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data[key] = msg + } else { + fd.Data[key] = fmt.Sprintf(msg, args...) + } +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["success"] = msg + } else { + fd.Data["success"] = fmt.Sprintf(msg, args...) + } +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["notice"] = msg + } else { + fd.Data["notice"] = fmt.Sprintf(msg, args...) + } +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["warning"] = msg + } else { + fd.Data["warning"] = fmt.Sprintf(msg, args...) + } +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["error"] = msg + } else { + fd.Data["error"] = fmt.Sprintf(msg, args...) + } +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + c.Data["flash"] = fd.Data + var flashValue string + for key, value := range fd.Data { + flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" + } + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + flash := NewFlash() + if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { + v, _ := url.QueryUnescape(cookie.Value) + vals := strings.Split(v, "\x00") + for _, v := range vals { + if len(v) > 0 { + kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") + if len(kv) == 2 { + flash.Data[kv[0]] = kv[1] + } + } + } + //read one time then delete it + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") + } + c.Data["flash"] = flash.Data + return flash +} diff --git a/pkg/flash_test.go b/pkg/flash_test.go new file mode 100644 index 00000000..d5e9608d --- /dev/null +++ b/pkg/flash_test.go @@ -0,0 +1,54 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type TestFlashController struct { + Controller +} + +func (t *TestFlashController) TestWriteFlash() { + flash := NewFlash() + flash.Notice("TestFlashString") + flash.Store(&t.Controller) + // we choose to serve json because we don't want to load a template html file + t.ServeJSON(true) +} + +func TestFlashHeader(t *testing.T) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.ServeHTTP(w, r) + + // get the Set-Cookie value + sc := w.Header().Get("Set-Cookie") + // match for the expected header + res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") + // validate the assertion + if !res { + t.Errorf("TestFlashHeader() unable to validate flash message") + } +} diff --git a/pkg/fs.go b/pkg/fs.go new file mode 100644 index 00000000..41cc6f6e --- /dev/null +++ b/pkg/fs.go @@ -0,0 +1,74 @@ +package beego + +import ( + "net/http" + "os" + "path/filepath" +) + +type FileSystem struct { +} + +func (d FileSystem) Open(name string) (http.File, error) { + return os.Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + + f, err := fs.Open(root) + if err != nil { + return err + } + info, err := f.Stat() + if err != nil { + err = walkFn(root, nil, err) + } else { + err = walk(fs, root, info, walkFn) + } + if err == filepath.SkipDir { + return nil + } + return err +} + +// walk recursively descends path, calling walkFn. +func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { + var err error + if !info.IsDir() { + return walkFn(path, info, nil) + } + + dir, err := fs.Open(path) + if err != nil { + if err1 := walkFn(path, info, err); err1 != nil { + return err1 + } + return err + } + defer dir.Close() + dirs, err := dir.Readdir(-1) + err1 := walkFn(path, info, err) + // If err != nil, walk can't walk into this directory. + // err1 != nil means walkFn want walk to skip this directory or stop walking. + // Therefore, if one of err and err1 isn't nil, walk will return. + if err != nil || err1 != nil { + // The caller's behavior is controlled by the return value, which is decided + // by walkFn. walkFn may ignore err and return nil. + // If walkFn returns SkipDir, it will be handled by the caller. + // So walk should return whatever walkFn returns. + return err1 + } + + for _, fileInfo := range dirs { + filename := filepath.Join(path, fileInfo.Name()) + if err = walk(fs, filename, fileInfo, walkFn); err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + return nil +} diff --git a/pkg/grace/grace.go b/pkg/grace/grace.go new file mode 100644 index 00000000..fb0cb7bb --- /dev/null +++ b/pkg/grace/grace.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "flag" + "net/http" + "os" + "strings" + "sync" + "syscall" + "time" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + regLock *sync.Mutex + runningServers map[string]*Server + runningServersOrder []string + socketPtrOffsetMap map[string]uint + runningServersForked bool + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = 60 * time.Second + + isChild bool + socketOrder string + + hookableSignals []os.Signal +) + +func init() { + flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") + flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") + + regLock = &sync.Mutex{} + runningServers = make(map[string]*Server) + runningServersOrder = []string{} + socketPtrOffsetMap = make(map[string]uint) + + hookableSignals = []os.Signal{ + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + } +} + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + regLock.Lock() + defer regLock.Unlock() + + if !flag.Parsed() { + flag.Parse() + } + if len(socketOrder) > 0 { + for i, addr := range strings.Split(socketOrder, ",") { + socketPtrOffsetMap[addr] = uint(i) + } + } else { + socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) + } + + srv = &Server{ + sigChan: make(chan os.Signal), + isChild: isChild, + SignalHooks: map[int]map[os.Signal][]func(){ + PreSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + PostSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + }, + state: StateInit, + Network: "tcp", + terminalChan: make(chan error), //no cache channel + } + srv.Server = &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, + } + + runningServersOrder = append(runningServersOrder, addr) + runningServers[addr] = srv + return srv +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/grace/server.go b/pkg/grace/server.go new file mode 100644 index 00000000..008a6171 --- /dev/null +++ b/pkg/grace/server.go @@ -0,0 +1,356 @@ +package grace + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + "time" +) + +// Server embedded http.Server +type Server struct { + *http.Server + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error +} + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + srv.state = StateRunning + defer func() { srv.state = StateTerminate }() + + // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS + // immediately return ErrServerClosed. Make sure the program doesn't exit + // and waits instead for Shutdown to return. + if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + log.Println(syscall.Getpid(), "Server.Serve() error:", err) + return err + } + + log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + // wait for Shutdown to return + if shutdownErr := <-srv.terminalChan; shutdownErr != nil { + return shutdownErr + } + return +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + + go srv.handleSignals() + + srv.ln, err = srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(trustFile) + if err != nil { + log.Println(err) + return err + } + pool.AppendCertsFromPEM(data) + srv.TLSConfig.ClientCAs = pool + log.Println("Mutual HTTPS") + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// getListener either opens a new socket to listen on, or takes the acceptor socket +// it got passed when restarted. +func (srv *Server) getListener(laddr string) (l net.Listener, err error) { + if srv.isChild { + var ptrOffset uint + if len(socketPtrOffsetMap) > 0 { + ptrOffset = socketPtrOffsetMap[laddr] + log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) + } + + f := os.NewFile(uintptr(3+ptrOffset), "") + l, err = net.FileListener(f) + if err != nil { + err = fmt.Errorf("net.FileListener error: %v", err) + return + } + } else { + l, err = net.Listen(srv.Network, laddr) + if err != nil { + err = fmt.Errorf("net.Listen error: %v", err) + return + } + } + return +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// handleSignals listens for os Signals and calls any hooked in function that the +// user had registered with the signal. +func (srv *Server) handleSignals() { + var sig os.Signal + + signal.Notify( + srv.sigChan, + hookableSignals..., + ) + + pid := syscall.Getpid() + for { + sig = <-srv.sigChan + srv.signalHooks(PreSignal, sig) + switch sig { + case syscall.SIGHUP: + log.Println(pid, "Received SIGHUP. forking.") + err := srv.fork() + if err != nil { + log.Println("Fork err:", err) + } + case syscall.SIGINT: + log.Println(pid, "Received SIGINT.") + srv.shutdown() + case syscall.SIGTERM: + log.Println(pid, "Received SIGTERM.") + srv.shutdown() + default: + log.Printf("Received %v: nothing i care about...\n", sig) + } + srv.signalHooks(PostSignal, sig) + } +} + +func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { + if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { + return + } + for _, f := range srv.SignalHooks[ppFlag][sig] { + f() + } +} + +// shutdown closes the listener so that no new connections are accepted. it also +// starts a goroutine that will serverTimeout (stop all running requests) the server +// after DefaultTimeout. +func (srv *Server) shutdown() { + if srv.state != StateRunning { + return + } + + srv.state = StateShuttingDown + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + ctx := context.Background() + if DefaultTimeout >= 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() + } + srv.terminalChan <- srv.Server.Shutdown(ctx) +} + +func (srv *Server) fork() (err error) { + regLock.Lock() + defer regLock.Unlock() + if runningServersForked { + return + } + runningServersForked = true + + var files = make([]*os.File, len(runningServers)) + var orderArgs = make([]string, len(runningServers)) + for _, srvPtr := range runningServers { + f, _ := srvPtr.ln.(*net.TCPListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f + orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + } + + log.Println(files) + path := os.Args[0] + var args []string + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + if arg == "-graceful" { + break + } + args = append(args, arg) + } + } + args = append(args, "-graceful") + if len(runningServers) > 1 { + args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) + log.Println(args) + } + cmd := exec.Command(path, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = files + err = cmd.Start() + if err != nil { + log.Fatalf("Restart: Failed to launch, error: %v", err) + } + + return +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { + if ppFlag != PreSignal && ppFlag != PostSignal { + err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") + return + } + for _, s := range hookableSignals { + if s == sig { + srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) + return + } + } + err = fmt.Errorf("Signal '%v' is not supported", sig) + return +} diff --git a/pkg/hooks.go b/pkg/hooks.go new file mode 100644 index 00000000..49c42d5a --- /dev/null +++ b/pkg/hooks.go @@ -0,0 +1,104 @@ +package beego + +import ( + "encoding/json" + "mime" + "net/http" + "path/filepath" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" +) + +// register MIME type with content type +func registerMime() error { + for k, v := range mimemaps { + mime.AddExtensionType(k, v) + } + return nil +} + +// register default error http handlers, 404,401,403,500 and 503. +func registerDefaultErrorHandler() error { + m := map[string]func(http.ResponseWriter, *http.Request){ + "401": unauthorized, + "402": paymentRequired, + "403": forbidden, + "404": notFound, + "405": methodNotAllowed, + "500": internalServerError, + "501": notImplemented, + "502": badGateway, + "503": serviceUnavailable, + "504": gatewayTimeout, + "417": invalidxsrf, + "422": missingxsrf, + "413": payloadTooLarge, + } + for e, h := range m { + if _, ok := ErrorMaps[e]; !ok { + ErrorHandler(e, h) + } + } + return nil +} + +func registerSession() error { + if BConfig.WebConfig.Session.SessionOn { + var err error + sessionConfig := AppConfig.String("sessionConfig") + conf := new(session.ManagerConfig) + if sessionConfig == "" { + conf.CookieName = BConfig.WebConfig.Session.SessionName + conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie + conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime + conf.Secure = BConfig.Listen.EnableHTTPS + conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime + conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) + conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly + conf.Domain = BConfig.WebConfig.Session.SessionDomain + conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + } else { + if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { + return err + } + } + if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil { + return err + } + go GlobalSessions.GC() + } + return nil +} + +func registerTemplate() error { + defer lockViewPaths() + if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { + if BConfig.RunMode == DEV { + logs.Warn(err) + } + return err + } + return nil +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + go beeAdminApp.Run() + } + return nil +} + +func registerGzip() error { + if BConfig.EnableGzip { + context.InitGzip( + AppConfig.DefaultInt("gzipMinLength", -1), + AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + ) + } + return nil +} diff --git a/pkg/httplib/README.md b/pkg/httplib/README.md new file mode 100644 index 00000000..97df8e6b --- /dev/null +++ b/pkg/httplib/README.md @@ -0,0 +1,97 @@ +# httplib +httplib is an libs help you to curl remote url. + +# How to use? + +## GET +you can use Get to crawl data. + + import "github.com/astaxie/beego/httplib" + + str, err := httplib.Get("http://beego.me/").String() + if err != nil { + // error + } + fmt.Println(str) + +## POST +POST data to remote url + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.Param("password","123456") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + +## Set timeout + +The default timeout is `60` seconds, function prototype: + + SetTimeout(connectTimeout, readWriteTimeout time.Duration) + +Example: + + // GET + httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + // POST + httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + +## Debug + +If you want to debug the request info, set the debug on + + httplib.Get("http://beego.me/").Debug(true) + +## Set HTTP Basic Auth + + str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() + if err != nil { + // error + } + fmt.Println(str) + +## Set HTTPS + +If request url is https, You can set the client support TSL: + + httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) + +More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config + +## Set HTTP Version + +some servers need to specify the protocol version of HTTP + + httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") + +## Set Cookie + +some http request need setcookie. So set it like this: + + cookie := &http.Cookie{} + cookie.Name = "username" + cookie.Value = "astaxie" + httplib.Get("http://beego.me/").SetCookie(cookie) + +## Upload file + +httplib support mutil file upload, use `req.PostFile()` + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.PostFile("uploadfile1", "httplib.pdf") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + + +See godoc for further documentation and examples. + +* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go new file mode 100644 index 00000000..60aa4e8b --- /dev/null +++ b/pkg/httplib/httplib.go @@ -0,0 +1,654 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "log" + "mime/multipart" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httputil" + "net/url" + "os" + "path" + "strings" + "sync" + "time" + + "gopkg.in/yaml.v2" +) + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + DumpBody: true, +} + +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + var resp http.Response + u, err := url.Parse(rawurl) + if err != nil { + log.Println("Httplib:", err) + } + req := http.Request{ + URL: u, + Method: method, + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + return &BeegoHTTPRequest{ + url: rawurl, + req: &req, + params: map[string][]string{}, + files: map[string]string{}, + setting: defaultSetting, + resp: &resp, + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + ShowDebug bool + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + DumpBody bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration +} + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + url string + req *http.Request + params map[string][]string + files map[string]string + setting BeegoHTTPSettings + resp *http.Response + body []byte + dump []byte +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.req +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.setting = setting + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.req.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.setting.EnableCookie = enable + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.setting.UserAgent = useragent + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.setting.ShowDebug = isdebug + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.setting.Retries = times + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.setting.DumpBody = isdump + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.dump +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.setting.ConnectTimeout = connectTimeout + b.setting.ReadWriteTimeout = readWriteTimeout + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.setting.TLSClientConfig = config + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.req.Header.Set(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.req.Host = host + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + if len(vers) == 0 { + vers = "HTTP/1.1" + } + + major, minor, ok := http.ParseHTTPVersion(vers) + if ok { + b.req.Proto = vers + b.req.ProtoMajor = major + b.req.ProtoMinor = minor + } + + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.req.Header.Add("Cookie", cookie.String()) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.setting.Transport = transport + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.setting.Proxy = proxy + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.setting.CheckRedirect = redirect + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + if param, ok := b.params[key]; ok { + b.params[key] = append(param, value) + } else { + b.params[key] = []string{value} + } + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.files[formname] = filename + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + switch t := data.(type) { + case string: + bf := bytes.NewBufferString(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + case []byte: + bf := bytes.NewBuffer(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + } + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := xml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/xml") + } + return b, nil +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := yaml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/x+yaml") + } + return b, nil +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := json.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/json") + } + return b, nil +} + +func (b *BeegoHTTPRequest) buildURL(paramBody string) { + // build GET url with query string + if b.req.Method == "GET" && len(paramBody) > 0 { + if strings.Contains(b.url, "?") { + b.url += "&" + paramBody + } else { + b.url = b.url + "?" + paramBody + } + return + } + + // build POST/PUT/PATCH url and body + if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { + // with files + if len(b.files) > 0 { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + if err != nil { + log.Println("Httplib:", err) + } + fh, err := os.Open(filename) + if err != nil { + log.Println("Httplib:", err) + } + //iocopy + _, err = io.Copy(fileWriter, fh) + fh.Close() + if err != nil { + log.Println("Httplib:", err) + } + } + for k, v := range b.params { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + pw.Close() + }() + b.Header("Content-Type", bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") + return + } + + // with params + if len(paramBody) > 0 { + b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Body(paramBody) + } + } +} + +func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { + if b.resp.StatusCode != 0 { + return b.resp, nil + } + resp, err := b.DoRequest() + if err != nil { + return nil, err + } + b.resp = resp + return resp, nil +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + + b.buildURL(paramBody) + urlParsed, err := url.Parse(b.url) + if err != nil { + return nil, err + } + + b.req.URL = urlParsed + + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.Dial == nil { + t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + } + + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + + client := &http.Client{ + Transport: trans, + Jar: jar, + } + + if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { + b.req.Header.Set("User-Agent", b.setting.UserAgent) + } + + if b.setting.CheckRedirect != nil { + client.CheckRedirect = b.setting.CheckRedirect + } + + if b.setting.ShowDebug { + dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) + if err != nil { + log.Println(err.Error()) + } + b.dump = dump + } + // retries default value is 0, it will run once. + // retries equal to -1, it will run forever until success + // retries is setted, it will retries fixed times. + // Sleeps for a 400ms inbetween calls to reduce spam + for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { + resp, err = client.Do(b.req) + if err == nil { + break + } + time.Sleep(b.setting.RetryDelay) + } + return resp, err +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + data, err := b.Bytes() + if err != nil { + return "", err + } + + return string(data), nil +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + if b.body != nil { + return b.body, nil + } + resp, err := b.getResponse() + if err != nil { + return nil, err + } + if resp.Body == nil { + return nil, nil + } + defer resp.Body.Close() + if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + b.body, err = ioutil.ReadAll(reader) + return b.body, err + } + b.body, err = ioutil.ReadAll(resp.Body) + return b.body, err +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + resp, err := b.getResponse() + if err != nil { + return err + } + if resp.Body == nil { + return nil + } + defer resp.Body.Close() + err = pathExistAndMkdir(filename) + if err != nil { + return err + } + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, resp.Body) + return err +} + +//Check that the file directory exists, there is no automatically created +func pathExistAndMkdir(filename string) (err error) { + filename = path.Dir(filename) + _, err = os.Stat(filename) + if err == nil { + return nil + } + if os.IsNotExist(err) { + err = os.MkdirAll(filename, os.ModePerm) + if err == nil { + return nil + } + } + return err +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return json.Unmarshal(data, v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return xml.Unmarshal(data, v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return yaml.Unmarshal(data, v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.getResponse() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return func(netw, addr string) (net.Conn, error) { + conn, err := net.DialTimeout(netw, addr, cTimeout) + if err != nil { + return nil, err + } + err = conn.SetDeadline(time.Now().Add(rwTimeout)) + return conn, err + } +} diff --git a/pkg/httplib/httplib_test.go b/pkg/httplib/httplib_test.go new file mode 100644 index 00000000..f6be8571 --- /dev/null +++ b/pkg/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +//func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +//} + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/pkg/log.go b/pkg/log.go new file mode 100644 index 00000000..cc4c0f81 --- /dev/null +++ b/pkg/log.go @@ -0,0 +1,127 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/logs/README.md b/pkg/logs/README.md new file mode 100644 index 00000000..c05bcc04 --- /dev/null +++ b/pkg/logs/README.md @@ -0,0 +1,72 @@ +## logs +logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/logs + + +## What adapters are supported? + +As of now this logs support console, file,smtp and conn. + + +## How to use it? + +First you must import it + +```golang +import ( + "github.com/astaxie/beego/logs" +) +``` + +Then init a Log (example with console adapter) + +```golang +log := logs.NewLogger(10000) +log.SetLogger("console", "") +``` + +> the first params stand for how many channel + +Use it like this: + +```golang +log.Trace("trace") +log.Info("info") +log.Warn("warning") +log.Debug("debug") +log.Critical("critical") +``` + +## File adapter + +Configure file adapter like this: + +```golang +log := NewLogger(10000) +log.SetLogger("file", `{"filename":"test.log"}`) +``` + +## Conn adapter + +Configure like this: + +```golang +log := NewLogger(1000) +log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) +log.Info("info") +``` + +## Smtp adapter + +Configure like this: + +```golang +log := NewLogger(10000) +log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) +log.Critical("sendmail critical") +time.Sleep(time.Second * 30) +``` diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go new file mode 100644 index 00000000..3ff9e20f --- /dev/null +++ b/pkg/logs/accesslog.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "strings" + "encoding/json" + "fmt" + "time" +) + +const ( + apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" + apacheFormat = "APACHE_FORMAT" + jsonFormat = "JSON_FORMAT" +) + +// AccessLogRecord struct for holding access log data. +type AccessLogRecord struct { + RemoteAddr string `json:"remote_addr"` + RequestTime time.Time `json:"request_time"` + RequestMethod string `json:"request_method"` + Request string `json:"request"` + ServerProtocol string `json:"server_protocol"` + Host string `json:"host"` + Status int `json:"status"` + BodyBytesSent int64 `json:"body_bytes_sent"` + ElapsedTime time.Duration `json:"elapsed_time"` + HTTPReferrer string `json:"http_referrer"` + HTTPUserAgent string `json:"http_user_agent"` + RemoteUser string `json:"remote_user"` +} + +func (r *AccessLogRecord) json() ([]byte, error) { + buffer := &bytes.Buffer{} + encoder := json.NewEncoder(buffer) + disableEscapeHTML(encoder) + + err := encoder.Encode(r) + return buffer.Bytes(), err +} + +func disableEscapeHTML(i interface{}) { + if e, ok := i.(interface { + SetEscapeHTML(bool) + }); ok { + e.SetEscapeHTML(false) + } +} + +// AccessLog - Format and print access log. +func AccessLog(r *AccessLogRecord, format string) { + var msg string + switch format { + case apacheFormat: + timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") + msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, + r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) + case jsonFormat: + fallthrough + default: + jsonData, err := r.json() + if err != nil { + msg = fmt.Sprintf(`{"Error": "%s"}`, err) + } else { + msg = string(jsonData) + } + } + beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) +} diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go new file mode 100644 index 00000000..867ff4cb --- /dev/null +++ b/pkg/logs/alils/alils.go @@ -0,0 +1,186 @@ +package alils + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/logs" + "github.com/gogo/protobuf/proto" +) + +const ( + // CacheSize set the flush size + CacheSize int = 64 + // Delimiter define the topic delimiter + Delimiter string = "##" +) + +// Config is the Config for Ali Log +type Config struct { + Project string `json:"project"` + Endpoint string `json:"endpoint"` + KeyID string `json:"key_id"` + KeySecret string `json:"key_secret"` + LogStore string `json:"log_store"` + Topics []string `json:"topics"` + Source string `json:"source"` + Level int `json:"level"` + FlushWhen int `json:"flush_when"` +} + +// aliLSWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type aliLSWriter struct { + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + Config +} + +// NewAliLS create a new Logger +func NewAliLS() logs.Logger { + alils := new(aliLSWriter) + alils.Level = logs.LevelTrace + return alils +} + +// Init parse config and init struct +func (c *aliLSWriter) Init(jsonConfig string) (err error) { + + json.Unmarshal([]byte(jsonConfig), c) + + if c.FlushWhen > CacheSize { + c.FlushWhen = CacheSize + } + + prj := &LogProject{ + Name: c.Project, + Endpoint: c.Endpoint, + AccessKeyID: c.KeyID, + AccessKeySecret: c.KeySecret, + } + + c.store, err = prj.GetLogStore(c.LogStore) + if err != nil { + return err + } + + // Create default Log Group + c.group = append(c.group, &LogGroup{ + Topic: proto.String(""), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + }) + + // Create other Log Group + c.groupMap = make(map[string]*LogGroup) + for _, topic := range c.Topics { + + lg := &LogGroup{ + Topic: proto.String(topic), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + } + + c.group = append(c.group, lg) + c.groupMap[topic] = lg + } + + if len(c.group) == 1 { + c.withMap = false + } else { + c.withMap = true + } + + c.lock = &sync.Mutex{} + + return nil +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { + + if level > c.Level { + return nil + } + + var topic string + var content string + var lg *LogGroup + if c.withMap { + + // Topic,LogGroup + strs := strings.SplitN(msg, Delimiter, 2) + if len(strs) == 2 { + pos := strings.LastIndex(strs[0], " ") + topic = strs[0][pos+1 : len(strs[0])] + content = strs[0][0:pos] + strs[1] + lg = c.groupMap[topic] + } + + // send to empty Topic + if lg == nil { + content = msg + lg = c.group[0] + } + } else { + content = msg + lg = c.group[0] + } + + c1 := &LogContent{ + Key: proto.String("msg"), + Value: proto.String(content), + } + + l := &Log{ + Time: proto.Uint32(uint32(when.Unix())), + Contents: []*LogContent{ + c1, + }, + } + + c.lock.Lock() + lg.Logs = append(lg.Logs, l) + c.lock.Unlock() + + if len(lg.Logs) >= c.FlushWhen { + c.flush(lg) + } + + return nil +} + +// Flush implementing method. empty. +func (c *aliLSWriter) Flush() { + + // flush all group + for _, lg := range c.group { + c.flush(lg) + } +} + +// Destroy destroy connection writer and close tcp listener. +func (c *aliLSWriter) Destroy() { +} + +func (c *aliLSWriter) flush(lg *LogGroup) { + + c.lock.Lock() + defer c.lock.Unlock() + err := c.store.PutLogs(lg) + if err != nil { + return + } + + lg.Logs = make([]*Log, 0, c.FlushWhen) +} + +func init() { + logs.Register(logs.AdapterAliLS, NewAliLS) +} diff --git a/pkg/logs/alils/config.go b/pkg/logs/alils/config.go new file mode 100755 index 00000000..e8c24448 --- /dev/null +++ b/pkg/logs/alils/config.go @@ -0,0 +1,13 @@ +package alils + +const ( + version = "0.5.0" // SDK version + signatureMethod = "hmac-sha1" // Signature method + + // OffsetNewest stands for the log head offset, i.e. the offset that will be + // assigned to the next message that will be produced to the shard. + OffsetNewest = "end" + // OffsetOldest stands for the oldest offset available on the logstore for a + // shard. + OffsetOldest = "begin" +) diff --git a/pkg/logs/alils/log.pb.go b/pkg/logs/alils/log.pb.go new file mode 100755 index 00000000..601b0d78 --- /dev/null +++ b/pkg/logs/alils/log.pb.go @@ -0,0 +1,1038 @@ +package alils + +import ( + "fmt" + "io" + "math" + + "github.com/gogo/protobuf/proto" + github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +var ( + // ErrInvalidLengthLog invalid proto + ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") + // ErrIntOverflowLog overflow + ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") +) + +// Log define the proto Log +type Log struct { + Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` + Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset the Log +func (m *Log) Reset() { *m = Log{} } + +// String return the Compact Log +func (m *Log) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*Log) ProtoMessage() {} + +// GetTime return the Log's Time +func (m *Log) GetTime() uint32 { + if m != nil && m.Time != nil { + return *m.Time + } + return 0 +} + +// GetContents return the Log's Contents +func (m *Log) GetContents() []*LogContent { + if m != nil { + return m.Contents + } + return nil +} + +// LogContent define the Log content struct +type LogContent struct { + Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` + Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogContent +func (m *LogContent) Reset() { *m = LogContent{} } + +// String return the compact text +func (m *LogContent) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogContent) ProtoMessage() {} + +// GetKey return the Key +func (m *LogContent) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +// GetValue return the Value +func (m *LogContent) GetValue() string { + if m != nil && m.Value != nil { + return *m.Value + } + return "" +} + +// LogGroup define the logs struct +type LogGroup struct { + Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` + Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` + Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` + Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroup +func (m *LogGroup) Reset() { *m = LogGroup{} } + +// String return the compact text +func (m *LogGroup) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroup) ProtoMessage() {} + +// GetLogs return the loggroup logs +func (m *LogGroup) GetLogs() []*Log { + if m != nil { + return m.Logs + } + return nil +} + +// GetReserved return Reserved +func (m *LogGroup) GetReserved() string { + if m != nil && m.Reserved != nil { + return *m.Reserved + } + return "" +} + +// GetTopic return Topic +func (m *LogGroup) GetTopic() string { + if m != nil && m.Topic != nil { + return *m.Topic + } + return "" +} + +// GetSource return Source +func (m *LogGroup) GetSource() string { + if m != nil && m.Source != nil { + return *m.Source + } + return "" +} + +// LogGroupList define the LogGroups +type LogGroupList struct { + LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroupList +func (m *LogGroupList) Reset() { *m = LogGroupList{} } + +// String return compact text +func (m *LogGroupList) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroupList) ProtoMessage() {} + +// GetLogGroups return the LogGroups +func (m *LogGroupList) GetLogGroups() []*LogGroup { + if m != nil { + return m.LogGroups + } + return nil +} + +// Marshal the logs to byte slice +func (m *Log) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo data +func (m *Log) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Time == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + data[i] = 0x8 + i++ + i = encodeVarintLog(data, i, uint64(*m.Time)) + if len(m.Contents) > 0 { + for _, msg := range m.Contents { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogContent +func (m *LogContent) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo logcontent to data +func (m *LogContent) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Key == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Key))) + i += copy(data[i:], *m.Key) + + if m.Value == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Value))) + i += copy(data[i:], *m.Value) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroup +func (m *LogGroup) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroup to data +func (m *LogGroup) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Logs) > 0 { + for _, msg := range m.Logs { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.Reserved != nil { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) + i += copy(data[i:], *m.Reserved) + } + if m.Topic != nil { + data[i] = 0x1a + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Topic))) + i += copy(data[i:], *m.Topic) + } + if m.Source != nil { + data[i] = 0x22 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Source))) + i += copy(data[i:], *m.Source) + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroupList +func (m *LogGroupList) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroupList to data +func (m *LogGroupList) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, msg := range m.LogGroups { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +func encodeFixed64Log(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Log(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintLog(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} + +// Size return the log's size +func (m *Log) Size() (n int) { + var l int + _ = l + if m.Time != nil { + n += 1 + sovLog(uint64(*m.Time)) + } + if len(m.Contents) > 0 { + for _, e := range m.Contents { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogContent size based on Key and Value +func (m *LogContent) Size() (n int) { + var l int + _ = l + if m.Key != nil { + l = len(*m.Key) + n += 1 + l + sovLog(uint64(l)) + } + if m.Value != nil { + l = len(*m.Value) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroup size based on Logs +func (m *LogGroup) Size() (n int) { + var l int + _ = l + if len(m.Logs) > 0 { + for _, e := range m.Logs { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.Reserved != nil { + l = len(*m.Reserved) + n += 1 + l + sovLog(uint64(l)) + } + if m.Topic != nil { + l = len(*m.Topic) + n += 1 + l + sovLog(uint64(l)) + } + if m.Source != nil { + l = len(*m.Source) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroupList size +func (m *LogGroupList) Size() (n int) { + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, e := range m.LogGroups { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +func sovLog(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozLog(x uint64) (n int) { + return sovLog((x << 1) ^ (x >> 63)) +} + +// Unmarshal data to log +func (m *Log) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Log: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) + } + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + v |= (uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Time = &v + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Contents = append(m.Contents, &LogContent{}) + if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogContent +func (m *LogContent) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Content: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Key = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Value = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroup +func (m *LogGroup) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Logs = append(m.Logs, &Log{}) + if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Reserved = &s + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Topic = &s + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Source = &s + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroupList +func (m *LogGroupList) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.LogGroups = append(m.LogGroups, &LogGroup{}) + if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skipLog(data []byte) (n int, err error) { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if data[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthLog + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipLog(data[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} diff --git a/pkg/logs/alils/log_config.go b/pkg/logs/alils/log_config.go new file mode 100755 index 00000000..e8564efb --- /dev/null +++ b/pkg/logs/alils/log_config.go @@ -0,0 +1,42 @@ +package alils + +// InputDetail define log detail +type InputDetail struct { + LogType string `json:"logType"` + LogPath string `json:"logPath"` + FilePattern string `json:"filePattern"` + LocalStorage bool `json:"localStorage"` + TimeFormat string `json:"timeFormat"` + LogBeginRegex string `json:"logBeginRegex"` + Regex string `json:"regex"` + Keys []string `json:"key"` + FilterKeys []string `json:"filterKey"` + FilterRegex []string `json:"filterRegex"` + TopicFormat string `json:"topicFormat"` +} + +// OutputDetail define the output detail +type OutputDetail struct { + Endpoint string `json:"endpoint"` + LogStoreName string `json:"logstoreName"` +} + +// LogConfig define Log Config +type LogConfig struct { + Name string `json:"configName"` + InputType string `json:"inputType"` + InputDetail InputDetail `json:"inputDetail"` + OutputType string `json:"outputType"` + OutputDetail OutputDetail `json:"outputDetail"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// GetAppliedMachineGroup returns applied machine group of this config. +func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { + groupNames, err = c.project.GetAppliedMachineGroups(c.Name) + return +} diff --git a/pkg/logs/alils/log_project.go b/pkg/logs/alils/log_project.go new file mode 100755 index 00000000..59db8cbf --- /dev/null +++ b/pkg/logs/alils/log_project.go @@ -0,0 +1,819 @@ +/* +Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). + +For more description about SLS, please read this article: +http://gitlab.alibaba-inc.com/sls/doc. +*/ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// Error message in SLS HTTP response. +type errorMessage struct { + Code string `json:"errorCode"` + Message string `json:"errorMessage"` +} + +// LogProject Define the Ali Project detail +type LogProject struct { + Name string // Project name + Endpoint string // IP or hostname of SLS endpoint + AccessKeyID string + AccessKeySecret string +} + +// NewLogProject creates a new SLS project. +func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { + p = &LogProject{ + Name: name, + Endpoint: endpoint, + AccessKeyID: AccessKeyID, + AccessKeySecret: accessKeySecret, + } + return p, nil +} + +// ListLogStore returns all logstore names of project p. +func (p *LogProject) ListLogStore() (storeNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores") + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + LogStores []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + storeNames = body.LogStores + + return +} + +// GetLogStore returns logstore according by logstore name. +func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/logstores/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + s = &LogStore{} + err = json.Unmarshal(buf, s) + if err != nil { + return + } + s.project = p + return +} + +// CreateLogStore creates a new logstore in SLS, +// where name is logstore name, +// and ttl is time-to-live(in day) of logs, +// and shardCnt is the number of shards. +func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteLogStore deletes a logstore according by logstore name. +func (p *LogProject) DeleteLogStore(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/logstores/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// UpdateLogStore updates a logstore according by logstore name, +// obviously we can't modify the logstore name itself. +func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// ListMachineGroup returns machine group name list and the total number of machine groups. +// The offset starts from 0 and the size is the max number of machine groups could be returned. +func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 500 + } + + uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + MachineGroups []string + Count int + Total int + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + m = body.MachineGroups + total = body.Total + + return +} + +// GetMachineGroup retruns machine group according by machine group name. +func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get machine group:%v", name) + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + m = &MachineGroup{} + err = json.Unmarshal(buf, m) + if err != nil { + return + } + m.project = p + return +} + +// CreateMachineGroup creates a new machine group in SLS. +func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/machinegroups", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// UpdateMachineGroup updates a machine group. +func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteMachineGroup deletes machine group according machine group name. +func (p *LogProject) DeleteMachineGroup(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// ListConfig returns config names list and the total number of configs. +// The offset starts from 0 and the size is the max number of configs could be returned. +func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 100 + } + + uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Total int + Configs []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + cfgNames = body.Configs + total = body.Total + return +} + +// GetConfig returns config according by config name. +func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/configs/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + c = &LogConfig{} + err = json.Unmarshal(buf, c) + if err != nil { + return + } + c.project = p + return +} + +// UpdateConfig updates a config. +func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/configs/"+c.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// CreateConfig creates a new config in SLS. +func (p *LogProject) CreateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/configs", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteConfig deletes a config according by config name. +func (p *LogProject) DeleteConfig(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/configs/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetAppliedMachineGroups returns applied machine group names list according config name. +func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/configs/%v/machinegroups", confName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get applied machine groups") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + Machinegroups []string + } + + body := &Body{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + groupNames = body.Machinegroups + return +} + +// GetAppliedConfigs returns applied config names list according machine group name groupName. +func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to applied configs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Cfg struct { + Count int `json:"count"` + Configs []string `json:"configs"` + } + + body := &Cfg{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + confNames = body.Configs + return +} + +// ApplyConfigToMachineGroup applies config to machine group. +func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "PUT", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to apply config to machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// RemoveConfigFromMachineGroup removes config from machine group. +func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "DELETE", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} diff --git a/pkg/logs/alils/log_store.go b/pkg/logs/alils/log_store.go new file mode 100755 index 00000000..fa502736 --- /dev/null +++ b/pkg/logs/alils/log_store.go @@ -0,0 +1,271 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strconv" + + lz4 "github.com/cloudflare/golz4" + "github.com/gogo/protobuf/proto" +) + +// LogStore Store the logs +type LogStore struct { + Name string `json:"logstoreName"` + TTL int + ShardCount int + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Shard define the Log Shard +type Shard struct { + ShardID int `json:"shardID"` +} + +// ListShards returns shard id list of this logstore. +func (s *LogStore) ListShards() (shardIDs []int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards", s.Name) + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + var shards []*Shard + err = json.Unmarshal(buf, &shards) + if err != nil { + return + } + + for _, v := range shards { + shardIDs = append(shardIDs, v.ShardID) + } + return +} + +// PutLogs put logs into logstore. +// The callers should transform user logs into LogGroup. +func (s *LogStore) PutLogs(lg *LogGroup) (err error) { + body, err := proto.Marshal(lg) + if err != nil { + return + } + + // Compresse body with lz4 + out := make([]byte, lz4.CompressBound(body)) + n, err := lz4.Compress(body, out) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-compresstype": "lz4", + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/x-protobuf", + } + + uri := fmt.Sprintf("/logstores/%v", s.Name) + r, err := request(s.project, "POST", uri, h, out[:n]) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to put logs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetCursor gets log cursor of one shard specified by shardID. +// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". +// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore +func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", + s.Name, shardID, from) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Cursor string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + cursor = body.Cursor + return +} + +// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogsBytes(shardID int, cursor string, + logGroupMaxCount int) (out []byte, nextCursor string, err error) { + + h := map[string]string{ + "x-sls-bodyrawsize": "0", + "Accept": "application/x-protobuf", + "Accept-Encoding": "lz4", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", + s.Name, shardID, cursor, logGroupMaxCount) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + v, ok := r.Header["X-Sls-Compresstype"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-compresstype' header") + return + } + if v[0] != "lz4" { + err = fmt.Errorf("unexpected compress type:%v", v[0]) + return + } + + v, ok = r.Header["X-Sls-Cursor"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-cursor' header") + return + } + nextCursor = v[0] + + v, ok = r.Header["X-Sls-Bodyrawsize"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") + return + } + bodyRawSize, err := strconv.Atoi(v[0]) + if err != nil { + return + } + + out = make([]byte, bodyRawSize) + err = lz4.Uncompress(buf, out) + if err != nil { + return + } + + return +} + +// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API +func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { + + gl = &LogGroupList{} + err = proto.Unmarshal(data, gl) + if err != nil { + return + } + + return +} + +// GetLogs gets logs from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogs(shardID int, cursor string, + logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { + + out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) + if err != nil { + return + } + + gl, err = LogsBytesDecode(out) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/machine_group.go b/pkg/logs/alils/machine_group.go new file mode 100755 index 00000000..b6c69a14 --- /dev/null +++ b/pkg/logs/alils/machine_group.go @@ -0,0 +1,91 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// MachineGroupAttribute define the Attribute +type MachineGroupAttribute struct { + ExternalName string `json:"externalName"` + TopicName string `json:"groupTopic"` +} + +// MachineGroup define the machine Group +type MachineGroup struct { + Name string `json:"groupName"` + Type string `json:"groupType"` + MachineIDType string `json:"machineIdentifyType"` + MachineIDList []string `json:"machineList"` + + Attribute MachineGroupAttribute `json:"groupAttribute"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Machine define the Machine +type Machine struct { + IP string + UniqueID string `json:"machine-uniqueid"` + UserdefinedID string `json:"userdefined-id"` +} + +// MachineList define the Machine List +type MachineList struct { + Total int + Machines []*Machine +} + +// ListMachines returns machine list of this machine group. +func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) + r, err := request(m.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + body := &MachineList{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + ms = body.Machines + total = body.Total + + return +} + +// GetAppliedConfigs returns applied configs of this machine group. +func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { + confNames, err = m.project.GetAppliedConfigs(m.Name) + return +} diff --git a/pkg/logs/alils/request.go b/pkg/logs/alils/request.go new file mode 100755 index 00000000..50d9c43c --- /dev/null +++ b/pkg/logs/alils/request.go @@ -0,0 +1,62 @@ +package alils + +import ( + "bytes" + "crypto/md5" + "fmt" + "net/http" +) + +// request sends a request to SLS. +func request(project *LogProject, method, uri string, headers map[string]string, + body []byte) (resp *http.Response, err error) { + + // The caller should provide 'x-sls-bodyrawsize' header + if _, ok := headers["x-sls-bodyrawsize"]; !ok { + err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") + return + } + + // SLS public request headers + headers["Host"] = project.Name + "." + project.Endpoint + headers["Date"] = nowRFC1123() + headers["x-sls-apiversion"] = version + headers["x-sls-signaturemethod"] = signatureMethod + if body != nil { + bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) + headers["Content-MD5"] = bodyMD5 + + if _, ok := headers["Content-Type"]; !ok { + err = fmt.Errorf("Can't find 'Content-Type' header") + return + } + } + + // Calc Authorization + // Authorization = "SLS :" + digest, err := signature(project, method, uri, headers) + if err != nil { + return + } + auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) + headers["Authorization"] = auth + + // Initialize http request + reader := bytes.NewReader(body) + urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) + req, err := http.NewRequest(method, urlStr, reader) + if err != nil { + return + } + for k, v := range headers { + req.Header.Add(k, v) + } + + // Get ready to do request + resp, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/signature.go b/pkg/logs/alils/signature.go new file mode 100755 index 00000000..2d611307 --- /dev/null +++ b/pkg/logs/alils/signature.go @@ -0,0 +1,111 @@ +package alils + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net/url" + "sort" + "strings" + "time" +) + +// GMT location +var gmtLoc = time.FixedZone("GMT", 0) + +// NowRFC1123 returns now time in RFC1123 format with GMT timezone, +// eg. "Mon, 02 Jan 2006 15:04:05 GMT". +func nowRFC1123() string { + return time.Now().In(gmtLoc).Format(time.RFC1123) +} + +// signature calculates a request's signature digest. +func signature(project *LogProject, method, uri string, + headers map[string]string) (digest string, err error) { + var contentMD5, contentType, date, canoHeaders, canoResource string + var slsHeaderKeys sort.StringSlice + + // SignString = VERB + "\n" + // + CONTENT-MD5 + "\n" + // + CONTENT-TYPE + "\n" + // + DATE + "\n" + // + CanonicalizedSLSHeaders + "\n" + // + CanonicalizedResource + + if val, ok := headers["Content-MD5"]; ok { + contentMD5 = val + } + + if val, ok := headers["Content-Type"]; ok { + contentType = val + } + + date, ok := headers["Date"] + if !ok { + err = fmt.Errorf("Can't find 'Date' header") + return + } + + // Calc CanonicalizedSLSHeaders + slsHeaders := make(map[string]string, len(headers)) + for k, v := range headers { + l := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(l, "x-sls-") { + slsHeaders[l] = strings.TrimSpace(v) + slsHeaderKeys = append(slsHeaderKeys, l) + } + } + + sort.Sort(slsHeaderKeys) + for i, k := range slsHeaderKeys { + canoHeaders += k + ":" + slsHeaders[k] + if i+1 < len(slsHeaderKeys) { + canoHeaders += "\n" + } + } + + // Calc CanonicalizedResource + u, err := url.Parse(uri) + if err != nil { + return + } + + canoResource += url.QueryEscape(u.Path) + if u.RawQuery != "" { + var keys sort.StringSlice + + vals := u.Query() + for k := range vals { + keys = append(keys, k) + } + + sort.Sort(keys) + canoResource += "?" + for i, k := range keys { + if i > 0 { + canoResource += "&" + } + + for _, v := range vals[k] { + canoResource += k + "=" + v + } + } + } + + signStr := method + "\n" + + contentMD5 + "\n" + + contentType + "\n" + + date + "\n" + + canoHeaders + "\n" + + canoResource + + // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) + mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) + _, err = mac.Write([]byte(signStr)) + if err != nil { + return + } + digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) + return +} diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go new file mode 100644 index 00000000..74c458ab --- /dev/null +++ b/pkg/logs/conn.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "io" + "net" + "time" +) + +// connWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type connWriter struct { + lg *logWriter + innerWriter io.WriteCloser + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` +} + +// NewConn create new ConnWrite returning as LoggerInterface. +func NewConn() Logger { + conn := new(connWriter) + conn.Level = LevelTrace + return conn +} + +// Init init connection writer with json config. +// json config only need key "level". +func (c *connWriter) Init(jsonConfig string) error { + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. +func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.needToConnectOnMsg() { + err := c.connect() + if err != nil { + return err + } + } + + if c.ReconnectOnMsg { + defer c.innerWriter.Close() + } + + _, err := c.lg.writeln(when, msg) + if err != nil { + return err + } + return nil +} + +// Flush implementing method. empty. +func (c *connWriter) Flush() { + +} + +// Destroy destroy connection writer and close tcp listener. +func (c *connWriter) Destroy() { + if c.innerWriter != nil { + c.innerWriter.Close() + } +} + +func (c *connWriter) connect() error { + if c.innerWriter != nil { + c.innerWriter.Close() + c.innerWriter = nil + } + + conn, err := net.Dial(c.Net, c.Addr) + if err != nil { + return err + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + } + + c.innerWriter = conn + c.lg = newLogWriter(conn) + return nil +} + +func (c *connWriter) needToConnectOnMsg() bool { + if c.Reconnect { + return true + } + + if c.innerWriter == nil { + return true + } + + return c.ReconnectOnMsg +} + +func init() { + Register(AdapterConn, NewConn) +} diff --git a/pkg/logs/conn_test.go b/pkg/logs/conn_test.go new file mode 100644 index 00000000..bb377d41 --- /dev/null +++ b/pkg/logs/conn_test.go @@ -0,0 +1,79 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "net" + "os" + "testing" +) + +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + +func TestConn(t *testing.T) { + log := NewLogger(1000) + log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) + log.Informational("informational") +} + +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 2 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} diff --git a/pkg/logs/console.go b/pkg/logs/console.go new file mode 100644 index 00000000..3dcaee1d --- /dev/null +++ b/pkg/logs/console.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "os" + "strings" + "time" + + "github.com/shiena/ansicolor" +) + +// brush is a color join function +type brush func(string) string + +// newBrush return a fix color Brush +func newBrush(color string) brush { + pre := "\033[" + reset := "\033[0m" + return func(text string) string { + return pre + color + "m" + text + reset + } +} + +var colors = []brush{ + newBrush("1;37"), // Emergency white + newBrush("1;36"), // Alert cyan + newBrush("1;35"), // Critical magenta + newBrush("1;31"), // Error red + newBrush("1;33"), // Warning yellow + newBrush("1;32"), // Notice green + newBrush("1;34"), // Informational blue + newBrush("1;44"), // Debug Background blue +} + +// consoleWriter implements LoggerInterface and writes messages to terminal. +type consoleWriter struct { + lg *logWriter + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color +} + +// NewConsole create ConsoleWriter returning as LoggerInterface. +func NewConsole() Logger { + cw := &consoleWriter{ + lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), + Level: LevelDebug, + Colorful: true, + } + return cw +} + +// Init init console logger. +// jsonConfig like '{"level":LevelTrace}'. +func (c *consoleWriter) Init(jsonConfig string) error { + if len(jsonConfig) == 0 { + return nil + } + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in console. +func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.Colorful { + msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) + } + c.lg.writeln(when, msg) + return nil +} + +// Destroy implementing method. empty. +func (c *consoleWriter) Destroy() { + +} + +// Flush implementing method. empty. +func (c *consoleWriter) Flush() { + +} + +func init() { + Register(AdapterConsole, NewConsole) +} diff --git a/pkg/logs/console_test.go b/pkg/logs/console_test.go new file mode 100644 index 00000000..4bc45f57 --- /dev/null +++ b/pkg/logs/console_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +// Try each log level in decreasing order of priority. +func testConsoleCalls(bl *BeeLogger) { + bl.Emergency("emergency") + bl.Alert("alert") + bl.Critical("critical") + bl.Error("error") + bl.Warning("warning") + bl.Notice("notice") + bl.Informational("informational") + bl.Debug("debug") +} + +// Test console logging by visually comparing the lines being output with and +// without a log level specification. +func TestConsole(t *testing.T) { + log1 := NewLogger(10000) + log1.EnableFuncCallDepth(true) + log1.SetLogger("console", "") + testConsoleCalls(log1) + + log2 := NewLogger(100) + log2.SetLogger("console", `{"level":3}`) + testConsoleCalls(log2) +} + +// Test console without color +func TestConsoleNoColor(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console", `{"color":false}`) + testConsoleCalls(log) +} + +// Test console async +func TestConsoleAsync(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console") + log.Async() + //log.Close() + testConsoleCalls(log) + for len(log.msgChan) != 0 { + time.Sleep(1 * time.Millisecond) + } +} diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go new file mode 100644 index 00000000..2b7b1710 --- /dev/null +++ b/pkg/logs/es/es.go @@ -0,0 +1,102 @@ +package es + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/elastic/go-elasticsearch/v6" + "github.com/elastic/go-elasticsearch/v6/esapi" + + "github.com/astaxie/beego/logs" +) + +// NewES return a LoggerInterface +func NewES() logs.Logger { + cw := &esLogger{ + Level: logs.LevelDebug, + } + return cw +} + +// esLogger will log msg into ES +// before you using this implementation, +// please import this package +// usually means that you can import this package in your main package +// for example, anonymous: +// import _ "github.com/astaxie/beego/logs/es" +type esLogger struct { + *elasticsearch.Client + DSN string `json:"dsn"` + Level int `json:"level"` +} + +// {"dsn":"http://localhost:9200/","level":1} +func (el *esLogger) Init(jsonconfig string) error { + err := json.Unmarshal([]byte(jsonconfig), el) + if err != nil { + return err + } + if el.DSN == "" { + return errors.New("empty dsn") + } else if u, err := url.Parse(el.DSN); err != nil { + return err + } else if u.Path == "" { + return errors.New("missing prefix") + } else { + conn, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{el.DSN}, + }) + if err != nil { + return err + } + el.Client = conn + } + return nil +} + +// WriteMsg will write the msg and level into es +func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { + if level > el.Level { + return nil + } + + idx := LogDocument{ + Timestamp: when.Format(time.RFC3339), + Msg: msg, + } + + body, err := json.Marshal(idx) + if err != nil { + return err + } + req := esapi.IndexRequest{ + Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), + DocumentType: "logs", + Body: strings.NewReader(string(body)), + } + _, err = req.Do(context.Background(), el.Client) + return err +} + +// Destroy is a empty method +func (el *esLogger) Destroy() { +} + +// Flush is a empty method +func (el *esLogger) Flush() { + +} + +type LogDocument struct { + Timestamp string `json:"timestamp"` + Msg string `json:"msg"` +} + +func init() { + logs.Register(logs.AdapterEs, NewES) +} diff --git a/pkg/logs/file.go b/pkg/logs/file.go new file mode 100644 index 00000000..222db989 --- /dev/null +++ b/pkg/logs/file.go @@ -0,0 +1,409 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// fileLogWriter implements LoggerInterface. +// It writes messages by lines limit, file size limit, or time frequency. +type fileLogWriter struct { + sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + // The opened file + Filename string `json:"filename"` + fileWriter *os.File + + // Rotate at line + MaxLines int `json:"maxlines"` + maxLinesCurLines int + + MaxFiles int `json:"maxfiles"` + MaxFilesCurFiles int + + // Rotate at size + MaxSize int `json:"maxsize"` + maxSizeCurSize int + + // Rotate daily + Daily bool `json:"daily"` + MaxDays int64 `json:"maxdays"` + dailyOpenDate int + dailyOpenTime time.Time + + // Rotate hourly + Hourly bool `json:"hourly"` + MaxHours int64 `json:"maxhours"` + hourlyOpenDate int + hourlyOpenTime time.Time + + Rotate bool `json:"rotate"` + + Level int `json:"level"` + + Perm string `json:"perm"` + + RotatePerm string `json:"rotateperm"` + + fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix +} + +// newFileWriter create a FileLogWriter returning as LoggerInterface. +func newFileWriter() Logger { + w := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Hourly: false, + MaxHours: 168, + Rotate: true, + RotatePerm: "0440", + Level: LevelTrace, + Perm: "0660", + MaxLines: 10000000, + MaxFiles: 999, + MaxSize: 1 << 28, + } + return w +} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":10000, +// "maxsize":1024, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":"0600" +// } +func (w *fileLogWriter) Init(jsonConfig string) error { + err := json.Unmarshal([]byte(jsonConfig), w) + if err != nil { + return err + } + if len(w.Filename) == 0 { + return errors.New("jsonconfig must have filename") + } + w.suffix = filepath.Ext(w.Filename) + w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix) + if w.suffix == "" { + w.suffix = ".log" + } + err = w.startLogger() + return err +} + +// start file logger. create log file and set to locker-inside file writer. +func (w *fileLogWriter) startLogger() error { + file, err := w.createLogFile() + if err != nil { + return err + } + if w.fileWriter != nil { + w.fileWriter.Close() + } + w.fileWriter = file + return w.initFd() +} + +func (w *fileLogWriter) needRotateDaily(size int, day int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Daily && day != w.dailyOpenDate) +} + +func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Hourly && hour != w.hourlyOpenDate) + +} + +// WriteMsg write logger message into file. +func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > w.Level { + return nil + } + hd, d, h := formatTimeHeader(when) + msg = string(hd) + msg + "\n" + if w.Rotate { + w.RLock() + if w.needRotateHourly(len(msg), h) { + w.RUnlock() + w.Lock() + if w.needRotateHourly(len(msg), h) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else if w.needRotateDaily(len(msg), d) { + w.RUnlock() + w.Lock() + if w.needRotateDaily(len(msg), d) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else { + w.RUnlock() + } + } + + w.Lock() + _, err := w.fileWriter.Write([]byte(msg)) + if err == nil { + w.maxLinesCurLines++ + w.maxSizeCurSize += len(msg) + } + w.Unlock() + return err +} + +func (w *fileLogWriter) createLogFile() (*os.File, error) { + // Open the log file + perm, err := strconv.ParseInt(w.Perm, 8, 64) + if err != nil { + return nil, err + } + + filepath := path.Dir(w.Filename) + os.MkdirAll(filepath, os.FileMode(perm)) + + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) + if err == nil { + // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask + os.Chmod(w.Filename, os.FileMode(perm)) + } + return fd, err +} + +func (w *fileLogWriter) initFd() error { + fd := w.fileWriter + fInfo, err := fd.Stat() + if err != nil { + return fmt.Errorf("get stat err: %s", err) + } + w.maxSizeCurSize = int(fInfo.Size()) + w.dailyOpenTime = time.Now() + w.dailyOpenDate = w.dailyOpenTime.Day() + w.hourlyOpenTime = time.Now() + w.hourlyOpenDate = w.hourlyOpenTime.Hour() + w.maxLinesCurLines = 0 + if w.Hourly { + go w.hourlyRotate(w.hourlyOpenTime) + } else if w.Daily { + go w.dailyRotate(w.dailyOpenTime) + } + if fInfo.Size() > 0 && w.MaxLines > 0 { + count, err := w.lines() + if err != nil { + return err + } + w.maxLinesCurLines = count + } + return nil +} + +func (w *fileLogWriter) dailyRotate(openTime time.Time) { + y, m, d := openTime.Add(24 * time.Hour).Date() + nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateDaily(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) hourlyRotate(openTime time.Time) { + y, m, d := openTime.Add(1 * time.Hour).Date() + h, _, _ := openTime.Add(1 * time.Hour).Clock() + nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateHourly(0, time.Now().Hour()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) lines() (int, error) { + fd, err := os.Open(w.Filename) + if err != nil { + return 0, err + } + defer fd.Close() + + buf := make([]byte, 32768) // 32k + count := 0 + lineSep := []byte{'\n'} + + for { + c, err := fd.Read(buf) + if err != nil && err != io.EOF { + return count, err + } + + count += bytes.Count(buf[:c], lineSep) + + if err == io.EOF { + break + } + } + + return count, nil +} + +// DoRotate means it need to write file in new file. +// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) +func (w *fileLogWriter) doRotate(logTime time.Time) error { + // file exists + // Find the next available number + num := w.MaxFilesCurFiles + 1 + fName := "" + format := "" + var openTime time.Time + rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) + if err != nil { + return err + } + + _, err = os.Lstat(w.Filename) + if err != nil { + //even if the file is not exist or other ,we should RESTART the logger + goto RESTART_LOGGER + } + + if w.Hourly { + format = "2006010215" + openTime = w.hourlyOpenTime + } else if w.Daily { + format = "2006-01-02" + openTime = w.dailyOpenTime + } + + // only when one of them be setted, then the file would be splited + if w.MaxLines > 0 || w.MaxSize > 0 { + for ; err == nil && num <= w.MaxFiles; num++ { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + } + } else { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + w.MaxFilesCurFiles = num + } + + // return error if the last file checked still existed + if err == nil { + return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) + } + + // close fileWriter before rename + w.fileWriter.Close() + + // Rename the file to its new found name + // even if occurs error,we MUST guarantee to restart new logger + err = os.Rename(w.Filename, fName) + if err != nil { + goto RESTART_LOGGER + } + + err = os.Chmod(fName, os.FileMode(rotatePerm)) + +RESTART_LOGGER: + + startLoggerErr := w.startLogger() + go w.deleteOldLog() + + if startLoggerErr != nil { + return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr) + } + if err != nil { + return fmt.Errorf("Rotate: %s", err) + } + return nil +} + +func (w *fileLogWriter) deleteOldLog() { + dir := filepath.Dir(w.Filename) + absolutePath, err := filepath.EvalSymlinks(w.Filename) + if err == nil { + dir = filepath.Dir(absolutePath) + } + filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r) + } + }() + + if info == nil { + return + } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } + return + }) +} + +// Destroy close the file description, close file writer. +func (w *fileLogWriter) Destroy() { + w.fileWriter.Close() +} + +// Flush flush file logger. +// there are no buffering messages in file logger in memory. +// flush file means sync file from disk. +func (w *fileLogWriter) Flush() { + w.fileWriter.Sync() +} + +func init() { + Register(AdapterFile, newFileWriter) +} diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go new file mode 100644 index 00000000..e7c2ca9a --- /dev/null +++ b/pkg/logs/file_test.go @@ -0,0 +1,420 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bufio" + "fmt" + "io/ioutil" + "os" + "strconv" + "testing" + "time" +) + +func TestFilePerm(t *testing.T) { + log := NewLogger(10000) + // use 0666 as test perm cause the default umask is 022 + log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + file, err := os.Stat("test.log") + if err != nil { + t.Fatal(err) + } + if file.Mode() != 0666 { + t.Fatal("unexpected log file permission") + } + os.Remove("test.log") +} + +func TestFile1(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test.log"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + f, err := os.Open("test.log") + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lineNum++ + } + } + var expected = LevelDebug + 1 + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") + } + os.Remove("test.log") +} + +func TestFile2(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError)) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + f, err := os.Open("test2.log") + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lineNum++ + } + } + var expected = LevelError + 1 + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") + } + os.Remove("test2.log") +} + +func TestFileDailyRotate_01(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + b, err := exists(rotateName) + if !b || err != nil { + os.Remove("test3.log") + t.Fatal("rotate not generated") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileDailyRotate_02(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2, true, false) +} + +func TestFileDailyRotate_03(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2, true, false) + os.Remove(fn) +} + +func TestFileDailyRotate_04(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) +} + +func TestFileDailyRotate_05(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) + os.Remove(fn) +} +func TestFileDailyRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileHourlyRotate_01(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" + b, err := exists(rotateName) + if !b || err != nil { + os.Remove("test3.log") + t.Fatal("rotate not generated") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileHourlyRotate_02(t *testing.T) { + fn1 := "rotate_hour.log" + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileRotate(t, fn1, fn2, false, true) +} + +func TestFileHourlyRotate_03(t *testing.T) { + fn1 := "rotate_hour.log" + fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" + os.Create(fn) + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileRotate(t, fn1, fn2, false, true) + os.Remove(fn) +} + +func TestFileHourlyRotate_04(t *testing.T) { + fn1 := "rotate_hour.log" + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileHourlyRotate(t, fn1, fn2) +} + +func TestFileHourlyRotate_05(t *testing.T) { + fn1 := "rotate_hour.log" + fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" + os.Create(fn) + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileHourlyRotate(t, fn1, fn2) + os.Remove(fn) +} + +func TestFileHourlyRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { + fw := &fileLogWriter{ + Daily: daily, + MaxDays: 7, + Hourly: hourly, + MaxHours: 168, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } + + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } + + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.Log(err) + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileDailyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location()) + today = today.Add(-1 * time.Second) + fw.dailyRotate(today) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Hourly: true, + MaxHours: 168, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() + hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location()) + hour = hour.Add(-1 * time.Second) + fw.hourlyRotate(hour) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func BenchmarkFile(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileAsynchronous(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileAsynchronousCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileOnGoroutine(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + go log.Debug("debug") + } + os.Remove("test4.log") +} diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go new file mode 100644 index 00000000..88ba0f9a --- /dev/null +++ b/pkg/logs/jianliao.go @@ -0,0 +1,72 @@ +package logs + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" +) + +// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type JLWriter struct { + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` +} + +// newJLWriter create jiaoliao writer. +func newJLWriter() Logger { + return &JLWriter{Level: LevelTrace} +} + +// Init JLWriter with json config string +func (s *JLWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) + + form := url.Values{} + form.Add("authorName", s.AuthorName) + form.Add("title", s.Title) + form.Add("text", text) + if s.RedirectURL != "" { + form.Add("redirectUrl", s.RedirectURL) + } + if s.ImageURL != "" { + form.Add("imageUrl", s.ImageURL) + } + + resp, err := http.PostForm(s.WebhookURL, form) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *JLWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *JLWriter) Destroy() { +} + +func init() { + Register(AdapterJianLiao, newJLWriter) +} diff --git a/pkg/logs/log.go b/pkg/logs/log.go new file mode 100644 index 00000000..39c006d2 --- /dev/null +++ b/pkg/logs/log.go @@ -0,0 +1,669 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package logs provide a general log interface +// Usage: +// +// import "github.com/astaxie/beego/logs" +// +// log := NewLogger(10000) +// log.SetLogger("console", "") +// +// > the first params stand for how many channel +// +// Use it like this: +// +// log.Trace("trace") +// log.Info("info") +// log.Warn("warning") +// log.Debug("debug") +// log.Critical("critical") +// +// more docs http://beego.me/docs/module/logs.md +package logs + +import ( + "fmt" + "log" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +// RFC5424 log message levels. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// levelLogLogger is defined to implement log.Logger +// the real log level will be LevelEmergency +const levelLoggerImpl = -1 + +// Name for adapter with beego official support +const ( + AdapterConsole = "console" + AdapterFile = "file" + AdapterMultiFile = "multifile" + AdapterMail = "smtp" + AdapterConn = "conn" + AdapterEs = "es" + AdapterJianLiao = "jianliao" + AdapterSlack = "slack" + AdapterAliLS = "alils" +) + +// Legacy log level constants to ensure backwards compatibility. +const ( + LevelInfo = LevelInformational + LevelTrace = LevelDebug + LevelWarn = LevelWarning +) + +type newLoggerFunc func() Logger + +// Logger defines the behavior of a log provider. +type Logger interface { + Init(config string) error + WriteMsg(when time.Time, msg string, level int) error + Destroy() + Flush() +} + +var adapters = make(map[string]newLoggerFunc) +var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} + +// Register makes a log provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, log newLoggerFunc) { + if log == nil { + panic("logs: Register provide is nil") + } + if _, dup := adapters[name]; dup { + panic("logs: Register called twice for provider " + name) + } + adapters[name] = log +} + +// BeeLogger is default logger in beego application. +// it can contain several providers and log message into all providers. +type BeeLogger struct { + lock sync.Mutex + level int + init bool + enableFuncCallDepth bool + loggerFuncCallDepth int + asynchronous bool + prefix string + msgChanLen int64 + msgChan chan *logMsg + signalChan chan string + wg sync.WaitGroup + outputs []*nameLogger +} + +const defaultAsyncMsgLen = 1e3 + +type nameLogger struct { + Logger + name string +} + +type logMsg struct { + level int + msg string + when time.Time +} + +var logMsgPool *sync.Pool + +// NewLogger returns a new BeeLogger. +// channelLen means the number of messages in chan(used where asynchronous is true). +// if the buffering chan is full, logger adapters write to file or other way. +func NewLogger(channelLens ...int64) *BeeLogger { + bl := new(BeeLogger) + bl.level = LevelDebug + bl.loggerFuncCallDepth = 2 + bl.msgChanLen = append(channelLens, 0)[0] + if bl.msgChanLen <= 0 { + bl.msgChanLen = defaultAsyncMsgLen + } + bl.signalChan = make(chan string, 1) + bl.setLogger(AdapterConsole) + return bl +} + +// Async set the log to asynchronous and start the goroutine +func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { + bl.lock.Lock() + defer bl.lock.Unlock() + if bl.asynchronous { + return bl + } + bl.asynchronous = true + if len(msgLen) > 0 && msgLen[0] > 0 { + bl.msgChanLen = msgLen[0] + } + bl.msgChan = make(chan *logMsg, bl.msgChanLen) + logMsgPool = &sync.Pool{ + New: func() interface{} { + return &logMsg{} + }, + } + bl.wg.Add(1) + go bl.startLogger() + return bl +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { + config := append(configs, "{}")[0] + for _, l := range bl.outputs { + if l.name == adapterName { + return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) + } + } + + logAdapter, ok := adapters[adapterName] + if !ok { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + + lg := logAdapter() + err := lg.Init(config) + if err != nil { + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) + return err + } + bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) + return nil +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLogger(adapterName, configs...) +} + +// DelLogger remove a logger adapter in BeeLogger. +func (bl *BeeLogger) DelLogger(adapterName string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + outputs := []*nameLogger{} + for _, lg := range bl.outputs { + if lg.name == adapterName { + lg.Destroy() + } else { + outputs = append(outputs, lg) + } + } + if len(outputs) == len(bl.outputs) { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + bl.outputs = outputs + return nil +} + +func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { + for _, l := range bl.outputs { + err := l.WriteMsg(when, msg, level) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) + } + } +} + +func (bl *BeeLogger) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + // writeMsg will always add a '\n' character + if p[len(p)-1] == '\n' { + p = p[0 : len(p)-1] + } + // set levelLoggerImpl to ensure all log message will be write out + err = bl.writeMsg(levelLoggerImpl, string(p)) + if err == nil { + return len(p), err + } + return 0, err +} + +func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { + if !bl.init { + bl.lock.Lock() + bl.setLogger(AdapterConsole) + bl.lock.Unlock() + } + + if len(v) > 0 { + msg = fmt.Sprintf(msg, v...) + } + + msg = bl.prefix + " " + msg + + when := time.Now() + if bl.enableFuncCallDepth { + _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" + line = 0 + } + _, filename := path.Split(file) + msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg + } + + //set level info in front of filename info + if logLevel == levelLoggerImpl { + // set to emergency to ensure all log will be print out correctly + logLevel = LevelEmergency + } else { + msg = levelPrefix[logLevel] + " " + msg + } + + if bl.asynchronous { + lm := logMsgPool.Get().(*logMsg) + lm.level = logLevel + lm.msg = msg + lm.when = when + if bl.outputs != nil { + bl.msgChan <- lm + } else { + logMsgPool.Put(lm) + } + } else { + bl.writeToLoggers(when, msg, logLevel) + } + return nil +} + +// SetLevel Set log message level. +// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), +// log providers will not even be sent the message. +func (bl *BeeLogger) SetLevel(l int) { + bl.level = l +} + +// GetLevel Get Current log message level. +func (bl *BeeLogger) GetLevel() int { + return bl.level +} + +// SetLogFuncCallDepth set log funcCallDepth +func (bl *BeeLogger) SetLogFuncCallDepth(d int) { + bl.loggerFuncCallDepth = d +} + +// GetLogFuncCallDepth return log funcCallDepth for wrapper +func (bl *BeeLogger) GetLogFuncCallDepth() int { + return bl.loggerFuncCallDepth +} + +// EnableFuncCallDepth enable log funcCallDepth +func (bl *BeeLogger) EnableFuncCallDepth(b bool) { + bl.enableFuncCallDepth = b +} + +// set prefix +func (bl *BeeLogger) SetPrefix(s string) { + bl.prefix = s +} + +// start logger chan reading. +// when chan is not empty, write logs. +func (bl *BeeLogger) startLogger() { + gameOver := false + for { + select { + case bm := <-bl.msgChan: + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + case sg := <-bl.signalChan: + // Now should only send "flush" or "close" to bl.signalChan + bl.flush() + if sg == "close" { + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil + gameOver = true + } + bl.wg.Done() + } + if gameOver { + break + } + } +} + +// Emergency Log EMERGENCY level message. +func (bl *BeeLogger) Emergency(format string, v ...interface{}) { + if LevelEmergency > bl.level { + return + } + bl.writeMsg(LevelEmergency, format, v...) +} + +// Alert Log ALERT level message. +func (bl *BeeLogger) Alert(format string, v ...interface{}) { + if LevelAlert > bl.level { + return + } + bl.writeMsg(LevelAlert, format, v...) +} + +// Critical Log CRITICAL level message. +func (bl *BeeLogger) Critical(format string, v ...interface{}) { + if LevelCritical > bl.level { + return + } + bl.writeMsg(LevelCritical, format, v...) +} + +// Error Log ERROR level message. +func (bl *BeeLogger) Error(format string, v ...interface{}) { + if LevelError > bl.level { + return + } + bl.writeMsg(LevelError, format, v...) +} + +// Warning Log WARNING level message. +func (bl *BeeLogger) Warning(format string, v ...interface{}) { + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) +} + +// Notice Log NOTICE level message. +func (bl *BeeLogger) Notice(format string, v ...interface{}) { + if LevelNotice > bl.level { + return + } + bl.writeMsg(LevelNotice, format, v...) +} + +// Informational Log INFORMATIONAL level message. +func (bl *BeeLogger) Informational(format string, v ...interface{}) { + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) +} + +// Debug Log DEBUG level message. +func (bl *BeeLogger) Debug(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) +} + +// Warn Log WARN level message. +// compatibility alias for Warning() +func (bl *BeeLogger) Warn(format string, v ...interface{}) { + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) +} + +// Info Log INFO level message. +// compatibility alias for Informational() +func (bl *BeeLogger) Info(format string, v ...interface{}) { + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) +} + +// Trace Log TRACE level message. +// compatibility alias for Debug() +func (bl *BeeLogger) Trace(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) +} + +// Flush flush all chan data. +func (bl *BeeLogger) Flush() { + if bl.asynchronous { + bl.signalChan <- "flush" + bl.wg.Wait() + bl.wg.Add(1) + return + } + bl.flush() +} + +// Close close logger, flush all chan data and destroy all adapters in BeeLogger. +func (bl *BeeLogger) Close() { + if bl.asynchronous { + bl.signalChan <- "close" + bl.wg.Wait() + close(bl.msgChan) + } else { + bl.flush() + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil + } + close(bl.signalChan) +} + +// Reset close all outputs, and set bl.outputs to nil +func (bl *BeeLogger) Reset() { + bl.Flush() + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil +} + +func (bl *BeeLogger) flush() { + if bl.asynchronous { + for { + if len(bl.msgChan) > 0 { + bm := <-bl.msgChan + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + continue + } + break + } + } + for _, l := range bl.outputs { + l.Flush() + } +} + +// beeLogger references the used application logger. +var beeLogger = NewLogger() + +// GetBeeLogger returns the default BeeLogger +func GetBeeLogger() *BeeLogger { + return beeLogger +} + +var beeLoggerMap = struct { + sync.RWMutex + logs map[string]*log.Logger +}{ + logs: map[string]*log.Logger{}, +} + +// GetLogger returns the default BeeLogger +func GetLogger(prefixes ...string) *log.Logger { + prefix := append(prefixes, "")[0] + if prefix != "" { + prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) + } + beeLoggerMap.RLock() + l, ok := beeLoggerMap.logs[prefix] + if ok { + beeLoggerMap.RUnlock() + return l + } + beeLoggerMap.RUnlock() + beeLoggerMap.Lock() + defer beeLoggerMap.Unlock() + l, ok = beeLoggerMap.logs[prefix] + if !ok { + l = log.New(beeLogger, prefix, 0) + beeLoggerMap.logs[prefix] = l + } + return l +} + +// Reset will remove all the adapter +func Reset() { + beeLogger.Reset() +} + +// Async set the beelogger with Async mode and hold msglen messages +func Async(msgLen ...int64) *BeeLogger { + return beeLogger.Async(msgLen...) +} + +// SetLevel sets the global log level used by the simple logger. +func SetLevel(l int) { + beeLogger.SetLevel(l) +} + +// SetPrefix sets the prefix +func SetPrefix(s string) { + beeLogger.SetPrefix(s) +} + +// EnableFuncCallDepth enable log funcCallDepth +func EnableFuncCallDepth(b bool) { + beeLogger.enableFuncCallDepth = b +} + +// SetLogFuncCall set the CallDepth, default is 4 +func SetLogFuncCall(b bool) { + beeLogger.EnableFuncCallDepth(b) + beeLogger.SetLogFuncCallDepth(4) +} + +// SetLogFuncCallDepth set log funcCallDepth +func SetLogFuncCallDepth(d int) { + beeLogger.loggerFuncCallDepth = d +} + +// SetLogger sets a new logger. +func SetLogger(adapter string, config ...string) error { + return beeLogger.SetLogger(adapter, config...) +} + +// Emergency logs a message at emergency level. +func Emergency(f interface{}, v ...interface{}) { + beeLogger.Emergency(formatLog(f, v...)) +} + +// Alert logs a message at alert level. +func Alert(f interface{}, v ...interface{}) { + beeLogger.Alert(formatLog(f, v...)) +} + +// Critical logs a message at critical level. +func Critical(f interface{}, v ...interface{}) { + beeLogger.Critical(formatLog(f, v...)) +} + +// Error logs a message at error level. +func Error(f interface{}, v ...interface{}) { + beeLogger.Error(formatLog(f, v...)) +} + +// Warning logs a message at warning level. +func Warning(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Warn compatibility alias for Warning() +func Warn(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Notice logs a message at notice level. +func Notice(f interface{}, v ...interface{}) { + beeLogger.Notice(formatLog(f, v...)) +} + +// Informational logs a message at info level. +func Informational(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Info compatibility alias for Warning() +func Info(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Debug logs a message at debug level. +func Debug(f interface{}, v ...interface{}) { + beeLogger.Debug(formatLog(f, v...)) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +func Trace(f interface{}, v ...interface{}) { + beeLogger.Trace(formatLog(f, v...)) +} + +func formatLog(f interface{}, v ...interface{}) string { + var msg string + switch f.(type) { + case string: + msg = f.(string) + if len(v) == 0 { + return msg + } + if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { + //format string + } else { + //do not contain format char + msg += strings.Repeat(" %v", len(v)) + } + default: + msg = fmt.Sprint(f) + if len(v) == 0 { + return msg + } + msg += strings.Repeat(" %v", len(v)) + } + return fmt.Sprintf(msg, v...) +} diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go new file mode 100644 index 00000000..a28bff6f --- /dev/null +++ b/pkg/logs/logger.go @@ -0,0 +1,176 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "io" + "runtime" + "sync" + "time" +) + +type logWriter struct { + sync.Mutex + writer io.Writer +} + +func newLogWriter(wr io.Writer) *logWriter { + return &logWriter{writer: wr} +} + +func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { + lg.Lock() + h, _, _ := formatTimeHeader(when) + n, err := lg.writer.Write(append(append(h, msg...), '\n')) + lg.Unlock() + return n, err +} + +const ( + y1 = `0123456789` + y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` + y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999` + y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` + mo1 = `000000000111` + mo2 = `123456789012` + d1 = `0000000001111111111222222222233` + d2 = `1234567890123456789012345678901` + h1 = `000000000011111111112222` + h2 = `012345678901234567890123` + mi1 = `000000000011111111112222222222333333333344444444445555555555` + mi2 = `012345678901234567890123456789012345678901234567890123456789` + s1 = `000000000011111111112222222222333333333344444444445555555555` + s2 = `012345678901234567890123456789012345678901234567890123456789` + ns1 = `0123456789` +) + +func formatTimeHeader(when time.Time) ([]byte, int, int) { + y, mo, d := when.Date() + h, mi, s := when.Clock() + ns := when.Nanosecond() / 1000000 + //len("2006/01/02 15:04:05.123 ")==24 + var buf [24]byte + + buf[0] = y1[y/1000%10] + buf[1] = y2[y/100] + buf[2] = y3[y-y/100*100] + buf[3] = y4[y-y/100*100] + buf[4] = '/' + buf[5] = mo1[mo-1] + buf[6] = mo2[mo-1] + buf[7] = '/' + buf[8] = d1[d-1] + buf[9] = d2[d-1] + buf[10] = ' ' + buf[11] = h1[h] + buf[12] = h2[h] + buf[13] = ':' + buf[14] = mi1[mi] + buf[15] = mi2[mi] + buf[16] = ':' + buf[17] = s1[s] + buf[18] = s2[s] + buf[19] = '.' + buf[20] = ns1[ns/100] + buf[21] = ns1[ns%100/10] + buf[22] = ns1[ns%10] + + buf[23] = ' ' + + return buf[0:], d, h +} + +var ( + green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) + white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) + yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) + red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) + blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) + magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) + cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) + + w32Green = string([]byte{27, 91, 52, 50, 109}) + w32White = string([]byte{27, 91, 52, 55, 109}) + w32Yellow = string([]byte{27, 91, 52, 51, 109}) + w32Red = string([]byte{27, 91, 52, 49, 109}) + w32Blue = string([]byte{27, 91, 52, 52, 109}) + w32Magenta = string([]byte{27, 91, 52, 53, 109}) + w32Cyan = string([]byte{27, 91, 52, 54, 109}) + + reset = string([]byte{27, 91, 48, 109}) +) + +var once sync.Once +var colorMap map[string]string + +func initColor() { + if runtime.GOOS == "windows" { + green = w32Green + white = w32White + yellow = w32Yellow + red = w32Red + blue = w32Blue + magenta = w32Magenta + cyan = w32Cyan + } + colorMap = map[string]string{ + //by color + "green": green, + "white": white, + "yellow": yellow, + "red": red, + //by method + "GET": blue, + "POST": cyan, + "PUT": yellow, + "DELETE": red, + "PATCH": green, + "HEAD": magenta, + "OPTIONS": white, + } +} + +// ColorByStatus return color by http code +// 2xx return Green +// 3xx return White +// 4xx return Yellow +// 5xx return Red +func ColorByStatus(code int) string { + once.Do(initColor) + switch { + case code >= 200 && code < 300: + return colorMap["green"] + case code >= 300 && code < 400: + return colorMap["white"] + case code >= 400 && code < 500: + return colorMap["yellow"] + default: + return colorMap["red"] + } +} + +// ColorByMethod return color by http code +func ColorByMethod(method string) string { + once.Do(initColor) + if c := colorMap[method]; c != "" { + return c + } + return reset +} + +// ResetColor return reset color +func ResetColor() string { + return reset +} diff --git a/pkg/logs/logger_test.go b/pkg/logs/logger_test.go new file mode 100644 index 00000000..15be500d --- /dev/null +++ b/pkg/logs/logger_test.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +func TestFormatHeader_0(t *testing.T) { + tm := time.Now() + if tm.Year() >= 2100 { + t.FailNow() + } + dur := time.Second + for { + if tm.Year() >= 2100 { + break + } + h, _, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + dur *= 2 + } +} + +func TestFormatHeader_1(t *testing.T) { + tm := time.Now() + year := tm.Year() + dur := time.Second + for { + if tm.Year() >= year+1 { + break + } + h, _, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + } +} diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go new file mode 100644 index 00000000..90168274 --- /dev/null +++ b/pkg/logs/multifile.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "time" +) + +// A filesLogWriter manages several fileLogWriter +// filesLogWriter will write logs to the file in json configuration and write the same level log to correspond file +// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log +// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log +// the rotate attribute also acts like fileLogWriter +type multiFileLogWriter struct { + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` +} + +var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":0, +// "maxsize":0, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":0600, +// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], +// } + +func (f *multiFileLogWriter) Init(config string) error { + writer := newFileWriter().(*fileLogWriter) + err := writer.Init(config) + if err != nil { + return err + } + f.fullLogWriter = writer + f.writers[LevelDebug+1] = writer + + //unmarshal "separate" field to f.Separate + json.Unmarshal([]byte(config), f) + + jsonMap := map[string]interface{}{} + json.Unmarshal([]byte(config), &jsonMap) + + for i := LevelEmergency; i < LevelDebug+1; i++ { + for _, v := range f.Separate { + if v == levelNames[i] { + jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." + levelNames[i] + f.fullLogWriter.suffix + jsonMap["level"] = i + bs, _ := json.Marshal(jsonMap) + writer = newFileWriter().(*fileLogWriter) + err := writer.Init(string(bs)) + if err != nil { + return err + } + f.writers[i] = writer + } + } + } + + return nil +} + +func (f *multiFileLogWriter) Destroy() { + for i := 0; i < len(f.writers); i++ { + if f.writers[i] != nil { + f.writers[i].Destroy() + } + } +} + +func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if f.fullLogWriter != nil { + f.fullLogWriter.WriteMsg(when, msg, level) + } + for i := 0; i < len(f.writers)-1; i++ { + if f.writers[i] != nil { + if level == f.writers[i].Level { + f.writers[i].WriteMsg(when, msg, level) + } + } + } + return nil +} + +func (f *multiFileLogWriter) Flush() { + for i := 0; i < len(f.writers); i++ { + if f.writers[i] != nil { + f.writers[i].Flush() + } + } +} + +// newFilesWriter create a FileLogWriter returning as LoggerInterface. +func newFilesWriter() Logger { + return &multiFileLogWriter{} +} + +func init() { + Register(AdapterMultiFile, newFilesWriter) +} diff --git a/pkg/logs/multifile_test.go b/pkg/logs/multifile_test.go new file mode 100644 index 00000000..57b96094 --- /dev/null +++ b/pkg/logs/multifile_test.go @@ -0,0 +1,78 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bufio" + "os" + "strconv" + "strings" + "testing" +) + +func TestFiles_1(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + fns := []string{""} + fns = append(fns, levelNames[0:]...) + name := "test" + suffix := ".log" + for _, fn := range fns { + + file := name + suffix + if fn != "" { + file = name + "." + fn + suffix + } + f, err := os.Open(file) + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + lastLine := "" + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lastLine = string(line) + lineNum++ + } + } + var expected = 1 + if fn == "" { + expected = LevelDebug + 1 + } + if lineNum != expected { + t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines") + } + if lineNum == 1 { + if !strings.Contains(lastLine, fn) { + t.Fatal(file + " " + lastLine + " not contains the log msg " + fn) + } + } + os.Remove(file) + } + +} diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go new file mode 100644 index 00000000..1cd2e5ae --- /dev/null +++ b/pkg/logs/slack.go @@ -0,0 +1,60 @@ +package logs + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" +) + +// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type SLACKWriter struct { + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` +} + +// newSLACKWriter create jiaoliao writer. +func newSLACKWriter() Logger { + return &SLACKWriter{Level: LevelTrace} +} + +// Init SLACKWriter with json config string +func (s *SLACKWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) + + form := url.Values{} + form.Add("payload", text) + + resp, err := http.PostForm(s.WebhookURL, form) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *SLACKWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SLACKWriter) Destroy() { +} + +func init() { + Register(AdapterSlack, newSLACKWriter) +} diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go new file mode 100644 index 00000000..6208d7b8 --- /dev/null +++ b/pkg/logs/smtp.go @@ -0,0 +1,149 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net" + "net/smtp" + "strings" + "time" +) + +// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. +type SMTPWriter struct { + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Subject string `json:"subject"` + FromAddress string `json:"fromAddress"` + RecipientAddresses []string `json:"sendTos"` + Level int `json:"level"` +} + +// NewSMTPWriter create smtp writer. +func newSMTPWriter() Logger { + return &SMTPWriter{Level: LevelTrace} +} + +// Init smtp writer with json config. +// config like: +// { +// "username":"example@gmail.com", +// "password:"password", +// "host":"smtp.gmail.com:465", +// "subject":"email title", +// "fromAddress":"from@example.com", +// "sendTos":["email1","email2"], +// "level":LevelError +// } +func (s *SMTPWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { + if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { + return nil + } + return smtp.PlainAuth( + "", + s.Username, + s.Password, + host, + ) +} + +func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { + client, err := smtp.Dial(hostAddressWithPort) + if err != nil { + return err + } + + host, _, _ := net.SplitHostPort(hostAddressWithPort) + tlsConn := &tls.Config{ + InsecureSkipVerify: true, + ServerName: host, + } + if err = client.StartTLS(tlsConn); err != nil { + return err + } + + if auth != nil { + if err = client.Auth(auth); err != nil { + return err + } + } + + if err = client.Mail(fromAddress); err != nil { + return err + } + + for _, rec := range recipients { + if err = client.Rcpt(rec); err != nil { + return err + } + } + + w, err := client.Data() + if err != nil { + return err + } + _, err = w.Write(msgContent) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + + return client.Quit() +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + hp := strings.Split(s.Host, ":") + + // Set up authentication information. + auth := s.getSMTPAuth(hp[0]) + + // Connect to the server, authenticate, set the sender and recipient, + // and send the email all in one step. + contentType := "Content-Type: text/plain" + "; charset=UTF-8" + mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) + + return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) +} + +// Flush implementing method. empty. +func (s *SMTPWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SMTPWriter) Destroy() { +} + +func init() { + Register(AdapterMail, newSMTPWriter) +} diff --git a/pkg/logs/smtp_test.go b/pkg/logs/smtp_test.go new file mode 100644 index 00000000..28e762d2 --- /dev/null +++ b/pkg/logs/smtp_test.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +func TestSmtp(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) + log.Critical("sendmail critical") + time.Sleep(time.Second * 30) +} diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go new file mode 100644 index 00000000..7722240b --- /dev/null +++ b/pkg/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/logs" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": beego.BConfig.ServerName, + "env": beego.BConfig.RunMode, + "appname": beego.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": beego.BConfig.AppName, + "build_version": beego.BuildVersion, + "build_revision": beego.BuildGitRevision, + "build_status": beego.BuildStatus, + "build_tag": beego.BuildTag, + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "go_version": beego.GoVersion, + "git_branch": beego.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := beego.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/metric/prometheus_test.go b/pkg/metric/prometheus_test.go new file mode 100644 index 00000000..d82a6dec --- /dev/null +++ b/pkg/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} diff --git a/pkg/migration/ddl.go b/pkg/migration/ddl.go new file mode 100644 index 00000000..cd2c1c49 --- /dev/null +++ b/pkg/migration/ddl.go @@ -0,0 +1,395 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +import ( + "fmt" + + "github.com/astaxie/beego/logs" +) + +// Index struct defines the structure of Index Columns +type Index struct { + Name string +} + +// Unique struct defines a single unique key combination +type Unique struct { + Definition string + Columns []*Column +} + +//Column struct defines a single column of a table +type Column struct { + Name string + Inc string + Null string + Default string + Unsign string + DataType string + remove bool + Modify bool +} + +// Foreign struct defines a single foreign relationship +type Foreign struct { + ForeignTable string + ForeignColumn string + OnDelete string + OnUpdate string + Column +} + +// RenameColumn struct allows renaming of columns +type RenameColumn struct { + OldName string + OldNull string + OldDefault string + OldUnsign string + OldDataType string + NewName string + Column +} + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + m.TableName = tablename + m.Engine = engine + m.Charset = charset + m.ModifyType = "create" +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + m.TableName = tablename + m.ModifyType = "alter" +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + return col +} + +//PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + m.AddPrimary(col) + return col +} + +//UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + + uniqueOriginal := &Unique{} + + for _, unique := range m.Uniques { + if unique.Definition == uni { + unique.AddColumnsToUnique(col) + uniqueOriginal = unique + } + } + if uniqueOriginal.Definition == "" { + unique := &Unique{Definition: uni} + unique.AddColumnsToUnique(col) + m.AddUnique(unique) + } + + return col +} + +//ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + + foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} + foreign.Name = colname + m.AddForeign(foreign) + return foreign +} + +//SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + foreign.OnDelete = "ON DELETE" + del + return foreign +} + +//SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + foreign.OnUpdate = "ON UPDATE" + update + return foreign +} + +//Remove marks the columns to be removed. +//it allows reverse m to create the column. +func (c *Column) Remove() { + c.remove = true +} + +//SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + if inc { + c.Inc = "auto_increment" + } + return c +} + +//SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + if null { + c.Null = "" + + } else { + c.Null = "NOT NULL" + } + return c +} + +//SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + c.Default = "DEFAULT " + def + return c +} + +//SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + if unsign { + c.Unsign = "UNSIGNED" + } + return c +} + +//SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + c.DataType = dataType + return c +} + +//SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + if null { + c.OldNull = "" + + } else { + c.OldNull = "NOT NULL" + } + return c +} + +//SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + c.OldDefault = def + return c +} + +//SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + if unsign { + c.OldUnsign = "UNSIGNED" + } + return c +} + +//SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + c.OldDataType = dataType + return c +} + +//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + m.Primary = append(m.Primary, c) + return c +} + +//AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + + unique.Columns = append(unique.Columns, columns...) + + return unique +} + +//AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + + m.Columns = append(m.Columns, columns...) + + return m +} + +//AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + m.Primary = append(m.Primary, primary) + return m +} + +//AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + m.Uniques = append(m.Uniques, unique) + return m +} + +//AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + m.Foreigns = append(m.Foreigns, foreign) + return m +} + +//AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + m.Indexes = append(m.Indexes, index) + return m +} + +//RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + rename := &RenameColumn{OldName: from, NewName: to} + m.Renames = append(m.Renames, rename) + return rename +} + +//GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + sql = "" + switch m.ModifyType { + case "create": + { + sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName) + for index, column := range m.Columns { + sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + if len(m.Columns) > index+1 { + sql += "," + } + } + + if len(m.Primary) > 0 { + sql += fmt.Sprintf(",\n PRIMARY KEY( ") + } + for index, column := range m.Primary { + sql += fmt.Sprintf(" `%s`", column.Name) + if len(m.Primary) > index+1 { + sql += "," + } + + } + if len(m.Primary) > 0 { + sql += fmt.Sprintf(")") + } + + for _, unique := range m.Uniques { + sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition) + for index, column := range unique.Columns { + sql += fmt.Sprintf(" `%s`", column.Name) + if len(unique.Columns) > index+1 { + sql += "," + } + } + sql += fmt.Sprintf(")") + } + for _, foreign := range m.Foreigns { + sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) + sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name) + sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) + + } + sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset) + break + } + case "alter": + { + sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) + for index, column := range m.Columns { + if !column.remove { + logs.Info("col") + sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + } else { + sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) + } + + if len(m.Columns) > index+1 { + sql += "," + } + } + for index, column := range m.Renames { + sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + if len(m.Renames) > index+1 { + sql += "," + } + } + + for index, foreign := range m.Foreigns { + sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) + sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name) + sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) + if len(m.Foreigns) > index+1 { + sql += "," + } + } + sql += ";" + + break + } + case "reverse": + { + + sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName) + for index, column := range m.Columns { + if column.remove { + sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + } else { + sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) + } + if len(m.Columns) > index+1 { + sql += "," + } + } + + if len(m.Primary) > 0 { + sql += fmt.Sprintf("\n DROP PRIMARY KEY,") + } + + for index, unique := range m.Uniques { + sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) + if len(m.Uniques) > index+1 { + sql += "," + } + + } + for index, column := range m.Renames { + sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) + if len(m.Renames) > index+1 { + sql += "," + } + } + + for _, foreign := range m.Foreigns { + sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) + sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) + sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name) + } + sql += ";" + } + case "delete": + { + sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName) + } + } + + return +} diff --git a/pkg/migration/doc.go b/pkg/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go new file mode 100644 index 00000000..5ddfd972 --- /dev/null +++ b/pkg/migration/migration.go @@ -0,0 +1,330 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "errors" + "sort" + "strings" + "time" + + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/orm" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +//Migration defines the migrations by either SQL or DDL +type Migration struct { + sqls []string + Created string + TableName string + Engine string + Charset string + ModifyType string + Columns []*Column + Indexes []*Index + Primary []*Column + Uniques []*Unique + Foreigns []*Foreign + Renames []*RenameColumn + RemoveColumns []*Column + RemoveIndexes []*Index + RemoveUniques []*Unique + RemoveForeigns []*Foreign +} + +var ( + migrationMap map[string]Migrationer +) + +func init() { + migrationMap = make(map[string]Migrationer) +} + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + + switch m.ModifyType { + case "reverse": + m.ModifyType = "alter" + case "delete": + m.ModifyType = "create" + } + m.sqls = append(m.sqls, m.GetSQL()) +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + + switch m.ModifyType { + case "alter": + m.ModifyType = "reverse" + case "create": + m.ModifyType = "delete" + } + m.sqls = append(m.sqls, m.GetSQL()) +} + +//Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + m.ModifyType = migrationType + m.sqls = append(m.sqls, m.GetSQL()) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + m.sqls = append(m.sqls, sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + m.sqls = make([]string, 0) +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + o := orm.NewOrm() + for _, s := range m.sqls { + logs.Info("exec sql:", s) + r := o.Raw(s) + _, err := r.Exec() + if err != nil { + return err + } + } + return m.addOrUpdateRecord(name, status) +} + +func (m *Migration) addOrUpdateRecord(name, status string) error { + o := orm.NewOrm() + if status == "down" { + status = "rollback" + p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() + if err != nil { + return nil + } + _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) + return err + } + status = "update" + p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare() + if err != nil { + return err + } + _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status) + return err +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + t, err := time.Parse(DateFormat, m.Created) + if err != nil { + return 0 + } + return t.Unix() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + if _, ok := migrationMap[name]; ok { + return errors.New("already exist name:" + name) + } + migrationMap[name] = m + return nil +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + sm := sortMap(migrationMap) + i := 0 + migs, _ := getAllMigrations() + for _, v := range sm { + if _, ok := migs[v.name]; !ok { + logs.Info("start upgrade", v.name) + v.m.Reset() + v.m.Up() + err := v.m.Exec(v.name, "up") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + logs.Info("end upgrade:", v.name) + i++ + } + } + logs.Info("total success upgrade:", i, " migration") + time.Sleep(2 * time.Second) + return nil +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + if v, ok := migrationMap[name]; ok { + logs.Info("start rollback") + v.Reset() + v.Down() + err := v.Exec(name, "down") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + logs.Info("end rollback") + time.Sleep(2 * time.Second) + return nil + } + logs.Error("not exist the migrationMap name:" + name) + time.Sleep(2 * time.Second) + return errors.New("not exist the migrationMap name:" + name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + sm := sortMap(migrationMap) + i := 0 + for j := len(sm) - 1; j >= 0; j-- { + v := sm[j] + if isRollBack(v.name) { + logs.Info("skip the", v.name) + time.Sleep(1 * time.Second) + continue + } + logs.Info("start reset:", v.name) + v.m.Reset() + v.m.Down() + err := v.m.Exec(v.name, "down") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + i++ + logs.Info("end reset:", v.name) + } + logs.Info("total success reset:", i, " migration") + time.Sleep(2 * time.Second) + return nil +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + err := Reset() + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + err = Upgrade(0) + return err +} + +type dataSlice []data + +type data struct { + created int64 + name string + m Migrationer +} + +// Len is part of sort.Interface. +func (d dataSlice) Len() int { + return len(d) +} + +// Swap is part of sort.Interface. +func (d dataSlice) Swap(i, j int) { + d[i], d[j] = d[j], d[i] +} + +// Less is part of sort.Interface. We use count as the value to sort by +func (d dataSlice) Less(i, j int) bool { + return d[i].created < d[j].created +} + +func sortMap(m map[string]Migrationer) dataSlice { + s := make(dataSlice, 0, len(m)) + for k, v := range m { + d := data{} + d.created = v.GetCreated() + d.name = k + d.m = v + s = append(s, d) + } + sort.Sort(s) + return s +} + +func isRollBack(name string) bool { + o := orm.NewOrm() + var maps []orm.Params + num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return false + } + if num <= 0 { + return false + } + if maps[0]["status"] == "rollback" { + return true + } + return false +} +func getAllMigrations() (map[string]string, error) { + o := orm.NewOrm() + var maps []orm.Params + migs := make(map[string]string) + num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return migs, err + } + if num > 0 { + for _, v := range maps { + name := v["name"].(string) + migs[name] = v["status"].(string) + } + } + return migs, nil +} diff --git a/pkg/mime.go b/pkg/mime.go new file mode 100644 index 00000000..ca2878ab --- /dev/null +++ b/pkg/mime.go @@ -0,0 +1,556 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var mimemaps = map[string]string{ + ".3dm": "x-world/x-3dmf", + ".3dmf": "x-world/x-3dmf", + ".7z": "application/x-7z-compressed", + ".a": "application/octet-stream", + ".aab": "application/x-authorware-bin", + ".aam": "application/x-authorware-map", + ".aas": "application/x-authorware-seg", + ".abc": "text/vndabc", + ".ace": "application/x-ace-compressed", + ".acgi": "text/html", + ".afl": "video/animaflex", + ".ai": "application/postscript", + ".aif": "audio/aiff", + ".aifc": "audio/aiff", + ".aiff": "audio/aiff", + ".aim": "application/x-aim", + ".aip": "text/x-audiosoft-intra", + ".alz": "application/x-alz-compressed", + ".ani": "application/x-navi-animation", + ".aos": "application/x-nokia-9000-communicator-add-on-software", + ".aps": "application/mime", + ".apk": "application/vnd.android.package-archive", + ".arc": "application/x-arc-compressed", + ".arj": "application/arj", + ".art": "image/x-jg", + ".asf": "video/x-ms-asf", + ".asm": "text/x-asm", + ".asp": "text/asp", + ".asx": "application/x-mplayer2", + ".au": "audio/basic", + ".avi": "video/x-msvideo", + ".avs": "video/avs-video", + ".bcpio": "application/x-bcpio", + ".bin": "application/mac-binary", + ".bmp": "image/bmp", + ".boo": "application/book", + ".book": "application/book", + ".boz": "application/x-bzip2", + ".bsh": "application/x-bsh", + ".bz2": "application/x-bzip2", + ".bz": "application/x-bzip", + ".c++": "text/plain", + ".c": "text/x-c", + ".cab": "application/vnd.ms-cab-compressed", + ".cat": "application/vndms-pkiseccat", + ".cc": "text/x-c", + ".ccad": "application/clariscad", + ".cco": "application/x-cocoa", + ".cdf": "application/cdf", + ".cer": "application/pkix-cert", + ".cha": "application/x-chat", + ".chat": "application/x-chat", + ".chrt": "application/vnd.kde.kchart", + ".class": "application/java", + ".com": "text/plain", + ".conf": "text/plain", + ".cpio": "application/x-cpio", + ".cpp": "text/x-c", + ".cpt": "application/mac-compactpro", + ".crl": "application/pkcs-crl", + ".crt": "application/pkix-cert", + ".crx": "application/x-chrome-extension", + ".csh": "text/x-scriptcsh", + ".css": "text/css", + ".csv": "text/csv", + ".cxx": "text/plain", + ".dar": "application/x-dar", + ".dcr": "application/x-director", + ".deb": "application/x-debian-package", + ".deepv": "application/x-deepv", + ".def": "text/plain", + ".der": "application/x-x509-ca-cert", + ".dif": "video/x-dv", + ".dir": "application/x-director", + ".divx": "video/divx", + ".dl": "video/dl", + ".dmg": "application/x-apple-diskimage", + ".doc": "application/msword", + ".dot": "application/msword", + ".dp": "application/commonground", + ".drw": "application/drafting", + ".dump": "application/octet-stream", + ".dv": "video/x-dv", + ".dvi": "application/x-dvi", + ".dwf": "drawing/x-dwf=(old)", + ".dwg": "application/acad", + ".dxf": "application/dxf", + ".dxr": "application/x-director", + ".el": "text/x-scriptelisp", + ".elc": "application/x-bytecodeelisp=(compiled=elisp)", + ".eml": "message/rfc822", + ".env": "application/x-envoy", + ".eps": "application/postscript", + ".es": "application/x-esrehber", + ".etx": "text/x-setext", + ".evy": "application/envoy", + ".exe": "application/octet-stream", + ".f77": "text/x-fortran", + ".f90": "text/x-fortran", + ".f": "text/x-fortran", + ".fdf": "application/vndfdf", + ".fif": "application/fractals", + ".fli": "video/fli", + ".flo": "image/florian", + ".flv": "video/x-flv", + ".flx": "text/vndfmiflexstor", + ".fmf": "video/x-atomic3d-feature", + ".for": "text/x-fortran", + ".fpx": "image/vndfpx", + ".frl": "application/freeloader", + ".funk": "audio/make", + ".g3": "image/g3fax", + ".g": "text/plain", + ".gif": "image/gif", + ".gl": "video/gl", + ".gsd": "audio/x-gsm", + ".gsm": "audio/x-gsm", + ".gsp": "application/x-gsp", + ".gss": "application/x-gss", + ".gtar": "application/x-gtar", + ".gz": "application/x-compressed", + ".gzip": "application/x-gzip", + ".h": "text/x-h", + ".hdf": "application/x-hdf", + ".help": "application/x-helpfile", + ".hgl": "application/vndhp-hpgl", + ".hh": "text/x-h", + ".hlb": "text/x-script", + ".hlp": "application/hlp", + ".hpg": "application/vndhp-hpgl", + ".hpgl": "application/vndhp-hpgl", + ".hqx": "application/binhex", + ".hta": "application/hta", + ".htc": "text/x-component", + ".htm": "text/html", + ".html": "text/html", + ".htmls": "text/html", + ".htt": "text/webviewhtml", + ".htx": "text/html", + ".ice": "x-conference/x-cooltalk", + ".ico": "image/x-icon", + ".ics": "text/calendar", + ".icz": "text/calendar", + ".idc": "text/plain", + ".ief": "image/ief", + ".iefs": "image/ief", + ".iges": "application/iges", + ".igs": "application/iges", + ".ima": "application/x-ima", + ".imap": "application/x-httpd-imap", + ".inf": "application/inf", + ".ins": "application/x-internett-signup", + ".ip": "application/x-ip2", + ".isu": "video/x-isvideo", + ".it": "audio/it", + ".iv": "application/x-inventor", + ".ivr": "i-world/i-vrml", + ".ivy": "application/x-livescreen", + ".jam": "audio/x-jam", + ".jav": "text/x-java-source", + ".java": "text/x-java-source", + ".jcm": "application/x-java-commerce", + ".jfif-tbnl": "image/jpeg", + ".jfif": "image/jpeg", + ".jnlp": "application/x-java-jnlp-file", + ".jpe": "image/jpeg", + ".jpeg": "image/jpeg", + ".jpg": "image/jpeg", + ".jps": "image/x-jps", + ".js": "application/javascript", + ".json": "application/json", + ".jut": "image/jutvision", + ".kar": "audio/midi", + ".karbon": "application/vnd.kde.karbon", + ".kfo": "application/vnd.kde.kformula", + ".flw": "application/vnd.kde.kivio", + ".kml": "application/vnd.google-earth.kml+xml", + ".kmz": "application/vnd.google-earth.kmz", + ".kon": "application/vnd.kde.kontour", + ".kpr": "application/vnd.kde.kpresenter", + ".kpt": "application/vnd.kde.kpresenter", + ".ksp": "application/vnd.kde.kspread", + ".kwd": "application/vnd.kde.kword", + ".kwt": "application/vnd.kde.kword", + ".ksh": "text/x-scriptksh", + ".la": "audio/nspaudio", + ".lam": "audio/x-liveaudio", + ".latex": "application/x-latex", + ".lha": "application/lha", + ".lhx": "application/octet-stream", + ".list": "text/plain", + ".lma": "audio/nspaudio", + ".log": "text/plain", + ".lsp": "text/x-scriptlisp", + ".lst": "text/plain", + ".lsx": "text/x-la-asf", + ".ltx": "application/x-latex", + ".lzh": "application/octet-stream", + ".lzx": "application/lzx", + ".m1v": "video/mpeg", + ".m2a": "audio/mpeg", + ".m2v": "video/mpeg", + ".m3u": "audio/x-mpegurl", + ".m": "text/x-m", + ".man": "application/x-troff-man", + ".manifest": "text/cache-manifest", + ".map": "application/x-navimap", + ".mar": "text/plain", + ".mbd": "application/mbedlet", + ".mc$": "application/x-magic-cap-package-10", + ".mcd": "application/mcad", + ".mcf": "text/mcf", + ".mcp": "application/netmc", + ".me": "application/x-troff-me", + ".mht": "message/rfc822", + ".mhtml": "message/rfc822", + ".mid": "application/x-midi", + ".midi": "application/x-midi", + ".mif": "application/x-frame", + ".mime": "message/rfc822", + ".mjf": "audio/x-vndaudioexplosionmjuicemediafile", + ".mjpg": "video/x-motion-jpeg", + ".mm": "application/base64", + ".mme": "application/base64", + ".mod": "audio/mod", + ".moov": "video/quicktime", + ".mov": "video/quicktime", + ".movie": "video/x-sgi-movie", + ".mp2": "audio/mpeg", + ".mp3": "audio/mpeg3", + ".mp4": "video/mp4", + ".mpa": "audio/mpeg", + ".mpc": "application/x-project", + ".mpe": "video/mpeg", + ".mpeg": "video/mpeg", + ".mpg": "video/mpeg", + ".mpga": "audio/mpeg", + ".mpp": "application/vndms-project", + ".mpt": "application/x-project", + ".mpv": "application/x-project", + ".mpx": "application/x-project", + ".mrc": "application/marc", + ".ms": "application/x-troff-ms", + ".mv": "video/x-sgi-movie", + ".my": "audio/make", + ".mzz": "application/x-vndaudioexplosionmzz", + ".nap": "image/naplps", + ".naplps": "image/naplps", + ".nc": "application/x-netcdf", + ".ncm": "application/vndnokiaconfiguration-message", + ".nif": "image/x-niff", + ".niff": "image/x-niff", + ".nix": "application/x-mix-transfer", + ".nsc": "application/x-conference", + ".nvd": "application/x-navidoc", + ".o": "application/octet-stream", + ".oda": "application/oda", + ".odb": "application/vnd.oasis.opendocument.database", + ".odc": "application/vnd.oasis.opendocument.chart", + ".odf": "application/vnd.oasis.opendocument.formula", + ".odg": "application/vnd.oasis.opendocument.graphics", + ".odi": "application/vnd.oasis.opendocument.image", + ".odm": "application/vnd.oasis.opendocument.text-master", + ".odp": "application/vnd.oasis.opendocument.presentation", + ".ods": "application/vnd.oasis.opendocument.spreadsheet", + ".odt": "application/vnd.oasis.opendocument.text", + ".oga": "audio/ogg", + ".ogg": "audio/ogg", + ".ogv": "video/ogg", + ".omc": "application/x-omc", + ".omcd": "application/x-omcdatamaker", + ".omcr": "application/x-omcregerator", + ".otc": "application/vnd.oasis.opendocument.chart-template", + ".otf": "application/vnd.oasis.opendocument.formula-template", + ".otg": "application/vnd.oasis.opendocument.graphics-template", + ".oth": "application/vnd.oasis.opendocument.text-web", + ".oti": "application/vnd.oasis.opendocument.image-template", + ".otm": "application/vnd.oasis.opendocument.text-master", + ".otp": "application/vnd.oasis.opendocument.presentation-template", + ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", + ".ott": "application/vnd.oasis.opendocument.text-template", + ".p10": "application/pkcs10", + ".p12": "application/pkcs-12", + ".p7a": "application/x-pkcs7-signature", + ".p7c": "application/pkcs7-mime", + ".p7m": "application/pkcs7-mime", + ".p7r": "application/x-pkcs7-certreqresp", + ".p7s": "application/pkcs7-signature", + ".p": "text/x-pascal", + ".part": "application/pro_eng", + ".pas": "text/pascal", + ".pbm": "image/x-portable-bitmap", + ".pcl": "application/vndhp-pcl", + ".pct": "image/x-pict", + ".pcx": "image/x-pcx", + ".pdb": "chemical/x-pdb", + ".pdf": "application/pdf", + ".pfunk": "audio/make", + ".pgm": "image/x-portable-graymap", + ".pic": "image/pict", + ".pict": "image/pict", + ".pkg": "application/x-newton-compatible-pkg", + ".pko": "application/vndms-pkipko", + ".pl": "text/x-scriptperl", + ".plx": "application/x-pixclscript", + ".pm4": "application/x-pagemaker", + ".pm5": "application/x-pagemaker", + ".pm": "text/x-scriptperl-module", + ".png": "image/png", + ".pnm": "application/x-portable-anymap", + ".pot": "application/mspowerpoint", + ".pov": "model/x-pov", + ".ppa": "application/vndms-powerpoint", + ".ppm": "image/x-portable-pixmap", + ".pps": "application/mspowerpoint", + ".ppt": "application/mspowerpoint", + ".ppz": "application/mspowerpoint", + ".pre": "application/x-freelance", + ".prt": "application/pro_eng", + ".ps": "application/postscript", + ".psd": "application/octet-stream", + ".pvu": "paleovu/x-pv", + ".pwz": "application/vndms-powerpoint", + ".py": "text/x-scriptphyton", + ".pyc": "application/x-bytecodepython", + ".qcp": "audio/vndqcelp", + ".qd3": "x-world/x-3dmf", + ".qd3d": "x-world/x-3dmf", + ".qif": "image/x-quicktime", + ".qt": "video/quicktime", + ".qtc": "video/x-qtc", + ".qti": "image/x-quicktime", + ".qtif": "image/x-quicktime", + ".ra": "audio/x-pn-realaudio", + ".ram": "audio/x-pn-realaudio", + ".rar": "application/x-rar-compressed", + ".ras": "application/x-cmu-raster", + ".rast": "image/cmu-raster", + ".rexx": "text/x-scriptrexx", + ".rf": "image/vndrn-realflash", + ".rgb": "image/x-rgb", + ".rm": "application/vndrn-realmedia", + ".rmi": "audio/mid", + ".rmm": "audio/x-pn-realaudio", + ".rmp": "audio/x-pn-realaudio", + ".rng": "application/ringing-tones", + ".rnx": "application/vndrn-realplayer", + ".roff": "application/x-troff", + ".rp": "image/vndrn-realpix", + ".rpm": "audio/x-pn-realaudio-plugin", + ".rt": "text/vndrn-realtext", + ".rtf": "text/richtext", + ".rtx": "text/richtext", + ".rv": "video/vndrn-realvideo", + ".s": "text/x-asm", + ".s3m": "audio/s3m", + ".s7z": "application/x-7z-compressed", + ".saveme": "application/octet-stream", + ".sbk": "application/x-tbook", + ".scm": "text/x-scriptscheme", + ".sdml": "text/plain", + ".sdp": "application/sdp", + ".sdr": "application/sounder", + ".sea": "application/sea", + ".set": "application/set", + ".sgm": "text/x-sgml", + ".sgml": "text/x-sgml", + ".sh": "text/x-scriptsh", + ".shar": "application/x-bsh", + ".shtml": "text/x-server-parsed-html", + ".sid": "audio/x-psid", + ".skd": "application/x-koan", + ".skm": "application/x-koan", + ".skp": "application/x-koan", + ".skt": "application/x-koan", + ".sit": "application/x-stuffit", + ".sitx": "application/x-stuffitx", + ".sl": "application/x-seelogo", + ".smi": "application/smil", + ".smil": "application/smil", + ".snd": "audio/basic", + ".sol": "application/solids", + ".spc": "text/x-speech", + ".spl": "application/futuresplash", + ".spr": "application/x-sprite", + ".sprite": "application/x-sprite", + ".spx": "audio/ogg", + ".src": "application/x-wais-source", + ".ssi": "text/x-server-parsed-html", + ".ssm": "application/streamingmedia", + ".sst": "application/vndms-pkicertstore", + ".step": "application/step", + ".stl": "application/sla", + ".stp": "application/step", + ".sv4cpio": "application/x-sv4cpio", + ".sv4crc": "application/x-sv4crc", + ".svf": "image/vnddwg", + ".svg": "image/svg+xml", + ".svr": "application/x-world", + ".swf": "application/x-shockwave-flash", + ".t": "application/x-troff", + ".talk": "text/x-speech", + ".tar": "application/x-tar", + ".tbk": "application/toolbook", + ".tcl": "text/x-scripttcl", + ".tcsh": "text/x-scripttcsh", + ".tex": "application/x-tex", + ".texi": "application/x-texinfo", + ".texinfo": "application/x-texinfo", + ".text": "text/plain", + ".tgz": "application/gnutar", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".tr": "application/x-troff", + ".tsi": "audio/tsp-audio", + ".tsp": "application/dsptype", + ".tsv": "text/tab-separated-values", + ".turbot": "image/florian", + ".txt": "text/plain", + ".uil": "text/x-uil", + ".uni": "text/uri-list", + ".unis": "text/uri-list", + ".unv": "application/i-deas", + ".uri": "text/uri-list", + ".uris": "text/uri-list", + ".ustar": "application/x-ustar", + ".uu": "text/x-uuencode", + ".uue": "text/x-uuencode", + ".vcd": "application/x-cdlink", + ".vcf": "text/x-vcard", + ".vcard": "text/x-vcard", + ".vcs": "text/x-vcalendar", + ".vda": "application/vda", + ".vdo": "video/vdo", + ".vew": "application/groupwise", + ".viv": "video/vivo", + ".vivo": "video/vivo", + ".vmd": "application/vocaltec-media-desc", + ".vmf": "application/vocaltec-media-file", + ".voc": "audio/voc", + ".vos": "video/vosaic", + ".vox": "audio/voxware", + ".vqe": "audio/x-twinvq-plugin", + ".vqf": "audio/x-twinvq", + ".vql": "audio/x-twinvq-plugin", + ".vrml": "application/x-vrml", + ".vrt": "x-world/x-vrt", + ".vsd": "application/x-visio", + ".vst": "application/x-visio", + ".vsw": "application/x-visio", + ".w60": "application/wordperfect60", + ".w61": "application/wordperfect61", + ".w6w": "application/msword", + ".wav": "audio/wav", + ".wb1": "application/x-qpro", + ".wbmp": "image/vnd.wap.wbmp", + ".web": "application/vndxara", + ".wiz": "application/msword", + ".wk1": "application/x-123", + ".wmf": "windows/metafile", + ".wml": "text/vnd.wap.wml", + ".wmlc": "application/vnd.wap.wmlc", + ".wmls": "text/vnd.wap.wmlscript", + ".wmlsc": "application/vnd.wap.wmlscriptc", + ".word": "application/msword", + ".wp5": "application/wordperfect", + ".wp6": "application/wordperfect", + ".wp": "application/wordperfect", + ".wpd": "application/wordperfect", + ".wq1": "application/x-lotus", + ".wri": "application/mswrite", + ".wrl": "application/x-world", + ".wrz": "model/vrml", + ".wsc": "text/scriplet", + ".wsrc": "application/x-wais-source", + ".wtk": "application/x-wintalk", + ".x-png": "image/png", + ".xbm": "image/x-xbitmap", + ".xdr": "video/x-amt-demorun", + ".xgz": "xgl/drawing", + ".xif": "image/vndxiff", + ".xl": "application/excel", + ".xla": "application/excel", + ".xlb": "application/excel", + ".xlc": "application/excel", + ".xld": "application/excel", + ".xlk": "application/excel", + ".xll": "application/excel", + ".xlm": "application/excel", + ".xls": "application/excel", + ".xlt": "application/excel", + ".xlv": "application/excel", + ".xlw": "application/excel", + ".xm": "audio/xm", + ".xml": "text/xml", + ".xmz": "xgl/movie", + ".xpix": "application/x-vndls-xpix", + ".xpm": "image/x-xpixmap", + ".xsr": "video/x-amt-showrun", + ".xwd": "image/x-xwd", + ".xyz": "chemical/x-pdb", + ".z": "application/x-compress", + ".zip": "application/zip", + ".zoo": "application/octet-stream", + ".zsh": "text/x-scriptzsh", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".docm": "application/vnd.ms-word.document.macroEnabled.12", + ".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + ".dotm": "application/vnd.ms-word.template.macroEnabled.12", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12", + ".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template", + ".xltm": "application/vnd.ms-excel.template.macroEnabled.12", + ".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12", + ".xlam": "application/vnd.ms-excel.addin.macroEnabled.12", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12", + ".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow", + ".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12", + ".potx": "application/vnd.openxmlformats-officedocument.presentationml.template", + ".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12", + ".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12", + ".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide", + ".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12", + ".thmx": "application/vnd.ms-officetheme", + ".onetoc": "application/onenote", + ".onetoc2": "application/onenote", + ".onetmp": "application/onenote", + ".onepkg": "application/onenote", + ".key": "application/x-iwork-keynote-sffkey", + ".kth": "application/x-iwork-keynote-sffkth", + ".nmbtemplate": "application/x-iwork-numbers-sfftemplate", + ".numbers": "application/x-iwork-numbers-sffnumbers", + ".pages": "application/x-iwork-pages-sffpages", + ".template": "application/x-iwork-pages-sfftemplate", + ".xpi": "application/x-xpinstall", + ".oex": "application/x-opera-extension", + ".mustache": "text/html", +} diff --git a/pkg/namespace.go b/pkg/namespace.go new file mode 100644 index 00000000..4952c9d5 --- /dev/null +++ b/pkg/namespace.go @@ -0,0 +1,396 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "strings" + + beecontext "github.com/astaxie/beego/context" +) + +type namespaceCond func(*beecontext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace struct { + prefix string + handlers *ControllerRegister +} + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + ns := &Namespace{ + prefix: prefix, + handlers: NewControllerRegister(), + } + for _, p := range params { + p(ns) + } + return ns +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + fn := func(ctx *beecontext.Context) { + if !cond(ctx) { + exception("405", ctx) + } + } + if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { + mr := new(FilterRouter) + mr.tree = NewTree() + mr.pattern = "*" + mr.filterFunc = fn + mr.tree.AddRouter("*", true) + n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...) + } else { + n.handlers.InsertFilter("*", BeforeRouter, fn) + } + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + var a int + if action == "before" { + a = BeforeRouter + } else if action == "after" { + a = FinishRouter + } + for _, f := range filter { + n.handlers.InsertFilter("*", a, f) + } + return n +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + n.handlers.Add(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + n.handlers.AddAuto(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + n.handlers.AddAutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + n.handlers.Get(rootpath, f) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + n.handlers.Post(rootpath, f) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + n.handlers.Delete(rootpath, f) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + n.handlers.Put(rootpath, f) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + n.handlers.Head(rootpath, f) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + n.handlers.Options(rootpath, f) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + n.handlers.Patch(rootpath, f) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + n.handlers.Any(rootpath, f) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + n.handlers.Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + n.handlers.Include(cList...) + return n +} + +// Namespace add nest Namespace +// usage: +//ns := beego.NewNamespace(“/v1”). +//Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +//) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + for _, ni := range ns { + for k, v := range ni.handlers.routers { + if _, ok := n.handlers.routers[k]; ok { + addPrefix(v, ni.prefix) + n.handlers.routers[k].AddTree(ni.prefix, v) + } else { + t := NewTree() + t.AddTree(ni.prefix, v) + addPrefix(t, ni.prefix) + n.handlers.routers[k] = t + } + } + if ni.handlers.enableFilter { + for pos, filterList := range ni.handlers.filters { + for _, mr := range filterList { + t := NewTree() + t.AddTree(ni.prefix, mr.tree) + mr.tree = t + n.handlers.insertFilterRouter(pos, mr) + } + } + } + } + return n +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + for _, n := range nl { + for k, v := range n.handlers.routers { + if _, ok := BeeApp.Handlers.routers[k]; ok { + addPrefix(v, n.prefix) + BeeApp.Handlers.routers[k].AddTree(n.prefix, v) + } else { + t := NewTree() + t.AddTree(n.prefix, v) + addPrefix(t, n.prefix) + BeeApp.Handlers.routers[k] = t + } + } + if n.handlers.enableFilter { + for pos, filterList := range n.handlers.filters { + for _, mr := range filterList { + t := NewTree() + t.AddTree(n.prefix, mr.tree) + mr.tree = t + BeeApp.Handlers.insertFilterRouter(pos, mr) + } + } + } + } +} + +func addPrefix(t *Tree, prefix string) { + for _, v := range t.fixrouters { + addPrefix(v, prefix) + } + if t.wildcard != nil { + addPrefix(t.wildcard, prefix) + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if !strings.HasPrefix(c.pattern, prefix) { + c.pattern = prefix + c.pattern + } + } + } +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(ns *Namespace) { + ns.Cond(cond) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Filter("before", filterList...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Filter("after", filterList...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.Include(cList...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(ns *Namespace) { + ns.Router(rootpath, c, mappingMethods...) + } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Get(rootpath, f) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Post(rootpath, f) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Head(rootpath, f) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Put(rootpath, f) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Delete(rootpath, f) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Any(rootpath, f) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Options(rootpath, f) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Patch(rootpath, f) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.AutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.AutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + n := NewNamespace(prefix, params...) + ns.Namespace(n) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + ns.Handler(rootpath, h) + } +} diff --git a/pkg/namespace_test.go b/pkg/namespace_test.go new file mode 100644 index 00000000..b3f20dff --- /dev/null +++ b/pkg/namespace_test.go @@ -0,0 +1,168 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/astaxie/beego/context" +) + +func TestNamespaceGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("v1_user")) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "v1_user" { + t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespacePost(t *testing.T) { + r, _ := http.NewRequest("POST", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNest(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order", func(ctx *context.Context) { + ctx.Output.Body([]byte("order")) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "order" { + t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNestParam(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/api/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Router("/api/list", &TestController{}, "*:List") + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestNamespaceFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Filter("before", func(ctx *context.Context) { + ctx.Output.Body([]byte("this is Filter")) + }). + Get("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "this is Filter" { + t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCond(t *testing.T) { + r, _ := http.NewRequest("GET", "/v2/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v2") + ns.Cond(func(ctx *context.Context) bool { + return ctx.Input.Domain() == "beego.me" + }). + AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Code != 405 { + t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) + } +} + +func TestNamespaceInside(t *testing.T) { + r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) + w := httptest.NewRecorder() + ns := NewNamespace("/v3", + NSAutoRouter(&TestController{}), + NSNamespace("/shop", + NSGet("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) + } +} diff --git a/pkg/parser.go b/pkg/parser.go new file mode 100644 index 00000000..3a311894 --- /dev/null +++ b/pkg/parser.go @@ -0,0 +1,591 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "unicode" + + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var globalRouterTemplate = `package {{.routersDir}} + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context/param"{{.globalimport}} +) + +func init() { +{{.globalinfo}} +} +` + +var ( + lastupdateFilename = "lastupdate.tmp" + commentFilename string + pkgLastupdate map[string]int64 + genInfoList map[string][]ControllerComments + + routerHooks = map[string]int{ + "beego.BeforeStatic": BeforeStatic, + "beego.BeforeRouter": BeforeRouter, + "beego.BeforeExec": BeforeExec, + "beego.AfterExec": AfterExec, + "beego.FinishRouter": FinishRouter, + } + + routerHooksMapping = map[int]string{ + BeforeStatic: "beego.BeforeStatic", + BeforeRouter: "beego.BeforeRouter", + BeforeExec: "beego.BeforeExec", + AfterExec: "beego.AfterExec", + FinishRouter: "beego.FinishRouter", + } +) + +const commentPrefix = "commentsRouter_" + +func init() { + pkgLastupdate = make(map[string]int64) +} + +func parserPkg(pkgRealpath, pkgpath string) error { + rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") + commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) + commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" + if !compareFile(pkgRealpath) { + logs.Info(pkgRealpath + " no changed") + return nil + } + genInfoList = make(map[string][]ControllerComments) + fileSet := token.NewFileSet() + astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { + name := info.Name() + return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") + }, parser.ParseComments) + + if err != nil { + return err + } + for _, pkg := range astPkgs { + for _, fl := range pkg.Files { + for _, d := range fl.Decls { + switch specDecl := d.(type) { + case *ast.FuncDecl: + if specDecl.Recv != nil { + exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser + if ok { + parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) + } + } + } + } + } + } + genRouterCode(pkgRealpath) + savetoFile(pkgRealpath) + return nil +} + +type parsedComment struct { + routerPath string + methods []string + params map[string]parsedParam + filters []parsedFilter + imports []parsedImport +} + +type parsedImport struct { + importPath string + importAlias string +} + +type parsedFilter struct { + pattern string + pos int + filter string + params []bool +} + +type parsedParam struct { + name string + datatype string + location string + defValue string + required bool +} + +func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { + if f.Doc != nil { + parsedComments, err := parseComment(f.Doc.List) + if err != nil { + return err + } + for _, parsedComment := range parsedComments { + if parsedComment.routerPath != "" { + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + cc.FilterComments = buildFilters(parsedComment.filters) + cc.ImportComments = buildImports(parsedComment.imports) + genInfoList[key] = append(genInfoList[key], cc) + } + } + } + return nil +} + +func buildImports(pis []parsedImport) []*ControllerImportComments { + var importComments []*ControllerImportComments + + for _, pi := range pis { + importComments = append(importComments, &ControllerImportComments{ + ImportPath: pi.importPath, + ImportAlias: pi.importAlias, + }) + } + + return importComments +} + +func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { + var filterComments []*ControllerFilterComments + + for _, pf := range pfs { + var ( + returnOnOutput bool + resetParams bool + ) + + if len(pf.params) >= 1 { + returnOnOutput = pf.params[0] + } + + if len(pf.params) >= 2 { + resetParams = pf.params[1] + } + + filterComments = append(filterComments, &ControllerFilterComments{ + Filter: pf.filter, + Pattern: pf.pattern, + Pos: pf.pos, + ReturnOnOutput: returnOnOutput, + ResetParams: resetParams, + }) + } + + return filterComments +} + +func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { + result := make([]*param.MethodParam, 0, len(funcParams)) + for _, fparam := range funcParams { + for _, pName := range fparam.Names { + methodParam := buildMethodParam(fparam, pName.Name, pc) + result = append(result, methodParam) + } + } + return result +} + +func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { + options := []param.MethodParamOption{} + if cparam, ok := pc.params[name]; ok { + //Build param from comment info + name = cparam.name + if cparam.required { + options = append(options, param.IsRequired) + } + switch cparam.location { + case "body": + options = append(options, param.InBody) + case "header": + options = append(options, param.InHeader) + case "path": + options = append(options, param.InPath) + } + if cparam.defValue != "" { + options = append(options, param.Default(cparam.defValue)) + } + } else { + if paramInPath(name, pc.routerPath) { + options = append(options, param.InPath) + } + } + return param.New(name, options...) +} + +func paramInPath(name, route string) bool { + return strings.HasSuffix(route, ":"+name) || + strings.Contains(route, ":"+name+"/") +} + +var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) + +func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { + pcs = []*parsedComment{} + params := map[string]parsedParam{} + filters := []parsedFilter{} + imports := []parsedImport{} + + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Param") { + pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) + if len(pv) < 4 { + logs.Error("Invalid @Param format. Needs at least 4 parameters") + } + p := parsedParam{} + names := strings.SplitN(pv[0], "=>", 2) + p.name = names[0] + funcParamName := p.name + if len(names) > 1 { + funcParamName = names[1] + } + p.location = pv[1] + p.datatype = pv[2] + switch len(pv) { + case 5: + p.required, _ = strconv.ParseBool(pv[3]) + case 6: + p.defValue = pv[3] + p.required, _ = strconv.ParseBool(pv[4]) + } + params[funcParamName] = p + } + } + + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Import") { + iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) + if len(iv) == 0 || len(iv) > 2 { + logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") + continue + } + + p := parsedImport{} + p.importPath = iv[0] + + if len(iv) == 2 { + p.importAlias = iv[1] + } + + imports = append(imports, p) + } + } + +filterLoop: + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Filter") { + fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) + if len(fv) < 3 { + logs.Error("Invalid @Filter format. Needs at least 3 parameters") + continue filterLoop + } + + p := parsedFilter{} + p.pattern = fv[0] + posName := fv[1] + if pos, exists := routerHooks[posName]; exists { + p.pos = pos + } else { + logs.Error("Invalid @Filter pos: ", posName) + continue filterLoop + } + + p.filter = fv[2] + fvParams := fv[3:] + for _, fvParam := range fvParams { + switch fvParam { + case "true": + p.params = append(p.params, true) + case "false": + p.params = append(p.params, false) + default: + logs.Error("Invalid @Filter param: ", fvParam) + continue filterLoop + } + } + + filters = append(filters, p) + } + } + + for _, c := range lines { + var pc = &parsedComment{} + pc.params = params + pc.filters = filters + pc.imports = imports + + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + pcs = append(pcs, pc) + } else { + return nil, errors.New("Router information is missing") + } + } + } + return +} + +// direct copy from bee\g_docs.go +// analysis params return []string +// @Param query form string true "The email for login" +// [query form string true "The email for login"] +func getparams(str string) []string { + var s []rune + var j int + var start bool + var r []string + var quoted int8 + for _, c := range str { + if unicode.IsSpace(c) && quoted == 0 { + if !start { + continue + } else { + start = false + j++ + r = append(r, string(s)) + s = make([]rune, 0) + continue + } + } + + start = true + if c == '"' { + quoted ^= 1 + continue + } + s = append(s, c) + } + if len(s) > 0 { + r = append(r, string(s)) + } + return r +} + +func genRouterCode(pkgRealpath string) { + os.Mkdir(getRouterDir(pkgRealpath), 0755) + logs.Info("generate router from comments") + var ( + globalinfo string + globalimport string + sortKey []string + ) + for k := range genInfoList { + sortKey = append(sortKey, k) + } + sort.Strings(sortKey) + for _, k := range sortKey { + cList := genInfoList[k] + sort.Sort(ControllerCommentsSlice(cList)) + for _, c := range cList { + allmethod := "nil" + if len(c.AllowHTTPMethods) > 0 { + allmethod = "[]string{" + for _, m := range c.AllowHTTPMethods { + allmethod += `"` + m + `",` + } + allmethod = strings.TrimRight(allmethod, ",") + "}" + } + + params := "nil" + if len(c.Params) > 0 { + params = "[]map[string]string{" + for _, p := range c.Params { + for k, v := range p { + params = params + `map[string]string{` + k + `:"` + v + `"},` + } + } + params = strings.TrimRight(params, ",") + "}" + } + + methodParams := "param.Make(" + if len(c.MethodParams) > 0 { + lines := make([]string, 0, len(c.MethodParams)) + for _, m := range c.MethodParams { + lines = append(lines, fmt.Sprint(m)) + } + methodParams += "\n " + + strings.Join(lines, ",\n ") + + ",\n " + } + methodParams += ")" + + imports := "" + if len(c.ImportComments) > 0 { + for _, i := range c.ImportComments { + var s string + if i.ImportAlias != "" { + s = fmt.Sprintf(` + %s "%s"`, i.ImportAlias, i.ImportPath) + } else { + s = fmt.Sprintf(` + "%s"`, i.ImportPath) + } + if !strings.Contains(globalimport, s) { + imports += s + } + } + } + + filters := "" + if len(c.FilterComments) > 0 { + for _, f := range c.FilterComments { + filters += fmt.Sprintf(` &beego.ControllerFilter{ + Pattern: "%s", + Pos: %s, + Filter: %s, + ReturnOnOutput: %v, + ResetParams: %v, + },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) + } + } + + if filters == "" { + filters = "nil" + } else { + filters = fmt.Sprintf(`[]*beego.ControllerFilter{ +%s + }`, filters) + } + + globalimport += imports + + globalinfo = globalinfo + ` + beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], + beego.ControllerComments{ + Method: "` + strings.TrimSpace(c.Method) + `", + ` + `Router: "` + c.Router + `"` + `, + AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, + Filters: ` + filters + `, + Params: ` + params + `}) +` + } + } + + if globalinfo != "" { + f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) + if err != nil { + panic(err) + } + defer f.Close() + + routersDir := AppConfig.DefaultString("routersdir", "routers") + content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) + content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) + content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) + f.WriteString(content) + } +} + +func compareFile(pkgRealpath string) bool { + if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { + return true + } + if utils.FileExists(lastupdateFilename) { + content, err := ioutil.ReadFile(lastupdateFilename) + if err != nil { + return true + } + json.Unmarshal(content, &pkgLastupdate) + lastupdate, err := getpathTime(pkgRealpath) + if err != nil { + return true + } + if v, ok := pkgLastupdate[pkgRealpath]; ok { + if lastupdate <= v { + return false + } + } + } + return true +} + +func savetoFile(pkgRealpath string) { + lastupdate, err := getpathTime(pkgRealpath) + if err != nil { + return + } + pkgLastupdate[pkgRealpath] = lastupdate + d, err := json.Marshal(pkgLastupdate) + if err != nil { + return + } + ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) +} + +func getpathTime(pkgRealpath string) (lastupdate int64, err error) { + fl, err := ioutil.ReadDir(pkgRealpath) + if err != nil { + return lastupdate, err + } + for _, f := range fl { + if lastupdate < f.ModTime().UnixNano() { + lastupdate = f.ModTime().UnixNano() + } + } + return lastupdate, nil +} + +func getRouterDir(pkgRealpath string) string { + dir := filepath.Dir(pkgRealpath) + for { + routersDir := AppConfig.DefaultString("routersdir", "routers") + d := filepath.Join(dir, routersDir) + if utils.FileExists(d) { + return d + } + + if r, _ := filepath.Rel(dir, AppPath); r == "." { + return d + } + // Parent dir. + dir = filepath.Dir(dir) + } +} diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..10e25f3f --- /dev/null +++ b/pkg/plugins/apiauth/apiauth.go @@ -0,0 +1,165 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" + "sort" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret func(string) string + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + ft := func(aid string) string { + if aid == appid { + return appkey + } + return "" + } + return APISecretAuth(ft, 300) +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + return func(ctx *context.Context) { + if ctx.Input.Query("appid") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: appid") + return + } + appsecret := f(ctx.Input.Query("appid")) + if appsecret == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("not exist this appid") + return + } + if ctx.Input.Query("signature") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: signature") + return + } + if ctx.Input.Query("timestamp") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: timestamp") + return + } + u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) + if err != nil { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") + return + } + t := time.Now() + if t.Sub(u).Seconds() > float64(timeout) { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("timeout! the request time is long ago, please try again") + return + } + if ctx.Input.Query("signature") != + Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("auth failed") + } + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { + var b bytes.Buffer + keys := make([]string, len(params)) + pa := make(map[string]string) + for k, v := range params { + pa[k] = v[0] + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, key := range keys { + if key == "signature" { + continue + } + + val := pa[key] + if key != "" && val != "" { + b.WriteString(key) + b.WriteString(val) + } + } + + stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL) + + sha256 := sha256.New + hash := hmac.New(sha256, []byte(appsecret)) + hash.Write([]byte(stringToSign)) + return base64.StdEncoding.EncodeToString(hash.Sum(nil)) +} diff --git a/pkg/plugins/apiauth/apiauth_test.go b/pkg/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/plugins/auth/basic.go b/pkg/plugins/auth/basic.go new file mode 100644 index 00000000..c478044a --- /dev/null +++ b/pkg/plugins/auth/basic.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "encoding/base64" + "net/http" + "strings" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +var defaultRealm = "Authorization Required" + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + secrets := func(user, pass string) bool { + return user == username && pass == password + } + return NewBasicAuthenticator(secrets, defaultRealm) +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuth{Secrets: secrets, Realm: Realm} + if username := a.CheckAuth(ctx.Request); username == "" { + a.RequireAuth(ctx.ResponseWriter, ctx.Request) + } + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider func(user, pass string) bool + +// BasicAuth store the SecretProvider and Realm +type BasicAuth struct { + Secrets SecretProvider + Realm string +} + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 || s[0] != "Basic" { + return "" + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return "" + } + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return "" + } + + if a.Secrets(pair[0], pair[1]) { + return pair[0] + } + return "" +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) + w.WriteHeader(401) + w.Write([]byte("401 Unauthorized\n")) +} diff --git a/pkg/plugins/authz/authz.go b/pkg/plugins/authz/authz.go new file mode 100644 index 00000000..9dc0db76 --- /dev/null +++ b/pkg/plugins/authz/authz.go @@ -0,0 +1,86 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/casbin/casbin" + "net/http" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuthorizer{enforcer: e} + + if !a.CheckPermission(ctx.Request) { + a.RequirePermission(ctx.ResponseWriter) + } + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer struct { + enforcer *casbin.Enforcer +} + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + username, _, _ := r.BasicAuth() + return username +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + user := a.GetUserName(r) + method := r.Method + path := r.URL.Path + return a.enforcer.Enforce(user, path, method) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + w.WriteHeader(403) + w.Write([]byte("403 Forbidden\n")) +} diff --git a/pkg/plugins/authz/authz_model.conf b/pkg/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/plugins/authz/authz_policy.csv b/pkg/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/plugins/authz/authz_test.go b/pkg/plugins/authz/authz_test.go new file mode 100644 index 00000000..49aed84c --- /dev/null +++ b/pkg/plugins/authz/authz_test.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/plugins/auth" + "github.com/casbin/casbin" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/plugins/cors/cors.go b/pkg/plugins/cors/cors.go new file mode 100644 index 00000000..45c327ab --- /dev/null +++ b/pkg/plugins/cors/cors.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + "net/http" + "regexp" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +const ( + headerAllowOrigin = "Access-Control-Allow-Origin" + headerAllowCredentials = "Access-Control-Allow-Credentials" + headerAllowHeaders = "Access-Control-Allow-Headers" + headerAllowMethods = "Access-Control-Allow-Methods" + headerExposeHeaders = "Access-Control-Expose-Headers" + headerMaxAge = "Access-Control-Max-Age" + + headerOrigin = "Origin" + headerRequestMethod = "Access-Control-Request-Method" + headerRequestHeaders = "Access-Control-Request-Headers" +) + +var ( + defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} + // Regex patterns are generated from AllowOrigins. These are used and generated internally. + allowOriginPatterns = []string{} +) + +// Options represents Access Control options. +type Options struct { + // If set, all origins are allowed. + AllowAllOrigins bool + // A list of allowed origins. Wild cards and FQDNs are supported. + AllowOrigins []string + // If set, allows to share auth credentials such as cookies. + AllowCredentials bool + // A list of allowed HTTP methods. + AllowMethods []string + // A list of allowed HTTP headers. + AllowHeaders []string + // A list of exposed HTTP headers. + ExposeHeaders []string + // Max age of the CORS headers. + MaxAge time.Duration +} + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + headers = make(map[string]string) + // if origin is not allowed, don't extend the headers + // with CORS headers. + if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { + return + } + + // add allow origin + if o.AllowAllOrigins { + headers[headerAllowOrigin] = "*" + } else { + headers[headerAllowOrigin] = origin + } + + // add allow credentials + headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) + + // add allow methods + if len(o.AllowMethods) > 0 { + headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") + } + + // add allow headers + if len(o.AllowHeaders) > 0 { + headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") + } + + // add exposed header + if len(o.ExposeHeaders) > 0 { + headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") + } + // add a max age header + if o.MaxAge > time.Duration(0) { + headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) + } + return +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + headers = make(map[string]string) + if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { + return + } + // verify if requested method is allowed + for _, method := range o.AllowMethods { + if method == rMethod { + headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") + break + } + } + + // verify if requested headers are allowed + var allowed []string + for _, rHeader := range strings.Split(rHeaders, ",") { + rHeader = strings.TrimSpace(rHeader) + lookupLoop: + for _, allowedHeader := range o.AllowHeaders { + if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { + allowed = append(allowed, rHeader) + break lookupLoop + } + } + } + + headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) + // add allow origin + if o.AllowAllOrigins { + headers[headerAllowOrigin] = "*" + } else { + headers[headerAllowOrigin] = origin + } + + // add allowed headers + if len(allowed) > 0 { + headers[headerAllowHeaders] = strings.Join(allowed, ",") + } + + // add exposed headers + if len(o.ExposeHeaders) > 0 { + headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") + } + // add a max age header + if o.MaxAge > time.Duration(0) { + headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) + } + return +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) (allowed bool) { + for _, pattern := range allowOriginPatterns { + allowed, _ = regexp.MatchString(pattern, origin) + if allowed { + return + } + } + return +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + // Allow default headers if nothing is specified. + if len(opts.AllowHeaders) == 0 { + opts.AllowHeaders = defaultAllowHeaders + } + + for _, origin := range opts.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") + } + + return func(ctx *context.Context) { + var ( + origin = ctx.Input.Header(headerOrigin) + requestedMethod = ctx.Input.Header(headerRequestMethod) + requestedHeaders = ctx.Input.Header(headerRequestHeaders) + // additional headers to be added + // to the response. + headers map[string]string + ) + + if ctx.Input.Method() == "OPTIONS" && + (requestedMethod != "" || requestedHeaders != "") { + headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) + for key, value := range headers { + ctx.Output.Header(key, value) + } + ctx.ResponseWriter.WriteHeader(http.StatusOK) + return + } + headers = opts.Header(origin) + + for key, value := range headers { + ctx.Output.Header(key, value) + } + } +} diff --git a/pkg/plugins/cors/cors_test.go b/pkg/plugins/cors/cors_test.go new file mode 100644 index 00000000..34039143 --- /dev/null +++ b/pkg/plugins/cors/cors_test.go @@ -0,0 +1,253 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cors + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header +type HTTPHeaderGuardRecorder struct { + *httptest.ResponseRecorder + savedHeaderMap http.Header +} + +// NewRecorder return HttpHeaderGuardRecorder +func NewRecorder() *HTTPHeaderGuardRecorder { + return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} +} + +func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { + gr.ResponseRecorder.WriteHeader(code) + gr.savedHeaderMap = gr.ResponseRecorder.Header() +} + +func (gr *HTTPHeaderGuardRecorder) Header() http.Header { + if gr.savedHeaderMap != nil { + // headers were written. clone so we don't get updates + clone := make(http.Header) + for k, v := range gr.savedHeaderMap { + clone[k] = v + } + return clone + } + return gr.ResponseRecorder.Header() +} + +func Test_AllowAll(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { + t.Errorf("Allow-Origin header should be *") + } +} + +func Test_AllowRegexMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://bar.foo.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != origin { + t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) + } +} + +func Test_AllowRegexNoMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://ww.foo.com.evil.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != "" { + t.Errorf("Allow-Origin header should not exist, found %v", headerValue) + } +} + +func Test_OtherHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + ExposeHeaders: []string{"Content-Length", "Hello"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) + methodsVal := recorder.HeaderMap.Get(headerAllowMethods) + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) + maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) + + if credentialsVal != "true" { + t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) + } + + if methodsVal != "PATCH,GET" { + t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) + } + + if headersVal != "Origin,X-whatever" { + t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) + } + + if exposedHeadersVal != "Content-Length,Hello" { + t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) + } + + if maxAgeVal != "300" { + t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) + } +} + +func Test_DefaultAllowHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + if headersVal != "Origin,Accept,Content-Type,Authorization" { + t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) + } +} + +func Test_Preflight(t *testing.T) { + recorder := NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowMethods: []string{"PUT", "PATCH"}, + AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, + })) + + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + r, _ := http.NewRequest("OPTIONS", "/foo", nil) + r.Header.Add(headerRequestMethod, "PUT") + r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") + handler.ServeHTTP(recorder, r) + + headers := recorder.Header() + methodsVal := headers.Get(headerAllowMethods) + headersVal := headers.Get(headerAllowHeaders) + originVal := headers.Get(headerAllowOrigin) + + if methodsVal != "PUT,PATCH" { + t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) + } + + if !strings.Contains(headersVal, "X-whatever") { + t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) + } + + if !strings.Contains(headersVal, "x-casesensitive") { + t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) + } + + if originVal != "*" { + t.Errorf("Allow-Origin is expected to be *, found %v", originVal) + } + + if recorder.Code != http.StatusOK { + t.Errorf("Status code is expected to be 200, found %d", recorder.Code) + } +} + +func Benchmark_WithoutCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} + +func Benchmark_WithCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} diff --git a/pkg/policy.go b/pkg/policy.go new file mode 100644 index 00000000..ab23f927 --- /dev/null +++ b/pkg/policy.go @@ -0,0 +1,97 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + var urlPath = cont.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := cont.Input.Method() + isWildcard := false + // Find policy for current method + t, ok := p.policies[httpMethod] + // If not found - find policy for whole controller + if !ok { + t, ok = p.policies["*"] + isWildcard = true + } + if ok { + runObjects := t.Match(urlPath, cont) + if r, ok := runObjects.([]PolicyFunc); ok { + return r + } else if !isWildcard { + // If no policies found and we checked not for "*" method - try to find it + t, ok = p.policies["*"] + if ok { + runObjects = t.Match(urlPath, cont) + if r, ok = runObjects.([]PolicyFunc); ok { + return r + } + } + } + } + return nil +} + +func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { + method = strings.ToUpper(method) + p.enablePolicy = true + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.policies[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.policies[method] = t + } +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + BeeApp.Handlers.addToPolicy(method, pattern, policy...) +} + +// Find policies and execute if were found +func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { + if !p.enablePolicy { + return false + } + // Find Policy for method + policyList := p.FindPolicy(cont) + if len(policyList) > 0 { + // Run policies + for _, runPolicy := range policyList { + runPolicy(cont) + if cont.ResponseWriter.Started { + return true + } + } + return false + } + return false +} diff --git a/pkg/router.go b/pkg/router.go new file mode 100644 index 00000000..6a8ac6f7 --- /dev/null +++ b/pkg/router.go @@ -0,0 +1,1052 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "reflect" + "strconv" + "strings" + "sync" + "time" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// default filter execution points +const ( + BeforeStatic = iota + BeforeRouter + BeforeExec + AfterExec + FinishRouter +) + +const ( + routerTypeBeego = iota + routerTypeRESTFul + routerTypeHandler +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = map[string]bool{ + "GET": true, + "POST": true, + "PUT": true, + "DELETE": true, + "PATCH": true, + "OPTIONS": true, + "HEAD": true, + "TRACE": true, + "CONNECT": true, + "MKCOL": true, + "COPY": true, + "MOVE": true, + "PROPFIND": true, + "PROPPATCH": true, + "LOCK": true, + "UNLOCK": true, + } + // these beego.Controller's methods shouldn't reflect to AutoRouter + exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", + "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", + "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", + "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", + "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", + "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", + "GetControllerAndAction", "ServeFormatted"} + + urlPlaceholder = "{{placeholder}}" + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &logFilter{} +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +// default log filter static file will not show +type logFilter struct { +} + +func (l *logFilter) Filter(ctx *beecontext.Context) bool { + requestPath := path.Clean(ctx.Request.URL.Path) + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + return true + } + for prefix := range BConfig.WebConfig.StaticDir { + if strings.HasPrefix(requestPath, prefix) { + return true + } + } + return false +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + exceptMethod = append(exceptMethod, action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo struct { + pattern string + controllerType reflect.Type + methods map[string]string + handler http.Handler + runFunction FilterFunc + routerType int + initialize func() ControllerInterface + methodParams []*param.MethodParam +} + +func (c *ControllerInfo) GetPattern() string { + return c.pattern +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister struct { + routers map[string]*Tree + enablePolicy bool + policies map[string]*Tree + enableFilter bool + filters [FinishRouter + 1][]*FilterRouter + pool sync.Pool +} + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return &ControllerRegister{ + routers: make(map[string]*Tree), + policies: make(map[string]*Tree), + pool: sync.Pool{ + New: func() interface{} { + return beecontext.NewContext() + }, + }, + } +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + p.addWithMethodParams(pattern, c, nil, mappingMethods...) +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + methods := make(map[string]string) + if len(mappingMethods) > 0 { + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + } else { + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) + } + } else { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) + } + } + } + } + + route := &ControllerInfo{} + route.pattern = pattern + route.methods = methods + route.routerType = routerTypeBeego + route.controllerType = t + route.initialize = func() ControllerInterface { + vc := reflect.New(route.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + + elemVal := reflect.ValueOf(c).Elem() + elemType := reflect.TypeOf(c).Elem() + execElem := reflect.ValueOf(execController).Elem() + + numOfFields := elemVal.NumField() + for i := 0; i < numOfFields; i++ { + fieldType := elemType.Field(i) + elemField := execElem.FieldByName(fieldType.Name) + if elemField.CanSet() { + fieldVal := elemVal.Field(i) + elemField.Set(fieldVal) + } + } + + return execController + } + + route.methodParams = methodParams + if len(methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + for k := range methods { + if k == "*" { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } + } +} + +func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.routers[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.routers[method] = t + } +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + if BConfig.RunMode == DEV { + skip := make(map[string]bool, 10) + wgopath := utils.GetGOPATHs() + go111module := os.Getenv(`GO111MODULE`) + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + // for go modules + if go111module == `on` { + pkgpath := filepath.Join(WorkPath, "..", t.PkgPath()) + if utils.FileExists(pkgpath) { + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath, t.PkgPath()) + } + } + } + } else { + if len(wgopath) == 0 { + panic("you are in dev mode. So please set gopath") + } + pkgpath := "" + for _, wg := range wgopath { + wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) + if utils.FileExists(wg) { + pkgpath = wg + break + } + } + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath, t.PkgPath()) + } + } + } + } + } + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + key := t.PkgPath() + ":" + t.Name() + if comm, ok := GlobalControllerRouter[key]; ok { + for _, a := range comm { + for _, f := range a.Filters { + p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + } + + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + } + } + } +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return p.pool.Get().(*beecontext.Context) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + p.pool.Put(ctx) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + p.AddMethod("get", pattern, f) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + p.AddMethod("post", pattern, f) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + p.AddMethod("put", pattern, f) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + p.AddMethod("delete", pattern, f) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + p.AddMethod("head", pattern, f) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + p.AddMethod("patch", pattern, f) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + p.AddMethod("options", pattern, f) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + p.AddMethod("*", pattern, f) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.runFunction = f + methods := make(map[string]string) + if method == "*" { + for val := range HTTPMETHOD { + methods[val] = val + } + } else { + methods[method] = method + } + route.methods = methods + for k := range methods { + if k == "*" { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.handler = h + if len(options) > 0 { + if _, ok := options[0].(bool); ok { + pattern = path.Join(pattern, "?:all(.*)") + } + } + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + p.AddAutoPrefix("/", c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + controllerName := strings.TrimSuffix(ct.Name(), "Controller") + for i := 0; i < rt.NumMethod(); i++ { + if !utils.InSlice(rt.Method(i).Name, exceptMethod) { + route := &ControllerInfo{} + route.routerType = routerTypeBeego + route.methods = map[string]string{"*": rt.Method(i).Name} + route.controllerType = ct + pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") + patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") + patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) + patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) + route.pattern = pattern + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + p.addToRouter(m, patternInit, route) + p.addToRouter(m, patternFix, route) + p.addToRouter(m, patternFixInit, route) + } + } + } +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + returnOnOutput: true, + } + if !BConfig.RouterCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + paramsLen := len(params) + if paramsLen > 0 { + mr.returnOnOutput = params[0] + } + if paramsLen > 1 { + mr.resetParams = params[1] + } + mr.tree.AddRouter(pattern, true) + return p.insertFilterRouter(pos, mr) +} + +// add Filter into +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { + if pos < BeforeStatic || pos > FinishRouter { + return errors.New("can not find your filter position") + } + p.enableFilter = true + p.filters[pos] = append(p.filters[pos], mr) + return nil +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + paths := strings.Split(endpoint, ".") + if len(paths) <= 1 { + logs.Warn("urlfor endpoint must like path.controller.method") + return "" + } + if len(values)%2 != 0 { + logs.Warn("urlfor params must key-value pair") + return "" + } + params := make(map[string]string) + if len(values) > 0 { + key := "" + for k, v := range values { + if k%2 == 0 { + key = fmt.Sprint(v) + } else { + params[key] = fmt.Sprint(v) + } + } + } + controllerName := strings.Join(paths[:len(paths)-1], "/") + methodName := paths[len(paths)-1] + for m, t := range p.routers { + ok, url := p.getURL(t, "/", controllerName, methodName, params, m) + if ok { + return url + } + } + return "" +} + +func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { + for _, subtree := range t.fixrouters { + u := path.Join(url, subtree.prefix) + ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + if t.wildcard != nil { + u := path.Join(url, urlPlaceholder) + ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if c.routerType == routerTypeBeego && + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + find := false + if HTTPMETHOD[strings.ToUpper(methodName)] { + if len(c.methods) == 0 { + find = true + } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { + find = true + } else if m, ok = c.methods["*"]; ok && m == methodName { + find = true + } + } + if !find { + for m, md := range c.methods { + if (m == "*" || m == httpMethod) && md == methodName { + find = true + } + } + } + if find { + if l.regexps == nil { + if len(l.wildcards) == 0 { + return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) + } + if len(l.wildcards) == 1 { + if v, ok := params[l.wildcards[0]]; ok { + delete(params, l.wildcards[0]) + return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) + } + return false, "" + } + if len(l.wildcards) == 3 && l.wildcards[0] == "." { + if p, ok := params[":path"]; ok { + if e, isok := params[":ext"]; isok { + delete(params, ":path") + delete(params, ":ext") + return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) + } + } + } + canSkip := false + for _, v := range l.wildcards { + if v == ":" { + canSkip = true + continue + } + if u, ok := params[v]; ok { + delete(params, v) + url = strings.Replace(url, urlPlaceholder, u, 1) + } else { + if canSkip { + canSkip = false + continue + } + return false, "" + } + } + return true, url + toURL(params) + } + var i int + var startReg bool + regURL := "" + for _, v := range strings.Trim(l.regexps.String(), "^$") { + if v == '(' { + startReg = true + continue + } else if v == ')' { + startReg = false + if v, ok := params[l.wildcards[i]]; ok { + delete(params, l.wildcards[i]) + regURL = regURL + v + i++ + } else { + break + } + } else if !startReg { + regURL = string(append([]rune(regURL), v)) + } + } + if l.regexps.MatchString(regURL) { + ps := strings.Split(regURL, "/") + for _, p := range ps { + url = strings.Replace(url, urlPlaceholder, p, 1) + } + return true, url + toURL(params) + } + } + } + } + } + + return false, "" +} + +func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { + var preFilterParams map[string]string + for _, filterR := range p.filters[pos] { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if filterR.resetParams { + preFilterParams = context.Input.Params() + } + if ok := filterR.ValidRouter(urlPath, context); ok { + filterR.filterFunc(context) + if filterR.resetParams { + context.Input.ResetParams() + for k, v := range preFilterParams { + context.Input.SetParam(k, v) + } + } + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + } + return false +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + startTime := time.Now() + var ( + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + ) + context := p.GetContext() + + context.Reset(rw, r) + + defer p.GiveBackContext(context) + if BConfig.RecoverFunc != nil { + defer BConfig.RecoverFunc(context) + } + + context.Output.EnableGzip = BConfig.EnableGzip + + if BConfig.RunMode == DEV { + context.Output.Header("Server", BConfig.ServerName) + } + + var urlPath = r.URL.Path + + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + + // filter wrong http method + if !HTTPMETHOD[r.Method] { + exception("405", context) + goto Admin + } + + // filter for static file + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { + goto Admin + } + + serverStaticRouter(context) + + if context.ResponseWriter.Started { + findRouter = true + goto Admin + } + + if r.Method != http.MethodGet && r.Method != http.MethodHead { + if BConfig.CopyRequestBody && !context.Input.IsUpload() { + // connection will close if the incoming data are larger (RFC 7231, 6.5.11) + if r.ContentLength > BConfig.MaxMemory { + logs.Error(errors.New("payload too large")) + exception("413", context) + goto Admin + } + context.Input.CopyBody(BConfig.MaxMemory) + } + context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + } + + // session init + if BConfig.WebConfig.Session.SessionOn { + var err error + context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + if err != nil { + logs.Error(err) + exception("503", context) + goto Admin + } + defer func() { + if context.Input.CruSession != nil { + context.Input.CruSession.SessionRelease(rw) + } + }() + } + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { + goto Admin + } + // User can define RunController and RunMethod in filter + if context.Input.RunController != nil && context.Input.RunMethod != "" { + findRouter = true + runMethod = context.Input.RunMethod + runRouter = context.Input.RunController + } else { + routerInfo, findRouter = p.FindRouter(context) + } + + // if no matches to url, throw a not found exception + if !findRouter { + exception("404", context) + goto Admin + } + if splat := context.Input.Param(":splat"); splat != "" { + for k, v := range strings.Split(splat, "/") { + context.Input.SetParam(strconv.Itoa(k), v) + } + } + + if routerInfo != nil { + // store router pattern into context + context.Input.SetData("RouterPattern", routerInfo.pattern) + } + + // execute middleware filters + if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { + goto Admin + } + + // check policies + if p.execPolicy(context, urlPath) { + goto Admin + } + + if routerInfo != nil { + if routerInfo.routerType == routerTypeRESTFul { + if _, ok := routerInfo.methods[r.Method]; ok { + isRunnable = true + routerInfo.runFunction(context) + } else { + exception("405", context) + goto Admin + } + } else if routerInfo.routerType == routerTypeHandler { + isRunnable = true + routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) + } else { + runRouter = routerInfo.controllerType + methodParams = routerInfo.methodParams + method := r.Method + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { + method = http.MethodPut + } + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { + method = http.MethodDelete + } + if m, ok := routerInfo.methods[method]; ok { + runMethod = m + } else if m, ok = routerInfo.methods["*"]; ok { + runMethod = m + } else { + runMethod = method + } + } + } + + // also defined runRouter & runMethod from filter + if !isRunnable { + // Invoke the request handler + var execController ControllerInterface + if routerInfo != nil && routerInfo.initialize != nil { + execController = routerInfo.initialize() + } else { + vc := reflect.New(runRouter) + var ok bool + execController, ok = vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + } + + // call the controller init function + execController.Init(context, runRouter.Name(), runMethod, execController) + + // call prepare function + execController.Prepare() + + // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf + if BConfig.WebConfig.EnableXSRF { + execController.XSRFToken() + if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || + (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { + execController.CheckXSRFCookie() + } + } + + execController.URLMapping() + + if !context.ResponseWriter.Started { + // exec main logic + switch runMethod { + case http.MethodGet: + execController.Get() + case http.MethodPost: + execController.Post() + case http.MethodDelete: + execController.Delete() + case http.MethodPut: + execController.Put() + case http.MethodHead: + execController.Head() + case http.MethodPatch: + execController.Patch() + case http.MethodOptions: + execController.Options() + case http.MethodTrace: + execController.Trace() + default: + if !execController.HandlerFunc(runMethod) { + vc := reflect.ValueOf(execController) + method := vc.MethodByName(runMethod) + in := param.ConvertParams(methodParams, method.Type(), context) + out := method.Call(in) + + // For backward compatibility we only handle response if we had incoming methodParams + if methodParams != nil { + p.handleParamResponse(context, execController, out) + } + } + } + + // render template + if !context.ResponseWriter.Started && context.Output.Status == 0 { + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + logs.Error(err) + } + } + } + } + + // finish all runRouter. release resource + execController.Finish() + } + + // execute middleware filters + if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { + goto Admin + } + + if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + goto Admin + } + +Admin: + // admin module record QPS + + statusCode := context.ResponseWriter.Status + if statusCode == 0 { + statusCode = 200 + } + + LogAccess(context, &startTime, statusCode) + + timeDur := time.Since(startTime) + context.ResponseWriter.Elapsed = timeDur + if BConfig.Listen.EnableAdmin { + pattern := "" + if routerInfo != nil { + pattern = routerInfo.pattern + } + + if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { + routerName := "" + if runRouter != nil { + routerName = runRouter.Name() + } + go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) + } + } + + if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { + match := map[bool]string{true: "match", false: "nomatch"} + devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", + context.Input.IP(), + logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), + timeDur.String(), + match[findRouter], + logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), + r.URL.Path) + if routerInfo != nil { + devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) + } + + logs.Debug(devInfo) + } + // Call WriteHeader if status code has been set changed + if context.Output.Status != 0 { + context.ResponseWriter.WriteHeader(context.Output.Status) + } +} + +func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { + // looping in reverse order for the case when both error and value are returned and error sets the response status code + for i := len(results) - 1; i >= 0; i-- { + result := results[i] + if result.Kind() != reflect.Interface || !result.IsNil() { + resultValue := result.Interface() + context.RenderMethodResult(resultValue) + } + } + if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { + context.Output.SetStatus(200) + } +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + var urlPath = context.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := context.Input.Method() + if t, ok := p.routers[httpMethod]; ok { + runObject := t.Match(urlPath, context) + if r, ok := runObject.(*ControllerInfo); ok { + return r, true + } + } + return +} + +func toURL(params map[string]string) string { + if len(params) == 0 { + return "" + } + u := "?" + for k, v := range params { + u += k + "=" + v + "&" + } + return strings.TrimRight(u, "&") +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + // Skip logging if AccessLogs config is false + if !BConfig.Log.AccessLogs { + return + } + // Skip logging static requests unless EnableStaticLogs config is true + if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { + return + } + var ( + requestTime time.Time + elapsedTime time.Duration + r = ctx.Request + ) + if startTime != nil { + requestTime = *startTime + elapsedTime = time.Since(*startTime) + } + record := &logs.AccessLogRecord{ + RemoteAddr: ctx.Input.IP(), + RequestTime: requestTime, + RequestMethod: r.Method, + Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + ServerProtocol: r.Proto, + Host: r.Host, + Status: statusCode, + ElapsedTime: elapsedTime, + HTTPReferrer: r.Header.Get("Referer"), + HTTPUserAgent: r.Header.Get("User-Agent"), + RemoteUser: r.Header.Get("Remote-User"), + BodyBytesSent: r.ContentLength, + } + logs.AccessLog(record, BConfig.Log.AccessLogsFormat) +} diff --git a/pkg/router_test.go b/pkg/router_test.go new file mode 100644 index 00000000..8ec7927a --- /dev/null +++ b/pkg/router_test.go @@ -0,0 +1,732 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +type TestController struct { + Controller +} + +func (tc *TestController) Get() { + tc.Data["Username"] = "astaxie" + tc.Ctx.Output.Body([]byte("ok")) +} + +func (tc *TestController) Post() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) +} + +func (tc *TestController) Param() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) +} + +func (tc *TestController) List() { + tc.Ctx.Output.Body([]byte("i am list")) +} + +func (tc *TestController) Params() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2"))) +} + +func (tc *TestController) Myext() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext"))) +} + +func (tc *TestController) GetURL() { + tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext"))) +} + +func (tc *TestController) GetParams() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" + + tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn")) +} + +func (tc *TestController) GetManyRouter() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page")) +} + +func (tc *TestController) GetEmptyBody() { + var res []byte + tc.Ctx.Output.Body(res) +} + +type JSONController struct { + Controller +} + +func (jc *JSONController) Prepare() { + jc.Data["json"] = "prepare" + jc.ServeJSON(true) +} + +func (jc *JSONController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/person/:last/:first", &TestController{}, "*:Param") + if a := handler.URLFor("TestController.List"); a != "/api/list" { + logs.Info(a) + t.Errorf("TestController.List must equal to /api/list") + } + if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { + t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a) + } +} + +func TestUrlFor3(t *testing.T) { + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { + t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) + } + if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { + t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) + } +} + +func TestUrlFor2(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") + handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) + if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { + logs.Info(handler.URLFor("TestController.GetURL")) + t.Errorf("TestController.List must equal to /v1/astaxie/edit") + } + + if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != + "/v1/za/cms_12_123.html" { + logs.Info(handler.URLFor("TestController.List")) + t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") + } + if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != + "/v1/za_cms/ttt_12_123.html" { + logs.Info(handler.URLFor("TestController.Param")) + t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") + } + if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", + ":title", "aaaa", ":entid", "aaaa") != + "/1111/11/aaaa/aaaa" { + logs.Info(handler.URLFor("TestController.Get")) + t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") + } +} + +func TestUserFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/api/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/api/list", &TestController{}, "*:List") + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestPostFunc(t *testing.T) { + r, _ := http.NewRequest("POST", "/astaxie", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/:name", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "astaxie" { + t.Errorf("post func should astaxie") + } +} + +func TestAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestAutoFunc2(t *testing.T) { + r, _ := http.NewRequest("GET", "/Test/List", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestAutoFuncParams(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "20091112" { + t.Errorf("user define func can't run") + } +} + +func TestAutoExtFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/myext.json", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "json" { + t.Errorf("user define func can't run") + } +} + +func TestRouteOk(t *testing.T) { + + r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.ServeHTTP(w, r) + body := w.Body.String() + if body != "anderson+thomas+kungfu" { + t.Errorf("url param set to [%s];", body) + } +} + +func TestManyRoute(t *testing.T) { + + r, _ := http.NewRequest("GET", "/beego32-12.html", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.ServeHTTP(w, r) + + body := w.Body.String() + + if body != "3212" { + t.Errorf("url param set to [%s];", body) + } +} + +// Test for issue #1669 +func TestEmptyResponse(t *testing.T) { + + r, _ := http.NewRequest("GET", "/beego-empty.html", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.ServeHTTP(w, r) + + if body := w.Body.String(); body != "" { + t.Error("want empty body") + } +} + +func TestNotFound(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound) + } +} + +// TestStatic tests the ability to serve static +// content from the filesystem +func TestStatic(t *testing.T) { + r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.ServeHTTP(w, r) + + if w.Code != 404 { + t.Errorf("handler.Static failed to serve file") + } +} + +func TestPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/json/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/json/list", &JSONController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != `"prepare"` { + t.Errorf(w.Body.String() + "user define func can't run") + } +} + +func TestAutoPrefix(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAutoPrefix("/admin", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestAutoPrefix can't run") + } +} + +func TestRouterGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("Get userlist")) + }) + handler.ServeHTTP(w, r) + if w.Body.String() != "Get userlist" { + t.Errorf("TestRouterGet can't run") + } +} + +func TestRouterPost(t *testing.T) { + r, _ := http.NewRequest("POST", "/user/123", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestRouterPost can't run") + } +} + +func sayhello(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sayhello")) +} + +func TestRouterHandler(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello)) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + +func TestRouterHandlerAll(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello), true) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + +// +// Benchmarks NewApp: +// + +func beegoFilterFunc(ctx *context.Context) { + ctx.WriteString("hello") +} + +type AdminController struct { + Controller +} + +func (a *AdminController) Get() { + a.Ctx.WriteString("hello") +} + +func TestRouterFunc(t *testing.T) { + mux := NewControllerRegister() + mux.Get("/action", beegoFilterFunc) + mux.Post("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + mux.ServeHTTP(rw, r) + if rw.Body.String() != "hello" { + t.Errorf("TestRouterFunc can't run") + } +} + +func BenchmarkFunc(b *testing.B) { + mux := NewControllerRegister() + mux.Get("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + b.ResetTimer() + for i := 0; i < b.N; i++ { + mux.ServeHTTP(rw, r) + } +} + +func BenchmarkController(b *testing.B) { + mux := NewControllerRegister() + mux.Add("/action", &AdminController{}) + rw, r := testRequest("GET", "/action") + b.ResetTimer() + for i := 0; i < b.N; i++ { + mux.ServeHTTP(rw, r) + } +} + +func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) { + request, _ := http.NewRequest(method, path, nil) + recorder := httptest.NewRecorder() + + return recorder, request +} + +// Expectation: A Filter with the correct configuration should be created given +// specific parameters. +func TestInsertFilter(t *testing.T) { + testName := "TestInsertFilter" + + mux := NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + if !mux.filters[BeforeRouter][0].returnOnOutput { + t.Errorf( + "%s: passing no variadic params should set returnOnOutput to true", + testName) + } + if mux.filters[BeforeRouter][0].resetParams { + t.Errorf( + "%s: passing no variadic params should set resetParams to false", + testName) + } + + mux = NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + if mux.filters[BeforeRouter][0].returnOnOutput { + t.Errorf( + "%s: passing false as 1st variadic param should set returnOnOutput to false", + testName) + } + + mux = NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + if !mux.filters[BeforeRouter][0].resetParams { + t.Errorf( + "%s: passing true as 2nd variadic param should set resetParams to true", + testName) + } +} + +// Expectation: the second variadic arg should cause the execution of the filter +// to preserve the parameters from before its execution. +func TestParamResetFilter(t *testing.T) { + testName := "TestParamResetFilter" + route := "/beego/*" // splat + path := "/beego/routes/routes" + + mux := NewControllerRegister() + + mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + + mux.Get(route, beegoHandleResetParams) + + rw, r := testRequest("GET", path) + mux.ServeHTTP(rw, r) + + // The two functions, `beegoResetParams` and `beegoHandleResetParams` add + // a response header of `Splat`. The expectation here is that that Header + // value should match what the _request's_ router set, not the filter's. + + headers := rw.Result().Header + if len(headers["Splat"]) != 1 { + t.Errorf( + "%s: There was an error in the test. Splat param not set in Header", + testName) + } + if headers["Splat"][0] != "routes/routes" { + t.Errorf( + "%s: expected `:splat` param to be [routes/routes] but it was [%s]", + testName, headers["Splat"][0]) + } +} + +// Execution point: BeforeRouter +// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle +func TestFilterBeforeRouter(t *testing.T) { + testName := "TestFilterBeforeRouter" + url := "/beforeRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "BeforeRouter1") { + t.Errorf(testName + " BeforeRouter did not run") + } + if strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " BeforeRouter did not return properly") + } +} + +// Execution point: BeforeExec +// expectation: only BeforeExec function is executed, match as router determines route only +func TestFilterBeforeExec(t *testing.T) { + testName := "TestFilterBeforeExec" + url := "/beforeExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "BeforeExec1") { + t.Errorf(testName + " BeforeExec did not run") + } + if strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " BeforeExec did not return properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } +} + +// Execution point: AfterExec +// expectation: only AfterExec function is executed, match as router handles +func TestFilterAfterExec(t *testing.T) { + testName := "TestFilterAfterExec" + url := "/afterExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "AfterExec1") { + t.Errorf(testName + " AfterExec did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only FinishRouter function is executed, match as router handles +func TestFilterFinishRouter(t *testing.T) { + testName := "TestFilterFinishRouter" + url := "/finishRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "AfterExec1") { + t.Errorf(testName + " AfterExec ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only first FinishRouter function is executed, match as router handles +func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { + testName := "TestFilterFinishRouterMultiFirstOnly" + url := "/finishRouterMultiFirstOnly" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter1 did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + // not expected in body + if strings.Contains(rw.Body.String(), "FinishRouter2") { + t.Errorf(testName + " FinishRouter2 did run") + } +} + +// Execution point: FinishRouter +// expectation: both FinishRouter functions execute, match as router handles +func TestFilterFinishRouterMulti(t *testing.T) { + testName := "TestFilterFinishRouterMulti" + url := "/finishRouterMulti" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter1 did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if !strings.Contains(rw.Body.String(), "FinishRouter2") { + t.Errorf(testName + " FinishRouter2 did not run properly") + } +} + +func beegoFilterNoOutput(ctx *context.Context) { +} + +func beegoBeforeRouter1(ctx *context.Context) { + ctx.WriteString("|BeforeRouter1") +} + +func beegoBeforeExec1(ctx *context.Context) { + ctx.WriteString("|BeforeExec1") +} + +func beegoAfterExec1(ctx *context.Context) { + ctx.WriteString("|AfterExec1") +} + +func beegoFinishRouter1(ctx *context.Context) { + ctx.WriteString("|FinishRouter1") +} + +func beegoFinishRouter2(ctx *context.Context) { + ctx.WriteString("|FinishRouter2") +} + +func beegoResetParams(ctx *context.Context) { + ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) +} + +func beegoHandleResetParams(ctx *context.Context) { + ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) +} + +// YAML +type YAMLController struct { + Controller +} + +func (jc *YAMLController) Prepare() { + jc.Data["yaml"] = "prepare" + jc.ServeYAML() +} + +func (jc *YAMLController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestYAMLPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/yaml/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/yaml/list", &YAMLController{}) + handler.ServeHTTP(w, r) + if strings.TrimSpace(w.Body.String()) != "prepare" { + t.Errorf(w.Body.String()) + } +} + +func TestRouterEntityTooLargeCopyBody(t *testing.T) { + _MaxMemory := BConfig.MaxMemory + _CopyRequestBody := BConfig.CopyRequestBody + BConfig.CopyRequestBody = true + BConfig.MaxMemory = 20 + + b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) + r, _ := http.NewRequest("POST", "/user/123", b) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + + BConfig.CopyRequestBody = _CopyRequestBody + BConfig.MaxMemory = _MaxMemory + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("TestRouterRequestEntityTooLarge can't run") + } +} diff --git a/pkg/session/README.md b/pkg/session/README.md new file mode 100644 index 00000000..6d0a297e --- /dev/null +++ b/pkg/session/README.md @@ -0,0 +1,114 @@ +session +============== + +session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`. + +## How to install? + + go get github.com/astaxie/beego/session + + +## What providers are supported? + +As of now this session manager support memory, file, Redis and MySQL. + + +## How to use it? + +First you must import it + + import ( + "github.com/astaxie/beego/session" + ) + +Then in you web app init the global session manager + + var globalSessions *session.Manager + +* Use **memory** as provider: + + func init() { + globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) + go globalSessions.GC() + } + +* Use **file** as provider, the last param is the path where you want file to be stored: + + func init() { + globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`) + go globalSessions.GC() + } + +* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: + + func init() { + globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`) + go globalSessions.GC() + } + +* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name): + + func init() { + globalSessions, _ = session.NewManager( + "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`) + go globalSessions.GC() + } + +* Use **Cookie** as provider: + + func init() { + globalSessions, _ = session.NewManager( + "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) + go globalSessions.GC() + } + + +Finally in the handlerfunc you can use it like this + + func login(w http.ResponseWriter, r *http.Request) { + sess := globalSessions.SessionStart(w, r) + defer sess.SessionRelease(w) + username := sess.Get("username") + fmt.Println(username) + if r.Method == "GET" { + t, _ := template.ParseFiles("login.gtpl") + t.Execute(w, nil) + } else { + fmt.Println("username:", r.Form["username"]) + sess.Set("username", r.Form["username"]) + fmt.Println("password:", r.Form["password"]) + } + } + + +## How to write own provider? + +When you develop a web app, maybe you want to write own provider because you must meet the requirements. + +Writing a provider is easy. You only need to define two struct types +(Session and Provider), which satisfy the interface definition. +Maybe you will find the **memory** provider is a good example. + + type SessionStore interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data + } + + type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (SessionStore, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (SessionStore, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() + } + + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..707d042c --- /dev/null +++ b/pkg/session/couchbase/sess_couchbase.go @@ -0,0 +1,247 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "net/http" + "strings" + "sync" + + couchbase "github.com/couchbase/go-couchbase" + + "github.com/astaxie/beego/session" +) + +var couchbpder = &Provider{} + +// SessionStore store each session +type SessionStore struct { + b *couchbase.Bucket + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Provider couchabse provided +type Provider struct { + maxlifetime int64 + savePath string + pool string + bucket string + b *couchbase.Bucket +} + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values[key] = value + return nil +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + cs.lock.RLock() + defer cs.lock.RUnlock() + if v, ok := cs.values[key]; ok { + return v + } + return nil +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + delete(cs.values, key) + return nil +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { + return cs.sid +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + defer cs.b.Close() + + bo, err := session.EncodeGob(cs.values) + if err != nil { + return + } + + cs.b.Set(cs.sid, int(cs.maxlifetime), bo) +} + +func (cp *Provider) getBucket() *couchbase.Bucket { + c, err := couchbase.Connect(cp.savePath) + if err != nil { + return nil + } + + pool, err := c.GetPool(cp.pool) + if err != nil { + return nil + } + + bucket, err := pool.GetBucket(cp.bucket) + if err != nil { + return nil + } + + return bucket +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + cp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + cp.savePath = configs[0] + } + if len(configs) > 1 { + cp.pool = configs[1] + } + if len(configs) > 2 { + cp.bucket = configs[2] + } + + return nil +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + cp.b = cp.getBucket() + + var ( + kv map[interface{}]interface{} + err error + doc []byte + ) + + err = cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } else if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + cp.b = cp.getBucket() + defer cp.b.Close() + + var doc []byte + + if err := cp.b.Get(sid, &doc); err != nil || doc == nil { + return false + } + return true +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + cp.b = cp.getBucket() + + var doc []byte + if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil { + cp.b.Set(sid, int(cp.maxlifetime), "") + } else { + err := cp.b.Delete(oldsid) + if err != nil { + return nil, err + } + _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc) + } + + err := cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + cp.b = cp.getBucket() + defer cp.b.Close() + + cp.b.Delete(sid) + return nil +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("couchbase", couchbpder) +} diff --git a/pkg/session/ledis/ledis_session.go b/pkg/session/ledis/ledis_session.go new file mode 100644 index 00000000..ee81df67 --- /dev/null +++ b/pkg/session/ledis/ledis_session.go @@ -0,0 +1,173 @@ +// Package ledis provide session Provider +package ledis + +import ( + "net/http" + "strconv" + "strings" + "sync" + + "github.com/ledisdb/ledisdb/config" + "github.com/ledisdb/ledisdb/ledis" + + "github.com/astaxie/beego/session" +) + +var ( + ledispder = &Provider{} + c *ledis.DB +) + +// SessionStore ledis session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values[key] = value + return nil +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + ls.lock.RLock() + defer ls.lock.RUnlock() + if v, ok := ls.values[key]; ok { + return v + } + return nil +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + delete(ls.values, key) + return nil +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return ls.sid +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(ls.values) + if err != nil { + return + } + c.Set([]byte(ls.sid), b) + c.Expire([]byte(ls.sid), ls.maxlifetime) +} + +// Provider ledis session provider +type Provider struct { + maxlifetime int64 + savePath string + db int +} + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + var err error + lp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) == 1 { + lp.savePath = configs[0] + } else if len(configs) == 2 { + lp.savePath = configs[0] + lp.db, err = strconv.Atoi(configs[1]) + if err != nil { + return err + } + } + cfg := new(config.Config) + cfg.DataDir = lp.savePath + + var ledisInstance *ledis.Ledis + ledisInstance, err = ledis.Open(cfg) + if err != nil { + return err + } + c, err = ledisInstance.Select(lp.db) + return err +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + var ( + kv map[interface{}]interface{} + err error + ) + + kvs, _ := c.Get([]byte(sid)) + + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob(kvs); err != nil { + return nil, err + } + } + + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + return ls, nil +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + count, _ := c.Exists([]byte(sid)) + return count != 0 +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + count, _ := c.Exists([]byte(sid)) + if count == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set([]byte(sid), []byte("")) + c.Expire([]byte(sid), lp.maxlifetime) + } else { + data, _ := c.Get([]byte(oldsid)) + c.Set([]byte(sid), data) + c.Expire([]byte(sid), lp.maxlifetime) + } + return lp.SessionRead(sid) +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + c.Del([]byte(sid)) + return nil +} + +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return 0 +} +func init() { + session.Register("ledis", ledispder) +} diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/session/memcache/sess_memcache.go new file mode 100644 index 00000000..85a2d815 --- /dev/null +++ b/pkg/session/memcache/sess_memcache.go @@ -0,0 +1,230 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "net/http" + "strings" + "sync" + + "github.com/astaxie/beego/session" + + "github.com/bradfitz/gomemcache/memcache" +) + +var mempder = &MemProvider{} +var client *memcache.Client + +// SessionStore memcache session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} + client.Set(&item) +} + +// MemProvider memcache session provider +type MemProvider struct { + maxlifetime int64 + conninfo []string + poolsize int + password string +} + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + rp.conninfo = strings.Split(savePath, ";") + client = memcache.New(rp.conninfo...) + return nil +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + if client == nil { + if err := rp.connectInit(); err != nil { + return nil, err + } + } + item, err := client.Get(sid) + if err != nil { + if err == memcache.ErrCacheMiss { + rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} + return rs, nil + } + return nil, err + } + var kv map[interface{}]interface{} + if len(item.Value) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(item.Value) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + if client == nil { + if err := rp.connectInit(); err != nil { + return false + } + } + if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + if client == nil { + if err := rp.connectInit(); err != nil { + return nil, err + } + } + var contain []byte + if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + item.Key = sid + item.Value = []byte("") + item.Expiration = int32(rp.maxlifetime) + client.Set(item) + } else { + client.Delete(oldsid) + item.Key = sid + item.Expiration = int32(rp.maxlifetime) + client.Set(item) + contain = item.Value + } + + var kv map[interface{}]interface{} + if len(contain) == 0 { + kv = make(map[interface{}]interface{}) + } else { + var err error + kv, err = session.DecodeGob(contain) + if err != nil { + return nil, err + } + } + + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + if client == nil { + if err := rp.connectInit(); err != nil { + return err + } + } + + return client.Delete(sid) +} + +func (rp *MemProvider) connectInit() error { + client = memcache.New(rp.conninfo...) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return 0 +} + +func init() { + session.Register("memcache", mempder) +} diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/session/mysql/sess_mysql.go new file mode 100644 index 00000000..301353ab --- /dev/null +++ b/pkg/session/mysql/sess_mysql.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "database/sql" + "net/http" + "sync" + "time" + + "github.com/astaxie/beego/session" + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = "session" + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return st.sid +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + defer st.c.Close() + b, err := session.EncodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", + b, time.Now().Unix(), st.sid) +} + +// Provider mysql session provider +type Provider struct { + maxlifetime int64 + savePath string +} + +// connect to mysql +func (mp *Provider) connectInit() *sql.DB { + db, e := sql.Open("mysql", mp.savePath) + if e != nil { + return nil + } + return db +} + +// SessionInit init mysql session. +// savepath is the connection string of mysql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", + sid, "", time.Now().Unix()) + } + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + c := mp.connectInit() + defer c.Close() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + return err != sql.ErrNoRows +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) + } + c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) + c.Close() + return nil +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) + c.Close() +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + c := mp.connectInit() + defer c.Close() + var total int + err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) + if err != nil { + return 0 + } + return total +} + +func init() { + session.Register("mysql", mysqlpder) +} diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..0b8b9645 --- /dev/null +++ b/pkg/session/postgres/sess_postgresql.go @@ -0,0 +1,243 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "database/sql" + "net/http" + "sync" + "time" + + "github.com/astaxie/beego/session" + // import postgresql Driver + _ "github.com/lib/pq" +) + +var postgresqlpder = &Provider{} + +// SessionStore postgresql session store +type SessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return st.sid +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + defer st.c.Close() + b, err := session.EncodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", + b, time.Now().Format(time.RFC3339), st.sid) + +} + +// Provider postgresql session provider +type Provider struct { + maxlifetime int64 + savePath string +} + +// connect to postgresql +func (mp *Provider) connectInit() *sql.DB { + db, e := sql.Open("postgres", mp.savePath) + if e != nil { + return nil + } + return db +} + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=$1", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + _, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", + sid, "", time.Now().Format(time.RFC3339)) + + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + c := mp.connectInit() + defer c.Close() + row := c.QueryRow("select session_data from session where session_key=$1", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + return err != sql.ErrNoRows +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=$1", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", + oldsid, "", time.Now().Format(time.RFC3339)) + } + c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM session where session_key=$1", sid) + c.Close() + return nil +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) + c.Close() +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + c := mp.connectInit() + defer c.Close() + var total int + err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) + if err != nil { + return 0 + } + return total +} + +func init() { + session.Register("postgresql", postgresqlpder) +} diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go new file mode 100644 index 00000000..5c382d61 --- /dev/null +++ b/pkg/session/redis/sess_redis.go @@ -0,0 +1,261 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/session" + + "github.com/gomodule/redigo/redis" +) + +var redispder = &Provider{} + +// MaxPoolSize redis max pool size +var MaxPoolSize = 100 + +// SessionStore redis session store +type SessionStore struct { + p *redis.Pool + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p.Get() + defer c.Close() + c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) +} + +// Provider redis session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Pool +} + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + var idleTimeout time.Duration = 0 + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + idleTimeout = time.Duration(timeout) * time.Second + } + } + rp.poollist = &redis.Pool{ + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", rp.savePath) + if err != nil { + return nil, err + } + if rp.password != "" { + if _, err = c.Do("AUTH", rp.password); err != nil { + c.Close() + return nil, err + } + } + // some redis proxy such as twemproxy is not support select command + if rp.dbNum > 0 { + _, err = c.Do("SELECT", rp.dbNum) + if err != nil { + c.Close() + return nil, err + } + } + return c, err + }, + MaxIdle: rp.poolsize, + } + + rp.poollist.IdleTimeout = idleTimeout + + return rp.poollist.Get().Err() +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + c := rp.poollist.Get() + defer c.Close() + + var kv map[interface{}]interface{} + + kvs, err := redis.String(c.Do("GET", sid)) + if err != nil && err != redis.ErrNil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist.Get() + defer c.Close() + + if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist.Get() + defer c.Close() + + if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Do("SET", sid, "", "EX", rp.maxlifetime) + } else { + c.Do("RENAME", oldsid, sid) + c.Do("EXPIRE", sid, rp.maxlifetime) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist.Get() + defer c.Close() + + c.Do("DEL", sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis", redispder) +} diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..2fe300df --- /dev/null +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -0,0 +1,220 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster +import ( + "net/http" + "strconv" + "strings" + "sync" + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" + "time" +) + +var redispder = &Provider{} + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = 1000 + +// SessionStore redis_cluster session store +type SessionStore struct { + p *rediss.ClusterClient + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) +} + +// Provider redis_cluster session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *rediss.ClusterClient +} + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + }) + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != rediss.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_cluster", redispder) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..6ecb2977 --- /dev/null +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "github.com/astaxie/beego/session" + "github.com/go-redis/redis" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +var redispder = &Provider{} + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = 100 + +// SessionStore redis_sentinel session store +type SessionStore struct { + p *redis.Client + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) +} + +// Provider redis_sentinel session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Client + masterName string +} + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = DefaultPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = DefaultPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + if len(configs) > 4 { + if configs[4] != "" { + rp.masterName = configs[4] + } else { + rp.masterName = "mymaster" + } + } else { + rp.masterName = "mymaster" + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + DB: rp.dbNum, + MasterName: rp.masterName, + }) + + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != redis.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_sentinel", redispder) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..fd4155c6 --- /dev/null +++ b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + //todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/pkg/session/sess_cookie.go b/pkg/session/sess_cookie.go new file mode 100644 index 00000000..6ad5debc --- /dev/null +++ b/pkg/session/sess_cookie.go @@ -0,0 +1,180 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/json" + "net/http" + "net/url" + "sync" +) + +var cookiepder = &CookieProvider{} + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore struct { + sid string + values map[interface{}]interface{} // session data + lock sync.RWMutex +} + +// Set value to cookie session. +// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return st.sid +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + st.lock.Lock() + encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) + st.lock.Unlock() + if err == nil { + cookie := &http.Cookie{Name: cookiepder.config.CookieName, + Value: url.QueryEscape(encodedCookie), + Path: "/", + HttpOnly: true, + Secure: cookiepder.config.Secure, + MaxAge: cookiepder.config.Maxage} + http.SetCookie(w, cookie) + } +} + +type cookieConfig struct { + SecurityKey string `json:"securityKey"` + BlockKey string `json:"blockKey"` + SecurityName string `json:"securityName"` + CookieName string `json:"cookieName"` + Secure bool `json:"secure"` + Maxage int `json:"maxage"` +} + +// CookieProvider Cookie session provider +type CookieProvider struct { + maxlifetime int64 + config *cookieConfig + block cipher.Block +} + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + pder.config = &cookieConfig{} + err := json.Unmarshal([]byte(config), pder.config) + if err != nil { + return err + } + if pder.config.BlockKey == "" { + pder.config.BlockKey = string(generateRandomKey(16)) + } + if pder.config.SecurityName == "" { + pder.config.SecurityName = string(generateRandomKey(20)) + } + pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) + if err != nil { + return err + } + pder.maxlifetime = maxlifetime + return nil +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + maps, _ := decodeCookie(pder.block, + pder.config.SecurityKey, + pder.config.SecurityName, + sid, pder.maxlifetime) + if maps == nil { + maps = make(map[interface{}]interface{}) + } + rs := &CookieSessionStore{sid: sid, values: maps} + return rs, nil +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + return true +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + return nil, nil +} + +// SessionDestroy Implement method, no used. +func (pder *CookieProvider) SessionDestroy(sid string) error { + return nil +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return 0 +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return nil +} + +func init() { + Register("cookie", cookiepder) +} diff --git a/pkg/session/sess_cookie_test.go b/pkg/session/sess_cookie_test.go new file mode 100644 index 00000000..b6726005 --- /dev/null +++ b/pkg/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/session/sess_file.go b/pkg/session/sess_file.go new file mode 100644 index 00000000..47ad54a7 --- /dev/null +++ b/pkg/session/sess_file.go @@ -0,0 +1,315 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "path" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + filepder = &FileProvider{} + gcmaxlifetime int64 +) + +// FileSessionStore File session store +type FileSessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values[key] = value + return nil +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + fs.lock.RLock() + defer fs.lock.RUnlock() + if v, ok := fs.values[key]; ok { + return v + } + return nil +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + delete(fs.values, key) + return nil +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return fs.sid +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + filepder.lock.Lock() + defer filepder.lock.Unlock() + b, err := EncodeGob(fs.values) + if err != nil { + SLogger.Println(err) + return + } + _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + if err != nil { + SLogger.Println(err) + return + } + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + if err != nil { + SLogger.Println(err) + return + } + } else { + return + } + f.Truncate(0) + f.Seek(0, 0) + f.Write(b) + f.Close() +} + +// FileProvider File session provider +type FileProvider struct { + lock sync.RWMutex + maxlifetime int64 + savePath string +} + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + fp.maxlifetime = maxlifetime + fp.savePath = savePath + return nil +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + invalidChars := "./" + if strings.ContainsAny(sid, invalidChars) { + return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) + } + if len(sid) < 2 { + return nil, errors.New("length of the sid is less than 2") + } + filepder.lock.Lock() + defer filepder.lock.Unlock() + + err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755) + if err != nil { + SLogger.Println(err.Error()) + } + _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + } else { + return nil, err + } + + defer f.Close() + + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + var kv map[interface{}]interface{} + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + if len(sid) < 2 { + SLogger.Println("min length of session id is 2", sid) + return false + } + + _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return err == nil +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + filepder.lock.Lock() + defer filepder.lock.Unlock() + os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return nil +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + gcmaxlifetime = fp.maxlifetime + filepath.Walk(fp.savePath, gcpath) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + a := &activeSession{} + err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { + return a.visit(path, f, err) + }) + if err != nil { + SLogger.Printf("filepath.Walk() returned %v\n", err) + return 0 + } + return a.total +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) + oldSidFile := path.Join(oldPath, oldsid) + newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) + newSidFile := path.Join(newPath, sid) + + // new sid file is exist + _, err := os.Stat(newSidFile) + if err == nil { + return nil, fmt.Errorf("newsid %s exist", newSidFile) + } + + err = os.MkdirAll(newPath, 0755) + if err != nil { + SLogger.Println(err.Error()) + } + + // if old sid file exist + // 1.read and parse file content + // 2.write content to new sid file + // 3.remove old sid file, change new sid file atime and ctime + // 4.return FileSessionStore + _, err = os.Stat(oldSidFile) + if err == nil { + b, err := ioutil.ReadFile(oldSidFile) + if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ioutil.WriteFile(newSidFile, b, 0777) + os.Remove(oldSidFile) + os.Chtimes(newSidFile, time.Now(), time.Now()) + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil + } + + // if old sid file not exist, just create new sid file and return + newf, err := os.Create(newSidFile) + if err != nil { + return nil, err + } + newf.Close() + ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} + return ss, nil +} + +// remove file in save path if expired +func gcpath(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() { + os.Remove(path) + } + return nil +} + +type activeSession struct { + total int +} + +func (as *activeSession) visit(paths string, f os.FileInfo, err error) error { + if err != nil { + return err + } + if f.IsDir() { + return nil + } + as.total = as.total + 1 + return nil +} + +func init() { + Register("file", filepder) +} diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go new file mode 100644 index 00000000..0cf021db --- /dev/null +++ b/pkg/session/sess_file_test.go @@ -0,0 +1,387 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStore_SessionRelease(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + + s.Set(i,i) + s.SessionRelease(nil) + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(i).(int) != i { + t.Error() + } + } +} \ No newline at end of file diff --git a/pkg/session/sess_mem.go b/pkg/session/sess_mem.go new file mode 100644 index 00000000..64d8b056 --- /dev/null +++ b/pkg/session/sess_mem.go @@ -0,0 +1,196 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "container/list" + "net/http" + "sync" + "time" +) + +var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore struct { + sid string //session id + timeAccessed time.Time //last access time + value map[interface{}]interface{} //session store + lock sync.RWMutex +} + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.value[key] = value + return nil +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.value[key]; ok { + return v + } + return nil +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.value, key) + return nil +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.value = make(map[interface{}]interface{}) + return nil +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return st.sid +} + +// SessionRelease Implement method, no used. +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +} + +// MemProvider Implement the provider interface +type MemProvider struct { + lock sync.RWMutex // locker + sessions map[string]*list.Element // map in memory + list *list.List // for gc + maxlifetime int64 + savePath string +} + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + pder.maxlifetime = maxlifetime + pder.savePath = savePath + return nil +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + pder.lock.RLock() + if element, ok := pder.sessions[sid]; ok { + go pder.SessionUpdate(sid) + pder.lock.RUnlock() + return element.Value.(*MemSessionStore), nil + } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushFront(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + pder.lock.RLock() + defer pder.lock.RUnlock() + if _, ok := pder.sessions[sid]; ok { + return true + } + return false +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + pder.lock.RLock() + if element, ok := pder.sessions[oldsid]; ok { + go pder.SessionUpdate(oldsid) + pder.lock.RUnlock() + pder.lock.Lock() + element.Value.(*MemSessionStore).sid = sid + pder.sessions[sid] = element + delete(pder.sessions, oldsid) + pder.lock.Unlock() + return element.Value.(*MemSessionStore), nil + } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushFront(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + delete(pder.sessions, sid) + pder.list.Remove(element) + return nil + } + return nil +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + pder.lock.RLock() + for { + element := pder.list.Back() + if element == nil { + break + } + if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { + pder.lock.RUnlock() + pder.lock.Lock() + pder.list.Remove(element) + delete(pder.sessions, element.Value.(*MemSessionStore).sid) + pder.lock.Unlock() + pder.lock.RLock() + } else { + break + } + } + pder.lock.RUnlock() +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return pder.list.Len() +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + element.Value.(*MemSessionStore).timeAccessed = time.Now() + pder.list.MoveToFront(element) + return nil + } + return nil +} + +func init() { + Register("memory", mempder) +} diff --git a/pkg/session/sess_mem_test.go b/pkg/session/sess_mem_test.go new file mode 100644 index 00000000..2e8934b8 --- /dev/null +++ b/pkg/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(w) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/session/sess_test.go b/pkg/session/sess_test.go new file mode 100644 index 00000000..906abec2 --- /dev/null +++ b/pkg/session/sess_test.go @@ -0,0 +1,131 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "crypto/aes" + "encoding/json" + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} + +func TestGenerate(t *testing.T) { + str := generateRandomKey(20) + if len(str) != 20 { + t.Fatal("generate length is not equal to 20") + } +} + +func TestCookieEncodeDecode(t *testing.T) { + hashKey := "testhashKey" + blockkey := generateRandomKey(16) + block, err := aes.NewCipher(blockkey) + if err != nil { + t.Fatal("NewCipher:", err) + } + securityName := string(generateRandomKey(20)) + val := make(map[interface{}]interface{}) + val["name"] = "astaxie" + val["gender"] = "male" + str, err := encodeCookie(block, hashKey, securityName, val) + if err != nil { + t.Fatal("encodeCookie:", err) + } + dst, err := decodeCookie(block, hashKey, securityName, str, 3600) + if err != nil { + t.Fatal("decodeCookie", err) + } + if dst["name"] != "astaxie" { + t.Fatal("dst get map error") + } + if dst["gender"] != "male" { + t.Fatal("dst get map error") + } +} + +func TestParseConfig(t *testing.T) { + s := `{"cookieName":"gosessionid","gclifetime":3600}` + cf := new(ManagerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(s), cf) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + + cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + cf2 := new(ManagerConfig) + cf2.EnableSetCookie = true + err = json.Unmarshal([]byte(cc), cf2) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf2.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf2.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + if cf2.EnableSetCookie { + t.Fatal("parseconfig get enableSetCookie error") + } + cconfig := new(cookieConfig) + err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) + if err != nil { + t.Fatal("parse ProviderConfig err,", err) + } + if cconfig.CookieName != "gosessionid" { + t.Fatal("ProviderConfig get cookieName error") + } + if cconfig.SecurityKey != "beegocookiehashkey" { + t.Fatal("ProviderConfig get securityKey error") + } +} diff --git a/pkg/session/sess_utils.go b/pkg/session/sess_utils.go new file mode 100644 index 00000000..20915bb6 --- /dev/null +++ b/pkg/session/sess_utils.go @@ -0,0 +1,207 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "encoding/gob" + "errors" + "fmt" + "io" + "strconv" + "time" + + "github.com/astaxie/beego/utils" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + for _, v := range obj { + gob.Register(v) + } + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return []byte(""), err + } + return buf.Bytes(), nil +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + buf := bytes.NewBuffer(encoded) + dec := gob.NewDecoder(buf) + var out map[interface{}]interface{} + err := dec.Decode(&out) + if err != nil { + return nil, err + } + return out, nil +} + +// generateRandomKey creates a random key with the given strength. +func generateRandomKey(strength int) []byte { + k := make([]byte, strength) + if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil { + return utils.RandomCreateBytes(strength) + } + return k +} + +// Encryption ----------------------------------------------------------------- + +// encrypt encrypts a value using the given block in counter mode. +// +// A random initialization vector (http://goo.gl/zF67k) with the length of the +// block size is prepended to the resulting ciphertext. +func encrypt(block cipher.Block, value []byte) ([]byte, error) { + iv := generateRandomKey(block.BlockSize()) + if iv == nil { + return nil, errors.New("encrypt: failed to generate random iv") + } + // Encrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + // Return iv + ciphertext. + return append(iv, value...), nil +} + +// decrypt decrypts a value using the given block in counter mode. +// +// The value to be decrypted must be prepended by a initialization vector +// (http://goo.gl/zF67k) with the length of the block size. +func decrypt(block cipher.Block, value []byte) ([]byte, error) { + size := block.BlockSize() + if len(value) > size { + // Extract iv. + iv := value[:size] + // Extract ciphertext. + value = value[size:] + // Decrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + return value, nil + } + return nil, errors.New("decrypt: the value could not be decrypted") +} + +func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { + var err error + var b []byte + // 1. EncodeGob. + if b, err = EncodeGob(value); err != nil { + return "", err + } + // 2. Encrypt (optional). + if b, err = encrypt(block, b); err != nil { + return "", err + } + b = encode(b) + // 3. Create MAC for "name|date|value". Extra pipe to be used later. + b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) + h := hmac.New(sha256.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + // Append mac, remove name. + b = append(b, sig...)[len(name)+1:] + // 4. Encode to base64. + b = encode(b) + // Done. + return string(b), nil +} + +func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { + // 1. Decode from base64. + b, err := decode([]byte(value)) + if err != nil { + return nil, err + } + // 2. Verify MAC. Value is "date|value|mac". + parts := bytes.SplitN(b, []byte("|"), 3) + if len(parts) != 3 { + return nil, errors.New("Decode: invalid value format") + } + + b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) + h := hmac.New(sha256.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { + return nil, errors.New("Decode: the value is not valid") + } + // 3. Verify date ranges. + var t1 int64 + if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { + return nil, errors.New("Decode: invalid timestamp") + } + t2 := time.Now().UTC().Unix() + if t1 > t2 { + return nil, errors.New("Decode: timestamp is too new") + } + if t1 < t2-gcmaxlifetime { + return nil, errors.New("Decode: expired timestamp") + } + // 4. Decrypt (optional). + b, err = decode(parts[1]) + if err != nil { + return nil, err + } + if b, err = decrypt(block, b); err != nil { + return nil, err + } + // 5. DecodeGob. + dst, err := DecodeGob(b) + if err != nil { + return nil, err + } + return dst, nil +} + +// Encoding ------------------------------------------------------------------- + +// encode encodes a value using base64. +func encode(value []byte) []byte { + encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) + base64.URLEncoding.Encode(encoded, value) + return encoded +} + +// decode decodes a cookie using base64. +func decode(value []byte) ([]byte, error) { + decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) + b, err := base64.URLEncoding.Decode(decoded, value) + if err != nil { + return nil, err + } + return decoded[:b], nil +} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 00000000..eb85360a --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,377 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/textproto" + "net/url" + "os" + "time" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() +} + +var provides = make(map[string]Provider) + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + if provide == nil { + panic("session: Register provide is nil") + } + if _, dup := provides[name]; dup { + panic("session: Register called twice for provider " + name) + } + provides[name] = provide +} + +//GetProvider +func GetProvider(name string) (Provider, error) { + provider, ok := provides[name] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name) + } + return provider, nil +} + +// ManagerConfig define the session config +type ManagerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + SessionIDPrefix string `json:"sessionIDPrefix"` +} + +// Manager contains Provider and its configuration. +type Manager struct { + provider Provider + config *ManagerConfig +} + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + provider, ok := provides[provideName] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) + } + + if cf.Maxlifetime == 0 { + cf.Maxlifetime = cf.Gclifetime + } + + if cf.EnableSidInHTTPHeader { + if cf.SessionNameInHTTPHeader == "" { + panic(errors.New("SessionNameInHTTPHeader is empty")) + } + + strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader) + if cf.SessionNameInHTTPHeader != strMimeHeader { + strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader + panic(errors.New(strErrMsg)) + } + } + + err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + if err != nil { + return nil, err + } + + if cf.SessionIDLength == 0 { + cf.SessionIDLength = 16 + } + + return &Manager{ + provider, + cf, + }, nil +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return manager.provider +} + +// getSid retrieves session identifier from HTTP Request. +// First try to retrieve id by reading from cookie, session cookie name is configurable, +// if not exist, then retrieve id from querying parameters. +// +// error is not nil when there is anything wrong. +// sid is empty when need to generate a new session id +// otherwise return an valid session id. +func (manager *Manager) getSid(r *http.Request) (string, error) { + cookie, errs := r.Cookie(manager.config.CookieName) + if errs != nil || cookie.Value == "" { + var sid string + if manager.config.EnableSidInURLQuery { + errs := r.ParseForm() + if errs != nil { + return "", errs + } + + sid = r.FormValue(manager.config.CookieName) + } + + // if not found in Cookie / param, then read it from request headers + if manager.config.EnableSidInHTTPHeader && sid == "" { + sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader] + if isFound && len(sids) != 0 { + return sids[0], nil + } + } + + return sid, nil + } + + // HTTP Request contains cookie for sessionid info. + return url.QueryUnescape(cookie.Value) +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) { + sid, errs := manager.getSid(r) + if errs != nil { + return nil, errs + } + + if sid != "" && manager.provider.SessionExist(sid) { + return manager.provider.SessionRead(sid) + } + + // Generate a new session + sid, errs = manager.sessionID() + if errs != nil { + return nil, errs + } + + session, err = manager.provider.SessionRead(sid) + if err != nil { + return nil, err + } + cookie := &http.Cookie{ + Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) + } + + return +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + if manager.config.EnableSidInHTTPHeader { + r.Header.Del(manager.config.SessionNameInHTTPHeader) + w.Header().Del(manager.config.SessionNameInHTTPHeader) + } + + cookie, err := r.Cookie(manager.config.CookieName) + if err != nil || cookie.Value == "" { + return + } + + sid, _ := url.QueryUnescape(cookie.Value) + manager.provider.SessionDestroy(sid) + if manager.config.EnableSetCookie { + expiration := time.Now() + cookie = &http.Cookie{Name: manager.config.CookieName, + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Expires: expiration, + MaxAge: -1, + Domain: manager.config.Domain} + + http.SetCookie(w, cookie) + } +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { + sessions, err = manager.provider.SessionRead(sid) + return +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + manager.provider.SessionGC() + time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { + sid, err := manager.sessionID() + if err != nil { + return + } + cookie, err := r.Cookie(manager.config.CookieName) + if err != nil || cookie.Value == "" { + //delete old cookie + session, _ = manager.provider.SessionRead(sid) + cookie = &http.Cookie{Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + } else { + oldsid, _ := url.QueryUnescape(cookie.Value) + session, _ = manager.provider.SessionRegenerate(oldsid, sid) + cookie.Value = url.QueryEscape(sid) + cookie.HttpOnly = true + cookie.Path = "/" + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) + } + + return +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return manager.provider.SessionAll() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + manager.config.Secure = secure +} + +func (manager *Manager) sessionID() (string, error) { + b := make([]byte, manager.config.SessionIDLength) + n, err := rand.Read(b) + if n != len(b) || err != nil { + return "", fmt.Errorf("Could not successfully read from the system CSPRNG") + } + return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil +} + +// Set cookie with https. +func (manager *Manager) isSecure(req *http.Request) bool { + if !manager.config.Secure { + return false + } + if req.URL.Scheme != "" { + return req.URL.Scheme == "https" + } + if req.TLS == nil { + return false + } + return true +} + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + sl := new(Log) + sl.Logger = log.New(out, "[SESSION]", 1e9) + return sl +} diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..de0c6360 --- /dev/null +++ b/pkg/session/ssdb/sess_ssdb.go @@ -0,0 +1,199 @@ +package ssdb + +import ( + "errors" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" + "github.com/ssdb/gossdb/ssdb" +) + +var ssdbProvider = &Provider{} + +// Provider holds ssdb client and configs +type Provider struct { + client *ssdb.Client + host string + port int + maxLifetime int64 +} + +func (p *Provider) connectInit() error { + var err error + if p.host == "" || p.port == 0 { + return errors.New("SessionInit First") + } + p.client, err = ssdb.Connect(p.host, p.port) + return err +} + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + p.maxLifetime = maxLifetime + address := strings.Split(savePath, ":") + p.host = address[0] + + var err error + if p.port, err = strconv.Atoi(address[1]); err != nil { + return err + } + return p.connectInit() +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + var kv map[interface{}]interface{} + value, err := p.client.Get(sid) + if err != nil { + return nil, err + } + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + if p.client == nil { + if err := p.connectInit(); err != nil { + panic(err) + } + } + value, err := p.client.Get(sid) + if err != nil { + panic(err) + } + if value == nil || len(value.(string)) == 0 { + return false + } + return true +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + //conn.Do("setx", key, v, ttl) + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + value, err := p.client.Get(oldsid) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + _, err = p.client.Del(oldsid) + if err != nil { + return nil, err + } + } + _, e := p.client.Do("setx", sid, value, p.maxLifetime) + if e != nil { + return nil, e + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + if p.client == nil { + if err := p.connectInit(); err != nil { + return err + } + } + _, err := p.client.Del(sid) + return err +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return 0 +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxLifetime int64 + client *ssdb.Client +} + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + s.values[key] = value + return nil +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + s.lock.Lock() + defer s.lock.Unlock() + if value, ok := s.values[key]; ok { + return value + } + return nil +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.values, key) + return nil +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + s.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return s.sid +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(s.values) + if err != nil { + return + } + s.client.Do("setx", s.sid, string(b), s.maxLifetime) +} + +func init() { + session.Register("ssdb", ssdbProvider) +} diff --git a/pkg/staticfile.go b/pkg/staticfile.go new file mode 100644 index 00000000..84e9aa7b --- /dev/null +++ b/pkg/staticfile.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "errors" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/hashicorp/golang-lru" +) + +var errNotStaticRequest = errors.New("request not a static file request") + +func serverStaticRouter(ctx *context.Context) { + if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { + return + } + + forbidden, filePath, fileInfo, err := lookupFile(ctx) + if err == errNotStaticRequest { + return + } + + if forbidden { + exception("403", ctx) + return + } + + if filePath == "" || fileInfo == nil { + if BConfig.RunMode == DEV { + logs.Warn("Can't find/open the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + if fileInfo.IsDir() { + requestURL := ctx.Input.URL() + if requestURL[len(requestURL)-1] != '/' { + redirectURL := requestURL + "/" + if ctx.Request.URL.RawQuery != "" { + redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery + } + ctx.Redirect(302, redirectURL) + } else { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + } + return + } else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) { + //over size file serve with http module + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + return + } + + var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) + var acceptEncoding string + if enableCompress { + acceptEncoding = context.ParseEncoding(ctx.Request) + } + b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding) + if err != nil { + if BConfig.RunMode == DEV { + logs.Warn("Can't compress the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + + if b { + ctx.Output.Header("Content-Encoding", n) + } else { + ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) + } + + http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader) +} + +type serveContentHolder struct { + data []byte + modTime time.Time + size int64 + originSize int64 //original file size:to judge file changed + encoding string +} + +type serveContentReader struct { + *bytes.Reader +} + +var ( + staticFileLruCache *lru.Cache + lruLock sync.RWMutex +) + +func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { + if staticFileLruCache == nil { + //avoid lru cache error + if BConfig.WebConfig.StaticCacheFileNum >= 1 { + staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum) + } else { + staticFileLruCache, _ = lru.New(1) + } + } + mapKey := acceptEncoding + ":" + filePath + lruLock.RLock() + var mapFile *serveContentHolder + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + lruLock.RUnlock() + if isOk(mapFile, fi) { + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil + } + lruLock.Lock() + defer lruLock.Unlock() + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + if !isOk(mapFile, fi) { + file, err := os.Open(filePath) + if err != nil { + return false, "", nil, nil, err + } + defer file.Close() + var bufferWriter bytes.Buffer + _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) + if err != nil { + return false, "", nil, nil, err + } + mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n} + if isOk(mapFile, fi) { + staticFileLruCache.Add(mapKey, mapFile) + } + } + + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil +} + +func isOk(s *serveContentHolder, fi os.FileInfo) bool { + if s == nil { + return false + } else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) { + return false + } + return s.modTime == fi.ModTime() && s.originSize == fi.Size() +} + +// isStaticCompress detect static files +func isStaticCompress(filePath string) bool { + for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip { + if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { + return true + } + } + return false +} + +// searchFile search the file by url path +// if none the static file prefix matches ,return notStaticRequestErr +func searchFile(ctx *context.Context) (string, os.FileInfo, error) { + requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path)) + // special processing : favicon.ico/robots.txt can be in any static dir + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + file := path.Join(".", requestPath) + if fi, _ := os.Stat(file); fi != nil { + return file, fi, nil + } + for _, staticDir := range BConfig.WebConfig.StaticDir { + filePath := path.Join(staticDir, requestPath) + if fi, _ := os.Stat(filePath); fi != nil { + return filePath, fi, nil + } + } + return "", nil, errNotStaticRequest + } + + for prefix, staticDir := range BConfig.WebConfig.StaticDir { + if !strings.Contains(requestPath, prefix) { + continue + } + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + continue + } + filePath := path.Join(staticDir, requestPath[len(prefix):]) + if fi, err := os.Stat(filePath); fi != nil { + return filePath, fi, err + } + } + return "", nil, errNotStaticRequest +} + +// lookupFile find the file to serve +// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) +// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex +func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { + fp, fi, err := searchFile(ctx) + if fp == "" || fi == nil { + return false, "", nil, err + } + if !fi.IsDir() { + return false, fp, fi, err + } + if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' { + ifp := filepath.Join(fp, "index.html") + if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { + return false, ifp, ifi, err + } + } + return !BConfig.WebConfig.DirectoryIndex, fp, fi, err +} diff --git a/pkg/staticfile_test.go b/pkg/staticfile_test.go new file mode 100644 index 00000000..e46c13ec --- /dev/null +++ b/pkg/staticfile_test.go @@ -0,0 +1,99 @@ +package beego + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +var currentWorkDir, _ = os.Getwd() +var licenseFile = filepath.Join(currentWorkDir, "LICENSE") + +func testOpenFile(encoding string, content []byte, t *testing.T) { + fi, _ := os.Stat(licenseFile) + b, n, sch, reader, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Log(err) + t.Fail() + } + + t.Log("open static file encoding "+n, b) + + assetOpenFileAndContent(sch, reader, content, t) +} +func TestOpenStaticFile_1(t *testing.T) { + file, _ := os.Open(licenseFile) + content, _ := ioutil.ReadAll(file) + testOpenFile("", content, t) +} + +func TestOpenStaticFileGzip_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("gzip", content, t) +} +func TestOpenStaticFileDeflate_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("deflate", content, t) +} + +func TestStaticCacheWork(t *testing.T) { + encodings := []string{"", "gzip", "deflate"} + + fi, _ := os.Stat(licenseFile) + for _, encoding := range encodings { + _, _, first, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + _, _, second, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + address1 := fmt.Sprintf("%p", first) + address2 := fmt.Sprintf("%p", second) + if address1 != address2 { + t.Errorf("encoding '%v' can not hit cache", encoding) + } + } +} + +func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { + t.Log(sch.size, len(content)) + if sch.size != int64(len(content)) { + t.Log("static content file size not same") + t.Fail() + } + bs, _ := ioutil.ReadAll(reader) + for i, v := range content { + if v != bs[i] { + t.Log("content not same") + t.Fail() + } + } + if staticFileLruCache.Len() == 0 { + t.Log("men map is empty") + t.Fail() + } +} diff --git a/pkg/swagger/swagger.go b/pkg/swagger/swagger.go new file mode 100644 index 00000000..a55676cd --- /dev/null +++ b/pkg/swagger/swagger.go @@ -0,0 +1,174 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +// Swagger list the resource +type Swagger struct { + SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` + Infos Information `json:"info" yaml:"info"` + Host string `json:"host,omitempty" yaml:"host,omitempty"` + BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Paths map[string]*Item `json:"paths" yaml:"paths"` + Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` + SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` +} + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information struct { + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Version string `json:"version,omitempty" yaml:"version,omitempty"` + TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"` + + Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"` + License *License `json:"license,omitempty" yaml:"license,omitempty"` +} + +// Contact information for the exposed API. +type Contact struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` + EMail string `json:"email,omitempty" yaml:"email,omitempty"` +} + +// License information for the exposed API. +type License struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` +} + +// Item Describes the operations available on a single path. +type Item struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Get *Operation `json:"get,omitempty" yaml:"get,omitempty"` + Put *Operation `json:"put,omitempty" yaml:"put,omitempty"` + Post *Operation `json:"post,omitempty" yaml:"post,omitempty"` + Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"` + Options *Operation `json:"options,omitempty" yaml:"options,omitempty"` + Head *Operation `json:"head,omitempty" yaml:"head,omitempty"` + Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"` +} + +// Operation Describes a single API operation on a path. +type Operation struct { + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` + Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` +} + +// Parameter Describes a single operation parameter. +type Parameter struct { + In string `json:"in,omitempty" yaml:"in,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Required bool `json:"required,omitempty" yaml:"required,omitempty"` + Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` +} + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// http://swagger.io/specification/#itemsObject +type ParameterItems struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array. + CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"` + Default string `json:"default,omitempty" yaml:"default,omitempty"` +} + +// Schema Object allows the definition of input and output data types. +type Schema struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` + Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` +} + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` + Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"` + AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"` +} + +// Response as they are returned from executing this operation. +type Response struct { + Description string `json:"description" yaml:"description"` + Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` +} + +// Security Allows the definition of a security scheme that can be used by the operations +type Security struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2". + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header". + Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode". + AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"` + Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme. +} + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` +} + +// ExternalDocs include Additional external documentation +type ExternalDocs struct { + Description string `json:"description,omitempty" yaml:"description,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` +} diff --git a/pkg/template.go b/pkg/template.go new file mode 100644 index 00000000..59875be7 --- /dev/null +++ b/pkg/template.go @@ -0,0 +1,406 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "html/template" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var ( + beegoTplFuncMap = make(template.FuncMap) + beeViewPathTemplateLocked = false + // beeViewPathTemplates caching map and supported template file extensions per view + beeViewPathTemplates = make(map[string]map[string]*template.Template) + templatesLock sync.RWMutex + // beeTemplateExt stores the template extension which will build + beeTemplateExt = []string{"tpl", "html", "gohtml"} + // beeTemplatePreprocessors stores associations of extension -> preprocessor handler + beeTemplateEngines = map[string]templatePreProcessor{} + beeTemplateFS = defaultFSFunc +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + if BConfig.RunMode == DEV { + templatesLock.RLock() + defer templatesLock.RUnlock() + } + if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { + if t, ok := beeTemplates[name]; ok { + var err error + if t.Lookup(name) != nil { + err = t.ExecuteTemplate(wr, name, data) + } else { + err = t.Execute(wr, data) + } + if err != nil { + logs.Trace("template Execute err:", err) + } + return err + } + panic("can't find templatefile in the path:" + viewPath + "/" + name) + } + panic("Unknown view path:" + viewPath) +} + +func init() { + beegoTplFuncMap["dateformat"] = DateFormat + beegoTplFuncMap["date"] = Date + beegoTplFuncMap["compare"] = Compare + beegoTplFuncMap["compare_not"] = CompareNot + beegoTplFuncMap["not_nil"] = NotNil + beegoTplFuncMap["not_null"] = NotNil + beegoTplFuncMap["substr"] = Substr + beegoTplFuncMap["html2str"] = HTML2str + beegoTplFuncMap["str2html"] = Str2html + beegoTplFuncMap["htmlquote"] = Htmlquote + beegoTplFuncMap["htmlunquote"] = Htmlunquote + beegoTplFuncMap["renderform"] = RenderForm + beegoTplFuncMap["assets_js"] = AssetsJs + beegoTplFuncMap["assets_css"] = AssetsCSS + beegoTplFuncMap["config"] = GetConfig + beegoTplFuncMap["map_get"] = MapGet + + // Comparisons + beegoTplFuncMap["eq"] = eq // == + beegoTplFuncMap["ge"] = ge // >= + beegoTplFuncMap["gt"] = gt // > + beegoTplFuncMap["le"] = le // <= + beegoTplFuncMap["lt"] = lt // < + beegoTplFuncMap["ne"] = ne // != + + beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + beegoTplFuncMap[key] = fn + return nil +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// visit will make the paths into two part,the first is subDir (without tf.root),the second is full path(without tf.root). +// if tf.root="views" and +// paths is "views/errors/404.html",the subDir will be "errors",the file will be "errors/404.html" +// paths is "views/admin/errors/404.html",the subDir will be "admin/errors",the file will be "admin/errors/404.html" +func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error { + if f == nil { + return err + } + if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 { + return nil + } + if !HasTemplateExt(paths) { + return nil + } + + replace := strings.NewReplacer("\\", "/") + file := strings.TrimLeft(replace.Replace(paths[len(tf.root):]), "/") + subDir := filepath.Dir(file) + + tf.files[subDir] = append(tf.files[subDir], file) + return nil +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + for _, v := range beeTemplateExt { + if strings.HasSuffix(paths, "."+v) { + return true + } + } + return false +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + for _, v := range beeTemplateExt { + if v == ext { + return + } + } + beeTemplateExt = append(beeTemplateExt, ext) +} + +// AddViewPath adds a new path to the supported view paths. +//Can later be used by setting a controller ViewPath to this folder +//will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + if beeViewPathTemplateLocked { + if _, exist := beeViewPathTemplates[viewPath]; exist { + return nil //Ignore if viewpath already exists + } + panic("Can not add new view paths after beego.Run()") + } + beeViewPathTemplates[viewPath] = make(map[string]*template.Template) + return BuildTemplate(viewPath) +} + +func lockViewPaths() { + beeViewPathTemplateLocked = true +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + var err error + fs := beeTemplateFS() + f, err := fs.Open(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return errors.New("dir open err") + } + defer f.Close() + + beeTemplates, ok := beeViewPathTemplates[dir] + if !ok { + panic("Unknown view path: " + dir) + } + self := &templateFile{ + root: dir, + files: make(map[string][]string), + } + err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { + return self.visit(path, f, err) + }) + if err != nil { + fmt.Printf("Walk() returned %v\n", err) + return err + } + buildAllFiles := len(files) == 0 + for _, v := range self.files { + for _, file := range v { + if buildAllFiles || utils.InSlice(file, files) { + templatesLock.Lock() + ext := filepath.Ext(file) + var t *template.Template + if len(ext) == 0 { + t, err = getTemplate(self.root, fs, file, v...) + } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { + t, err = fn(self.root, file, beegoTplFuncMap) + } else { + t, err = getTemplate(self.root, fs, file, v...) + } + if err != nil { + logs.Error("parse template err:", file, err) + templatesLock.Unlock() + return err + } + beeTemplates[file] = t + templatesLock.Unlock() + } + } + } + return nil +} + +func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) { + var fileAbsPath string + var rParent string + var err error + if strings.HasPrefix(file, "../") { + rParent = filepath.Join(filepath.Dir(parent), file) + fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) + } else { + rParent = file + fileAbsPath = filepath.Join(root, file) + } + f, err := fs.Open(fileAbsPath) + if err != nil { + panic("can't find template file:" + file) + } + defer f.Close() + data, err := ioutil.ReadAll(f) + if err != nil { + return nil, [][]string{}, err + } + t, err = t.New(file).Parse(string(data)) + if err != nil { + return nil, [][]string{}, err + } + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") + allSub := reg.FindAllStringSubmatch(string(data), -1) + for _, m := range allSub { + if len(m) == 2 { + tl := t.Lookup(m[1]) + if tl != nil { + continue + } + if !HasTemplateExt(m[1]) { + continue + } + _, _, err = getTplDeep(root, fs, m[1], rParent, t) + if err != nil { + return nil, [][]string{}, err + } + } + } + return t, allSub, nil +} + +func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) { + t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) + var subMods [][]string + t, subMods, err = getTplDeep(root, fs, file, "", t) + if err != nil { + return nil, err + } + t, err = _getTemplate(t, root, fs, subMods, others...) + + if err != nil { + return nil, err + } + return +} + +func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) { + t = t0 + for _, m := range subMods { + if len(m) == 2 { + tpl := t.Lookup(m[1]) + if tpl != nil { + continue + } + //first check filename + for _, otherFile := range others { + if otherFile == m[1] { + var subMods1 [][]string + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) + if err != nil { + logs.Trace("template parse file err:", err) + } else if len(subMods1) > 0 { + t, err = _getTemplate(t, root, fs, subMods1, others...) + } + break + } + } + //second check define + for _, otherFile := range others { + var data []byte + fileAbsPath := filepath.Join(root, otherFile) + f, err := fs.Open(fileAbsPath) + if err != nil { + f.Close() + logs.Trace("template file parse error, not success open file:", err) + continue + } + data, err = ioutil.ReadAll(f) + f.Close() + if err != nil { + logs.Trace("template file parse error, not success read file:", err) + continue + } + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") + allSub := reg.FindAllStringSubmatch(string(data), -1) + for _, sub := range allSub { + if len(sub) == 2 && sub[1] == m[1] { + var subMods1 [][]string + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) + if err != nil { + logs.Trace("template parse file err:", err) + } else if len(subMods1) > 0 { + t, err = _getTemplate(t, root, fs, subMods1, others...) + if err != nil { + logs.Trace("template parse file err:", err) + } + } + break + } + } + } + } + + } + return +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + beeTemplateFS = fnt +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + BConfig.WebConfig.ViewsPath = path + return BeeApp +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + if url != "/" { + url = strings.TrimRight(url, "/") + } + BConfig.WebConfig.StaticDir[url] = path + return BeeApp +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + if url != "/" { + url = strings.TrimRight(url, "/") + } + delete(BConfig.WebConfig.StaticDir, url) + return BeeApp +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + AddTemplateExt(extension) + beeTemplateEngines[extension] = fn + return BeeApp +} diff --git a/pkg/template_test.go b/pkg/template_test.go new file mode 100644 index 00000000..287faadc --- /dev/null +++ b/pkg/template_test.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "github.com/astaxie/beego/testdata" + "github.com/elazarl/go-bindata-assetfs" + "net/http" + "os" + "path/filepath" + "testing" +) + +var header = `{{define "header"}} +

Hello, astaxie!

+{{end}}` + +var index = ` + + + beego welcome template + + +{{template "block"}} +{{template "header"}} +{{template "blocks/block.tpl"}} + + +` + +var block = `{{define "block"}} +

Hello, blocks!

+{{end}}` + +func TestTemplate(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "header.tpl", + "index.tpl", + "blocks/block.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(header) + } else if k == 1 { + f.WriteString(index) + } else if k == 2 { + f.WriteString(block) + } + + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", nil); err != nil { + t.Fatal(err) + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +var menu = ` +` +var user = ` + + + beego welcome template + + +{{template "../public/menu.tpl"}} + + +` + +func TestRelativeTemplate(t *testing.T) { + dir := "_beeTmp" + + //Just add dir to known viewPaths + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + + files := []string{ + "easyui/public/menu.tpl", + "easyui/rbac/user.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(menu) + } else if k == 1 { + f.WriteString(user) + } + f.Close() + } + } + if err := BuildTemplate(dir, files[1]); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { + t.Fatal(err) + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +var add = `{{ template "layout_blog.tpl" . }} +{{ define "css" }} + +{{ end}} + + +{{ define "content" }} +

{{ .Title }}

+

This is SomeVar: {{ .SomeVar }}

+{{ end }} + +{{ define "js" }} + +{{ end}}` + +var layoutBlog = ` + + + Lin Li + + + + + {{ block "css" . }}{{ end }} + + + +
+ {{ block "content" . }}{{ end }} +
+ + + {{ block "js" . }}{{ end }} + +` + +var output = ` + + + Lin Li + + + + + + + + + + +
+ +

Hello

+

This is SomeVar: val

+ +
+ + + + + + + + + + + + +` + +func TestTemplateLayout(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "add.tpl", + "layout_blog.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(add) + } else if k == 1 { + f.WriteString(layoutBlog) + } + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 2 { + t.Fatalf("should be 2 but got %v", len(beeTemplates)) + } + out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + if out.String() != output { + t.Log(out.String()) + t.Fatal("Compare failed") + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +type TestingFileSystem struct { + assetfs *assetfs.AssetFS +} + +func (d TestingFileSystem) Open(name string) (http.File, error) { + return d.assetfs.Open(name) +} + +var outputBinData = ` + + + beego welcome template + + + + +

Hello, blocks!

+ + +

Hello, astaxie!

+ + + +

Hello

+

This is SomeVar: val

+ + +` + +func TestFsBinData(t *testing.T) { + SetTemplateFSFunc(func() http.FileSystem { + return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + }) + dir := "views" + if err := AddViewPath("views"); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + out := bytes.NewBufferString("") + if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + + if out.String() != outputBinData { + t.Log(out.String()) + t.Fatal("Compare failed") + } +} diff --git a/pkg/templatefunc.go b/pkg/templatefunc.go new file mode 100644 index 00000000..ba1ec5eb --- /dev/null +++ b/pkg/templatefunc.go @@ -0,0 +1,780 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "html" + "html/template" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + bt := []rune(s) + if start < 0 { + start = 0 + } + if start > len(bt) { + start = start % len(bt) + } + var end int + if (start + length) > (len(bt) - 1) { + end = len(bt) + } else { + end = start + length + } + return string(bt[start:end]) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + + re := regexp.MustCompile(`\<[\S\s]+?\>`) + html = re.ReplaceAllStringFunc(html, strings.ToLower) + + //remove STYLE + re = regexp.MustCompile(`\`) + html = re.ReplaceAllString(html, "") + + //remove SCRIPT + re = regexp.MustCompile(`\`) + html = re.ReplaceAllString(html, "") + + re = regexp.MustCompile(`\<[\S\s]+?\>`) + html = re.ReplaceAllString(html, "\n") + + re = regexp.MustCompile(`\s{2,}`) + html = re.ReplaceAllString(html, "\n") + + return strings.TrimSpace(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + datestring = t.Format(layout) + return +} + +// DateFormat pattern rules. +var datePatterns = []string{ + // year + "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 + "y", "06", //A two digit representation of a year Examples: 99 or 03 + + // month + "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 + "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 + "M", "Jan", // A short textual representation of a month, three letters Jan through Dec + "F", "January", // A full textual representation of a month, such as January or March January through December + + // day + "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 + "j", "2", // Day of the month without leading zeros 1 to 31 + + // week + "D", "Mon", // A textual representation of a day, three letters Mon through Sun + "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday + + // time + "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 + "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 + "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 + "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 + + "a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm + "A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM + + "i", "04", // Minutes with leading zeros 00 to 59 + "s", "05", // Seconds, with leading zeros 00 through 59 + + // time zone + "T", "MST", + "P", "-07:00", + "O", "-0700", + + // RFC 2822 + "r", time.RFC1123Z, +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return time.ParseInLocation(format, dateString, time.Local) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return t.Format(format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". +func Compare(a, b interface{}) (equal bool) { + equal = false + if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) { + equal = true + } + return +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return !Compare(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return CompareNot(a, nil) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { + switch returnType { + case "String": + value = AppConfig.String(key) + case "Bool": + value, err = AppConfig.Bool(key) + case "Int": + value, err = AppConfig.Int(key) + case "Int64": + value, err = AppConfig.Int64(key) + case "Float": + value, err = AppConfig.Float(key) + case "DIY": + value, err = AppConfig.DIY(key) + default: + err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") + } + + if err != nil { + if reflect.TypeOf(returnType) != reflect.TypeOf(defaultVal) { + err = errors.New("defaultVal type does not match returnType") + } else { + value, err = defaultVal, nil + } + } else if reflect.TypeOf(value).Kind() == reflect.String { + if value == "" { + if reflect.TypeOf(defaultVal).Kind() != reflect.String { + err = errors.New("defaultVal type must be a String if the returnType is a String") + } else { + value = defaultVal.(string) + } + } + } + + return +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return template.HTML(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + //HTML编码为实体符号 + /* + Encodes `text` for raw use in HTML. + >>> htmlquote("<'&\\">") + '<'&">' + */ + + text = html.EscapeString(text) + text = strings.NewReplacer( + `“`, "“", + `”`, "”", + ` `, " ", + ).Replace(text) + + return strings.TrimSpace(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + //实体符号解释为HTML + /* + Decodes `text` that's HTML quoted. + >>> htmlunquote('<'&">') + '<\\'&">' + */ + + text = html.UnescapeString(text) + + return strings.TrimSpace(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return BeeApp.Handlers.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +// Support for anonymous struct. +func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() { + continue + } + + fieldT := objT.Field(i) + if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { + err := parseFormToStruct(form, fieldT.Type, fieldV) + if err != nil { + return err + } + continue + } + + tags := strings.Split(fieldT.Tag.Get("form"), ",") + var tag string + if len(tags) == 0 || len(tags[0]) == 0 { + tag = fieldT.Name + } else if tags[0] == "-" { + continue + } else { + tag = tags[0] + } + + formValues := form[tag] + var value string + if len(formValues) == 0 { + defaultValue := fieldT.Tag.Get("default") + if defaultValue != "" { + value = defaultValue + } else { + continue + } + } + if len(formValues) == 1 { + value = formValues[0] + if value == "" { + continue + } + } + + switch fieldT.Type.Kind() { + case reflect.Bool: + if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { + fieldV.SetBool(true) + continue + } + if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { + fieldV.SetBool(false) + continue + } + b, err := strconv.ParseBool(value) + if err != nil { + return err + } + fieldV.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + fieldV.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + fieldV.SetFloat(x) + case reflect.Interface: + fieldV.Set(reflect.ValueOf(value)) + case reflect.String: + fieldV.SetString(value) + case reflect.Struct: + switch fieldT.Type.String() { + case "time.Time": + var ( + t time.Time + err error + ) + if len(value) >= 25 { + value = value[:25] + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if strings.HasSuffix(strings.ToUpper(value), "Z") { + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if len(value) >= 19 { + if strings.Contains(value, "T") { + value = value[:19] + t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) + } else { + value = value[:19] + t, err = time.ParseInLocation(formatDateTime, value, time.Local) + } + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + t, err = time.ParseInLocation(formatDate, value, time.Local) + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + t, err = time.ParseInLocation(formatTime, value, time.Local) + } + if err != nil { + return err + } + fieldV.Set(reflect.ValueOf(t)) + } + case reflect.Slice: + if fieldT.Type == sliceOfInts { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + val, err := strconv.Atoi(formVals[i]) + if err != nil { + return err + } + fieldV.Index(i).SetInt(int64(val)) + } + } else if fieldT.Type == sliceOfStrings { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + fieldV.Index(i).SetString(formVals[i]) + } + } + } + } + return nil +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return fmt.Errorf("%v must be a struct pointer", obj) + } + objT = objT.Elem() + objV = objV.Elem() + + return parseFormToStruct(form, objT, objV) +} + +var sliceOfInts = reflect.TypeOf([]int(nil)) +var sliceOfStrings = reflect.TypeOf([]string(nil)) + +var unKind = map[reflect.Kind]bool{ + reflect.Uintptr: true, + reflect.Complex64: true, + reflect.Complex128: true, + reflect.Array: true, + reflect.Chan: true, + reflect.Func: true, + reflect.Map: true, + reflect.Ptr: true, + reflect.Slice: true, + reflect.Struct: true, + reflect.UnsafePointer: true, +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return template.HTML("") + } + objT = objT.Elem() + objV = objV.Elem() + + var raw []string + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() || unKind[fieldV.Kind()] { + continue + } + + fieldT := objT.Field(i) + + label, name, fType, id, class, ignored, required := parseFormTag(fieldT) + if ignored { + continue + } + + raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required)) + } + return template.HTML(strings.Join(raw, "
")) +} + +// renderFormField returns a string containing HTML of a single form field. +func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { + if id != "" { + id = " id=\"" + id + "\"" + } + + if class != "" { + class = " class=\"" + class + "\"" + } + + requiredString := "" + if required { + requiredString = " required" + } + + if isValidForInput(fType) { + return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) + } + + return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) +} + +// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. +func isValidForInput(fType string) bool { + validInputTypes := strings.Fields("text password checkbox radio submit reset hidden image file button search email url tel number range date month week time datetime datetime-local color") + for _, validType := range validInputTypes { + if fType == validType { + return true + } + } + return false +} + +// parseFormTag takes the stuct-tag of a StructField and parses the `form` value. +// returned are the form label, name-property, type and wether the field should be ignored. +func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { + tags := strings.Split(fieldT.Tag.Get("form"), ",") + label = fieldT.Name + ": " + name = fieldT.Name + fType = "text" + ignored = false + id = fieldT.Tag.Get("id") + class = fieldT.Tag.Get("class") + + required = false + requiredField := fieldT.Tag.Get("required") + if requiredField != "-" && requiredField != "" { + required, _ = strconv.ParseBool(requiredField) + } + + switch len(tags) { + case 1: + if tags[0] == "-" { + ignored = true + } + if len(tags[0]) > 0 { + name = tags[0] + } + case 2: + if len(tags[0]) > 0 { + name = tags[0] + } + if len(tags[1]) > 0 { + fType = tags[1] + } + case 3: + if len(tags[0]) > 0 { + name = tags[0] + } + if len(tags[1]) > 0 { + fType = tags[1] + } + if len(tags[2]) > 0 { + label = tags[2] + } + } + + return +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +// go1.2 added template funcs. begin +var ( + errBadComparisonType = errors.New("invalid type for comparison") + errBadComparison = errors.New("incompatible types for comparison") + errNoComparison = errors.New("missing argument for comparison") +) + +type kind int + +const ( + invalidKind kind = iota + boolKind + complexKind + intKind + floatKind + stringKind + uintKind +) + +func basicKind(v reflect.Value) (kind, error) { + switch v.Kind() { + case reflect.Bool: + return boolKind, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intKind, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintKind, nil + case reflect.Float32, reflect.Float64: + return floatKind, nil + case reflect.Complex64, reflect.Complex128: + return complexKind, nil + case reflect.String: + return stringKind, nil + } + return invalidKind, errBadComparisonType +} + +// eq evaluates the comparison a == b || a == c || ... +func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + if len(arg2) == 0 { + return false, errNoComparison + } + for _, arg := range arg2 { + v2 := reflect.ValueOf(arg) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind: + truth = v1.Bool() == v2.Bool() + case complexKind: + truth = v1.Complex() == v2.Complex() + case floatKind: + truth = v1.Float() == v2.Float() + case intKind: + truth = v1.Int() == v2.Int() + case stringKind: + truth = v1.String() == v2.String() + case uintKind: + truth = v1.Uint() == v2.Uint() + default: + panic("invalid kind") + } + if truth { + return true, nil + } + } + return false, nil +} + +// ne evaluates the comparison a != b. +func ne(arg1, arg2 interface{}) (bool, error) { + // != is the inverse of ==. + equal, err := eq(arg1, arg2) + return !equal, err +} + +// lt evaluates the comparison a < b. +func lt(arg1, arg2 interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + v2 := reflect.ValueOf(arg2) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + panic("invalid kind") + } + return truth, nil +} + +// le evaluates the comparison <= b. +func le(arg1, arg2 interface{}) (bool, error) { + // <= is < or ==. + lessThan, err := lt(arg1, arg2) + if lessThan || err != nil { + return lessThan, err + } + return eq(arg1, arg2) +} + +// gt evaluates the comparison a > b. +func gt(arg1, arg2 interface{}) (bool, error) { + // > is the inverse of <=. + lessOrEqual, err := le(arg1, arg2) + if err != nil { + return false, err + } + return !lessOrEqual, nil +} + +// ge evaluates the comparison a >= b. +func ge(arg1, arg2 interface{}) (bool, error) { + // >= is the inverse of <. + lessThan, err := lt(arg1, arg2) + if err != nil { + return false, err + } + return !lessThan, nil +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + arg1Type := reflect.TypeOf(arg1) + arg1Val := reflect.ValueOf(arg1) + + if arg1Type.Kind() == reflect.Map && len(arg2) > 0 { + // check whether arg2[0] type equals to arg1 key type + // if they are different, make conversion + arg2Val := reflect.ValueOf(arg2[0]) + arg2Type := reflect.TypeOf(arg2[0]) + if arg2Type.Kind() != arg1Type.Key().Kind() { + // convert arg2Value to string + var arg2ConvertedVal interface{} + arg2String := fmt.Sprintf("%v", arg2[0]) + + // convert string representation to any other type + switch arg1Type.Key().Kind() { + case reflect.Bool: + arg2ConvertedVal, _ = strconv.ParseBool(arg2String) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg2ConvertedVal, _ = strconv.ParseInt(arg2String, 0, 64) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + arg2ConvertedVal, _ = strconv.ParseUint(arg2String, 0, 64) + case reflect.Float32, reflect.Float64: + arg2ConvertedVal, _ = strconv.ParseFloat(arg2String, 64) + case reflect.String: + arg2ConvertedVal = arg2String + default: + arg2ConvertedVal = arg2Val.Interface() + } + arg2Val = reflect.ValueOf(arg2ConvertedVal) + } + + storedVal := arg1Val.MapIndex(arg2Val) + + if storedVal.IsValid() { + var result interface{} + + switch arg1Type.Elem().Kind() { + case reflect.Bool: + result = storedVal.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + result = storedVal.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + result = storedVal.Uint() + case reflect.Float32, reflect.Float64: + result = storedVal.Float() + case reflect.String: + result = storedVal.String() + default: + result = storedVal.Interface() + } + + // if there is more keys, handle this recursively + if len(arg2) > 1 { + return MapGet(result, arg2[1:]...) + } + return result, nil + } + return nil, nil + + } + return nil, nil +} diff --git a/pkg/templatefunc_test.go b/pkg/templatefunc_test.go new file mode 100644 index 00000000..b4c19c2e --- /dev/null +++ b/pkg/templatefunc_test.go @@ -0,0 +1,380 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "html/template" + "net/url" + "reflect" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
` + + `年龄:
` + + `Sex:
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestRenderFormField(t *testing.T) { + html := renderFormField("Label: ", "Name", "text", "Value", "", "", false) + if html != `Label: ` { + t.Errorf("Wrong html output for input[type=text]: %v ", html) + } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } +} + +func TestParseFormTag(t *testing.T) { + // create struct to contain field with different types of struct-tag `form` + type user struct { + All int `form:"name,text,年龄:"` + NoName int `form:",hidden,年龄:"` + OnlyLabel int `form:",,年龄:"` + OnlyName int `form:"name" id:"name" class:"form-name"` + Ignored int `form:"-"` + Required int `form:"name" required:"true"` + IgnoreRequired int `form:"name"` + NotRequired int `form:"name" required:"false"` + } + + objT := reflect.TypeOf(&user{}).Elem() + + label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0)) + if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { + t.Errorf("Form Tag with name, label and type was not correctly parsed.") + } + + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1)) + if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { + t.Errorf("Form Tag with label and type but without name was not correctly parsed.") + } + + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2)) + if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { + t.Errorf("Form Tag containing only label was not correctly parsed.") + } + + label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3)) + if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && + id == "name" && class == "form-name") { + t.Errorf("Form Tag containing only name was not correctly parsed.") + } + + _, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4)) + if !ignored { + t.Errorf("Form Tag that should be ignored was not correctly parsed.") + } + + _, name, _, _, _, _, required := parseFormTag(objT.Field(5)) + if !(name == "name" && required) { + t.Errorf("Form Tag containing only name and required was not correctly parsed.") + } + + _, name, _, _, _, _, required = parseFormTag(objT.Field(6)) + if !(name == "name" && !required) { + t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") + } + + _, name, _, _, _, _, required = parseFormTag(objT.Field(7)) + if !(name == "name" && !required) { + t.Errorf("Form Tag containing only name and not required was not correctly parsed.") + } + +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/testdata/Makefile b/pkg/testdata/Makefile new file mode 100644 index 00000000..e80e8238 --- /dev/null +++ b/pkg/testdata/Makefile @@ -0,0 +1,2 @@ +build_view: + $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... \ No newline at end of file diff --git a/pkg/testdata/bindata.go b/pkg/testdata/bindata.go new file mode 100644 index 00000000..beade103 --- /dev/null +++ b/pkg/testdata/bindata.go @@ -0,0 +1,296 @@ +// Code generated by go-bindata. +// sources: +// views/blocks/block.tpl +// views/header.tpl +// views/index.tpl +// DO NOT EDIT! + +package testdata + +import ( + "bytes" + "compress/gzip" + "fmt" + "github.com/elazarl/go-bindata-assetfs" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +func bindataRead(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + clErr := gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + if clErr != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +type asset struct { + bytes []byte + info os.FileInfo +} + +type bindataFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time +} + +func (fi bindataFileInfo) Name() string { + return fi.name +} +func (fi bindataFileInfo) Size() int64 { + return fi.size +} +func (fi bindataFileInfo) Mode() os.FileMode { + return fi.mode +} +func (fi bindataFileInfo) ModTime() time.Time { + return fi.modTime +} +func (fi bindataFileInfo) IsDir() bool { + return false +} +func (fi bindataFileInfo) Sys() interface{} { + return nil +} + +var _viewsBlocksBlockTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\x4a\xca\xc9\x4f\xce\x56\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x00\x8b\x15\x2b\xda\xe8\x67\x18\xda\x71\x55\x57\xa7\xe6\xa5\xd4\xd6\x02\x02\x00\x00\xff\xff\xfd\xa1\x7a\xf6\x32\x00\x00\x00") + +func viewsBlocksBlockTplBytes() ([]byte, error) { + return bindataRead( + _viewsBlocksBlockTpl, + "views/blocks/block.tpl", + ) +} + +func viewsBlocksBlockTpl() (*asset, error) { + bytes, err := viewsBlocksBlockTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsHeaderTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\xca\x48\x4d\x4c\x49\x2d\x52\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x48\x2c\x2e\x49\xac\xc8\x4c\x55\xb4\xd1\xcf\x30\xb4\xe3\xaa\xae\x4e\xcd\x4b\xa9\xad\x05\x04\x00\x00\xff\xff\xe4\x12\x47\x01\x34\x00\x00\x00") + +func viewsHeaderTplBytes() ([]byte, error) { + return bindataRead( + _viewsHeaderTpl, + "views/header.tpl", + ) +} + +func viewsHeaderTpl() (*asset, error) { + bytes, err := viewsHeaderTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsIndexTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x64\x8f\xbd\x8a\xc3\x30\x10\x84\x6b\xeb\x29\xe6\xfc\x00\x16\xb8\x3c\x16\x35\x77\xa9\x13\x88\x09\xa4\xf4\xcf\x12\x99\x48\x48\xd8\x82\x10\x84\xde\x3d\xc8\x8a\x8b\x90\x6a\xa4\xd9\x6f\xd8\x59\xfa\xf9\x3f\xfe\x75\xd7\xd3\x01\x3a\x58\xa3\x04\x15\x01\x48\x73\x3f\xe5\x07\x40\x61\x0e\x86\xd5\xc0\x7c\x73\x78\xb0\x19\x9d\x65\x04\xb6\xde\xf4\x81\x49\x96\x69\x8e\xc8\x3d\x43\x83\x9b\x9e\x4a\x88\x2a\xc6\x9d\x43\x3d\x18\x37\xde\xeb\x94\x3e\xdd\x1c\xe1\xe5\xcb\xde\xe0\x55\x6e\xd2\x04\x6f\x32\x20\x2a\xd2\xad\x8a\x11\x4d\x97\x57\x22\x25\x92\xba\x55\xa2\x22\xaf\xd0\xe9\x79\xc5\xbc\xe2\xec\x2c\x5f\xfa\xe5\x17\x99\x7b\x7f\x36\xd2\x97\x8a\xa5\x19\xc9\x72\xe7\x2b\x00\x00\xff\xff\xb2\x39\xca\x9f\xff\x00\x00\x00") + +func viewsIndexTplBytes() ([]byte, error) { + return bindataRead( + _viewsIndexTpl, + "views/index.tpl", + ) +} + +func viewsIndexTpl() (*asset, error) { + bytes, err := viewsIndexTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) + } + return a.bytes, nil + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// MustAsset is like Asset but panics when Asset would return an error. +// It simplifies safe initialization of global variables. +func MustAsset(name string) []byte { + a, err := Asset(name) + if err != nil { + panic("asset: Asset(" + name + "): " + err.Error()) + } + + return a +} + +// AssetInfo loads and returns the asset info for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func AssetInfo(name string) (os.FileInfo, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) + } + return a.info, nil + } + return nil, fmt.Errorf("AssetInfo %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. +var _bindata = map[string]func() (*asset, error){ + "views/blocks/block.tpl": viewsBlocksBlockTpl, + "views/header.tpl": viewsHeaderTpl, + "views/index.tpl": viewsIndexTpl, +} + +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. +func AssetDir(name string) ([]string, error) { + node := _bintree + if len(name) != 0 { + cannonicalName := strings.Replace(name, "\\", "/", -1) + pathList := strings.Split(cannonicalName, "/") + for _, p := range pathList { + node = node.Children[p] + if node == nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + } + } + if node.Func != nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + rv := make([]string, 0, len(node.Children)) + for childName := range node.Children { + rv = append(rv, childName) + } + return rv, nil +} + +type bintree struct { + Func func() (*asset, error) + Children map[string]*bintree +} + +var _bintree = &bintree{nil, map[string]*bintree{ + "views": &bintree{nil, map[string]*bintree{ + "blocks": &bintree{nil, map[string]*bintree{ + "block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}}, + }}, + "header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}}, + "index.tpl": &bintree{viewsIndexTpl, map[string]*bintree{}}, + }}, +}} + +// RestoreAsset restores an asset under the given directory +func RestoreAsset(dir, name string) error { + data, err := Asset(name) + if err != nil { + return err + } + info, err := AssetInfo(name) + if err != nil { + return err + } + err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) + if err != nil { + return err + } + err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode()) + if err != nil { + return err + } + err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime()) + if err != nil { + return err + } + return nil +} + +// RestoreAssets restores an asset under the given directory recursively +func RestoreAssets(dir, name string) error { + children, err := AssetDir(name) + // File + if err != nil { + return RestoreAsset(dir, name) + } + // Dir + for _, child := range children { + err = RestoreAssets(dir, filepath.Join(name, child)) + if err != nil { + return err + } + } + return nil +} + +func _filePath(dir, name string) string { + cannonicalName := strings.Replace(name, "\\", "/", -1) + return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) +} + +func assetFS() *assetfs.AssetFS { + assetInfo := func(path string) (os.FileInfo, error) { + return os.Stat(path) + } + for k := range _bintree.Children { + return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} + } + panic("unreachable") +} diff --git a/pkg/testdata/views/blocks/block.tpl b/pkg/testdata/views/blocks/block.tpl new file mode 100644 index 00000000..2a9c57fc --- /dev/null +++ b/pkg/testdata/views/blocks/block.tpl @@ -0,0 +1,3 @@ +{{define "block"}} +

Hello, blocks!

+{{end}} \ No newline at end of file diff --git a/pkg/testdata/views/header.tpl b/pkg/testdata/views/header.tpl new file mode 100644 index 00000000..041fa403 --- /dev/null +++ b/pkg/testdata/views/header.tpl @@ -0,0 +1,3 @@ +{{define "header"}} +

Hello, astaxie!

+{{end}} \ No newline at end of file diff --git a/pkg/testdata/views/index.tpl b/pkg/testdata/views/index.tpl new file mode 100644 index 00000000..21b7fc06 --- /dev/null +++ b/pkg/testdata/views/index.tpl @@ -0,0 +1,15 @@ + + + + beego welcome template + + + + {{template "block"}} + {{template "header"}} + {{template "blocks/block.tpl"}} + +

{{ .Title }}

+

This is SomeVar: {{ .SomeVar }}

+ + diff --git a/pkg/testing/assertions.go b/pkg/testing/assertions.go new file mode 100644 index 00000000..96c5d4dd --- /dev/null +++ b/pkg/testing/assertions.go @@ -0,0 +1,15 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing diff --git a/pkg/testing/client.go b/pkg/testing/client.go new file mode 100644 index 00000000..c3737e9c --- /dev/null +++ b/pkg/testing/client.go @@ -0,0 +1,65 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/httplib" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest struct { + httplib.BeegoHTTPRequest +} + +func getPort() string { + if port == "" { + config, err := config.NewConfig("ini", "../conf/app.conf") + if err != nil { + return "8080" + } + port = config.String("httpport") + return port + } + return port +} + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Get(baseURL + getPort() + path)} +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Post(baseURL + getPort() + path)} +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Put(baseURL + getPort() + path)} +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Delete(baseURL + getPort() + path)} +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Head(baseURL + getPort() + path)} +} diff --git a/pkg/toolbox/healthcheck.go b/pkg/toolbox/healthcheck.go new file mode 100644 index 00000000..e3544b3a --- /dev/null +++ b/pkg/toolbox/healthcheck.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +// AdminCheckList holds health checker map +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker interface { + Check() error +} + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/toolbox/profile.go b/pkg/toolbox/profile.go new file mode 100644 index 00000000..06e40ede --- /dev/null +++ b/pkg/toolbox/profile.go @@ -0,0 +1,184 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "io" + "log" + "os" + "path" + "runtime" + "runtime/debug" + "runtime/pprof" + "strconv" + "time" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + switch input { + case "lookup goroutine": + p := pprof.Lookup("goroutine") + p.WriteTo(w, 2) + case "lookup heap": + p := pprof.Lookup("heap") + p.WriteTo(w, 2) + case "lookup threadcreate": + p := pprof.Lookup("threadcreate") + p.WriteTo(w, 2) + case "lookup block": + p := pprof.Lookup("block") + p.WriteTo(w, 2) + case "get cpuprof": + GetCPUProfile(w) + case "get memprof": + MemProf(w) + case "gc summary": + PrintGCSummary(w) + } +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + filename := "mem-" + strconv.Itoa(pid) + ".memprof" + if f, err := os.Create(filename); err != nil { + fmt.Fprintf(w, "create file %s error %s\n", filename, err.Error()) + log.Fatal("record heap profile failed: ", err) + } else { + runtime.GC() + pprof.WriteHeapProfile(f) + f.Close() + fmt.Fprintf(w, "create heap profile %s \n", filename) + _, fl := path.Split(os.Args[0]) + fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) + } +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + sec := 30 + filename := "cpu-" + strconv.Itoa(pid) + ".pprof" + f, err := os.Create(filename) + if err != nil { + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + log.Fatal("record cpu profile failed: ", err) + } + pprof.StartCPUProfile(f) + time.Sleep(time.Duration(sec) * time.Second) + pprof.StopCPUProfile() + + fmt.Fprintf(w, "create cpu profile %s \n", filename) + _, fl := path.Split(os.Args[0]) + fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + memStats := &runtime.MemStats{} + runtime.ReadMemStats(memStats) + gcstats := &debug.GCStats{PauseQuantiles: make([]time.Duration, 100)} + debug.ReadGCStats(gcstats) + + printGC(memStats, gcstats, w) +} + +func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { + + if gcstats.NumGC > 0 { + lastPause := gcstats.Pause[0] + elapsed := time.Now().Sub(startTime) + overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 + allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() + + fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", + gcstats.NumGC, + toS(lastPause), + toS(avg(gcstats.Pause)), + overhead, + toH(memStats.Alloc), + toH(memStats.Sys), + toH(uint64(allocatedRate)), + toS(gcstats.PauseQuantiles[94]), + toS(gcstats.PauseQuantiles[98]), + toS(gcstats.PauseQuantiles[99])) + } else { + // while GC has disabled + elapsed := time.Now().Sub(startTime) + allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() + + fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", + toH(memStats.Alloc), + toH(memStats.Sys), + toH(uint64(allocatedRate))) + } +} + +func avg(items []time.Duration) time.Duration { + var sum time.Duration + for _, item := range items { + sum += item + } + return time.Duration(int64(sum) / int64(len(items))) +} + +// format bytes number friendly +func toH(bytes uint64) string { + switch { + case bytes < 1024: + return fmt.Sprintf("%dB", bytes) + case bytes < 1024*1024: + return fmt.Sprintf("%.2fK", float64(bytes)/1024) + case bytes < 1024*1024*1024: + return fmt.Sprintf("%.2fM", float64(bytes)/1024/1024) + default: + return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) + } +} + +// short string format +func toS(d time.Duration) string { + + u := uint64(d) + if u < uint64(time.Second) { + switch { + case u == 0: + return "0" + case u < uint64(time.Microsecond): + return fmt.Sprintf("%.2fns", float64(u)) + case u < uint64(time.Millisecond): + return fmt.Sprintf("%.2fus", float64(u)/1000) + default: + return fmt.Sprintf("%.2fms", float64(u)/1000/1000) + } + } else { + switch { + case u < uint64(time.Minute): + return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) + case u < uint64(time.Hour): + return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) + default: + return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) + } + } + +} diff --git a/pkg/toolbox/profile_test.go b/pkg/toolbox/profile_test.go new file mode 100644 index 00000000..07a20c4e --- /dev/null +++ b/pkg/toolbox/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/toolbox/statistics.go b/pkg/toolbox/statistics.go new file mode 100644 index 00000000..fd73dfb3 --- /dev/null +++ b/pkg/toolbox/statistics.go @@ -0,0 +1,149 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "sync" + "time" +) + +// Statistics struct +type Statistics struct { + RequestURL string + RequestController string + RequestNum int64 + MinTime time.Duration + MaxTime time.Duration + TotalTime time.Duration +} + +// URLMap contains several statistics struct to log different data +type URLMap struct { + lock sync.RWMutex + LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit + urlmap map[string]map[string]*Statistics +} + +// AddStatistics add statistics task. +// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + m.lock.Lock() + defer m.lock.Unlock() + if method, ok := m.urlmap[requestURL]; ok { + if s, ok := method[requestMethod]; ok { + s.RequestNum++ + if s.MaxTime < requesttime { + s.MaxTime = requesttime + } + if s.MinTime > requesttime { + s.MinTime = requesttime + } + s.TotalTime += requesttime + } else { + nb := &Statistics{ + RequestURL: requestURL, + RequestController: requestController, + RequestNum: 1, + MinTime: requesttime, + MaxTime: requesttime, + TotalTime: requesttime, + } + m.urlmap[requestURL][requestMethod] = nb + } + + } else { + if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { + return + } + methodmap := make(map[string]*Statistics) + nb := &Statistics{ + RequestURL: requestURL, + RequestController: requestController, + RequestNum: 1, + MinTime: requesttime, + MaxTime: requesttime, + TotalTime: requesttime, + } + methodmap[requestMethod] = nb + m.urlmap[requestURL] = methodmap + } +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + + var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} + + var resultLists [][]string + content := make(map[string]interface{}) + content["Fields"] = fields + + for k, v := range m.urlmap { + for kk, vv := range v { + result := []string{ + fmt.Sprintf("% -50s", k), + fmt.Sprintf("% -10s", kk), + fmt.Sprintf("% -16d", vv.RequestNum), + fmt.Sprintf("%d", vv.TotalTime), + fmt.Sprintf("% -16s", toS(vv.TotalTime)), + fmt.Sprintf("%d", vv.MaxTime), + fmt.Sprintf("% -16s", toS(vv.MaxTime)), + fmt.Sprintf("%d", vv.MinTime), + fmt.Sprintf("% -16s", toS(vv.MinTime)), + fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), + fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), + } + resultLists = append(resultLists, result) + } + } + content["Data"] = resultLists + return content +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + + var resultLists []map[string]interface{} + + for k, v := range m.urlmap { + for kk, vv := range v { + result := map[string]interface{}{ + "request_url": k, + "method": kk, + "times": vv.RequestNum, + "total_time": toS(vv.TotalTime), + "max_time": toS(vv.MaxTime), + "min_time": toS(vv.MinTime), + "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), + } + resultLists = append(resultLists, result) + } + } + return resultLists +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = &URLMap{ + urlmap: make(map[string]map[string]*Statistics), + } +} diff --git a/pkg/toolbox/statistics_test.go b/pkg/toolbox/statistics_test.go new file mode 100644 index 00000000..ac29476c --- /dev/null +++ b/pkg/toolbox/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/toolbox/task.go b/pkg/toolbox/task.go new file mode 100644 index 00000000..c902fdfc --- /dev/null +++ b/pkg/toolbox/task.go @@ -0,0 +1,640 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "log" + "math" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +// bounds provides a range of acceptable values (plus a map of name to value). +type bounds struct { + min, max uint + names map[string]uint +} + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker + taskLock sync.RWMutex + stop chan bool + changed chan bool + isstart bool + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + days = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ + "jan": 1, + "feb": 2, + "mar": 3, + "apr": 4, + "may": 5, + "jun": 6, + "jul": 7, + "aug": 8, + "sep": 9, + "oct": 10, + "nov": 11, + "dec": 12, + }} + weeks = bounds{0, 6, map[string]uint{ + "sun": 0, + "mon": 1, + "tue": 2, + "wed": 3, + "thu": 4, + "fri": 5, + "sat": 6, + }} +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule struct { + Second uint64 + Minute uint64 + Hour uint64 + Day uint64 + Month uint64 + Week uint64 +} + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// It's not a thread-safe structure. +// Only nearest errors will be saved in ErrList +type Task struct { + Taskname string + Spec *Schedule + SpecStr string + DoFunc TaskFunc + Prev time.Time + Next time.Time + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := &Task{ + Taskname: tname, + DoFunc: f, + // Make configurable + ErrLimit: 100, + SpecStr: spec, + // we only store the pointer, so it won't use too many space + Errlist: make([]*taskerr, 100, 100), + } + task.SetCron(spec) + return task +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + return t.SpecStr +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + var str string + for _, v := range t.Errlist { + str += v.t.String() + ":" + v.errinfo + "
" + } + return str +} + +// Run run all tasks +func (t *Task) Run() error { + err := t.DoFunc() + if err != nil { + index := t.errCnt % t.ErrLimit + t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} + t.errCnt++ + } + return err +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.Next = t.Spec.Next(now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + return t.Next +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.Prev = now +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + return t.Prev +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//   -:duration +// /n : do as n times of time duration +///////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.Spec = t.parse(spec) +} + +func (t *Task) parse(spec string) *Schedule { + if len(spec) > 0 && spec[0] == '@' { + return t.parseSpec(spec) + } + // Split on whitespace. We require 5 or 6 fields. + // (second) (minute) (hour) (day of month) (month) (day of week, optional) + fields := strings.Fields(spec) + if len(fields) != 5 && len(fields) != 6 { + log.Panicf("Expected 5 or 6 fields, found %d: %s", len(fields), spec) + } + + // If a sixth field is not provided (DayOfWeek), then it is equivalent to star. + if len(fields) == 5 { + fields = append(fields, "*") + } + + schedule := &Schedule{ + Second: getField(fields[0], seconds), + Minute: getField(fields[1], minutes), + Hour: getField(fields[2], hours), + Day: getField(fields[3], days), + Month: getField(fields[4], months), + Week: getField(fields[5], weeks), + } + + return schedule +} + +func (t *Task) parseSpec(spec string) *Schedule { + switch spec { + case "@yearly", "@annually": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: 1 << days.min, + Month: 1 << months.min, + Week: all(weeks), + } + + case "@monthly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: 1 << days.min, + Month: all(months), + Week: all(weeks), + } + + case "@weekly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: all(days), + Month: all(months), + Week: 1 << weeks.min, + } + + case "@daily", "@midnight": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: all(days), + Month: all(months), + Week: all(weeks), + } + + case "@hourly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: all(hours), + Day: all(days), + Month: all(months), + Week: all(weeks), + } + } + log.Panicf("Unrecognized descriptor: %s", spec) + return nil +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + + // Start at the earliest possible time (the upcoming second). + t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) + + // This flag indicates whether a field has been incremented. + added := false + + // If no time is found within five years, return zero. + yearLimit := t.Year() + 5 + +WRAP: + if t.Year() > yearLimit { + return time.Time{} + } + + // Find the first applicable month. + // If it's this month, then do nothing. + for 1< 0 + dowMatch = 1< 0 + ) + + if s.Day&starBit > 0 || s.Week&starBit > 0 { + return domMatch && dowMatch + } + return domMatch || dowMatch +} + +// StartTask start all tasks +func StartTask() { + taskLock.Lock() + defer taskLock.Unlock() + if isstart { + //If already started, no need to start another goroutine. + return + } + isstart = true + go run() +} + +func run() { + now := time.Now().Local() + for _, t := range AdminTaskList { + t.SetNext(now) + } + + for { + // we only use RLock here because NewMapSorter copy the reference, do not change any thing + taskLock.RLock() + sortList := NewMapSorter(AdminTaskList) + taskLock.RUnlock() + sortList.Sort() + var effective time.Time + if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { + // If there are no entries yet, just sleep - it still handles new entries + // and stop requests. + effective = now.AddDate(10, 0, 0) + } else { + effective = sortList.Vals[0].GetNext() + } + select { + case now = <-time.After(effective.Sub(now)): + // Run every entry whose next time was this effective time. + for _, e := range sortList.Vals { + if e.GetNext() != effective { + break + } + go e.Run() + e.SetPrev(e.GetNext()) + e.SetNext(effective) + } + continue + case <-changed: + now = time.Now().Local() + taskLock.Lock() + for _, t := range AdminTaskList { + t.SetNext(now) + } + taskLock.Unlock() + continue + case <-stop: + return + } + } +} + +// StopTask stop all tasks +func StopTask() { + taskLock.Lock() + defer taskLock.Unlock() + if isstart { + isstart = false + stop <- true + } + +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + taskLock.Lock() + defer taskLock.Unlock() + t.SetNext(time.Now().Local()) + AdminTaskList[taskname] = t + if isstart { + changed <- true + } +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + taskLock.Lock() + defer taskLock.Unlock() + delete(AdminTaskList, taskname) + if isstart { + changed <- true + } +} + +// MapSorter sort map for tasker +type MapSorter struct { + Keys []string + Vals []Tasker +} + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + ms := &MapSorter{ + Keys: make([]string, 0, len(m)), + Vals: make([]Tasker, 0, len(m)), + } + for k, v := range m { + ms.Keys = append(ms.Keys, k) + ms.Vals = append(ms.Vals, v) + } + return ms +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext().IsZero() { + return false + } + if ms.Vals[j].GetNext().IsZero() { + return true + } + return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func getField(field string, r bounds) uint64 { + // list = range {"," range} + var bits uint64 + ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) + for _, expr := range ranges { + bits |= getRange(expr, r) + } + return bits +} + +// getRange returns the bits indicated by the given expression: +// number | number "-" number [ "/" number ] +func getRange(expr string, r bounds) uint64 { + + var ( + start, end, step uint + rangeAndStep = strings.Split(expr, "/") + lowAndHigh = strings.Split(rangeAndStep[0], "-") + singleDigit = len(lowAndHigh) == 1 + ) + + var extrastar uint64 + if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" { + start = r.min + end = r.max + extrastar = starBit + } else { + start = parseIntOrName(lowAndHigh[0], r.names) + switch len(lowAndHigh) { + case 1: + end = start + case 2: + end = parseIntOrName(lowAndHigh[1], r.names) + default: + log.Panicf("Too many hyphens: %s", expr) + } + } + + switch len(rangeAndStep) { + case 1: + step = 1 + case 2: + step = mustParseInt(rangeAndStep[1]) + + // Special handling: "N/step" means "N-max/step". + if singleDigit { + end = r.max + } + default: + log.Panicf("Too many slashes: %s", expr) + } + + if start < r.min { + log.Panicf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr) + } + if end > r.max { + log.Panicf("End of range (%d) above maximum (%d): %s", end, r.max, expr) + } + if start > end { + log.Panicf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr) + } + + return getBits(start, end, step) | extrastar +} + +// parseIntOrName returns the (possibly-named) integer contained in expr. +func parseIntOrName(expr string, names map[string]uint) uint { + if names != nil { + if namedInt, ok := names[strings.ToLower(expr)]; ok { + return namedInt + } + } + return mustParseInt(expr) +} + +// mustParseInt parses the given expression as an int or panics. +func mustParseInt(expr string) uint { + num, err := strconv.Atoi(expr) + if err != nil { + log.Panicf("Failed to parse int from %s: %s", expr, err) + } + if num < 0 { + log.Panicf("Negative number (%d) not allowed: %s", num, expr) + } + + return uint(num) +} + +// getBits sets all bits in the range [min, max], modulo the given step size. +func getBits(min, max, step uint) uint64 { + var bits uint64 + + // If step is 1, use shifts. + if step == 1 { + return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) + } + + // Else, use a simple loop. + for i := min; i <= max; i += step { + bits |= 1 << i + } + return bits +} + +// all returns all bits within the given bounds. (plus the star bit) +func all(r bounds) uint64 { + return getBits(r.min, r.max, 1) | starBit +} + +func init() { + AdminTaskList = make(map[string]Tasker) + stop = make(chan bool) + changed = make(chan bool) +} diff --git a/pkg/toolbox/task_test.go b/pkg/toolbox/task_test.go new file mode 100644 index 00000000..3a4cce2f --- /dev/null +++ b/pkg/toolbox/task_test.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + err := tk.Run() + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func() error { + cnt ++ + fmt.Printf("Hello, world! %d \n", cnt) + return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200 ; i ++ { + e := tk.Run() + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 101", l[1].errinfo) +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/pkg/tree.go b/pkg/tree.go new file mode 100644 index 00000000..9e53003b --- /dev/null +++ b/pkg/tree.go @@ -0,0 +1,585 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "path" + "regexp" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +var ( + allowSuffixExt = []string{".json", ".xml", ".html"} +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree struct { + //prefix set for static router + prefix string + //search fix route first + fixrouters []*Tree + //if set, failure to match fixrouters search then search wildcard + wildcard *Tree + //if set, failure to match wildcard search + leaves []*leafInfo +} + +// NewTree return a new Tree +func NewTree() *Tree { + return &Tree{} +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + t.addtree(splitPath(prefix), tree, nil, "") +} + +func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg string) { + if len(segments) == 0 { + panic("prefix should has path") + } + seg := segments[0] + iswild, params, regexpStr := splitSegment(seg) + // if it's ? meaning can igone this, so add one more rule for it + if len(params) > 0 && params[0] == ":" { + params = params[1:] + if len(segments[1:]) > 0 { + t.addtree(segments[1:], tree, append(wildcards, params...), reg) + } else { + filterTreeWithPrefix(tree, wildcards, reg) + } + } + //Rule: /login/*/access match /login/2009/11/access + //if already has *, and when loop the access, should as a regexpStr + if !iswild && utils.InSlice(":splat", wildcards) { + iswild = true + regexpStr = seg + } + //Rule: /user/:id/* + if seg == "*" && len(wildcards) > 0 && reg == "" { + regexpStr = "(.+)" + } + if len(segments) == 1 { + if iswild { + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "([^.]+).(.+)" + } else { + for _, w := range params { + if w == "." || w == ":" { + continue + } + regexpStr = "([^/]+)/" + regexpStr + } + } + } + reg = strings.Trim(reg+"/"+regexpStr, "/") + filterTreeWithPrefix(tree, append(wildcards, params...), reg) + t.wildcard = tree + } else { + reg = strings.Trim(reg+"/"+regexpStr, "/") + filterTreeWithPrefix(tree, append(wildcards, params...), reg) + tree.prefix = seg + t.fixrouters = append(t.fixrouters, tree) + } + return + } + + if iswild { + if t.wildcard == nil { + t.wildcard = NewTree() + } + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "([^.]+).(.+)" + params = params[1:] + } else { + for range params { + regexpStr = "([^/]+)/" + regexpStr + } + } + } else { + if seg == "*.*" { + params = params[1:] + } + } + reg = strings.TrimRight(strings.TrimRight(reg, "/")+"/"+regexpStr, "/") + t.wildcard.addtree(segments[1:], tree, append(wildcards, params...), reg) + } else { + subTree := NewTree() + subTree.prefix = seg + t.fixrouters = append(t.fixrouters, subTree) + subTree.addtree(segments[1:], tree, append(wildcards, params...), reg) + } +} + +func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) { + for _, v := range t.fixrouters { + filterTreeWithPrefix(v, wildcards, reg) + } + if t.wildcard != nil { + filterTreeWithPrefix(t.wildcard, wildcards, reg) + } + for _, l := range t.leaves { + if reg != "" { + if l.regexps != nil { + l.wildcards = append(wildcards, l.wildcards...) + l.regexps = regexp.MustCompile("^" + reg + "/" + strings.Trim(l.regexps.String(), "^$") + "$") + } else { + for _, v := range l.wildcards { + if v == ":splat" { + reg = reg + "/(.+)" + } else { + reg = reg + "/([^/]+)" + } + } + l.regexps = regexp.MustCompile("^" + reg + "$") + l.wildcards = append(wildcards, l.wildcards...) + } + } else { + l.wildcards = append(wildcards, l.wildcards...) + if l.regexps != nil { + for _, w := range wildcards { + if w == ":splat" { + reg = "(.+)/" + reg + } else { + reg = "([^/]+)/" + reg + } + } + l.regexps = regexp.MustCompile("^" + reg + strings.Trim(l.regexps.String(), "^$") + "$") + } + } + } +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + t.addseg(splitPath(pattern), runObject, nil, "") +} + +// "/" +// "admin" -> +func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { + if len(segments) == 0 { + if reg != "" { + t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + } else { + t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + } + } else { + seg := segments[0] + iswild, params, regexpStr := splitSegment(seg) + // if it's ? meaning can igone this, so add one more rule for it + if len(params) > 0 && params[0] == ":" { + t.addseg(segments[1:], route, wildcards, reg) + params = params[1:] + } + //Rule: /login/*/access match /login/2009/11/access + //if already has *, and when loop the access, should as a regexpStr + if !iswild && utils.InSlice(":splat", wildcards) { + iswild = true + regexpStr = seg + } + //Rule: /user/:id/* + if seg == "*" && len(wildcards) > 0 && reg == "" { + regexpStr = "(.+)" + } + if iswild { + if t.wildcard == nil { + t.wildcard = NewTree() + } + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "/([^.]+).(.+)" + params = params[1:] + } else { + for range params { + regexpStr = "/([^/]+)" + regexpStr + } + } + } else { + if seg == "*.*" { + params = params[1:] + } + } + t.wildcard.addseg(segments[1:], route, append(wildcards, params...), reg+regexpStr) + } else { + var subTree *Tree + for _, sub := range t.fixrouters { + if sub.prefix == seg { + subTree = sub + break + } + } + if subTree == nil { + subTree = NewTree() + subTree.prefix = seg + t.fixrouters = append(t.fixrouters, subTree) + } + subTree.addseg(segments[1:], route, wildcards, reg) + } + } +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + if len(pattern) == 0 || pattern[0] != '/' { + return nil + } + w := make([]string, 0, 20) + return t.match(pattern[1:], pattern, w, ctx) +} + +func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { + if len(pattern) > 0 { + i := 0 + for ; i < len(pattern) && pattern[i] == '/'; i++ { + } + pattern = pattern[i:] + } + // Handle leaf nodes: + if len(pattern) == 0 { + for _, l := range t.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + if t.wildcard != nil { + for _, l := range t.wildcard.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + } + return nil + } + var seg string + i, l := 0, len(pattern) + for ; i < l && pattern[i] != '/'; i++ { + } + if i == 0 { + seg = pattern + pattern = "" + } else { + seg = pattern[:i] + pattern = pattern[i:] + } + for _, subTree := range t.fixrouters { + if subTree.prefix == seg { + if len(pattern) != 0 && pattern[0] == '/' { + treePattern = pattern[1:] + } else { + treePattern = pattern + } + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) + if runObject != nil { + break + } + } + } + if runObject == nil && len(t.fixrouters) > 0 { + // Filter the .json .xml .html extension + for _, str := range allowSuffixExt { + if strings.HasSuffix(seg, str) { + for _, subTree := range t.fixrouters { + if subTree.prefix == seg[:len(seg)-len(str)] { + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) + if runObject != nil { + ctx.Input.SetParam(":ext", str[1:]) + } + } + } + } + } + } + if runObject == nil && t.wildcard != nil { + runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) + } + + if runObject == nil && len(t.leaves) > 0 { + wildcardValues = append(wildcardValues, seg) + start, i := 0, 0 + for ; i < len(pattern); i++ { + if pattern[i] == '/' { + if i != 0 && start < len(pattern) { + wildcardValues = append(wildcardValues, pattern[start:i]) + } + start = i + 1 + continue + } + } + if start > 0 { + wildcardValues = append(wildcardValues, pattern[start:i]) + } + for _, l := range t.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + } + return runObject +} + +type leafInfo struct { + // names of wildcards that lead to this leaf. eg, ["id" "name"] for the wildcard ":id" and ":name" + wildcards []string + + // if the leaf is regexp + regexps *regexp.Regexp + + runObject interface{} +} + +func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { + //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + if leaf.regexps == nil { + if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path + return true + } + // match * + if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { + ctx.Input.SetParam(":splat", treePattern) + return true + } + // match *.* or :id + if len(leaf.wildcards) >= 2 && leaf.wildcards[len(leaf.wildcards)-2] == ":path" && leaf.wildcards[len(leaf.wildcards)-1] == ":ext" { + if len(leaf.wildcards) == 2 { + lastone := wildcardValues[len(wildcardValues)-1] + strs := strings.SplitN(lastone, ".", 2) + if len(strs) == 2 { + ctx.Input.SetParam(":ext", strs[1]) + } + ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[:len(wildcardValues)-1]...), strs[0])) + return true + } else if len(wildcardValues) < 2 { + return false + } + var index int + for index = 0; index < len(leaf.wildcards)-2; index++ { + ctx.Input.SetParam(leaf.wildcards[index], wildcardValues[index]) + } + lastone := wildcardValues[len(wildcardValues)-1] + strs := strings.SplitN(lastone, ".", 2) + if len(strs) == 2 { + ctx.Input.SetParam(":ext", strs[1]) + } + if index > (len(wildcardValues) - 1) { + ctx.Input.SetParam(":path", "") + } else { + ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[index:len(wildcardValues)-1]...), strs[0])) + } + return true + } + // match :id + if len(leaf.wildcards) != len(wildcardValues) { + return false + } + for j, v := range leaf.wildcards { + ctx.Input.SetParam(v, wildcardValues[j]) + } + return true + } + + if !leaf.regexps.MatchString(path.Join(wildcardValues...)) { + return false + } + matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...)) + for i, match := range matches[1:] { + if i < len(leaf.wildcards) { + ctx.Input.SetParam(leaf.wildcards[i], match) + } + } + return true +} + +// "/" -> [] +// "/admin" -> ["admin"] +// "/admin/" -> ["admin"] +// "/admin/users" -> ["admin", "users"] +func splitPath(key string) []string { + key = strings.Trim(key, "/ ") + if key == "" { + return []string{} + } + return strings.Split(key, "/") +} + +// "admin" -> false, nil, "" +// ":id" -> true, [:id], "" +// "?:id" -> true, [: :id], "" : meaning can empty +// ":id:int" -> true, [:id], ([0-9]+) +// ":name:string" -> true, [:name], ([\w]+) +// ":id([0-9]+)" -> true, [:id], ([0-9]+) +// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+) +// "cms_:id_:page.html" -> true, [:id_ :page], cms_(.+)(.+).html +// "cms_:id(.+)_:page.html" -> true, [:id :page], cms_(.+)_(.+).html +// "*" -> true, [:splat], "" +// "*.*" -> true,[. :path :ext], "" . meaning separator +func splitSegment(key string) (bool, []string, string) { + if strings.HasPrefix(key, "*") { + if key == "*.*" { + return true, []string{".", ":path", ":ext"}, "" + } + return true, []string{":splat"}, "" + } + if strings.ContainsAny(key, ":") { + var paramsNum int + var out []rune + var start bool + var startexp bool + var param []rune + var expt []rune + var skipnum int + params := []string{} + reg := regexp.MustCompile(`[a-zA-Z0-9_]+`) + for i, v := range key { + if skipnum > 0 { + skipnum-- + continue + } + if start { + //:id:int and :name:string + if v == ':' { + if len(key) >= i+4 { + if key[i+1:i+4] == "int" { + out = append(out, []rune("([0-9]+)")...) + params = append(params, ":"+string(param)) + start = false + startexp = false + skipnum = 3 + param = make([]rune, 0) + paramsNum++ + continue + } + } + if len(key) >= i+7 { + if key[i+1:i+7] == "string" { + out = append(out, []rune(`([\w]+)`)...) + params = append(params, ":"+string(param)) + paramsNum++ + start = false + startexp = false + skipnum = 6 + param = make([]rune, 0) + continue + } + } + } + // params only support a-zA-Z0-9 + if reg.MatchString(string(v)) { + param = append(param, v) + continue + } + if v != '(' { + out = append(out, []rune(`(.+)`)...) + params = append(params, ":"+string(param)) + param = make([]rune, 0) + paramsNum++ + start = false + startexp = false + } + } + if startexp { + if v != ')' { + expt = append(expt, v) + continue + } + } + // Escape Sequence '\' + if i > 0 && key[i-1] == '\\' { + out = append(out, v) + } else if v == ':' { + param = make([]rune, 0) + start = true + } else if v == '(' { + startexp = true + start = false + if len(param) > 0 { + params = append(params, ":"+string(param)) + param = make([]rune, 0) + } + paramsNum++ + expt = make([]rune, 0) + expt = append(expt, '(') + } else if v == ')' { + startexp = false + expt = append(expt, ')') + out = append(out, expt...) + param = make([]rune, 0) + } else if v == '?' { + params = append(params, ":") + } else { + out = append(out, v) + } + } + if len(param) > 0 { + if paramsNum > 0 { + out = append(out, []rune(`(.+)`)...) + } + params = append(params, ":"+string(param)) + } + return true, params, string(out) + } + return false, nil, "" +} diff --git a/pkg/tree_test.go b/pkg/tree_test.go new file mode 100644 index 00000000..d412a348 --- /dev/null +++ b/pkg/tree_test.go @@ -0,0 +1,306 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + "testing" + + "github.com/astaxie/beego/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset(ctx) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset(ctx) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset(ctx) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} + +func TestSplitPath(t *testing.T) { + a := splitPath("") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/admin") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin should retrun [admin]") + } + a = splitPath("/admin/") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin/ should retrun [admin]") + } + a = splitPath("/admin/users") + if len(a) != 2 || a[0] != "admin" || a[1] != "users" { + t.Fatal("/admin should retrun [admin users]") + } + a = splitPath("/admin/:id:int") + if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { + t.Fatal("/admin should retrun [admin :id:int]") + } +} + +func TestSplitSegment(t *testing.T) { + + items := map[string]struct { + isReg bool + params []string + regStr string + }{ + "admin": {false, nil, ""}, + "*": {true, []string{":splat"}, ""}, + "*.*": {true, []string{".", ":path", ":ext"}, ""}, + ":id": {true, []string{":id"}, ""}, + "?:id": {true, []string{":", ":id"}, ""}, + ":id:int": {true, []string{":id"}, "([0-9]+)"}, + ":name:string": {true, []string{":name"}, `([\w]+)`}, + ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, + ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, + ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, + "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, + `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, + `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, + } + + for pattern, v := range items { + b, w, r := splitSegment(pattern) + if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { + t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) + } + } +} diff --git a/pkg/unregroute_test.go b/pkg/unregroute_test.go new file mode 100644 index 00000000..08b1b77b --- /dev/null +++ b/pkg/unregroute_test.go @@ -0,0 +1,226 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// +// The unregroute_test.go contains tests for the unregister route +// functionality, that allows overriding route paths in children project +// that embed parent routers. +// + +const contentRootOriginal = "ok-original-root" +const contentLevel1Original = "ok-original-level1" +const contentLevel2Original = "ok-original-level2" + +const contentRootReplacement = "ok-replacement-root" +const contentLevel1Replacement = "ok-replacement-level1" +const contentLevel2Replacement = "ok-replacement-level2" + +// TestPreUnregController will supply content for the original routes, +// before unregistration +type TestPreUnregController struct { + Controller +} + +func (tc *TestPreUnregController) GetFixedRoot() { + tc.Ctx.Output.Body([]byte(contentRootOriginal)) +} +func (tc *TestPreUnregController) GetFixedLevel1() { + tc.Ctx.Output.Body([]byte(contentLevel1Original)) +} +func (tc *TestPreUnregController) GetFixedLevel2() { + tc.Ctx.Output.Body([]byte(contentLevel2Original)) +} + +// TestPostUnregController will supply content for the overriding routes, +// after the original ones are unregistered. +type TestPostUnregController struct { + Controller +} + +func (tc *TestPostUnregController) GetFixedRoot() { + tc.Ctx.Output.Body([]byte(contentRootReplacement)) +} +func (tc *TestPostUnregController) GetFixedLevel1() { + tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) +} +func (tc *TestPostUnregController) GetFixedLevel2() { + tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) +} + +// TestUnregisterFixedRouteRoot replaces just the root fixed route path. +// In this case, for a path like "/level1/level2" or "/level1", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteRoot(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, "Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, "Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, "Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the root path + findAndRemoveSingleTree(handler.routers[method]) + + // Replace the root path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + + // Test replacement root (expect change) + testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) + + // Test level 1 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) + + // Test level 2 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) + +} + +// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. +// In this case, for a path like "/level1/level2" or "/", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteLevel1(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the level1 path + subPaths := splitPath("/level1") + if handler.routers[method].prefix == strings.Trim("/level1", "/ ") { + findAndRemoveSingleTree(handler.routers[method]) + } else { + findAndRemoveTree(subPaths, handler.routers[method], method) + } + + // Replace the "level1" path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + + // Test replacement root (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) + + // Test level 1 (expect change) + testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement) + + // Test level 2 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) + +} + +// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed +// route path. In this case, for a path like "/level1" or "/", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteLevel2(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the level2 path + subPaths := splitPath("/level1/level2") + if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") { + findAndRemoveSingleTree(handler.routers[method]) + } else { + findAndRemoveTree(subPaths, handler.routers[method], method) + } + + // Replace the "/level1/level2" path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + + // Test replacement root (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) + + // Test level 1 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) + + // Test level 2 (expect change) + testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) + +} + +func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, + testName, method, path, expectedBodyContent string) { + + r, err := http.NewRequest(method, path, nil) + if err != nil { + t.Errorf("httpRecorderBodyTest NewRequest error: %v", err) + return + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + body := w.Body.String() + if body != expectedBodyContent { + t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body) + } +} diff --git a/pkg/utils/caller.go b/pkg/utils/caller.go new file mode 100644 index 00000000..73c52a62 --- /dev/null +++ b/pkg/utils/caller.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "reflect" + "runtime" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} diff --git a/pkg/utils/caller_test.go b/pkg/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/utils/captcha/LICENSE b/pkg/utils/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/utils/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/utils/captcha/README.md b/pkg/utils/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
+ {{create_captcha}} + +
+``` diff --git a/pkg/utils/captcha/captcha.go b/pkg/utils/captcha/captcha.go new file mode 100644 index 00000000..42ac70d3 --- /dev/null +++ b/pkg/utils/captcha/captcha.go @@ -0,0 +1,270 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// {{create_captcha}} +// +//
+// ``` +package captcha + +import ( + "fmt" + "html/template" + "net/http" + "path" + "strings" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha struct { + // beego cache store + store cache.Cache + + // url prefix for captcha image + URLPrefix string + + // specify captcha id input field name + FieldIDName string + // specify captcha result input field name + FieldCaptchaName string + + // captcha image width and height + StdWidth int + StdHeight int + + // captcha chars nums + ChallengeNums int + + // captcha expiration seconds + Expiration time.Duration + + // cache key prefix + CachePrefix string +} + +// generate key string +func (c *Captcha) key(id string) string { + return c.CachePrefix + id +} + +// generate rand chars with default chars +func (c *Captcha) genRandChars() []byte { + return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) +} + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + var chars []byte + + id := path.Base(ctx.Request.RequestURI) + if i := strings.Index(id, "."); i != -1 { + id = id[:i] + } + + key := c.key(id) + + if len(ctx.Input.Query("reload")) > 0 { + chars = c.genRandChars() + if err := c.store.Put(key, chars, c.Expiration); err != nil { + ctx.Output.SetStatus(500) + ctx.WriteString("captcha reload error") + logs.Error("Reload Create Captcha Error:", err) + return + } + } else { + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + ctx.Output.SetStatus(404) + ctx.WriteString("captcha not found") + return + } + } + + img := NewImage(chars, c.StdWidth, c.StdHeight) + if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { + logs.Error("Write Captcha Image Error:", err) + } +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + value, err := c.CreateCaptcha() + if err != nil { + logs.Error("Create Captcha Error:", err) + return "" + } + + // create html + return template.HTML(fmt.Sprintf(``+ + ``+ + ``+ + ``, c.FieldIDName, value, c.URLPrefix, value, c.URLPrefix, value)) +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + // generate captcha id + id := string(utils.RandomCreateBytes(15)) + + // get the captcha chars + chars := c.genRandChars() + + // save to store + if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { + return "", err + } + + return id, nil +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + req.ParseForm() + return c.Verify(req.Form.Get(c.FieldIDName), req.Form.Get(c.FieldCaptchaName)) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + if len(challenge) == 0 || len(id) == 0 { + return + } + + var chars []byte + + key := c.key(id) + + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + return + } + + defer func() { + // finally remove it + c.store.Delete(key) + }() + + if len(chars) != len(challenge) { + return + } + // verify challenge + for i, c := range chars { + if c != challenge[i]-48 { + return + } + } + + return true +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + cpt := &Captcha{} + cpt.store = store + cpt.FieldIDName = fieldIDName + cpt.FieldCaptchaName = fieldCaptchaName + cpt.ChallengeNums = challengeNums + cpt.Expiration = expiration + cpt.CachePrefix = cachePrefix + cpt.StdWidth = stdWidth + cpt.StdHeight = stdHeight + + if len(urlPrefix) == 0 { + urlPrefix = defaultURLPrefix + } + + if urlPrefix[len(urlPrefix)-1] != '/' { + urlPrefix += "/" + } + + cpt.URLPrefix = urlPrefix + + return cpt +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + cpt := NewCaptcha(urlPrefix, store) + + // create filter for serve captcha image + beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) + + // add to template func map + beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) + + return cpt +} diff --git a/pkg/utils/captcha/image.go b/pkg/utils/captcha/image.go new file mode 100644 index 00000000..c3c9a83a --- /dev/null +++ b/pkg/utils/captcha/image.go @@ -0,0 +1,501 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "bytes" + "image" + "image/color" + "image/png" + "io" + "math" +) + +const ( + fontWidth = 11 + fontHeight = 18 + blackChar = 1 + + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 + // Maximum absolute skew factor of a single digit. + maxSkew = 0.7 + // Number of background circles. + circleCount = 20 +) + +var font = [][]byte{ + { // 0 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 1 + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 2 + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 3 + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 4 + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + }, + { // 5 + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 6 + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 7 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + }, + { // 8 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 9 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + }, +} + +// Image struct +type Image struct { + *image.Paletted + numWidth int + numHeight int + dotSize int +} + +var prng = &siprng{} + +// randIntn returns a pseudorandom non-negative int in range [0, n). +func randIntn(n int) int { + return prng.Intn(n) +} + +// randInt returns a pseudorandom int in range [from, to]. +func randInt(from, to int) int { + return prng.Intn(to+1-from) + from +} + +// randFloat returns a pseudorandom float64 in range [from, to]. +func randFloat(from, to float64) float64 { + return (to-from)*prng.Float64() + from +} + +func randomPalette() color.Palette { + p := make([]color.Color, circleCount+1) + // Transparent color. + p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00} + // Primary color. + prim := color.RGBA{ + uint8(randIntn(129)), + uint8(randIntn(129)), + uint8(randIntn(129)), + 0xFF, + } + p[1] = prim + // Circle colors. + for i := 2; i <= circleCount; i++ { + p[i] = randomBrightness(prim, 255) + } + return p +} + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + m := new(Image) + m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette()) + m.calculateSizes(width, height, len(digits)) + // Randomly position captcha inside the image. + maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize + maxy := height - m.numHeight - m.dotSize*2 + var border int + if width > height { + border = height / 5 + } else { + border = width / 5 + } + x := randInt(border, maxx-border) + y := randInt(border, maxy-border) + // Draw digits. + for _, n := range digits { + m.drawDigit(font[n], x, y) + x += m.numWidth + m.dotSize + } + // Draw strike-through line. + m.strikeThrough() + // Apply wave distortion. + m.distort(randFloat(5, 10), randFloat(100, 200)) + // Fill image with random circles. + m.fillWithCircles(circleCount, m.dotSize) + return m +} + +// encodedPNG encodes an image to PNG and returns +// the result as a byte slice. +func (m *Image) encodedPNG() []byte { + var buf bytes.Buffer + if err := png.Encode(&buf, m.Paletted); err != nil { + panic(err.Error()) + } + return buf.Bytes() +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.encodedPNG()) + return int64(n), err +} + +func (m *Image) calculateSizes(width, height, ncount int) { + // Goal: fit all digits inside the image. + var border int + if width > height { + border = height / 4 + } else { + border = width / 4 + } + // Convert everything to floats for calculations. + w := float64(width - border*2) + h := float64(height - border*2) + // fw takes into account 1-dot spacing between digits. + fw := float64(fontWidth + 1) + fh := float64(fontHeight) + nc := float64(ncount) + // Calculate the width of a single digit taking into account only the + // width of the image. + nw := w / nc + // Calculate the height of a digit from this width. + nh := nw * fh / fw + // Digit too high? + if nh > h { + // Fit digits based on height. + nh = h + nw = fw / fh * nh + } + // Calculate dot size. + m.dotSize = int(nh / fh) + if m.dotSize < 1 { + m.dotSize = 1 + } + // Save everything, making the actual width smaller by 1 dot to account + // for spacing between digits. + m.numWidth = int(nw) - m.dotSize + m.numHeight = int(nh) +} + +func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) { + for x := fromX; x <= toX; x++ { + m.SetColorIndex(x, y, colorIdx) + } +} + +func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) { + f := 1 - radius + dfx := 1 + dfy := -2 * radius + xo := 0 + yo := radius + + m.SetColorIndex(x, y+radius, colorIdx) + m.SetColorIndex(x, y-radius, colorIdx) + m.drawHorizLine(x-radius, x+radius, y, colorIdx) + + for xo < yo { + if f >= 0 { + yo-- + dfy += 2 + f += dfy + } + xo++ + dfx += 2 + f += dfx + m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx) + m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx) + } +} + +func (m *Image) fillWithCircles(n, maxradius int) { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + for i := 0; i < n; i++ { + colorIdx := uint8(randInt(1, circleCount-1)) + r := randInt(1, maxradius) + m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx) + } +} + +func (m *Image) strikeThrough() { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + y := randInt(maxy/3, maxy-maxy/3) + amplitude := randFloat(5, 20) + period := randFloat(80, 180) + dx := 2.0 * math.Pi / period + for x := 0; x < maxx; x++ { + xo := amplitude * math.Cos(float64(y)*dx) + yo := amplitude * math.Sin(float64(x)*dx) + for yn := 0; yn < m.dotSize; yn++ { + r := randInt(0, m.dotSize) + m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1) + } + } +} + +func (m *Image) drawDigit(digit []byte, x, y int) { + skf := randFloat(-maxSkew, maxSkew) + xs := float64(x) + r := m.dotSize / 2 + y += randInt(-r, r) + for yo := 0; yo < fontHeight; yo++ { + for xo := 0; xo < fontWidth; xo++ { + if digit[yo*fontWidth+xo] != blackChar { + continue + } + m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1) + } + xs += skf + x = int(xs) + } +} + +func (m *Image) distort(amplude float64, period float64) { + w := m.Bounds().Max.X + h := m.Bounds().Max.Y + + oldm := m.Paletted + newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette) + + dx := 2.0 * math.Pi / period + for x := 0; x < w; x++ { + for y := 0; y < h; y++ { + xo := amplude * math.Sin(float64(y)*dx) + yo := amplude * math.Cos(float64(x)*dx) + newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo))) + } + } + m.Paletted = newm +} + +func randomBrightness(c color.RGBA, max uint8) color.RGBA { + minc := min3(c.R, c.G, c.B) + maxc := max3(c.R, c.G, c.B) + if maxc > max { + return c + } + n := randIntn(int(max-maxc)) - int(minc) + return color.RGBA{ + uint8(int(c.R) + n), + uint8(int(c.G) + n), + uint8(int(c.B) + n), + c.A, + } +} + +func min3(x, y, z uint8) (m uint8) { + m = x + if y < m { + m = y + } + if z < m { + m = z + } + return +} + +func max3(x, y, z uint8) (m uint8) { + m = x + if y > m { + m = y + } + if z > m { + m = z + } + return +} diff --git a/pkg/utils/captcha/image_test.go b/pkg/utils/captcha/image_test.go new file mode 100644 index 00000000..5e35b7f7 --- /dev/null +++ b/pkg/utils/captcha/image_test.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/utils" +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/utils/captcha/siprng.go b/pkg/utils/captcha/siprng.go new file mode 100644 index 00000000..5e256cf9 --- /dev/null +++ b/pkg/utils/captcha/siprng.go @@ -0,0 +1,277 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "crypto/rand" + "encoding/binary" + "io" + "sync" +) + +// siprng is PRNG based on SipHash-2-4. +type siprng struct { + mu sync.Mutex + k0, k1, ctr uint64 +} + +// siphash implements SipHash-2-4, accepting a uint64 as a message. +func siphash(k0, k1, m uint64) uint64 { + // Initialization. + v0 := k0 ^ 0x736f6d6570736575 + v1 := k1 ^ 0x646f72616e646f6d + v2 := k0 ^ 0x6c7967656e657261 + v3 := k1 ^ 0x7465646279746573 + t := uint64(8) << 56 + + // Compression. + v3 ^= m + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= m + + // Compress last block. + v3 ^= t + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= t + + // Finalization. + v2 ^= 0xff + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 3. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 4. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + return v0 ^ v1 ^ v2 ^ v3 +} + +// rekey sets a new PRNG key, which is read from crypto/rand. +func (p *siprng) rekey() { + var k [16]byte + if _, err := io.ReadFull(rand.Reader, k[:]); err != nil { + panic(err.Error()) + } + p.k0 = binary.LittleEndian.Uint64(k[0:8]) + p.k1 = binary.LittleEndian.Uint64(k[8:16]) + p.ctr = 1 +} + +// Uint64 returns a new pseudorandom uint64. +// It rekeys PRNG on the first call and every 64 MB of generated data. +func (p *siprng) Uint64() uint64 { + p.mu.Lock() + if p.ctr == 0 || p.ctr > 8*1024*1024 { + p.rekey() + } + v := siphash(p.k0, p.k1, p.ctr) + p.ctr++ + p.mu.Unlock() + return v +} + +func (p *siprng) Int63() int64 { + return int64(p.Uint64() & 0x7fffffffffffffff) +} + +func (p *siprng) Uint32() uint32 { + return uint32(p.Uint64()) +} + +func (p *siprng) Int31() int32 { + return int32(p.Uint32() & 0x7fffffff) +} + +func (p *siprng) Intn(n int) int { + if n <= 0 { + panic("invalid argument to Intn") + } + if n <= 1<<31-1 { + return int(p.Int31n(int32(n))) + } + return int(p.Int63n(int64(n))) +} + +func (p *siprng) Int63n(n int64) int64 { + if n <= 0 { + panic("invalid argument to Int63n") + } + max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) + v := p.Int63() + for v > max { + v = p.Int63() + } + return v % n +} + +func (p *siprng) Int31n(n int32) int32 { + if n <= 0 { + panic("invalid argument to Int31n") + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := p.Int31() + for v > max { + v = p.Int31() + } + return v % n +} + +func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) } diff --git a/pkg/utils/captcha/siprng_test.go b/pkg/utils/captcha/siprng_test.go new file mode 100644 index 00000000..189d3d3c --- /dev/null +++ b/pkg/utils/captcha/siprng_test.go @@ -0,0 +1,33 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import "testing" + +func TestSiphash(t *testing.T) { + good := uint64(0xe849e8bb6ffe2567) + cur := siphash(0, 0, 0) + if cur != good { + t.Fatalf("siphash: expected %x, got %x", good, cur) + } +} + +func BenchmarkSiprng(b *testing.B) { + b.SetBytes(8) + p := &siprng{} + for i := 0; i < b.N; i++ { + p.Uint64() + } +} diff --git a/pkg/utils/debug.go b/pkg/utils/debug.go new file mode 100644 index 00000000..93c27b70 --- /dev/null +++ b/pkg/utils/debug.go @@ -0,0 +1,478 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "fmt" + "log" + "reflect" + "runtime" +) + +var ( + dunno = []byte("???") + centerDot = []byte("·") + dot = []byte(".") +) + +type pointerInfo struct { + prev *pointerInfo + n int + addr uintptr + pos int + used []int +} + +// Display print the data in console +func Display(data ...interface{}) { + display(true, data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return display(false, data...) +} + +func display(displayed bool, data ...interface{}) string { + var pc, file, line, ok = runtime.Caller(2) + + if !ok { + return "" + } + + var buf = new(bytes.Buffer) + + fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) + + fmt.Fprintf(buf, "\n[Variables]\n") + + for i := 0; i < len(data); i += 2 { + var output = fomateinfo(len(data[i].(string))+3, data[i+1]) + fmt.Fprintf(buf, "%s = %s", data[i], output) + } + + if displayed { + log.Print(buf) + } + return buf.String() +} + +// return data dump and format bytes +func fomateinfo(headlen int, data ...interface{}) []byte { + var buf = new(bytes.Buffer) + + if len(data) > 1 { + fmt.Fprint(buf, " ") + + fmt.Fprint(buf, "[") + + fmt.Fprintln(buf) + } + + for k, v := range data { + var buf2 = new(bytes.Buffer) + var pointers *pointerInfo + var interfaces = make([]reflect.Value, 0, 10) + + printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) + + if k < len(data)-1 { + fmt.Fprint(buf2, ", ") + } + + fmt.Fprintln(buf2) + + buf.Write(buf2.Bytes()) + } + + if len(data) > 1 { + fmt.Fprintln(buf) + + fmt.Fprint(buf, " ") + + fmt.Fprint(buf, "]") + } + + return buf.Bytes() +} + +// check data is golang basic type +func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool { + switch kind { + case reflect.Bool: + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Complex64, reflect.Complex128: + return true + case reflect.String: + return true + case reflect.Chan: + return true + case reflect.Invalid: + return true + case reflect.Interface: + for _, in := range *interfaces { + if reflect.DeepEqual(in, val) { + return true + } + } + return false + case reflect.UnsafePointer: + if val.IsNil() { + return true + } + + var elem = val.Elem() + + if isSimpleType(elem, elem.Kind(), pointers, interfaces) { + return true + } + + var addr = val.Elem().UnsafeAddr() + + for p := *pointers; p != nil; p = p.prev { + if addr == p.addr { + return true + } + } + + return false + } + + return false +} + +// dump value +func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { + var t = val.Kind() + + switch t { + case reflect.Bool: + fmt.Fprint(buf, val.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fmt.Fprint(buf, val.Int()) + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + fmt.Fprint(buf, val.Uint()) + case reflect.Float32, reflect.Float64: + fmt.Fprint(buf, val.Float()) + case reflect.Complex64, reflect.Complex128: + fmt.Fprint(buf, val.Complex()) + case reflect.UnsafePointer: + fmt.Fprintf(buf, "unsafe.Pointer(0x%X)", val.Pointer()) + case reflect.Ptr: + if val.IsNil() { + fmt.Fprint(buf, "nil") + return + } + + var addr = val.Elem().UnsafeAddr() + + for p := *pointers; p != nil; p = p.prev { + if addr == p.addr { + p.used = append(p.used, buf.Len()) + fmt.Fprintf(buf, "0x%X", addr) + return + } + } + + *pointers = &pointerInfo{ + prev: *pointers, + addr: addr, + pos: buf.Len(), + used: make([]int, 0), + } + + fmt.Fprint(buf, "&") + + printKeyValue(buf, val.Elem(), pointers, interfaces, structFilter, formatOutput, indent, level) + case reflect.String: + fmt.Fprint(buf, "\"", val.String(), "\"") + case reflect.Interface: + var value = val.Elem() + + if !value.IsValid() { + fmt.Fprint(buf, "nil") + } else { + for _, in := range *interfaces { + if reflect.DeepEqual(in, val) { + fmt.Fprint(buf, "repeat") + return + } + } + + *interfaces = append(*interfaces, val) + + printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) + } + case reflect.Struct: + var t = val.Type() + + fmt.Fprint(buf, t) + fmt.Fprint(buf, "{") + + for i := 0; i < val.NumField(); i++ { + if formatOutput { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + var name = t.Field(i).Name + + if formatOutput { + for ind := 0; ind < level; ind++ { + fmt.Fprint(buf, indent) + } + } + + fmt.Fprint(buf, name) + fmt.Fprint(buf, ": ") + + if structFilter != nil && structFilter(t.String(), name) { + fmt.Fprint(buf, "ignore") + } else { + printKeyValue(buf, val.Field(i), pointers, interfaces, structFilter, formatOutput, indent, level+1) + } + + fmt.Fprint(buf, ",") + } + + if formatOutput { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Array, reflect.Slice: + fmt.Fprint(buf, val.Type()) + fmt.Fprint(buf, "{") + + var allSimple = true + + for i := 0; i < val.Len(); i++ { + var elem = val.Index(i) + + var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + + if !isSimple { + allSimple = false + } + + if formatOutput && !isSimple { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + if formatOutput && !isSimple { + for ind := 0; ind < level; ind++ { + fmt.Fprint(buf, indent) + } + } + + printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) + + if i != val.Len()-1 || !allSimple { + fmt.Fprint(buf, ",") + } + } + + if formatOutput && !allSimple { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Map: + var t = val.Type() + var keys = val.MapKeys() + + fmt.Fprint(buf, t) + fmt.Fprint(buf, "{") + + var allSimple = true + + for i := 0; i < len(keys); i++ { + var elem = val.MapIndex(keys[i]) + + var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + + if !isSimple { + allSimple = false + } + + if formatOutput && !isSimple { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + if formatOutput && !isSimple { + for ind := 0; ind <= level; ind++ { + fmt.Fprint(buf, indent) + } + } + + printKeyValue(buf, keys[i], pointers, interfaces, structFilter, formatOutput, indent, level+1) + fmt.Fprint(buf, ": ") + printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) + + if i != val.Len()-1 || !allSimple { + fmt.Fprint(buf, ",") + } + } + + if formatOutput && !allSimple { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Chan: + fmt.Fprint(buf, val.Type()) + case reflect.Invalid: + fmt.Fprint(buf, "invalid") + default: + fmt.Fprint(buf, "unknow") + } +} + +// PrintPointerInfo dump pointer value +func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { + var anyused = false + var pointerNum = 0 + + for p := pointers; p != nil; p = p.prev { + if len(p.used) > 0 { + anyused = true + } + pointerNum++ + p.n = pointerNum + } + + if anyused { + var pointerBufs = make([][]rune, pointerNum+1) + + for i := 0; i < len(pointerBufs); i++ { + var pointerBuf = make([]rune, buf.Len()+headlen) + + for j := 0; j < len(pointerBuf); j++ { + pointerBuf[j] = ' ' + } + + pointerBufs[i] = pointerBuf + } + + for pn := 0; pn <= pointerNum; pn++ { + for p := pointers; p != nil; p = p.prev { + if len(p.used) > 0 && p.n >= pn { + if pn == p.n { + pointerBufs[pn][p.pos+headlen] = '└' + + var maxpos = 0 + + for i, pos := range p.used { + if i < len(p.used)-1 { + pointerBufs[pn][pos+headlen] = '┴' + } else { + pointerBufs[pn][pos+headlen] = '┘' + } + + maxpos = pos + } + + for i := 0; i < maxpos-p.pos-1; i++ { + if pointerBufs[pn][i+p.pos+headlen+1] == ' ' { + pointerBufs[pn][i+p.pos+headlen+1] = '─' + } + } + } else { + pointerBufs[pn][p.pos+headlen] = '│' + + for _, pos := range p.used { + if pointerBufs[pn][pos+headlen] == ' ' { + pointerBufs[pn][pos+headlen] = '│' + } else { + pointerBufs[pn][pos+headlen] = '┼' + } + } + } + } + } + + buf.WriteString(string(pointerBufs[pn]) + "\n") + } + } +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + var buf = new(bytes.Buffer) + + for i := skip; ; i++ { + var pc, file, line, ok = runtime.Caller(i) + + if !ok { + break + } + + buf.WriteString(indent) + + fmt.Fprintf(buf, "at %s() [%s:%d]\n", function(pc), file, line) + } + + return buf.Bytes() +} + +// return the name of the function containing the PC if possible, +func function(pc uintptr) []byte { + fn := runtime.FuncForPC(pc) + if fn == nil { + return dunno + } + name := []byte(fn.Name()) + // The name includes the path name to the package, which is unnecessary + // since the file name is already included. Plus, it has center dots. + // That is, we see + // runtime/debug.*T·ptrmethod + // and want + // *T.ptrmethod + if period := bytes.Index(name, dot); period >= 0 { + name = name[period+1:] + } + name = bytes.Replace(name, centerDot, dot, -1) + return name +} diff --git a/pkg/utils/debug_test.go b/pkg/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 00000000..6090eb17 --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bufio" + "errors" + "io" + "os" + "path/filepath" + "regexp" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + path, _ := filepath.Abs(os.Args[0]) + return path +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return filepath.Dir(SelfPath()) +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + if _, err := os.Stat(name); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + for _, path := range paths { + if fullpath = filepath.Join(path, filename); FileExists(fullpath) { + return + } + } + err = errors.New(fullpath + " not found in paths") + return +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + re, err := regexp.Compile(patten) + if err != nil { + return + } + + fd, err := os.Open(filename) + if err != nil { + return + } + lines = make([]string, 0) + reader := bufio.NewReader(fd) + prefix := "" + var isLongLine bool + for { + byteLine, isPrefix, er := reader.ReadLine() + if er != nil && er != io.EOF { + return nil, er + } + if er == io.EOF { + break + } + line := string(byteLine) + if isPrefix { + prefix += line + continue + } else { + isLongLine = true + } + + line = prefix + line + if isLongLine { + prefix = "" + } + if re.MatchString(line) { + lines = append(lines, line) + } + } + return lines, nil +} diff --git a/pkg/utils/file_test.go b/pkg/utils/file_test.go new file mode 100644 index 00000000..b2644157 --- /dev/null +++ b/pkg/utils/file_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "path/filepath" + "reflect" + "testing" +) + +var noExistedFile = "/tmp/not_existed_file" + +func TestSelfPath(t *testing.T) { + path := SelfPath() + if path == "" { + t.Error("path cannot be empty") + } + t.Logf("SelfPath: %s", path) +} + +func TestSelfDir(t *testing.T) { + dir := SelfDir() + t.Logf("SelfDir: %s", dir) +} + +func TestFileExists(t *testing.T) { + if !FileExists("./file.go") { + t.Errorf("./file.go should exists, but it didn't") + } + + if FileExists(noExistedFile) { + t.Errorf("Weird, how could this file exists: %s", noExistedFile) + } +} + +func TestSearchFile(t *testing.T) { + path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) + if err != nil { + t.Error(err) + } + t.Log(path) + + _, err = SearchFile(noExistedFile, ".") + if err == nil { + t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) + } +} + +func TestGrepFile(t *testing.T) { + _, err := GrepFile("", noExistedFile) + if err == nil { + t.Error("expect file-not-existed error, but got nothing") + } + + path := filepath.Join(".", "testdata", "grepe.test") + lines, err := GrepFile(`^\s*[^#]+`, path) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(lines, []string{"hello", "world"}) { + t.Errorf("expect [hello world], but receive %v", lines) + } +} diff --git a/pkg/utils/mail.go b/pkg/utils/mail.go new file mode 100644 index 00000000..80a366ca --- /dev/null +++ b/pkg/utils/mail.go @@ -0,0 +1,424 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/mail" + "net/smtp" + "net/textproto" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" +) + +const ( + maxLineLength = 76 + + upperhex = "0123456789ABCDEF" +) + +// Email is the type used for email messages +type Email struct { + Auth smtp.Auth + Identity string `json:"identity"` + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Port int `json:"port"` + From string `json:"from"` + To []string + Bcc []string + Cc []string + Subject string + Text string // Plaintext message (optional) + HTML string // Html message (optional) + Headers textproto.MIMEHeader + Attachments []*Attachment + ReadReceipt []string +} + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment struct { + Filename string + Header textproto.MIMEHeader + Content []byte +} + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. +func NewEMail(config string) *Email { + e := new(Email) + e.Headers = textproto.MIMEHeader{} + err := json.Unmarshal([]byte(config), e) + if err != nil { + return nil + } + return e +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + buff := &bytes.Buffer{} + w := multipart.NewWriter(buff) + // Set the appropriate headers (overwriting any conflicts) + // Leave out Bcc (only included in envelope headers) + e.Headers.Set("To", strings.Join(e.To, ",")) + if e.Cc != nil { + e.Headers.Set("Cc", strings.Join(e.Cc, ",")) + } + e.Headers.Set("From", e.From) + e.Headers.Set("Subject", e.Subject) + if len(e.ReadReceipt) != 0 { + e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ",")) + } + e.Headers.Set("MIME-Version", "1.0") + + // Write the envelope headers (including any custom headers) + if err := headerToBytes(buff, e.Headers); err != nil { + return nil, fmt.Errorf("Failed to render message headers: %s", err) + } + + e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + fmt.Fprintf(buff, "%s:", "Content-Type") + fmt.Fprintf(buff, " %s\r\n", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + + // Start the multipart/mixed part + fmt.Fprintf(buff, "--%s\r\n", w.Boundary()) + header := textproto.MIMEHeader{} + // Check to see if there is a Text or HTML field + if e.Text != "" || e.HTML != "" { + subWriter := multipart.NewWriter(buff) + // Create the multipart alternative part + header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary())) + // Write the header + if err := headerToBytes(buff, header); err != nil { + return nil, fmt.Errorf("Failed to render multipart message headers: %s", err) + } + // Create the body sections + if e.Text != "" { + header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.Text); err != nil { + return nil, err + } + } + if e.HTML != "" { + header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.HTML); err != nil { + return nil, err + } + } + if err := subWriter.Close(); err != nil { + return nil, err + } + } + // Create attachment part, if necessary + for _, a := range e.Attachments { + ap, err := w.CreatePart(a.Header) + if err != nil { + return nil, err + } + // Write the base64Wrapped content to the part + base64Wrap(ap, a.Content) + } + if err := w.Close(); err != nil { + return nil, err + } + return buff.Bytes(), nil +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { + if len(args) < 1 || len(args) > 2 { // change && to || + err = errors.New("Must specify a file name and number of parameters can not exceed at least two") + return + } + filename := args[0] + id := "" + if len(args) > 1 { + id = args[1] + } + f, err := os.Open(filename) + if err != nil { + return + } + defer f.Close() + ct := mime.TypeByExtension(filepath.Ext(filename)) + basename := path.Base(filename) + return e.Attach(f, basename, ct, id) +} + +// Attach is used to attach content from an io.Reader to the email. +// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { + if len(args) < 1 || len(args) > 2 { // change && to || + err = errors.New("Must specify the file type and number of parameters can not exceed at least two") + return + } + c := args[0] //Content-Type + id := "" + if len(args) > 1 { + id = args[1] //Content-ID + } + var buffer bytes.Buffer + if _, err = io.Copy(&buffer, r); err != nil { + return + } + at := &Attachment{ + Filename: filename, + Header: textproto.MIMEHeader{}, + Content: buffer.Bytes(), + } + // Get the Content-Type to be used in the MIMEHeader + if c != "" { + at.Header.Set("Content-Type", c) + } else { + // If the Content-Type is blank, set the Content-Type to "application/octet-stream" + at.Header.Set("Content-Type", "application/octet-stream") + } + if id != "" { + at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename)) + at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id)) + } else { + at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) + } + at.Header.Set("Content-Transfer-Encoding", "base64") + e.Attachments = append(e.Attachments, at) + return at, nil +} + +// Send will send out the mail +func (e *Email) Send() error { + if e.Auth == nil { + e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host) + } + // Merge the To, Cc, and Bcc fields + to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc)) + to = append(append(append(to, e.To...), e.Cc...), e.Bcc...) + // Check to make sure there is at least one recipient and one "From" address + if len(to) == 0 { + return errors.New("Must specify at least one To address") + } + + // Use the username if no From is provided + if len(e.From) == 0 { + e.From = e.Username + } + + from, err := mail.ParseAddress(e.From) + if err != nil { + return err + } + + // use mail's RFC 2047 to encode any string + e.Subject = qEncode("utf-8", e.Subject) + + raw, err := e.Bytes() + if err != nil { + return err + } + return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw) +} + +// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045) +func quotePrintEncode(w io.Writer, s string) error { + var buf [3]byte + mc := 0 + for i := 0; i < len(s); i++ { + c := s[i] + // We're assuming Unix style text formats as input (LF line break), and + // quoted-printble uses CRLF line breaks. (Literal CRs will become + // "=0D", but probably shouldn't be there to begin with!) + if c == '\n' { + io.WriteString(w, "\r\n") + mc = 0 + continue + } + + var nextOut []byte + if isPrintable(c) { + nextOut = append(buf[:0], c) + } else { + nextOut = buf[:] + qpEscape(nextOut, c) + } + + // Add a soft line break if the next (encoded) byte would push this line + // to or past the limit. + if mc+len(nextOut) >= maxLineLength { + if _, err := io.WriteString(w, "=\r\n"); err != nil { + return err + } + mc = 0 + } + + if _, err := w.Write(nextOut); err != nil { + return err + } + mc += len(nextOut) + } + // No trailing end-of-line?? Soft line break, then. TODO: is this sane? + if mc > 0 { + io.WriteString(w, "=\r\n") + } + return nil +} + +// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise +func isPrintable(c byte) bool { + return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t') +} + +// qpEscape is a helper function for quotePrintEncode which escapes a +// non-printable byte. Expects len(dest) == 3. +func qpEscape(dest []byte, c byte) { + const nums = "0123456789ABCDEF" + dest[0] = '=' + dest[1] = nums[(c&0xf0)>>4] + dest[2] = nums[(c & 0xf)] +} + +// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer +func headerToBytes(w io.Writer, t textproto.MIMEHeader) error { + for k, v := range t { + // Write the header key + _, err := fmt.Fprintf(w, "%s:", k) + if err != nil { + return err + } + // Write each value in the header + for _, c := range v { + _, err := fmt.Fprintf(w, " %s\r\n", c) + if err != nil { + return err + } + } + } + return nil +} + +// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars) +// The output is then written to the specified io.Writer +func base64Wrap(w io.Writer, b []byte) { + // 57 raw bytes per 76-byte base64 line. + const maxRaw = 57 + // Buffer for each line, including trailing CRLF. + var buffer [maxLineLength + len("\r\n")]byte + copy(buffer[maxLineLength:], "\r\n") + // Process raw chunks until there's no longer enough to fill a line. + for len(b) >= maxRaw { + base64.StdEncoding.Encode(buffer[:], b[:maxRaw]) + w.Write(buffer[:]) + b = b[maxRaw:] + } + // Handle the last chunk of bytes. + if len(b) > 0 { + out := buffer[:base64.StdEncoding.EncodedLen(len(b))] + base64.StdEncoding.Encode(out, b) + out = append(out, "\r\n"...) + w.Write(out) + } +} + +// Encode returns the encoded-word form of s. If s is ASCII without special +// characters, it is returned unchanged. The provided charset is the IANA +// charset name of s. It is case insensitive. +// RFC 2047 encoded-word +func qEncode(charset, s string) string { + if !needsEncoding(s) { + return s + } + return encodeWord(charset, s) +} + +func needsEncoding(s string) bool { + for _, b := range s { + if (b < ' ' || b > '~') && b != '\t' { + return true + } + } + return false +} + +// encodeWord encodes a string into an encoded-word. +func encodeWord(charset, s string) string { + buf := getBuffer() + + buf.WriteString("=?") + buf.WriteString(charset) + buf.WriteByte('?') + buf.WriteByte('q') + buf.WriteByte('?') + + enc := make([]byte, 3) + for i := 0; i < len(s); i++ { + b := s[i] + switch { + case b == ' ': + buf.WriteByte('_') + case b <= '~' && b >= '!' && b != '=' && b != '?' && b != '_': + buf.WriteByte(b) + default: + enc[0] = '=' + enc[1] = upperhex[b>>4] + enc[2] = upperhex[b&0x0f] + buf.Write(enc) + } + } + buf.WriteString("?=") + + es := buf.String() + putBuffer(buf) + return es +} + +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +func getBuffer() *bytes.Buffer { + return bufPool.Get().(*bytes.Buffer) +} + +func putBuffer(buf *bytes.Buffer) { + if buf.Len() > 1024 { + return + } + buf.Reset() + bufPool.Put(buf) +} diff --git a/pkg/utils/mail_test.go b/pkg/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

Fancy Html is supported, too!

" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/utils/pagination/controller.go b/pkg/utils/pagination/controller.go new file mode 100644 index 00000000..2f022d0c --- /dev/null +++ b/pkg/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/context" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { + paginator = NewPaginator(context.Request, per, nums) + context.Input.SetData("paginator", &paginator) + return +} diff --git a/pkg/utils/pagination/doc.go b/pkg/utils/pagination/doc.go new file mode 100644 index 00000000..9abc6d78 --- /dev/null +++ b/pkg/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/utils/pagination/paginator.go b/pkg/utils/pagination/paginator.go new file mode 100644 index 00000000..c6db31e0 --- /dev/null +++ b/pkg/utils/pagination/paginator.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "math" + "net/http" + "net/url" + "strconv" +) + +// Paginator within the state of a http request. +type Paginator struct { + Request *http.Request + PerPageNums int + MaxPages int + + nums int64 + pageRange []int + pageNums int + page int +} + +// PageNums Returns the total number of pages. +func (p *Paginator) PageNums() int { + if p.pageNums != 0 { + return p.pageNums + } + pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) + if p.MaxPages > 0 { + pageNums = math.Min(pageNums, float64(p.MaxPages)) + } + p.pageNums = int(pageNums) + return p.pageNums +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return p.nums +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + p.nums, _ = toInt64(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + if p.page != 0 { + return p.page + } + if p.Request.Form == nil { + p.Request.ParseForm() + } + p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) + if p.page > p.PageNums() { + p.page = p.PageNums() + } + if p.page <= 0 { + p.page = 1 + } + return p.page +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + if p.pageRange == nil && p.nums > 0 { + var pages []int + pageNums := p.PageNums() + page := p.Page() + switch { + case page >= pageNums-4 && pageNums > 9: + start := pageNums - 9 + 1 + pages = make([]int, 9) + for i := range pages { + pages[i] = start + i + } + case page >= 5 && pageNums > 9: + start := page - 5 + 1 + pages = make([]int, int(math.Min(9, float64(page+4+1)))) + for i := range pages { + pages[i] = start + i + } + default: + pages = make([]int, int(math.Min(9, float64(pageNums)))) + for i := range pages { + pages[i] = i + 1 + } + } + p.pageRange = pages + } + return p.pageRange +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + link, _ := url.ParseRequestURI(p.Request.URL.String()) + values := link.Query() + if page == 1 { + values.Del("p") + } else { + values.Set("p", strconv.Itoa(page)) + } + link.RawQuery = values.Encode() + return link.String() +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + if p.HasPrev() { + link = p.PageLink(p.Page() - 1) + } + return +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + if p.HasNext() { + link = p.PageLink(p.Page() + 1) + } + return +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return p.PageLink(1) +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return p.PageLink(p.PageNums()) +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return p.Page() > 1 +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return p.Page() < p.PageNums() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return p.Page() == page +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (p.Page() - 1) * p.PerPageNums +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return p.PageNums() > 1 +} + +// NewPaginator Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + p := Paginator{} + p.Request = req + if per <= 0 { + per = 10 + } + p.PerPageNums = per + p.SetNums(nums) + return &p +} diff --git a/pkg/utils/pagination/utils.go b/pkg/utils/pagination/utils.go new file mode 100644 index 00000000..686e68b0 --- /dev/null +++ b/pkg/utils/pagination/utils.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "fmt" + "reflect" +) + +// ToInt64 convert any numeric value to int64 +func toInt64(value interface{}) (d int64, err error) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + err = fmt.Errorf("ToInt64 need numeric not `%T`", value) + } + return +} diff --git a/pkg/utils/rand.go b/pkg/utils/rand.go new file mode 100644 index 00000000..344d1cd5 --- /dev/null +++ b/pkg/utils/rand.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "crypto/rand" + r "math/rand" + "time" +) + +var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + if len(alphabets) == 0 { + alphabets = alphaNum + } + var bytes = make([]byte, n) + var randBy bool + if num, err := rand.Read(bytes); num != n || err != nil { + r.Seed(time.Now().UnixNano()) + randBy = true + } + for i, b := range bytes { + if randBy { + bytes[i] = alphabets[r.Intn(len(alphabets))] + } else { + bytes[i] = alphabets[b%byte(len(alphabets))] + } + } + return bytes +} diff --git a/pkg/utils/rand_test.go b/pkg/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/utils/safemap.go b/pkg/utils/safemap.go new file mode 100644 index 00000000..1793030a --- /dev/null +++ b/pkg/utils/safemap.go @@ -0,0 +1,91 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync" +) + +// BeeMap is a map with lock +type BeeMap struct { + lock *sync.RWMutex + bm map[interface{}]interface{} +} + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return &BeeMap{ + lock: new(sync.RWMutex), + bm: make(map[interface{}]interface{}), + } +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + if val, ok := m.bm[k]; ok { + return val + } + return nil +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + m.lock.Lock() + defer m.lock.Unlock() + if val, ok := m.bm[k]; !ok { + m.bm[k] = v + } else if val != v { + m.bm[k] = v + } else { + return false + } + return true +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + m.lock.RLock() + defer m.lock.RUnlock() + _, ok := m.bm[k] + return ok +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.bm, k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + r := make(map[interface{}]interface{}) + for k, v := range m.bm { + r[k] = v + } + return r +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + m.lock.RLock() + defer m.lock.RUnlock() + return len(m.bm) +} diff --git a/pkg/utils/safemap_test.go b/pkg/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/utils/slice.go b/pkg/utils/slice.go new file mode 100644 index 00000000..8f2cef98 --- /dev/null +++ b/pkg/utils/slice.go @@ -0,0 +1,170 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "math/rand" + "time" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + for _, vv := range sl { + if vv == v { + return true + } + } + return false +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + for _, vv := range sl { + if vv == v { + return true + } + } + return false +} + +// SliceRandList generate an int slice from min to max. +func SliceRandList(min, max int) []int { + if max < min { + min, max = max, min + } + length := max - min + 1 + t0 := time.Now() + rand.Seed(int64(t0.Nanosecond())) + list := rand.Perm(length) + for index := range list { + list[index] += min + } + return list +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + c = append(slice1, slice2...) + return +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + for _, v := range slice { + dslice = append(dslice, a(v)) + } + return +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + randnum := rand.Intn(len(a)) + b = a[randnum] + return +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + for _, v := range intslice { + sum += v + } + return +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + for _, v := range slice { + if a(v) { + ftslice = append(ftslice, v) + } + } + return +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + for _, v := range slice1 { + if !InSliceIface(v, slice2) { + diffslice = append(diffslice, v) + } + } + return +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + for _, v := range slice1 { + if InSliceIface(v, slice2) { + diffslice = append(diffslice, v) + } + } + return +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + if size >= len(slice) { + chunkslice = append(chunkslice, slice) + return + } + end := size + for i := 0; i <= (len(slice) - size); i += size { + chunkslice = append(chunkslice, slice[i:end]) + end += size + } + return +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + for i := start; i <= end; i += step { + intslice = append(intslice, i) + } + return +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + if size <= len(slice) { + return slice + } + for i := 0; i < (size - len(slice)); i++ { + slice = append(slice, val) + } + return slice +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + for _, v := range slice { + if !InSliceIface(v, uniqueslice) { + uniqueslice = append(uniqueslice, v) + } + } + return +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + for i := 0; i < len(slice); i++ { + a := rand.Intn(len(slice)) + b := rand.Intn(len(slice)) + slice[a], slice[b] = slice[b], slice[a] + } + return slice +} diff --git a/pkg/utils/slice_test.go b/pkg/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/utils/testdata/grepe.test b/pkg/utils/testdata/grepe.test new file mode 100644 index 00000000..6c014c40 --- /dev/null +++ b/pkg/utils/testdata/grepe.test @@ -0,0 +1,7 @@ +# empty lines + + + +hello +# comment +world diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 00000000..3874b803 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,89 @@ +package utils + +import ( + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + gopath := os.Getenv("GOPATH") + if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { + gopath = defaultGOPATH() + } + return filepath.SplitList(gopath) +} + +func compareGoVersion(a, b string) int { + reg := regexp.MustCompile("^\\d*") + + a = strings.TrimPrefix(a, "go") + b = strings.TrimPrefix(b, "go") + + versionsA := strings.Split(a, ".") + versionsB := strings.Split(b, ".") + + for i := 0; i < len(versionsA) && i < len(versionsB); i++ { + versionA := versionsA[i] + versionB := versionsB[i] + + vA, err := strconv.Atoi(versionA) + if err != nil { + str := reg.FindString(versionA) + if str != "" { + vA, _ = strconv.Atoi(str) + } else { + vA = -1 + } + } + + vB, err := strconv.Atoi(versionB) + if err != nil { + str := reg.FindString(versionB) + if str != "" { + vB, _ = strconv.Atoi(str) + } else { + vB = -1 + } + } + + if vA > vB { + // vA = 12, vB = 8 + return 1 + } else if vA < vB { + // vA = 6, vB = 8 + return -1 + } else if vA == -1 { + // vA = rc1, vB = rc3 + return strings.Compare(versionA, versionB) + } + + // vA = vB = 8 + continue + } + + if len(versionsA) > len(versionsB) { + return 1 + } else if len(versionsA) == len(versionsB) { + return 0 + } + + return -1 +} + +func defaultGOPATH() string { + env := "HOME" + if runtime.GOOS == "windows" { + env = "USERPROFILE" + } else if runtime.GOOS == "plan9" { + env = "home" + } + if home := os.Getenv(env); home != "" { + return filepath.Join(home, "go") + } + return "" +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 00000000..ced6f63f --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,36 @@ +package utils + +import ( + "testing" +) + +func TestCompareGoVersion(t *testing.T) { + targetVersion := "go1.8" + if compareGoVersion("go1.12.4", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8.7", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7.6", targetVersion) != -1 { + t.Error("should be -1") + } + + if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8rc1", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7rc1", targetVersion) != -1 { + t.Error("should be -1") + } +} diff --git a/pkg/validation/README.md b/pkg/validation/README.md new file mode 100644 index 00000000..43373e47 --- /dev/null +++ b/pkg/validation/README.md @@ -0,0 +1,147 @@ +validation +============== + +validation is a form validation for a data validation and error collecting using Go. + +## Installation and tests + +Install: + + go get github.com/astaxie/beego/validation + +Test: + + go test github.com/astaxie/beego/validation + +## Example + +Direct Use: + + import ( + "github.com/astaxie/beego/validation" + "log" + ) + + type User struct { + Name string + Age int + } + + func main() { + u := User{"man", 40} + valid := validation.Validation{} + valid.Required(u.Name, "name") + valid.MaxSize(u.Name, 15, "nameMax") + valid.Range(u.Age, 0, 140, "age") + if valid.HasErrors() { + // validation does not pass + // print invalid message + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + } + // or use like this + if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { + log.Println(v.Error.Key, v.Error.Message) + } + } + +Struct Tag Use: + + import ( + "github.com/astaxie/beego/validation" + ) + + // validation function follow with "valid" tag + // functions divide with ";" + // parameters in parentheses "()" and divide with "," + // Match function's pattern string must in "//" + type user struct { + Id int + Name string `valid:"Required;Match(/^(test)?\\w*@;com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + func main() { + valid := validation.Validation{} + // ignore empty field valid + // see CanSkipFuncs + // valid := validation.Validation{RequiredFirst:true} + u := user{Name: "test", Age: 40} + b, err := valid.Valid(u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + } + } + +Use custom function: + + import ( + "github.com/astaxie/beego/validation" + ) + + type user struct { + Id int + Name string `valid:"Required;IsMe"` + Age int `valid:"Required;Range(1, 140)"` + } + + func IsMe(v *validation.Validation, obj interface{}, key string) { + name, ok:= obj.(string) + if !ok { + // wrong use case? + return + } + + if name != "me" { + // valid false + v.SetError("Name", "is not me!") + } + } + + func main() { + valid := validation.Validation{} + if err := validation.AddCustomFunc("IsMe", IsMe); err != nil { + // hadle error + } + u := user{Name: "test", Age: 40} + b, err := valid.Valid(u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + } + } + +Struct Tag Functions: + + Required + Min(min int) + Max(max int) + Range(min, max int) + MinSize(min int) + MaxSize(max int) + Length(length int) + Alpha + Numeric + AlphaNumeric + Match(pattern string) + AlphaDash + Email + IP + Base64 + Mobile + Tel + Phone + ZipCode + + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/pkg/validation/util.go b/pkg/validation/util.go new file mode 100644 index 00000000..82206f4f --- /dev/null +++ b/pkg/validation/util.go @@ -0,0 +1,298 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +const ( + // ValidTag struct tag + ValidTag = "valid" + + LabelTag = "label" + + wordsize = 32 << (^uint(0) >> 32 & 1) +) + +var ( + // key: function name + // value: the number of parameters + funcs = make(Funcs) + + // doesn't belong to validation functions + unFuncs = map[string]bool{ + "Clear": true, + "HasErrors": true, + "ErrorMap": true, + "Error": true, + "apply": true, + "Check": true, + "Valid": true, + "NoMatch": true, + } + // ErrInt64On32 show 32 bit platform not support int64 + ErrInt64On32 = fmt.Errorf("not support int64 on 32-bit platform") +) + +func init() { + v := &Validation{} + t := reflect.TypeOf(v) + for i := 0; i < t.NumMethod(); i++ { + m := t.Method(i) + if !unFuncs[m.Name] { + funcs[m.Name] = m.Func + } + } +} + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + if unFuncs[name] { + return fmt.Errorf("invalid function name: %s", name) + } + + funcs[name] = reflect.ValueOf(f) + return nil +} + +// ValidFunc Valid function type +type ValidFunc struct { + Name string + Params []interface{} +} + +// Funcs Validate function map +type Funcs map[string]reflect.Value + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + if _, ok := f[name]; !ok { + err = fmt.Errorf("%s does not exist", name) + return + } + if len(params) != f[name].Type().NumIn() { + err = fmt.Errorf("The number of params is not adapted") + return + } + in := make([]reflect.Value, len(params)) + for k, param := range params { + in[k] = reflect.ValueOf(param) + } + result = f[name].Call(in) + return +} + +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { + tag := f.Tag.Get(ValidTag) + label := f.Tag.Get(LabelTag) + if len(tag) == 0 { + return + } + if vfs, tag, err = getRegFuncs(tag, f.Name); err != nil { + return + } + fs := strings.Split(tag, ";") + for _, vfunc := range fs { + var vf ValidFunc + if len(vfunc) == 0 { + continue + } + vf, err = parseFunc(vfunc, f.Name, label) + if err != nil { + return + } + vfs = append(vfs, vf) + } + return +} + +// Get Match function +// May be get NoMatch function in the future +func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) { + tag = strings.TrimSpace(tag) + index := strings.Index(tag, "Match(/") + if index == -1 { + str = tag + return + } + end := strings.LastIndex(tag, "/)") + if end < index { + err = fmt.Errorf("invalid Match function") + return + } + reg, err := regexp.Compile(tag[index+len("Match(/") : end]) + if err != nil { + return + } + vfs = []ValidFunc{{"Match", []interface{}{reg, key + ".Match"}}} + str = strings.TrimSpace(tag[:index]) + strings.TrimSpace(tag[end+len("/)"):]) + return +} + +func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + + vfunc = strings.TrimSpace(vfunc) + start := strings.Index(vfunc, "(") + var num int + + // doesn't need parameter valid function + if start == -1 { + if num, err = numIn(vfunc); err != nil { + return + } + if num != 0 { + err = fmt.Errorf("%s require %d parameters", vfunc, num) + return + } + v = ValidFunc{vfunc, []interface{}{key + "." + vfunc + "." + label}} + return + } + + end := strings.Index(vfunc, ")") + if end == -1 { + err = fmt.Errorf("invalid valid function") + return + } + + name := strings.TrimSpace(vfunc[:start]) + if num, err = numIn(name); err != nil { + return + } + + params := strings.Split(vfunc[start+1:end], ",") + // the num of param must be equal + if num != len(params) { + err = fmt.Errorf("%s require %d parameters", name, num) + return + } + + tParams, err := trim(name, key+"."+ name + "." + label, params) + if err != nil { + return + } + v = ValidFunc{name, tParams} + return +} + +func numIn(name string) (num int, err error) { + fn, ok := funcs[name] + if !ok { + err = fmt.Errorf("doesn't exists %s valid function", name) + return + } + // sub *Validation obj and key + num = fn.Type().NumIn() - 3 + return +} + +func trim(name, key string, s []string) (ts []interface{}, err error) { + ts = make([]interface{}, len(s), len(s)+1) + fn, ok := funcs[name] + if !ok { + err = fmt.Errorf("doesn't exists %s valid function", name) + return + } + for i := 0; i < len(s); i++ { + var param interface{} + // skip *Validation and obj params + if param, err = parseParam(fn.Type().In(i+2), strings.TrimSpace(s[i])); err != nil { + return + } + ts[i] = param + } + ts = append(ts, key) + return +} + +// modify the parameters's type to adapt the function input parameters' type +func parseParam(t reflect.Type, s string) (i interface{}, err error) { + switch t.Kind() { + case reflect.Int: + i, err = strconv.Atoi(s) + case reflect.Int64: + if wordsize == 32 { + return nil, ErrInt64On32 + } + i, err = strconv.ParseInt(s, 10, 64) + case reflect.Int32: + var v int64 + v, err = strconv.ParseInt(s, 10, 32) + if err == nil { + i = int32(v) + } + case reflect.Int16: + var v int64 + v, err = strconv.ParseInt(s, 10, 16) + if err == nil { + i = int16(v) + } + case reflect.Int8: + var v int64 + v, err = strconv.ParseInt(s, 10, 8) + if err == nil { + i = int8(v) + } + case reflect.String: + i = s + case reflect.Ptr: + if t.Elem().String() != "regexp.Regexp" { + err = fmt.Errorf("not support %s", t.Elem().String()) + return + } + i, err = regexp.Compile(s) + default: + err = fmt.Errorf("not support %s", t.Kind().String()) + } + return +} + +func mergeParam(v *Validation, obj interface{}, params []interface{}) []interface{} { + return append([]interface{}{v, obj}, params...) +} diff --git a/pkg/validation/util_test.go b/pkg/validation/util_test.go new file mode 100644 index 00000000..58ca38db --- /dev/null +++ b/pkg/validation/util_test.go @@ -0,0 +1,128 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "log" + "reflect" + "testing" +) + +type user struct { + ID int + Tag string `valid:"Maxx(aa)"` + Name string `valid:"Required;"` + Age int `valid:"Required; Range(1, 140)"` + match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` +} + +func TestGetValidFuncs(t *testing.T) { + u := user{Name: "test", Age: 1} + tf := reflect.TypeOf(u) + var vfs []ValidFunc + var err error + + f, _ := tf.FieldByName("ID") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 0 { + t.Fatal("should get none ValidFunc") + } + + f, _ = tf.FieldByName("Tag") + if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" { + t.Fatal(err) + } + + f, _ = tf.FieldByName("Name") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 1 { + t.Fatal("should get 1 ValidFunc") + } + if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { + t.Error("Required funcs should be got") + } + + f, _ = tf.FieldByName("Age") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 2 { + t.Fatal("should get 2 ValidFunc") + } + if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { + t.Error("Required funcs should be got") + } + if vfs[1].Name != "Range" && len(vfs[1].Params) != 2 { + t.Error("Range funcs should be got") + } + + f, _ = tf.FieldByName("match") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 3 { + t.Fatal("should get 3 ValidFunc but now is", len(vfs)) + } +} + +type User struct { + Name string `valid:"Required;MaxSize(5)" ` + Sex string `valid:"Required;" label:"sex_label"` + Age int `valid:"Required;Range(1, 140);" label:"age_label"` +} + +func TestValidation(t *testing.T) { + u := User{"man1238888456", "", 1140} + valid := Validation{} + b, err := valid.Valid(&u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + if len(valid.Errors) != 3 { + t.Error("must be has 3 error") + } + } else { + t.Error("must be has 3 error") + } +} + +func TestCall(t *testing.T) { + u := user{Name: "test", Age: 180} + tf := reflect.TypeOf(u) + var vfs []ValidFunc + var err error + f, _ := tf.FieldByName("Age") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + valid := &Validation{} + vfs[1].Params = append([]interface{}{valid, u.Age}, vfs[1].Params...) + if _, err = funcs.Call(vfs[1].Name, vfs[1].Params...); err != nil { + t.Fatal(err) + } + if len(valid.Errors) != 1 { + t.Error("age out of range should be has an error") + } +} diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go new file mode 100644 index 00000000..190e0f0e --- /dev/null +++ b/pkg/validation/validation.go @@ -0,0 +1,456 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "reflect" + "regexp" + "strings" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error struct { + Message, Key, Name, Field, Tmpl string + Value interface{} + LimitValue interface{} +} + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result struct { + Error *Error + Ok bool +} + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation struct { + // if this field set true, in struct tag valid + // if the struct field vale is empty + // it will skip those valid functions, see CanSkipFuncs + RequiredFirst bool + + Errors []*Error + ErrorsMap map[string][]*Error +} + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + v.Errors = []*Error{} + v.ErrorsMap = nil +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return len(v.Errors) > 0 +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + return v.ErrorsMap +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + result := (&Result{ + Ok: false, + Error: &Error{}, + }).Message(message, args...) + v.Errors = append(v.Errors, result.Error) + return result +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return v.apply(Required{key}, obj) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return v.apply(Min{min, key}, obj) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return v.apply(Max{max, key}, obj) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return v.apply(MinSize{min, key}, obj) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return v.apply(MaxSize{max, key}, obj) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return v.apply(Length{n, key}, obj) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return v.apply(Alpha{key}, obj) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return v.apply(Numeric{key}, obj) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return v.apply(AlphaNumeric{key}, obj) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return v.apply(Match{regex, key}, obj) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return v.apply(NoMatch{Match{Regexp: regex}, key}, obj) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return v.apply(Email{Match{Regexp: emailPattern}, key}, obj) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return v.apply(IP{Match{Regexp: ipPattern}, key}, obj) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + return v.apply(Tel{Match{Regexp: telPattern}, key}, obj) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, + Tel{Match: Match{Regexp: telPattern}}, key}, obj) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) +} + +func (v *Validation) apply(chk Validator, obj interface{}) *Result { + if nil == obj { + if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { + if reflect.ValueOf(obj).IsNil() { + if chk.IsSatisfied(nil) { + return &Result{Ok: true} + } + } else { + if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { + return &Result{Ok: true} + } + } + } else if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + + // Add the error to the validation context. + key := chk.GetKey() + Name := key + Field := "" + Label := "" + parts := strings.Split(key, ".") + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } + } + + err := &Error{ + Message: Label + " " + chk.DefaultMessage(), + Key: key, + Name: Name, + Field: Field, + Value: obj, + Tmpl: MessageTmpls[Name], + LimitValue: chk.GetLimitValue(), + } + v.setError(err) + + // Also return it in the result. + return &Result{ + Ok: false, + Error: err, + } +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + Name := key + Field := "" + + Label := "" + parts := strings.Split(key, ".") + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } + } + + err := &Error{ + Message: Label + " " + message, + Key: key, + Name: Name, + Field: Field, + } + v.setError(err) +} + +func (v *Validation) setError(err *Error) { + v.Errors = append(v.Errors, err) + if v.ErrorsMap == nil { + v.ErrorsMap = make(map[string][]*Error) + } + if _, ok := v.ErrorsMap[err.Field]; !ok { + v.ErrorsMap[err.Field] = []*Error{} + } + v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + err := &Error{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg} + v.setError(err) + return err +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + var result *Result + for _, check := range checks { + result = v.apply(check, obj) + if !result.Ok { + return result + } + } + return result +} + +// Valid Validate a struct. +// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + switch { + case isStruct(objT): + case isStructPtr(objT): + objT = objT.Elem() + objV = objV.Elem() + default: + err = fmt.Errorf("%v must be a struct or a struct pointer", obj) + return + } + + for i := 0; i < objT.NumField(); i++ { + var vfs []ValidFunc + if vfs, err = getValidFuncs(objT.Field(i)); err != nil { + return + } + + var hasRequired bool + for _, vf := range vfs { + if vf.Name == "Required" { + hasRequired = true + } + + currentField := objV.Field(i).Interface() + if objV.Field(i).Kind() == reflect.Ptr { + if objV.Field(i).IsNil() { + currentField = "" + } else { + currentField = objV.Field(i).Elem().Interface() + } + } + + chk := Required{""}.IsSatisfied(currentField) + if !hasRequired && v.RequiredFirst && !chk { + if _, ok := CanSkipFuncs[vf.Name]; ok { + continue + } + } + + if _, err = funcs.Call(vf.Name, + mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil { + return + } + } + } + + if !v.HasErrors() { + if form, ok := obj.(ValidFormer); ok { + form.Valid(v) + } + } + + return !v.HasErrors(), nil +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + //Step 1: validate obj itself firstly + // fails if objc is not struct + pass, err := v.Valid(objc) + if err != nil || !pass { + return pass, err // Stop recursive validation + } + // Step 2: Validate struct's struct fields + objT := reflect.TypeOf(objc) + objV := reflect.ValueOf(objc) + + if isStructPtr(objT) { + objT = objT.Elem() + objV = objV.Elem() + } + + for i := 0; i < objT.NumField(); i++ { + + t := objT.Field(i).Type + + // Recursive applies to struct or pointer to structs fields + if isStruct(t) || isStructPtr(t) { + // Step 3: do the recursive validation + // Only valid the Public field recursively + if objV.Field(i).CanInterface() { + pass, err = v.RecursiveValid(objV.Field(i).Interface()) + } + } + } + return pass, err +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + if _, ok := CanSkipFuncs[skipFunc]; !ok { + CanSkipFuncs[skipFunc] = struct{}{} + } +} diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if !valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/validation/validators.go b/pkg/validation/validators.go new file mode 100644 index 00000000..38b6f1aa --- /dev/null +++ b/pkg/validation/validators.go @@ -0,0 +1,738 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + "github.com/astaxie/beego/logs" + "reflect" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = map[string]struct{}{ + "Email": {}, + "IP": {}, + "Mobile": {}, + "Tel": {}, + "Phone": {}, + "ZipCode": {}, +} + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + if len(msg) == 0 { + return + } + + once.Do(func() { + for name := range msg { + MessageTmpls[name] = msg[name] + } + }) + logs.Warn(`you must SetDefaultMessage at once`) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required struct { + Key string +} + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + if obj == nil { + return false + } + + if str, ok := obj.(string); ok { + return len(strings.TrimSpace(str)) > 0 + } + if _, ok := obj.(bool); ok { + return true + } + if i, ok := obj.(int); ok { + return i != 0 + } + if i, ok := obj.(uint); ok { + return i != 0 + } + if i, ok := obj.(int8); ok { + return i != 0 + } + if i, ok := obj.(uint8); ok { + return i != 0 + } + if i, ok := obj.(int16); ok { + return i != 0 + } + if i, ok := obj.(uint16); ok { + return i != 0 + } + if i, ok := obj.(uint32); ok { + return i != 0 + } + if i, ok := obj.(int32); ok { + return i != 0 + } + if i, ok := obj.(int64); ok { + return i != 0 + } + if i, ok := obj.(uint64); ok { + return i != 0 + } + if t, ok := obj.(time.Time); ok { + return !t.IsZero() + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() > 0 + } + return true +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return MessageTmpls["Required"] +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return r.Key +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return nil +} + +// Min check struct +type Min struct { + Min int + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + var v int + switch obj.(type) { + case int64: + if wordsize == 32 { + return false + } + v = int(obj.(int64)) + case int: + v = obj.(int) + case int32: + v = int(obj.(int32)) + case int16: + v = int(obj.(int16)) + case int8: + v = int(obj.(int8)) + default: + return false + } + + return v >= m.Min +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Min"], m.Min) +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return m.Min +} + +// Max validate struct +type Max struct { + Max int + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + var v int + switch obj.(type) { + case int64: + if wordsize == 32 { + return false + } + v = int(obj.(int64)) + case int: + v = obj.(int) + case int32: + v = int(obj.(int32)) + case int16: + v = int(obj.(int16)) + case int8: + v = int(obj.(int8)) + default: + return false + } + + return v <= m.Max +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Max"], m.Max) +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return m.Max +} + +// Range Requires an integer to be within Min, Max inclusive. +type Range struct { + Min + Max + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Range"], r.Min.Min, r.Max.Max) +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return r.Key +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return []int{r.Min.Min, r.Max.Max} +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize struct { + Min int + Key string +} + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) >= m.Min + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() >= m.Min + } + return false +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["MinSize"], m.Min) +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return m.Min +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize struct { + Max int + Key string +} + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) <= m.Max + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() <= m.Max + } + return false +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["MaxSize"], m.Max) +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return m.Max +} + +// Length Requires an array or string to be exactly a given length. +type Length struct { + N int + Key string +} + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) == l.N + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() == l.N + } + return false +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Length"], l.N) +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return l.Key +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return l.N +} + +// Alpha check the alpha +type Alpha struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if ('Z' < v || v < 'A') && ('z' < v || v < 'a') { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return MessageTmpls["Alpha"] +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return nil +} + +// Numeric check number +type Numeric struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if '9' < v || v < '0' { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return MessageTmpls["Numeric"] +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return n.Key +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return nil +} + +// AlphaNumeric check alpha and number +type AlphaNumeric struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if ('Z' < v || v < 'A') && ('z' < v || v < 'a') && ('9' < v || v < '0') { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return MessageTmpls["AlphaNumeric"] +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return nil +} + +// Match Requires a string to match a given regex. +type Match struct { + Regexp *regexp.Regexp + Key string +} + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return m.Regexp.MatchString(fmt.Sprintf("%v", obj)) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Match"], m.Regexp.String()) +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return m.Regexp.String() +} + +// NoMatch Requires a string to not match a given regex. +type NoMatch struct { + Match + Key string +} + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return !n.Match.IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["NoMatch"], n.Regexp.String()) +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return n.Key +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return n.Regexp.String() +} + +var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) + +// AlphaDash check not Alpha +type AlphaDash struct { + NoMatch + Key string +} + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return MessageTmpls["AlphaDash"] +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return nil +} + +var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) + +// Email check struct +type Email struct { + Match + Key string +} + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return MessageTmpls["Email"] +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return e.Key +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return nil +} + +var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) + +// IP check struct +type IP struct { + Match + Key string +} + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return MessageTmpls["IP"] +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return i.Key +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return nil +} + +var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) + +// Base64 check struct +type Base64 struct { + Match + Key string +} + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return MessageTmpls["Base64"] +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return b.Key +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return nil +} + +// just for chinese mobile phone number +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`) + +// Mobile check struct +type Mobile struct { + Match + Key string +} + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return MessageTmpls["Mobile"] +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return nil +} + +// just for chinese telephone number +var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) + +// Tel check telephone struct +type Tel struct { + Match + Key string +} + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return MessageTmpls["Tel"] +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return t.Key +} + +// GetLimitValue return the limit value +func (t Tel) GetLimitValue() interface{} { + return nil +} + +// Phone just for chinese telephone or mobile phone number +type Phone struct { + Mobile + Tel + Key string +} + +// IsSatisfied judge whether obj is valid +func (p Phone) IsSatisfied(obj interface{}) bool { + return p.Mobile.IsSatisfied(obj) || p.Tel.IsSatisfied(obj) +} + +// DefaultMessage return the default Phone error message +func (p Phone) DefaultMessage() string { + return MessageTmpls["Phone"] +} + +// GetKey return the p.Key +func (p Phone) GetKey() string { + return p.Key +} + +// GetLimitValue return the limit value +func (p Phone) GetLimitValue() interface{} { + return nil +} + +// just for chinese zipcode +var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`) + +// ZipCode check the zip struct +type ZipCode struct { + Match + Key string +} + +// DefaultMessage return the default Zip error message +func (z ZipCode) DefaultMessage() string { + return MessageTmpls["ZipCode"] +} + +// GetKey return the z.Key +func (z ZipCode) GetKey() string { + return z.Key +} + +// GetLimitValue return the limit value +func (z ZipCode) GetLimitValue() interface{} { + return nil +} From 30eb889a91f58189ac0b6d059031ee66e556d966 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 23:00:06 +0800 Subject: [PATCH 2/3] Format code --- build_info.go | 8 +++--- cache/redis/redis.go | 2 +- config/yaml/yaml.go | 2 +- logs/accesslog.go | 2 +- logs/file.go | 30 +++++++++++----------- logs/file_test.go | 30 +++++++++++----------- metric/prometheus.go | 6 ++--- orm/cmd_utils.go | 6 ++--- orm/db_alias.go | 2 +- orm/orm_log.go | 2 +- pkg/build_info.go | 8 +++--- pkg/cache/redis/redis.go | 2 +- pkg/common/kv_test.go | 2 +- pkg/config/yaml/yaml.go | 2 +- pkg/logs/accesslog.go | 2 +- pkg/logs/file.go | 30 +++++++++++----------- pkg/logs/file_test.go | 30 +++++++++++----------- pkg/metric/prometheus.go | 6 ++--- pkg/orm/cmd_utils.go | 6 ++--- pkg/orm/db_alias.go | 6 ++--- pkg/orm/models_test.go | 6 ++--- pkg/orm/orm_log.go | 2 +- pkg/orm/types.go | 6 ++--- pkg/session/redis_cluster/redis_cluster.go | 17 ++++++------ pkg/session/sess_file_test.go | 5 ++-- pkg/staticfile.go | 2 +- pkg/templatefunc.go | 2 +- pkg/toolbox/task.go | 2 +- pkg/toolbox/task_test.go | 4 +-- pkg/validation/util.go | 2 +- session/redis_cluster/redis_cluster.go | 17 ++++++------ session/sess_file_test.go | 5 ++-- staticfile.go | 2 +- templatefunc.go | 2 +- toolbox/task.go | 2 +- toolbox/task_test.go | 4 +-- validation/util.go | 2 +- 37 files changed, 133 insertions(+), 133 deletions(-) diff --git a/build_info.go b/build_info.go index 6dc2835e..c31152ea 100644 --- a/build_info.go +++ b/build_info.go @@ -15,11 +15,11 @@ package beego var ( - BuildVersion string + BuildVersion string BuildGitRevision string - BuildStatus string - BuildTag string - BuildTime string + BuildStatus string + BuildTag string + BuildTime string GoVersion string diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 56faf211..d8737b3c 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -57,7 +57,7 @@ type Cache struct { maxIdle int //the timeout to a value less than the redis server's timeout. - timeout time.Duration + timeout time.Duration } // NewRedisCache create new redis cache with default collection name. diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go index 5def2da3..a5644c7b 100644 --- a/config/yaml/yaml.go +++ b/config/yaml/yaml.go @@ -296,7 +296,7 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { case map[string]interface{}: { tmpData = v.(map[string]interface{}) - if idx == len(keys) - 1 { + if idx == len(keys)-1 { return tmpData, nil } } diff --git a/logs/accesslog.go b/logs/accesslog.go index 3ff9e20f..9011b602 100644 --- a/logs/accesslog.go +++ b/logs/accesslog.go @@ -16,9 +16,9 @@ package logs import ( "bytes" - "strings" "encoding/json" "fmt" + "strings" "time" ) diff --git a/logs/file.go b/logs/file.go index 222db989..40a3572a 100644 --- a/logs/file.go +++ b/logs/file.go @@ -373,21 +373,21 @@ func (w *fileLogWriter) deleteOldLog() { if info == nil { return } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } return }) } diff --git a/logs/file_test.go b/logs/file_test.go index e7c2ca9a..385eac43 100644 --- a/logs/file_test.go +++ b/logs/file_test.go @@ -186,7 +186,7 @@ func TestFileDailyRotate_06(t *testing.T) { //test file mode func TestFileHourlyRotate_01(t *testing.T) { log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -237,7 +237,7 @@ func TestFileHourlyRotate_05(t *testing.T) { func TestFileHourlyRotate_06(t *testing.T) { //test file mode log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -269,19 +269,19 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { RotatePerm: "0440", } - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) @@ -328,8 +328,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, + Hourly: true, + MaxHours: 168, Rotate: true, Level: LevelTrace, Perm: "0660", diff --git a/metric/prometheus.go b/metric/prometheus.go index 7722240b..86e2c1b1 100644 --- a/metric/prometheus.go +++ b/metric/prometheus.go @@ -57,15 +57,15 @@ func registerBuildInfo() { Subsystem: "build_info", Help: "The building information", ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, + "appname": beego.BConfig.AppName, "build_version": beego.BuildVersion, "build_revision": beego.BuildGitRevision, "build_status": beego.BuildStatus, "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), "go_version": beego.GoVersion, "git_branch": beego.GitBranch, - "start_time": time.Now().Format("2006-01-02 15:04:05"), + "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 61f17346..692a079f 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver!=DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) } columns = append(columns, column) diff --git a/orm/db_alias.go b/orm/db_alias.go index bf6c350c..fe6abeb5 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -424,7 +424,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } diff --git a/orm/orm_log.go b/orm/orm_log.go index f107bb59..5bb3a24f 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -61,7 +61,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error con += " - " + err.Error() } logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil{ + if LogFunc != nil { LogFunc(logMap) } DebugLog.Println(con) diff --git a/pkg/build_info.go b/pkg/build_info.go index 6dc2835e..c31152ea 100644 --- a/pkg/build_info.go +++ b/pkg/build_info.go @@ -15,11 +15,11 @@ package beego var ( - BuildVersion string + BuildVersion string BuildGitRevision string - BuildStatus string - BuildTag string - BuildTime string + BuildStatus string + BuildTag string + BuildTime string GoVersion string diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go index 56faf211..d8737b3c 100644 --- a/pkg/cache/redis/redis.go +++ b/pkg/cache/redis/redis.go @@ -57,7 +57,7 @@ type Cache struct { maxIdle int //the timeout to a value less than the redis server's timeout. - timeout time.Duration + timeout time.Duration } // NewRedisCache create new redis cache with default collection name. diff --git a/pkg/common/kv_test.go b/pkg/common/kv_test.go index ed7dc7ef..45adf5ff 100644 --- a/pkg/common/kv_test.go +++ b/pkg/common/kv_test.go @@ -23,7 +23,7 @@ import ( func TestKVs(t *testing.T) { key := "my-key" kvs := NewKVs(KV{ - Key: key, + Key: key, Value: 12, }) diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go index 5def2da3..a5644c7b 100644 --- a/pkg/config/yaml/yaml.go +++ b/pkg/config/yaml/yaml.go @@ -296,7 +296,7 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { case map[string]interface{}: { tmpData = v.(map[string]interface{}) - if idx == len(keys) - 1 { + if idx == len(keys)-1 { return tmpData, nil } } diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go index 3ff9e20f..9011b602 100644 --- a/pkg/logs/accesslog.go +++ b/pkg/logs/accesslog.go @@ -16,9 +16,9 @@ package logs import ( "bytes" - "strings" "encoding/json" "fmt" + "strings" "time" ) diff --git a/pkg/logs/file.go b/pkg/logs/file.go index 222db989..40a3572a 100644 --- a/pkg/logs/file.go +++ b/pkg/logs/file.go @@ -373,21 +373,21 @@ func (w *fileLogWriter) deleteOldLog() { if info == nil { return } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } return }) } diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go index e7c2ca9a..385eac43 100644 --- a/pkg/logs/file_test.go +++ b/pkg/logs/file_test.go @@ -186,7 +186,7 @@ func TestFileDailyRotate_06(t *testing.T) { //test file mode func TestFileHourlyRotate_01(t *testing.T) { log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -237,7 +237,7 @@ func TestFileHourlyRotate_05(t *testing.T) { func TestFileHourlyRotate_06(t *testing.T) { //test file mode log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -269,19 +269,19 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { RotatePerm: "0440", } - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) @@ -328,8 +328,8 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, + Hourly: true, + MaxHours: 168, Rotate: true, Level: LevelTrace, Perm: "0660", diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go index 7722240b..86e2c1b1 100644 --- a/pkg/metric/prometheus.go +++ b/pkg/metric/prometheus.go @@ -57,15 +57,15 @@ func registerBuildInfo() { Subsystem: "build_info", Help: "The building information", ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, + "appname": beego.BConfig.AppName, "build_version": beego.BuildVersion, "build_revision": beego.BuildGitRevision, "build_status": beego.BuildStatus, "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), "go_version": beego.GoVersion, "git_branch": beego.GitBranch, - "start_time": time.Now().Format("2006-01-02 15:04:05"), + "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/pkg/orm/cmd_utils.go b/pkg/orm/cmd_utils.go index 61f17346..692a079f 100644 --- a/pkg/orm/cmd_utils.go +++ b/pkg/orm/cmd_utils.go @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver!=DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) } columns = append(columns, column) diff --git a/pkg/orm/db_alias.go b/pkg/orm/db_alias.go index 90c5de3c..a3f2a0b9 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/orm/db_alias.go @@ -244,7 +244,7 @@ var _ dbQuerier = new(TxDB) var _ txEnder = new(TxDB) func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { - return t.PrepareContext(context.Background(),query) + return t.PrepareContext(context.Background(), query) } func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { @@ -260,7 +260,7 @@ func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{ } func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { - return t.QueryContext(context.Background(),query,args...) + return t.QueryContext(context.Background(), query, args...) } func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { @@ -268,7 +268,7 @@ func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface } func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { - return t.QueryRowContext(context.Background(),query,args...) + return t.QueryRowContext(context.Background(), query, args...) } func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { diff --git a/pkg/orm/models_test.go b/pkg/orm/models_test.go index f14ee9cf..4c00050d 100644 --- a/pkg/orm/models_test.go +++ b/pkg/orm/models_test.go @@ -490,11 +490,11 @@ func init() { } err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, common.KV{ - Key:MaxIdleConnsKey, - Value:20, + Key: MaxIdleConnsKey, + Value: 20, }) - if err != nil{ + if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) } diff --git a/pkg/orm/orm_log.go b/pkg/orm/orm_log.go index f107bb59..5bb3a24f 100644 --- a/pkg/orm/orm_log.go +++ b/pkg/orm/orm_log.go @@ -61,7 +61,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error con += " - " + err.Error() } logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil{ + if LogFunc != nil { LogFunc(logMap) } DebugLog.Println(con) diff --git a/pkg/orm/types.go b/pkg/orm/types.go index b7a38826..8255d93e 100644 --- a/pkg/orm/types.go +++ b/pkg/orm/types.go @@ -110,7 +110,7 @@ type DQL interface { // Like Read(), but with "FOR UPDATE" clause, useful in transaction. // Some databases are not support this feature. - ReadForUpdate( md interface{}, cols ...string) error + ReadForUpdate(md interface{}, cols ...string) error ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error // Try to read a row from the database, or insert one if it doesn't exist @@ -129,14 +129,14 @@ type DQL interface { // args[2] int offset default offset 0 // args[3] string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated( md interface{}, name string, args ...interface{}) (int64, error) + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...interface{}) (int64, error) // create a models to models queryer // for example: // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M( md interface{}, name string) QueryM2Mer + QueryM2M(md interface{}, name string) QueryM2Mer QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go index 2fe300df..262fa2e3 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -31,13 +31,14 @@ // // more docs: http://beego.me/docs/module/session.md package redis_cluster + import ( + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" "net/http" "strconv" "strings" "sync" - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" "time" ) @@ -101,7 +102,7 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { return } c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) } // Provider redis_cluster session provider @@ -146,10 +147,10 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, + Password: rp.password, PoolSize: rp.poolsize, }) return rp.poollist.Ping().Err() @@ -186,15 +187,15 @@ func (rp *Provider) SessionExist(sid string) bool { // SessionRegenerate generate new sid for redis_cluster session func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := rp.poollist - + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { // oldsid doesn't exists, set the new sid directly // ignore error here, since if it return error // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) } else { c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } return rp.SessionRead(sid) } diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go index 0cf021db..021c43fc 100644 --- a/pkg/session/sess_file_test.go +++ b/pkg/session/sess_file_test.go @@ -369,8 +369,7 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error(err) } - - s.Set(i,i) + s.Set(i, i) s.SessionRelease(nil) } @@ -384,4 +383,4 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error() } } -} \ No newline at end of file +} diff --git a/pkg/staticfile.go b/pkg/staticfile.go index 84e9aa7b..e26776c5 100644 --- a/pkg/staticfile.go +++ b/pkg/staticfile.go @@ -202,7 +202,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) { if !strings.Contains(requestPath, prefix) { continue } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { continue } filePath := path.Join(staticDir, requestPath[len(prefix):]) diff --git a/pkg/templatefunc.go b/pkg/templatefunc.go index ba1ec5eb..6f02b8d6 100644 --- a/pkg/templatefunc.go +++ b/pkg/templatefunc.go @@ -362,7 +362,7 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e value = value[:25] t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if len(value) >= 19 { if strings.Contains(value, "T") { value = value[:19] diff --git a/pkg/toolbox/task.go b/pkg/toolbox/task.go index c902fdfc..fb2c5f16 100644 --- a/pkg/toolbox/task.go +++ b/pkg/toolbox/task.go @@ -113,7 +113,7 @@ type Task struct { Next time.Time Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func diff --git a/pkg/toolbox/task_test.go b/pkg/toolbox/task_test.go index 3a4cce2f..b63f4391 100644 --- a/pkg/toolbox/task_test.go +++ b/pkg/toolbox/task_test.go @@ -59,12 +59,12 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 task := func() error { - cnt ++ + cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) } tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200 ; i ++ { + for i := 0; i < 200; i++ { e := tk.Run() assert.NotNil(t, e) } diff --git a/pkg/validation/util.go b/pkg/validation/util.go index 82206f4f..918b206c 100644 --- a/pkg/validation/util.go +++ b/pkg/validation/util.go @@ -213,7 +213,7 @@ func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { return } - tParams, err := trim(name, key+"."+ name + "." + label, params) + tParams, err := trim(name, key+"."+name+"."+label, params) if err != nil { return } diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go index 2fe300df..262fa2e3 100644 --- a/session/redis_cluster/redis_cluster.go +++ b/session/redis_cluster/redis_cluster.go @@ -31,13 +31,14 @@ // // more docs: http://beego.me/docs/module/session.md package redis_cluster + import ( + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" "net/http" "strconv" "strings" "sync" - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" "time" ) @@ -101,7 +102,7 @@ func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { return } c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) } // Provider redis_cluster session provider @@ -146,10 +147,10 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } else { rp.dbNum = 0 } - + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, + Password: rp.password, PoolSize: rp.poolsize, }) return rp.poollist.Ping().Err() @@ -186,15 +187,15 @@ func (rp *Provider) SessionExist(sid string) bool { // SessionRegenerate generate new sid for redis_cluster session func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { c := rp.poollist - + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { // oldsid doesn't exists, set the new sid directly // ignore error here, since if it return error // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) } else { c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } return rp.SessionRead(sid) } diff --git a/session/sess_file_test.go b/session/sess_file_test.go index 0cf021db..021c43fc 100644 --- a/session/sess_file_test.go +++ b/session/sess_file_test.go @@ -369,8 +369,7 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error(err) } - - s.Set(i,i) + s.Set(i, i) s.SessionRelease(nil) } @@ -384,4 +383,4 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { t.Error() } } -} \ No newline at end of file +} diff --git a/staticfile.go b/staticfile.go index 84e9aa7b..e26776c5 100644 --- a/staticfile.go +++ b/staticfile.go @@ -202,7 +202,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) { if !strings.Contains(requestPath, prefix) { continue } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { continue } filePath := path.Join(staticDir, requestPath[len(prefix):]) diff --git a/templatefunc.go b/templatefunc.go index ba1ec5eb..6f02b8d6 100644 --- a/templatefunc.go +++ b/templatefunc.go @@ -362,7 +362,7 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e value = value[:25] t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) } else if len(value) >= 19 { if strings.Contains(value, "T") { value = value[:19] diff --git a/toolbox/task.go b/toolbox/task.go index c902fdfc..fb2c5f16 100644 --- a/toolbox/task.go +++ b/toolbox/task.go @@ -113,7 +113,7 @@ type Task struct { Next time.Time Errlist []*taskerr // like errtime:errinfo ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func diff --git a/toolbox/task_test.go b/toolbox/task_test.go index 3a4cce2f..b63f4391 100644 --- a/toolbox/task_test.go +++ b/toolbox/task_test.go @@ -59,12 +59,12 @@ func TestSpec(t *testing.T) { func TestTask_Run(t *testing.T) { cnt := -1 task := func() error { - cnt ++ + cnt++ fmt.Printf("Hello, world! %d \n", cnt) return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) } tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200 ; i ++ { + for i := 0; i < 200; i++ { e := tk.Run() assert.NotNil(t, e) } diff --git a/validation/util.go b/validation/util.go index 82206f4f..918b206c 100644 --- a/validation/util.go +++ b/validation/util.go @@ -213,7 +213,7 @@ func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { return } - tParams, err := trim(name, key+"."+ name + "." + label, params) + tParams, err := trim(name, key+"."+name+"."+label, params) if err != nil { return } From 79c2157ad47c392ee780f71f42535717444de08b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 15:34:55 +0000 Subject: [PATCH 3/3] Fix UT --- orm/models_test.go | 497 ++++ orm/orm_test.go | 2500 +++++++++++++++++ orm/utils_test.go | 70 + pkg/LICENSE | 13 + test.sh => scripts/test.sh | 2 +- .../test_docker_compose.yaml | 0 6 files changed, 3081 insertions(+), 1 deletion(-) create mode 100644 orm/models_test.go create mode 100644 orm/orm_test.go create mode 100644 orm/utils_test.go create mode 100644 pkg/LICENSE rename test.sh => scripts/test.sh (94%) rename test_docker_compose.yaml => scripts/test_docker_compose.yaml (100%) diff --git a/orm/models_test.go b/orm/models_test.go new file mode 100644 index 00000000..e3a635f2 --- /dev/null +++ b/orm/models_test.go @@ -0,0 +1,497 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + // As tidb can't use go get, so disable the tidb testing now + // _ "github.com/pingcap/tidb" +) + +// A slice string field. +type SliceStringField []string + +func (e SliceStringField) Value() []string { + return []string(e) +} + +func (e *SliceStringField) Set(d []string) { + *e = SliceStringField(d) +} + +func (e *SliceStringField) Add(v string) { + *e = append(*e, v) +} + +func (e *SliceStringField) String() string { + return strings.Join(e.Value(), ",") +} + +func (e *SliceStringField) FieldType() int { + return TypeVarCharField +} + +func (e *SliceStringField) SetRaw(value interface{}) error { + switch d := value.(type) { + case []string: + e.Set(d) + case string: + if len(d) > 0 { + parts := strings.Split(d, ",") + v := make([]string, 0, len(parts)) + for _, p := range parts { + v = append(v, strings.TrimSpace(p)) + } + e.Set(v) + } + default: + return fmt.Errorf(" unknown value `%v`", value) + } + return nil +} + +func (e *SliceStringField) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(SliceStringField) + +// A json field. +type JSONFieldTest struct { + Name string + Data string +} + +func (e *JSONFieldTest) String() string { + data, _ := json.Marshal(e) + return string(data) +} + +func (e *JSONFieldTest) FieldType() int { + return TypeTextField +} + +func (e *JSONFieldTest) SetRaw(value interface{}) error { + switch d := value.(type) { + case string: + return json.Unmarshal([]byte(d), e) + default: + return fmt.Errorf(" unknown value `%v`", value) + } +} + +func (e *JSONFieldTest) RawValue() interface{} { + return e.String() +} + +var _ Fielder = new(JSONFieldTest) + +type Data struct { + ID int `orm:"column(id)"` + Boolean bool + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + JSON string `orm:"type(json);default({\"name\":\"json\"})"` + Jsonb string `orm:"type(jsonb)"` + Time time.Time `orm:"type(time)"` + Date time.Time `orm:"type(date)"` + DateTime time.Time `orm:"column(datetime)"` + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 `orm:"digits(8);decimals(4)"` +} + +type DataNull struct { + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` +} + +type String string +type Boolean bool +type Byte byte +type Rune rune +type Int int +type Int8 int8 +type Int16 int16 +type Int32 int32 +type Int64 int64 +type Uint uint +type Uint8 uint8 +type Uint16 uint16 +type Uint32 uint32 +type Uint64 uint64 +type Float32 float64 +type Float64 float64 + +type DataCustom struct { + ID int `orm:"column(id)"` + Boolean Boolean + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` + Byte Byte + Rune Rune + Int Int + Int8 Int8 + Int16 Int16 + Int32 Int32 + Int64 Int64 + Uint Uint + Uint8 Uint8 + Uint16 Uint16 + Uint32 Uint32 + Uint64 Uint64 + Float32 Float32 + Float64 Float64 + Decimal Float64 `orm:"digits(8);decimals(4)"` +} + +// only for mysql +type UserBig struct { + ID uint64 `orm:"column(id)"` + Name string +} + +type User struct { + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` + unexport bool `orm:"-"` + unexportBool bool +} + +func (u *User) TableIndex() [][]string { + return [][]string{ + {"Id", "UserName"}, + {"Id", "Created"}, + } +} + +func (u *User) TableUnique() [][]string { + return [][]string{ + {"UserName", "Email"}, + } +} + +func NewUser() *User { + obj := new(User) + return obj +} + +type Profile struct { + ID int `orm:"column(id)"` + Age int16 + Money float64 + User *User `orm:"reverse(one)" json:"-"` + BestPost *Post `orm:"rel(one);null"` +} + +func (u *Profile) TableName() string { + return "user_profile" +} + +func NewProfile() *Profile { + obj := new(Profile) + return obj +} + +type Post struct { + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` +} + +func (u *Post) TableIndex() [][]string { + return [][]string{ + {"Id", "Created"}, + } +} + +func NewPost() *Post { + obj := new(Post) + return obj +} + +type Tag struct { + ID int `orm:"column(id)"` + Name string `orm:"size(30)"` + BestPost *Post `orm:"rel(one);null"` + Posts []*Post `orm:"reverse(many)" json:"-"` +} + +func NewTag() *Tag { + obj := new(Tag) + return obj +} + +type PostTags struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk)"` + Tag *Tag `orm:"rel(fk)"` +} + +func (m *PostTags) TableName() string { + return "prefix_post_tags" +} + +type Comment struct { + ID int `orm:"column(id)"` + Post *Post `orm:"rel(fk);column(post)"` + Content string `orm:"type(text)"` + Parent *Comment `orm:"null;rel(fk)"` + Created time.Time `orm:"auto_now_add"` +} + +func NewComment() *Comment { + obj := new(Comment) + return obj +} + +type Group struct { + ID int `orm:"column(gid);size(32)"` + Name string + Permissions []*Permission `orm:"reverse(many)" json:"-"` +} + +type Permission struct { + ID int `orm:"column(id)"` + Name string + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` +} + +type GroupPermissions struct { + ID int `orm:"column(id)"` + Group *Group `orm:"rel(fk)"` + Permission *Permission `orm:"rel(fk)"` +} + +type ModelID struct { + ID int64 +} + +type ModelBase struct { + ModelID + + Created time.Time `orm:"auto_now_add;type(datetime)"` + Updated time.Time `orm:"auto_now;type(datetime)"` +} + +type InLine struct { + // Common Fields + ModelBase + + // Other Fields + Name string `orm:"unique"` + Email string +} + +func NewInLine() *InLine { + return new(InLine) +} + +type InLineOneToOne struct { + // Common Fields + ModelBase + + Note string + InLine *InLine `orm:"rel(fk);column(inline)"` +} + +func NewInLineOneToOne() *InLineOneToOne { + return new(InLineOneToOne) +} + +type IntegerPk struct { + ID int64 `orm:"pk"` + Value string +} + +type UintPk struct { + ID uint32 `orm:"pk"` + Name string +} + +type PtrPk struct { + ID *IntegerPk `orm:"pk;rel(one)"` + Positive bool +} + +var DBARGS = struct { + Driver string + Source string + Debug string +}{ + os.Getenv("ORM_DRIVER"), + os.Getenv("ORM_SOURCE"), + os.Getenv("ORM_DEBUG"), +} + +var ( + IsMysql = DBARGS.Driver == "mysql" + IsSqlite = DBARGS.Driver == "sqlite3" + IsPostgres = DBARGS.Driver == "postgres" + IsTidb = DBARGS.Driver == "tidb" +) + +var ( + dORM Ormer + dDbBaser dbBaser +) + +var ( + helpinfo = `need driver and source! + + Default DB Drivers. + + driver: url + mysql: https://github.com/go-sql-driver/mysql + sqlite3: https://github.com/mattn/go-sqlite3 + postgres: https://github.com/lib/pq + tidb: https://github.com/pingcap/tidb + + usage: + + go get -u github.com/astaxie/beego/orm + go get -u github.com/go-sql-driver/mysql + go get -u github.com/mattn/go-sqlite3 + go get -u github.com/lib/pq + go get -u github.com/pingcap/tidb + + #### MySQL + mysql -u root -e 'create database orm_test;' + export ORM_DRIVER=mysql + export ORM_SOURCE="root:@/orm_test?charset=utf8" + go test -v github.com/astaxie/beego/orm + + + #### Sqlite3 + export ORM_DRIVER=sqlite3 + export ORM_SOURCE='file:memory_test?mode=memory' + go test -v github.com/astaxie/beego/orm + + + #### PostgreSQL + psql -c 'create database orm_test;' -U postgres + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + go test -v github.com/astaxie/beego/orm + + #### TiDB + export ORM_DRIVER=tidb + export ORM_SOURCE='memory://test/test' + go test -v github.com/astaxie/beego/orm + + ` +) + +func init() { + Debug, _ = StrTo(DBARGS.Debug).Bool() + + if DBARGS.Driver == "" || DBARGS.Source == "" { + fmt.Println(helpinfo) + os.Exit(2) + } + + RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) + + alias := getDbAlias("default") + if alias.Driver == DRMySQL { + alias.Engine = "INNODB" + } + +} diff --git a/orm/orm_test.go b/orm/orm_test.go new file mode 100644 index 00000000..eac7b33a --- /dev/null +++ b/orm/orm_test.go @@ -0,0 +1,2500 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package orm + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io/ioutil" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var _ = os.PathSeparator + +var ( + testDate = formatDate + " -0700" + testDateTime = formatDateTime + " -0700" + testTime = formatTime + " -0700" +) + +type argAny []interface{} + +// get interface by index from interface slice +func (a argAny) Get(i int, args ...interface{}) (r interface{}) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) { + if len(args) == 0 { + return false, fmt.Errorf("miss args") + } + b := args[0] + arg := argAny(args) + + switch v := a.(type) { + case reflect.Kind: + ok = reflect.ValueOf(b).Kind() == v + case time.Time: + if v2, vo := b.(time.Time); vo { + if arg.Get(1) != nil { + format := ToStr(arg.Get(1)) + a = v.Format(format) + b = v2.Format(format) + ok = a == b + } else { + err = fmt.Errorf("compare datetime miss format") + goto wrongArg + } + } + default: + ok = ToStr(a) == ToStr(b) + } + ok = is && ok || !is && !ok + if !ok { + if is { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } else { + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) + } + } + +wrongArg: + if err != nil { + return false, err + } + + return true, nil +} + +func AssertIs(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(true, a, args...); !ok { + return err + } + return nil +} + +func AssertNot(a interface{}, args ...interface{}) error { + if ok, err := ValuesCompare(false, a, args...); !ok { + return err + } + return nil +} + +func getCaller(skip int) string { + pc, file, line, _ := runtime.Caller(skip) + fun := runtime.FuncForPC(pc) + _, fn := filepath.Split(file) + data, err := ioutil.ReadFile(file) + var codes []string + if err == nil { + lines := bytes.Split(data, []byte{'\n'}) + n := 10 + for i := 0; i < n; i++ { + o := line - n + if o < 0 { + continue + } + cur := o + i + 1 + flag := " " + if cur == line { + flag = ">>" + } + code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + if code != "" { + codes = append(codes, code) + } + } + } + funName := fun.Name() + if i := strings.LastIndex(funName, "."); i > -1 { + funName = funName[i+1:] + } + return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) +} + +// Deprecated: Using stretchr/testify/assert +func throwFail(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.Fail() + } +} + +func throwFailNow(t *testing.T, err error, args ...interface{}) { + if err != nil { + con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) + if len(args) > 0 { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, fmt.Sprintf("%v", arg)) + } + con += " " + strings.Join(parts, ", ") + } + t.Error(con) + t.FailNow() + } +} + +func TestGetDB(t *testing.T) { + if db, err := GetDB(); err != nil { + throwFailNow(t, err) + } else { + err = db.Ping() + throwFailNow(t, err) + } +} + +func TestSyncDb(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + err := RunSyncdb("default", true, Debug) + throwFail(t, err) + + modelCache.clean() +} + +func TestRegisterModels(t *testing.T) { + RegisterModel(new(Data), new(DataNull), new(DataCustom)) + RegisterModel(new(User)) + RegisterModel(new(Profile)) + RegisterModel(new(Post)) + RegisterModel(new(Tag)) + RegisterModel(new(Comment)) + RegisterModel(new(UserBig)) + RegisterModel(new(PostTags)) + RegisterModel(new(Group)) + RegisterModel(new(Permission)) + RegisterModel(new(GroupPermissions)) + RegisterModel(new(InLine)) + RegisterModel(new(InLineOneToOne)) + RegisterModel(new(IntegerPk)) + RegisterModel(new(UintPk)) + RegisterModel(new(PtrPk)) + + BootStrap() + + dORM = NewOrm() + dDbBaser = getDbAlias("default").DbBaser +} + +func TestModelSyntax(t *testing.T) { + user := &User{} + ind := reflect.ValueOf(user).Elem() + fn := getFullName(ind.Type()) + mi, ok := modelCache.getByFullName(fn) + throwFail(t, AssertIs(ok, true)) + + mi, ok = modelCache.get("user") + throwFail(t, AssertIs(ok, true)) + if ok { + throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + } +} + +var DataValues = map[string]interface{}{ + "Boolean": true, + "Char": "char", + "Text": "text", + "JSON": `{"name":"json"}`, + "Jsonb": `{"name": "jsonb"}`, + "Time": time.Now(), + "Date": time.Now(), + "DateTime": time.Now(), + "Byte": byte(1<<8 - 1), + "Rune": rune(1<<31 - 1), + "Int": int(1<<31 - 1), + "Int8": int8(1<<7 - 1), + "Int16": int16(1<<15 - 1), + "Int32": int32(1<<31 - 1), + "Int64": int64(1<<63 - 1), + "Uint": uint(1<<32 - 1), + "Uint8": uint8(1<<8 - 1), + "Uint16": uint16(1<<16 - 1), + "Uint32": uint32(1<<32 - 1), + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Float32": float32(100.1234), + "Float64": float64(100.1234), + "Decimal": float64(100.1234), +} + +func TestDataTypes(t *testing.T) { + d := Data{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + if name == "JSON" { + continue + } + e := ind.FieldByName(name) + e.Set(reflect.ValueOf(value)) + } + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = Data{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestNullDataTypes(t *testing.T) { + d := DataNull{} + + if IsPostgres { + // can removed when this fixed + // https://github.com/lib/pq/pull/125 + d.DateTime = time.Now() + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + data := `{"ok":1,"data":{"arr":[1,2],"msg":"gopher"}}` + d = DataNull{ID: 1, JSON: data} + num, err := dORM.Update(&d) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + d = DataNull{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.JSON, data)) + + throwFail(t, AssertIs(d.NullBool.Valid, false)) + throwFail(t, AssertIs(d.NullString.Valid, false)) + throwFail(t, AssertIs(d.NullInt64.Valid, false)) + throwFail(t, AssertIs(d.NullFloat64.Valid, false)) + + throwFail(t, AssertIs(d.BooleanPtr, nil)) + throwFail(t, AssertIs(d.CharPtr, nil)) + throwFail(t, AssertIs(d.TextPtr, nil)) + throwFail(t, AssertIs(d.BytePtr, nil)) + throwFail(t, AssertIs(d.RunePtr, nil)) + throwFail(t, AssertIs(d.IntPtr, nil)) + throwFail(t, AssertIs(d.Int8Ptr, nil)) + throwFail(t, AssertIs(d.Int16Ptr, nil)) + throwFail(t, AssertIs(d.Int32Ptr, nil)) + throwFail(t, AssertIs(d.Int64Ptr, nil)) + throwFail(t, AssertIs(d.UintPtr, nil)) + throwFail(t, AssertIs(d.Uint8Ptr, nil)) + throwFail(t, AssertIs(d.Uint16Ptr, nil)) + throwFail(t, AssertIs(d.Uint32Ptr, nil)) + throwFail(t, AssertIs(d.Uint64Ptr, nil)) + throwFail(t, AssertIs(d.Float32Ptr, nil)) + throwFail(t, AssertIs(d.Float64Ptr, nil)) + throwFail(t, AssertIs(d.DecimalPtr, nil)) + throwFail(t, AssertIs(d.TimePtr, nil)) + throwFail(t, AssertIs(d.DatePtr, nil)) + throwFail(t, AssertIs(d.DateTimePtr, nil)) + + _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() + throwFail(t, err) + + d = DataNull{ID: 2} + err = dORM.Read(&d) + throwFail(t, err) + + booleanPtr := true + charPtr := string("test") + textPtr := string("test") + bytePtr := byte('t') + runePtr := rune('t') + intPtr := int(42) + int8Ptr := int8(42) + int16Ptr := int16(42) + int32Ptr := int32(42) + int64Ptr := int64(42) + uintPtr := uint(42) + uint8Ptr := uint8(42) + uint16Ptr := uint16(42) + uint32Ptr := uint32(42) + uint64Ptr := uint64(42) + float32Ptr := float32(42.0) + float64Ptr := float64(42.0) + decimalPtr := float64(42.0) + timePtr := time.Now() + datePtr := time.Now() + dateTimePtr := time.Now() + + d = DataNull{ + DateTime: time.Now(), + NullString: sql.NullString{String: "test", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + BooleanPtr: &booleanPtr, + CharPtr: &charPtr, + TextPtr: &textPtr, + BytePtr: &bytePtr, + RunePtr: &runePtr, + IntPtr: &intPtr, + Int8Ptr: &int8Ptr, + Int16Ptr: &int16Ptr, + Int32Ptr: &int32Ptr, + Int64Ptr: &int64Ptr, + UintPtr: &uintPtr, + Uint8Ptr: &uint8Ptr, + Uint16Ptr: &uint16Ptr, + Uint32Ptr: &uint32Ptr, + Uint64Ptr: &uint64Ptr, + Float32Ptr: &float32Ptr, + Float64Ptr: &float64Ptr, + DecimalPtr: &decimalPtr, + TimePtr: &timePtr, + DatePtr: &datePtr, + DateTimePtr: &dateTimePtr, + } + + id, err = dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + d = DataNull{ID: 3} + err = dORM.Read(&d) + throwFail(t, err) + + throwFail(t, AssertIs(d.NullBool.Valid, true)) + throwFail(t, AssertIs(d.NullBool.Bool, true)) + + throwFail(t, AssertIs(d.NullString.Valid, true)) + throwFail(t, AssertIs(d.NullString.String, "test")) + + throwFail(t, AssertIs(d.NullInt64.Valid, true)) + throwFail(t, AssertIs(d.NullInt64.Int64, 42)) + + throwFail(t, AssertIs(d.NullFloat64.Valid, true)) + throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) + + throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr)) + throwFail(t, AssertIs(*d.CharPtr, charPtr)) + throwFail(t, AssertIs(*d.TextPtr, textPtr)) + throwFail(t, AssertIs(*d.BytePtr, bytePtr)) + throwFail(t, AssertIs(*d.RunePtr, runePtr)) + throwFail(t, AssertIs(*d.IntPtr, intPtr)) + throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr)) + throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr)) + throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr)) + throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr)) + throwFail(t, AssertIs(*d.UintPtr, uintPtr)) + throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr)) + throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr)) + throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr)) + throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr)) + throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) + throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) + throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) + + // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() + assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) + assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) + assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) + + // test support for pointer fields using RawSeter.QueryRows() + var dnList []*DataNull + Q := dDbBaser.TableQuote() + num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + equal := reflect.DeepEqual(*dnList[0], d) + throwFailNow(t, AssertIs(equal, true)) +} + +func TestDataCustomTypes(t *testing.T) { + d := DataCustom{} + ind := reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + e.Set(reflect.ValueOf(value).Convert(e.Type())) + } + + id, err := dORM.Insert(&d) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + d = DataCustom{ID: 1} + err = dORM.Read(&d) + throwFail(t, err) + + ind = reflect.Indirect(reflect.ValueOf(&d)) + + for name, value := range DataValues { + e := ind.FieldByName(name) + if !e.IsValid() { + continue + } + vu := e.Interface() + value = reflect.ValueOf(value).Convert(e.Type()).Interface() + throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestCRUD(t *testing.T) { + profile := NewProfile() + profile.Age = 30 + profile.Money = 1234.12 + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + u := &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + + throwFail(t, AssertIs(u.UserName, "slene")) + throwFail(t, AssertIs(u.Email, "vslene@gmail.com")) + throwFail(t, AssertIs(u.Password, "pass")) + throwFail(t, AssertIs(u.Status, 3)) + throwFail(t, AssertIs(u.IsStaff, true)) + throwFail(t, AssertIs(u.IsActive, true)) + + assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) + + user.UserName = "astaxie" + user.Profile = profile + num, err := dORM.Update(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "astaxie")) + throwFail(t, AssertIs(u.Profile.ID, profile.ID)) + + u = &User{UserName: "astaxie", Password: "pass"} + err = dORM.Read(u, "UserName") + throwFailNow(t, err) + throwFailNow(t, AssertIs(id, 1)) + + u.UserName = "QQ" + u.Password = "111" + num, err = dORM.Update(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFailNow(t, err) + throwFail(t, AssertIs(u.UserName, "QQ")) + throwFail(t, AssertIs(u.Password, "pass")) + + num, err = dORM.Delete(profile) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: user.ID} + err = dORM.Read(u) + throwFail(t, err) + throwFail(t, AssertIs(true, u.Profile == nil)) + + num, err = dORM.Delete(user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + u = &User{ID: 100} + err = dORM.Read(u) + throwFail(t, AssertIs(err, ErrNoRows)) + + ub := UserBig{} + ub.Name = "name" + id, err = dORM.Insert(&ub) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + ub = UserBig{ID: 1} + err = dORM.Read(&ub) + throwFail(t, err) + throwFail(t, AssertIs(ub.Name, "name")) + + num, err = dORM.Delete(&ub, "name") + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertTestData(t *testing.T) { + var users []*User + + profile := NewProfile() + profile.Age = 28 + profile.Money = 1234.12 + + id, err := dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 1 + user.IsStaff = false + user.IsActive = true + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + profile = NewProfile() + profile.Age = 30 + profile.Money = 4321.09 + + id, err = dORM.Insert(profile) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "astaxie" + user.Email = "astaxie@gmail.com" + user.Password = "password" + user.Status = 2 + user.IsStaff = true + user.IsActive = false + user.Profile = profile + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 3)) + + user = NewUser() + user.UserName = "nobody" + user.Email = "nobody@gmail.com" + user.Password = "nobody" + user.Status = 3 + user.IsStaff = false + user.IsActive = false + + users = append(users, user) + + id, err = dORM.Insert(user) + throwFail(t, err) + throwFail(t, AssertIs(id, 4)) + + tags := []*Tag{ + {Name: "golang", BestPost: &Post{ID: 2}}, + {Name: "example"}, + {Name: "format"}, + {Name: "c++"}, + } + + posts := []*Post{ + {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. +This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, + {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. +With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, + {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code. +The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`}, + } + + comments := []*Comment{ + {Post: posts[0], Content: "a comment"}, + {Post: posts[1], Content: "yes"}, + {Post: posts[1]}, + {Post: posts[1]}, + {Post: posts[2]}, + {Post: posts[2]}, + } + + for _, tag := range tags { + id, err := dORM.Insert(tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, post := range posts { + id, err := dORM.Insert(post) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(post.Tags) + if num > 0 { + nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + + for _, comment := range comments { + id, err := dORM.Insert(comment) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + permissions := []*Permission{ + {Name: "writePosts"}, + {Name: "readComments"}, + {Name: "readPosts"}, + } + + groups := []*Group{ + { + Name: "admins", + Permissions: []*Permission{permissions[0], permissions[1], permissions[2]}, + }, + { + Name: "users", + Permissions: []*Permission{permissions[1], permissions[2]}, + }, + } + + for _, permission := range permissions { + id, err := dORM.Insert(permission) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + + for _, group := range groups { + _, err := dORM.Insert(group) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num := len(group.Permissions) + if num > 0 { + nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions) + throwFailNow(t, err) + throwFailNow(t, AssertIs(nums, num)) + } + } + +} + +func TestCustomField(t *testing.T) { + user := User{ID: 2} + err := dORM.Read(&user) + throwFailNow(t, err) + + user.Langs = append(user.Langs, "zh-CN", "en-US") + user.Extra.Name = "beego" + user.Extra.Data = "orm" + _, err = dORM.Update(&user, "Langs", "Extra") + throwFailNow(t, err) + + user = User{ID: 2} + err = dORM.Read(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(len(user.Langs), 2)) + throwFailNow(t, AssertIs(user.Langs[0], "zh-CN")) + throwFailNow(t, AssertIs(user.Langs[1], "en-US")) + + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) +} + +func TestExpr(t *testing.T) { + user := &User{} + qs := dORM.QueryTable(user) + qs = dORM.QueryTable((*User)(nil)) + qs = dORM.QueryTable("User") + qs = dORM.QueryTable("user") + num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("created", time.Now()).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + // num, err = qs.Filter("created", time.Now().Format(format_Date)).Count() + // throwFail(t, err) + // throwFail(t, AssertIs(num, 3)) +} + +func TestOperators(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", String("slene")).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__exact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__iexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__contains", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + var shouldNum int + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__contains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("user_name__icontains", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gt", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__gte", 1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("status__lt", Uint(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__lte", Int(3)).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + num, err = qs.Filter("user_name__startswith", "s").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsSqlite || IsTidb { + shouldNum = 1 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__startswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__istartswith", "S").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name__endswith", "e").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + if IsSqlite || IsTidb { + shouldNum = 2 + } else { + shouldNum = 0 + } + + num, err = qs.Filter("user_name__endswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, shouldNum)) + + num, err = qs.Filter("user_name__iendswith", "E").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("status__in", 1, 2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("status__in", []int{1, 2}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + n1, n2 := 1, 2 + num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", 2, 3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Filter("id__between", []int{2, 3}).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("user_name", "= 'slene'").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.FilterRaw("status", "IN (1, 2)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSetCond(t *testing.T) { + cond := NewCondition() + cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + + qs := dORM.QueryTable("user") + num, err := qs.SetCond(cond1).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond2).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond3 := cond.AndNotCond(cond.And("status__in", 1)) + num, err = qs.SetCond(cond3).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + cond4 := cond.And("user_name", "slene").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond4).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene")) + num, err = qs.SetCond(cond5).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) +} + +func TestLimit(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Limit(-1).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + num, err = qs.Limit(-1, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + num, err = qs.Limit(0, 2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOffset(t *testing.T) { + var posts []*Post + qs := dORM.QueryTable("post") + num, err := qs.Limit(1).Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Offset(2).All(&posts) + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) +} + +func TestOrderBy(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestAll(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("Id").All(&users) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFail(t, AssertIs(users[0].UserName, "slene")) + throwFail(t, AssertIs(users[1].UserName, "astaxie")) + throwFail(t, AssertIs(users[2].UserName, "nobody")) + + var users2 []User + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").All(&users2) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + + qs = dORM.QueryTable("user") + num, err = qs.OrderBy("Id").RelatedSel().All(&users2, "UserName") + throwFail(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(users2), 3)) + throwFailNow(t, AssertIs(users2[0].UserName, "slene")) + throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) + throwFailNow(t, AssertIs(users2[0].ID, 0)) + throwFailNow(t, AssertIs(users2[1].ID, 0)) + throwFailNow(t, AssertIs(users2[2].ID, 0)) + throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) + throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + var users3 []*User + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "nothing").All(&users3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + throwFailNow(t, AssertIs(users3 == nil, false)) +} + +func TestOne(t *testing.T) { + var user User + qs := dORM.QueryTable("user") + err := qs.One(&user) + throwFail(t, err) + + user = User{} + err = qs.OrderBy("Id").Limit(1).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "slene")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + user = User{} + err = qs.OrderBy("-Id").Limit(100).One(&user) + throwFailNow(t, err) + throwFail(t, AssertIs(user.UserName, "nobody")) + throwFail(t, AssertNot(err, ErrMultiRows)) + + err = qs.Filter("user_name", "nothing").One(&user) + throwFail(t, AssertIs(err, ErrNoRows)) + +} + +func TestValues(t *testing.T) { + var maps []Params + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[0]["Profile__Age"], 28)) + throwFail(t, AssertIs(maps[2]["Profile__Age"], nil)) + } + + num, err = qs.Filter("UserName", "slene").Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestValuesList(t *testing.T) { + var list []ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("Id").ValuesList(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][1], "slene")) + throwFail(t, AssertIs(list[2][9], nil)) + } + + num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0][0], "slene")) + throwFail(t, AssertIs(list[0][1], 28)) + throwFail(t, AssertIs(list[2][1], nil)) + } +} + +func TestValuesFlat(t *testing.T) { + var list ParamsList + qs := dORM.QueryTable("user") + + num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "slene")) + throwFail(t, AssertIs(list[1], "astaxie")) + throwFail(t, AssertIs(list[2], "nobody")) + } +} + +func TestRelatedSel(t *testing.T) { + if IsTidb { + // Skip it. TiDB does not support relation now. + return + } + qs := dORM.QueryTable("user") + num, err := qs.Filter("profile__age", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("profile__user__profile__age__gt", 28).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var user User + err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "slene").RelatedSel().One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertNot(user.Profile, nil)) + if user.Profile != nil { + throwFail(t, AssertIs(user.Profile.Age, 28)) + } + + err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(user.Profile, nil)) + + qs = dORM.QueryTable("user_profile") + num, err = qs.Filter("user__username", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + var posts []*Post + qs = dORM.QueryTable("post") + num, err = qs.RelatedSel().All(&posts) + throwFail(t, err) + throwFailNow(t, AssertIs(num, 4)) + + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + throwFailNow(t, AssertIs(posts[1].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[2].User.UserName, "astaxie")) + throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) +} + +func TestReverseQuery(t *testing.T) { + var profile Profile + err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + profile = Profile{} + err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile) + throwFailNow(t, err) + throwFailNow(t, AssertIs(profile.Age, 30)) + + var user User + err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + + user = User{} + err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.UserName, "astaxie")) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + var posts []*Post + num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].Title, "Introduction")) + + posts = []*Post{} + num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang"). + Filter("User__UserName", "slene").RelatedSel().All(&posts) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(posts[0].User == nil, false)) + throwFailNow(t, AssertIs(posts[0].User.UserName, "slene")) + + var tags []*Tag + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + tags = []*Tag{} + num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction"). + Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(tags[0].Name, "golang")) + throwFailNow(t, AssertIs(tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples")) + throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false)) + throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie")) +} + +func TestLoadRelated(t *testing.T) { + // load reverse foreign key + user := User{ID: 3} + + err := dORM.Read(&user) + throwFailNow(t, err) + + num, err := dORM.LoadRelated(&user, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) + + num, err = dORM.LoadRelated(&user, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + + num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(user.Posts), 2)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(user.Posts), 1)) + throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) + + // load reverse one to one + profile := Profile{ID: 3} + profile.BestPost = &Post{ID: 2} + num, err = dORM.Update(&profile, "BestPost") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + err = dORM.Read(&profile) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&profile, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&profile, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(profile.User == nil, false)) + throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age)) + + // load rel one to one + err = dORM.Read(&user) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&user, "Profile") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + + num, err = dORM.LoadRelated(&user, "Profile", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(user.Profile == nil, false)) + throwFailNow(t, AssertIs(user.Profile.Age, 30)) + throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) + throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) + + post := Post{ID: 2} + + // load rel foreign key + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "User") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + + num, err = dORM.LoadRelated(&post, "User", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(post.User == nil, false)) + throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) + throwFailNow(t, AssertIs(post.User.Profile == nil, false)) + throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) + + // load rel m2m + post = Post{ID: 2} + + err = dORM.Read(&post) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&post, "Tags") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + + num, err = dORM.LoadRelated(&post, "Tags", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(post.Tags), 2)) + throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) + throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false)) + throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) + + // load reverse m2m + tag := Tag{ID: 1} + + err = dORM.Read(&tag) + throwFailNow(t, err) + + num, err = dORM.LoadRelated(&tag, "Posts") + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) + + num, err = dORM.LoadRelated(&tag, "Posts", true) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) + throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) + throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) +} + +func TestQueryM2M(t *testing.T) { + post := Post{ID: 4} + m2m := dORM.QueryM2M(&post, "Tags") + + tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} + tag2 := &Tag{Name: "TestTag3"} + tag3 := []interface{}{&Tag{Name: "TestTag4"}} + + tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]} + + for _, tag := range tags { + _, err := dORM.Insert(tag) + throwFailNow(t, err) + } + + num, err := m2m.Add(tag1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 5)) + + num, err = m2m.Remove(tag3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + exist := m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(tag2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(tag2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + tag := Tag{Name: "test"} + _, err = dORM.Insert(&tag) + throwFailNow(t, err) + + m2m = dORM.QueryM2M(&tag, "Posts") + + post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}} + post2 := &Post{Title: "TestPost3"} + post3 := []interface{}{&Post{Title: "TestPost4"}} + + posts := []interface{}{post1[0], post1[1], post2, post3[0]} + + for _, post := range posts { + p := post.(*Post) + p.User = &User{ID: 1} + _, err := dORM.Insert(post) + throwFailNow(t, err) + } + + num, err = m2m.Add(post1) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Add(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Add(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 4)) + + num, err = m2m.Remove(post3) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, true)) + + num, err = m2m.Remove(post2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + exist = m2m.Exist(post2) + throwFailNow(t, AssertIs(exist, false)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Clear() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + + num, err = m2m.Count() + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 0)) + + num, err = dORM.Delete(&tag) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) +} + +func TestQueryRelate(t *testing.T) { + // post := &Post{Id: 2} + + // qs := dORM.QueryRelate(post, "Tags") + // num, err := qs.Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + + // var tags []*Tag + // num, err = qs.All(&tags) + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) + // throwFailNow(t, AssertIs(tags[0].Name, "golang")) + + // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count() + // throwFailNow(t, err) + // throwFailNow(t, AssertIs(num, 2)) +} + +func TestPkManyRelated(t *testing.T) { + permission := &Permission{Name: "readPosts"} + err := dORM.Read(permission, "Name") + throwFailNow(t, err) + + var groups []*Group + qs := dORM.QueryTable("Group") + num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) +} + +func TestPrepareInsert(t *testing.T) { + qs := dORM.QueryTable("user") + i, err := qs.PrepareInsert() + throwFailNow(t, err) + + var user User + user.UserName = "testing1" + num, err := i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + user.UserName = "testing2" + num, err = i.Insert(&user) + throwFail(t, err) + throwFail(t, AssertIs(num > 0, true)) + + num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 2)) + + err = i.Close() + throwFail(t, err) + err = i.Close() + throwFail(t, AssertIs(err, ErrStmtClosed)) +} + +func TestRawExec(t *testing.T) { + Q := dDbBaser.TableQuote() + + query := fmt.Sprintf("UPDATE %suser%s SET %suser_name%s = ? WHERE %suser_name%s = ?", Q, Q, Q, Q, Q, Q) + res, err := dORM.Raw(query, "testing", "slene").Exec() + throwFail(t, err) + num, err := res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) + + res, err = dORM.Raw(query, "slene", "testing").Exec() + throwFail(t, err) + num, err = res.RowsAffected() + throwFail(t, AssertIs(num, 1), err) +} + +func TestRawQueryRow(t *testing.T) { + var ( + Boolean bool + Char string + Text string + Time time.Time + Date time.Time + DateTime time.Time + Byte byte + Rune rune + Int int + Int8 int + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float32 + Float64 float64 + Decimal float64 + ) + + dataValues := make(map[string]interface{}, len(DataValues)) + + for k, v := range DataValues { + dataValues[strings.ToLower(k)] = v + } + + Q := dDbBaser.TableQuote() + + cols := []string{ + "id", "boolean", "char", "text", "time", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", + "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", + } + sep := fmt.Sprintf("%s, %s", Q, Q) + query := fmt.Sprintf("SELECT %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) + var id int + values := []interface{}{ + &id, &Boolean, &Char, &Text, &Time, &Date, &DateTime, &Byte, &Rune, &Int, &Int8, &Int16, &Int32, + &Int64, &Uint, &Uint8, &Uint16, &Uint32, &Uint64, &Float32, &Float64, &Decimal, + } + err := dORM.Raw(query, 1).QueryRow(values...) + throwFailNow(t, err) + for i, col := range cols { + vu := values[i] + v := reflect.ValueOf(vu).Elem().Interface() + switch col { + case "id": + throwFail(t, AssertIs(id, 1)) + case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testTime)) + case "date": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDate)) + case "datetime": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + throwFail(t, AssertIs(v, value, testDateTime)) + default: + throwFail(t, AssertIs(v, dataValues[col])) + } + } + + var ( + uid int + status *int + pid *int + ) + + cols = []string{ + "id", "Status", "profile_id", + } + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) + throwFail(t, err) + throwFail(t, AssertIs(uid, 4)) + throwFail(t, AssertIs(*status, 3)) + throwFail(t, AssertIs(pid, nil)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nd *DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + err = dORM.Raw(query, newId).QueryRow(&nd) + throwFailNow(t, err) + + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +// user_profile table +type userProfile struct { + User + Age int + Money float64 +} + +func TestQueryRows(t *testing.T) { + Q := dDbBaser.TableQuote() + + var datas []*Data + + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas), 1)) + + ind := reflect.Indirect(reflect.ValueOf(datas[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range DataValues { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Time": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) + } + + var ids []int + var usernames []string + query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&ids, &usernames) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 3)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) + + // test query rows by nested struct + var l []userProfile + query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&l) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 2)) + throwFailNow(t, AssertIs(len(l), 2)) + throwFailNow(t, AssertIs(l[0].UserName, "slene")) + throwFailNow(t, AssertIs(l[0].Age, 28)) + throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[1].Age, 30)) + + // test for sql.Null* fields + nData := &DataNull{ + NullString: sql.NullString{String: "test sql.null", Valid: true}, + NullBool: sql.NullBool{Bool: true, Valid: true}, + NullInt64: sql.NullInt64{Int64: 42, Valid: true}, + NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, + } + newId, err := dORM.Insert(nData) + throwFailNow(t, err) + + var nDataList []*DataNull + query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q) + num, err = dORM.Raw(query, newId).QueryRows(&nDataList) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + + nd := nDataList[0] + throwFailNow(t, AssertNot(nd, nil)) + throwFail(t, AssertIs(nd.NullBool.Valid, true)) + throwFail(t, AssertIs(nd.NullBool.Bool, true)) + throwFail(t, AssertIs(nd.NullString.Valid, true)) + throwFail(t, AssertIs(nd.NullString.String, "test sql.null")) + throwFail(t, AssertIs(nd.NullInt64.Valid, true)) + throwFail(t, AssertIs(nd.NullInt64.Int64, 42)) + throwFail(t, AssertIs(nd.NullFloat64.Valid, true)) + throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42)) +} + +func TestRawValues(t *testing.T) { + Q := dDbBaser.TableQuote() + + var maps []Params + query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q) + num, err := dORM.Raw(query, 1).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(maps[0]["user_name"], "slene")) + } + + var lists []ParamsList + num, err = dORM.Raw(query, 1).ValuesList(&lists) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + if num == 1 { + throwFail(t, AssertIs(lists[0][0], "slene")) + } + + query = fmt.Sprintf("SELECT %sprofile_id%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q) + var list ParamsList + num, err = dORM.Raw(query).ValuesFlat(&list) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(list[0], "2")) + throwFail(t, AssertIs(list[1], "3")) + throwFail(t, AssertIs(list[2], nil)) + } +} + +func TestRawPrepare(t *testing.T) { + switch { + case IsMysql || IsSqlite: + + pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + throwFail(t, err) + if pre != nil { + r, err := pre.Exec("name1") + throwFail(t, err) + + tid, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(tid > 0, true)) + + r, err = pre.Exec("name2") + throwFail(t, err) + + id, err := r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+1)) + + r, err = pre.Exec("name3") + throwFail(t, err) + + id, err = r.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id, tid+2)) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + + case IsPostgres: + + pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + throwFail(t, err) + if pre != nil { + _, err := pre.Exec("name1") + throwFail(t, err) + + _, err = pre.Exec("name2") + throwFail(t, err) + + _, err = pre.Exec("name3") + throwFail(t, err) + + err = pre.Close() + throwFail(t, err) + + res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() + throwFail(t, err) + + if err == nil { + num, err := res.RowsAffected() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + } + } + } +} + +func TestUpdate(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{ + "is_staff": true, + "is_active": true, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + // with join + num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{ + "is_staff": false, + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColAdd, 100), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMinus, 50), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColMultiply, 3), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Update(Params{ + "Nums": ColValue(ColExcept, 5), + }) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + user := User{UserName: "slene"} + err = dORM.Read(&user, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(user.Nums, 30)) +} + +func TestDelete(t *testing.T) { + qs := dORM.QueryTable("user_profile") + num, err := qs.Filter("user__user_name", "slene").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("user") + num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestTransaction(t *testing.T) { + // this test worked when database support transaction + + o := NewOrm() + err := o.Begin() + throwFail(t, err) + + var names = []string{"1", "2", "3"} + + var tag Tag + tag.Name = names[0] + id, err := o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + switch { + case IsMysql || IsSqlite: + res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + throwFail(t, err) + if err == nil { + id, err = res.LastInsertId() + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + } + } + + err = o.Rollback() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name__in", names).Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + err = o.Begin() + throwFail(t, err) + + tag.Name = "commit" + id, err = o.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + o.Commit() + throwFail(t, err) + + num, err = o.QueryTable("tag").Filter("name", "commit").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + +} + +func TestTransactionIsolationLevel(t *testing.T) { + // this test worked when database support transaction isolation level + if IsSqlite { + return + } + + o1 := NewOrm() + o2 := NewOrm() + + // start two transaction with isolation level repeatable read + err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + throwFail(t, err) + + // o1 insert tag + var tag Tag + tag.Name = "test-transaction" + id, err := o1.Insert(&tag) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // o2 query tag table, no result + num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o1 commit + o1.Commit() + + // o2 query tag table, still no result + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + // o2 commit and query tag table, get the result + o2.Commit() + num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestBeginTxWithContextCanceled(t *testing.T) { + o := NewOrm() + ctx, cancel := context.WithCancel(context.Background()) + o.BeginTx(ctx, nil) + id, err := o.Insert(&Tag{Name: "test-context"}) + throwFail(t, err) + throwFail(t, AssertIs(id > 0, true)) + + // cancel the context before commit to make it error + cancel() + err = o.Commit() + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestReadOrCreate(t *testing.T) { + u := &User{ + UserName: "Kyle", + Email: "kylemcc@gmail.com", + Password: "other_pass", + Status: 7, + IsStaff: false, + IsActive: true, + } + + created, pk, err := dORM.ReadOrCreate(u, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.ID, pk)) + throwFail(t, AssertIs(u.UserName, "Kyle")) + throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) + throwFail(t, AssertIs(u.Password, "other_pass")) + throwFail(t, AssertIs(u.Status, 7)) + throwFail(t, AssertIs(u.IsStaff, false)) + throwFail(t, AssertIs(u.IsActive, true)) + throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime)) + + nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} + created, pk, err = dORM.ReadOrCreate(nu, "UserName") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.UserName, u.UserName)) + throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above + throwFail(t, AssertIs(nu.Password, u.Password)) + throwFail(t, AssertIs(nu.Status, u.Status)) + throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) + throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + + dORM.Delete(u) +} + +func TestInLine(t *testing.T) { + name := "inline" + email := "hello@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 1)) + + il := NewInLine() + il.ID = 1 + err = dORM.Read(il) + throwFail(t, err) + + throwFail(t, AssertIs(il.Name, name)) + throwFail(t, AssertIs(il.Email, email)) + throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) + throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) +} + +func TestInLineOneToOne(t *testing.T) { + name := "121" + email := "121@go.com" + inline := NewInLine() + inline.Name = name + inline.Email = email + + id, err := dORM.Insert(inline) + throwFail(t, err) + throwFail(t, AssertIs(id, 2)) + + note := "one2one" + il121 := NewInLineOneToOne() + il121.Note = note + il121.InLine = inline + _, err = dORM.Insert(il121) + throwFail(t, err) + throwFail(t, AssertIs(il121.ID, 1)) + + il := NewInLineOneToOne() + err = dORM.QueryTable(il).Filter("Id", 1).RelatedSel().One(il) + + throwFail(t, err) + throwFail(t, AssertIs(il.Note, note)) + throwFail(t, AssertIs(il.InLine.ID, id)) + throwFail(t, AssertIs(il.InLine.Name, name)) + throwFail(t, AssertIs(il.InLine.Email, email)) + + rinline := NewInLine() + err = dORM.QueryTable(rinline).Filter("InLineOneToOne__Id", 1).One(rinline) + + throwFail(t, err) + throwFail(t, AssertIs(rinline.ID, id)) + throwFail(t, AssertIs(rinline.Name, name)) + throwFail(t, AssertIs(rinline.Email, email)) +} + +func TestIntegerPk(t *testing.T) { + its := []IntegerPk{ + {ID: math.MinInt64, Value: "-"}, + {ID: 0, Value: "0"}, + {ID: math.MaxInt64, Value: "+"}, + } + + num, err := dORM.InsertMulti(len(its), its) + throwFail(t, err) + throwFail(t, AssertIs(num, len(its))) + + for _, intPk := range its { + out := IntegerPk{ID: intPk.ID} + err = dORM.Read(&out) + throwFail(t, err) + throwFail(t, AssertIs(out.Value, intPk.Value)) + } + + num, err = dORM.InsertMulti(1, []*IntegerPk{{ + ID: 1, Value: "ok", + }}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestInsertAuto(t *testing.T) { + u := &User{ + UserName: "autoPre", + Email: "autoPre@gmail.com", + } + + id, err := dORM.Insert(u) + throwFail(t, err) + + id += 100 + su := &User{ + ID: int(id), + UserName: "auto", + Email: "auto@gmail.com", + } + + nid, err := dORM.Insert(su) + throwFail(t, err) + throwFail(t, AssertIs(nid, id)) + + users := []User{ + {ID: int(id + 100), UserName: "auto_100"}, + {ID: int(id + 110), UserName: "auto_110"}, + {ID: int(id + 120), UserName: "auto_120"}, + } + num, err := dORM.InsertMulti(100, users) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + u = &User{ + UserName: "auto_121", + } + + nid, err = dORM.Insert(u) + throwFail(t, err) + throwFail(t, AssertIs(nid, id+120+1)) +} + +func TestUintPk(t *testing.T) { + name := "go" + u := &UintPk{ + ID: 8, + Name: name, + } + + created, _, err := dORM.ReadOrCreate(u, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, true)) + throwFail(t, AssertIs(u.Name, name)) + + nu := &UintPk{ID: 8} + created, pk, err := dORM.ReadOrCreate(nu, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(nu.ID, u.ID)) + throwFail(t, AssertIs(pk, u.ID)) + throwFail(t, AssertIs(nu.Name, name)) + + dORM.Delete(u) +} + +func TestPtrPk(t *testing.T) { + parent := &IntegerPk{ID: 10, Value: "10"} + + id, _ := dORM.Insert(parent) + if !IsMysql { + // MySql does not support last_insert_id in this case: see #2382 + throwFail(t, AssertIs(id, 10)) + } + + ptr := PtrPk{ID: parent, Positive: true} + num, err := dORM.InsertMulti(2, []PtrPk{ptr}) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(ptr.ID, parent)) + + nptr := &PtrPk{ID: parent} + created, pk, err := dORM.ReadOrCreate(nptr, "ID") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, true)) + + nptr = &PtrPk{Positive: true} + created, pk, err = dORM.ReadOrCreate(nptr, "Positive") + throwFail(t, err) + throwFail(t, AssertIs(created, false)) + throwFail(t, AssertIs(pk, 10)) + throwFail(t, AssertIs(nptr.ID, parent)) + + nptr.Positive = false + num, err = dORM.Update(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + throwFail(t, AssertIs(nptr.ID, parent)) + throwFail(t, AssertIs(nptr.Positive, false)) + + num, err = dORM.Delete(nptr) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) +} + +func TestSnake(t *testing.T) { + cases := map[string]string{ + "i": "i", + "I": "i", + "iD": "i_d", + "ID": "i_d", + "NO": "n_o", + "NOO": "n_o_o", + "NOOooOOoo": "n_o_ooo_o_ooo", + "OrderNO": "order_n_o", + "tagName": "tag_name", + "tag_Name": "tag__name", + "tag_name": "tag_name", + "_tag_name": "_tag_name", + "tag_666name": "tag_666name", + "tag_666Name": "tag_666_name", + } + for name, want := range cases { + got := snakeString(name) + throwFail(t, AssertIs(got, want)) + } +} + +func TestIgnoreCaseTag(t *testing.T) { + type testTagModel struct { + ID int `orm:"pk"` + NOO string `orm:"column(n)"` + Name01 string `orm:"NULL"` + Name02 string `orm:"COLUMN(Name)"` + Name03 string `orm:"Column(name)"` + } + modelCache.clean() + RegisterModel(&testTagModel{}) + info, ok := modelCache.get("test_tag_model") + throwFail(t, AssertIs(ok, true)) + throwFail(t, AssertNot(info, nil)) + if t == nil { + return + } + throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) + throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) + throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) + throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) +} + +func TestInsertOrUpdate(t *testing.T) { + RegisterModel(new(User)) + user := User{UserName: "unique_username133", Status: 1, Password: "o"} + user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} + user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + dORM.Insert(&user) + test := User{UserName: "unique_username133"} + fmt.Println(dORM.Driver().Name()) + if dORM.Driver().Name() == "sqlite3" { + fmt.Println("sqlite3 is nonsupport") + return + } + // test1 + _, err := dORM.InsertOrUpdate(&user1, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user1.Status, test.Status)) + } + // test2 + _, err = dORM.InsertOrUpdate(&user2, "user_name") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status, test.Status)) + throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + } + + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres { + return + } + // test3 + + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + } + // test4 - + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + } + // test5 * + _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + } + // test6 / + _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") + if err != nil { + fmt.Println(err) + if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + } else { + throwFailNow(t, err) + } + } else { + dORM.Read(&test, "user_name") + throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + } +} diff --git a/orm/utils_test.go b/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/pkg/LICENSE b/pkg/LICENSE new file mode 100644 index 00000000..5dbd4243 --- /dev/null +++ b/pkg/LICENSE @@ -0,0 +1,13 @@ +Copyright 2014 astaxie + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/test.sh b/scripts/test.sh similarity index 94% rename from test.sh rename to scripts/test.sh index 78928fea..d626d24b 100644 --- a/test.sh +++ b/scripts/test.sh @@ -6,7 +6,7 @@ export ORM_DRIVER=mysql export TZ=UTC export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" -go test ./... +go test ../... # clear all container docker-compose -f test_docker_compose.yaml down diff --git a/test_docker_compose.yaml b/scripts/test_docker_compose.yaml similarity index 100% rename from test_docker_compose.yaml rename to scripts/test_docker_compose.yaml