From 9c51952db485cb32a6658df173f622f6930199cd Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Wed, 22 Jul 2020 22:50:08 +0800 Subject: [PATCH] Move package --- orm/cmd_utils.go | 10 +- orm/db_alias.go | 87 +- orm/orm_alias_adapt_test.go | 46 - pkg/admin.go | 458 +++++++ pkg/admin_test.go | 239 ++++ pkg/adminui.go | 356 ++++++ pkg/app.go | 496 ++++++++ pkg/beego.go | 123 ++ pkg/build_info.go | 27 + pkg/cache/README.md | 59 + pkg/cache/cache.go | 103 ++ pkg/cache/cache_test.go | 191 +++ pkg/cache/conv.go | 100 ++ pkg/cache/conv_test.go | 143 +++ pkg/cache/file.go | 258 ++++ pkg/cache/memcache/memcache.go | 188 +++ pkg/cache/memcache/memcache_test.go | 108 ++ pkg/cache/memory.go | 256 ++++ pkg/cache/redis/redis.go | 272 +++++ pkg/cache/redis/redis_test.go | 144 +++ pkg/cache/ssdb/ssdb.go | 231 ++++ pkg/cache/ssdb/ssdb_test.go | 104 ++ pkg/config.go | 524 ++++++++ pkg/config/config.go | 242 ++++ pkg/config/config_test.go | 55 + pkg/config/env/env.go | 87 ++ pkg/config/env/env_test.go | 75 ++ pkg/config/fake.go | 134 +++ pkg/config/ini.go | 504 ++++++++ pkg/config/ini_test.go | 190 +++ pkg/config/json.go | 269 +++++ pkg/config/json_test.go | 222 ++++ pkg/config/xml/xml.go | 228 ++++ pkg/config/xml/xml_test.go | 125 ++ pkg/config/yaml/yaml.go | 316 +++++ pkg/config/yaml/yaml_test.go | 115 ++ pkg/config_test.go | 146 +++ pkg/context/acceptencoder.go | 232 ++++ pkg/context/acceptencoder_test.go | 59 + pkg/context/context.go | 263 +++++ pkg/context/context_test.go | 47 + pkg/context/input.go | 689 +++++++++++ pkg/context/input_test.go | 217 ++++ pkg/context/output.go | 408 +++++++ pkg/context/param/conv.go | 78 ++ pkg/context/param/methodparams.go | 69 ++ pkg/context/param/options.go | 37 + pkg/context/param/parsers.go | 149 +++ pkg/context/param/parsers_test.go | 84 ++ pkg/context/renderer.go | 12 + pkg/context/response.go | 27 + pkg/controller.go | 706 +++++++++++ pkg/controller_test.go | 181 +++ pkg/doc.go | 17 + pkg/error.go | 488 ++++++++ pkg/error_test.go | 88 ++ pkg/filter.go | 44 + pkg/filter_test.go | 68 ++ pkg/flash.go | 110 ++ pkg/flash_test.go | 54 + pkg/fs.go | 74 ++ pkg/grace/grace.go | 166 +++ pkg/grace/server.go | 356 ++++++ pkg/hooks.go | 104 ++ pkg/httplib/README.md | 97 ++ pkg/httplib/httplib.go | 654 ++++++++++ pkg/httplib/httplib_test.go | 286 +++++ pkg/log.go | 127 ++ pkg/logs/README.md | 72 ++ pkg/logs/accesslog.go | 83 ++ pkg/logs/alils/alils.go | 186 +++ pkg/logs/alils/config.go | 13 + pkg/logs/alils/log.pb.go | 1038 ++++++++++++++++ pkg/logs/alils/log_config.go | 42 + pkg/logs/alils/log_project.go | 819 +++++++++++++ pkg/logs/alils/log_store.go | 271 +++++ pkg/logs/alils/machine_group.go | 91 ++ pkg/logs/alils/request.go | 62 + pkg/logs/alils/signature.go | 111 ++ pkg/logs/conn.go | 119 ++ pkg/logs/conn_test.go | 79 ++ pkg/logs/console.go | 99 ++ pkg/logs/console_test.go | 64 + pkg/logs/es/es.go | 102 ++ pkg/logs/file.go | 409 +++++++ pkg/logs/file_test.go | 420 +++++++ pkg/logs/jianliao.go | 72 ++ pkg/logs/log.go | 669 +++++++++++ pkg/logs/logger.go | 176 +++ pkg/logs/logger_test.go | 57 + pkg/logs/multifile.go | 119 ++ pkg/logs/multifile_test.go | 78 ++ pkg/logs/slack.go | 60 + pkg/logs/smtp.go | 149 +++ pkg/logs/smtp_test.go | 27 + pkg/metric/prometheus.go | 99 ++ pkg/metric/prometheus_test.go | 42 + pkg/migration/ddl.go | 395 +++++++ pkg/migration/doc.go | 32 + pkg/migration/migration.go | 330 ++++++ pkg/mime.go | 556 +++++++++ pkg/namespace.go | 396 +++++++ pkg/namespace_test.go | 168 +++ pkg/parser.go | 591 +++++++++ pkg/plugins/apiauth/apiauth.go | 165 +++ 
pkg/plugins/apiauth/apiauth_test.go | 20 + pkg/plugins/auth/basic.go | 107 ++ pkg/plugins/authz/authz.go | 86 ++ pkg/plugins/authz/authz_model.conf | 14 + pkg/plugins/authz/authz_policy.csv | 7 + pkg/plugins/authz/authz_test.go | 107 ++ pkg/plugins/cors/cors.go | 228 ++++ pkg/plugins/cors/cors_test.go | 253 ++++ pkg/policy.go | 97 ++ pkg/router.go | 1052 +++++++++++++++++ pkg/router_test.go | 732 ++++++++++++ pkg/session/README.md | 114 ++ pkg/session/couchbase/sess_couchbase.go | 247 ++++ pkg/session/ledis/ledis_session.go | 173 +++ pkg/session/memcache/sess_memcache.go | 230 ++++ pkg/session/mysql/sess_mysql.go | 228 ++++ pkg/session/postgres/sess_postgresql.go | 243 ++++ pkg/session/redis/sess_redis.go | 261 ++++ pkg/session/redis_cluster/redis_cluster.go | 220 ++++ .../redis_sentinel/sess_redis_sentinel.go | 234 ++++ .../sess_redis_sentinel_test.go | 90 ++ pkg/session/sess_cookie.go | 180 +++ pkg/session/sess_cookie_test.go | 105 ++ pkg/session/sess_file.go | 315 +++++ pkg/session/sess_file_test.go | 387 ++++++ pkg/session/sess_mem.go | 196 +++ pkg/session/sess_mem_test.go | 58 + pkg/session/sess_test.go | 131 ++ pkg/session/sess_utils.go | 207 ++++ pkg/session/session.go | 377 ++++++ pkg/session/ssdb/sess_ssdb.go | 199 ++++ pkg/staticfile.go | 234 ++++ pkg/staticfile_test.go | 99 ++ pkg/swagger/swagger.go | 174 +++ pkg/template.go | 406 +++++++ pkg/template_test.go | 316 +++++ pkg/templatefunc.go | 780 ++++++++++++ pkg/templatefunc_test.go | 380 ++++++ pkg/testdata/Makefile | 2 + pkg/testdata/bindata.go | 296 +++++ pkg/testdata/views/blocks/block.tpl | 3 + pkg/testdata/views/header.tpl | 3 + pkg/testdata/views/index.tpl | 15 + pkg/testing/assertions.go | 15 + pkg/testing/client.go | 65 + pkg/toolbox/healthcheck.go | 48 + pkg/toolbox/profile.go | 184 +++ pkg/toolbox/profile_test.go | 28 + pkg/toolbox/statistics.go | 149 +++ pkg/toolbox/statistics_test.go | 40 + pkg/toolbox/task.go | 640 ++++++++++ pkg/toolbox/task_test.go | 85 ++ pkg/tree.go | 585 +++++++++ pkg/tree_test.go | 306 +++++ pkg/unregroute_test.go | 226 ++++ pkg/utils/caller.go | 25 + pkg/utils/caller_test.go | 28 + pkg/utils/captcha/LICENSE | 19 + pkg/utils/captcha/README.md | 45 + pkg/utils/captcha/captcha.go | 270 +++++ pkg/utils/captcha/image.go | 501 ++++++++ pkg/utils/captcha/image_test.go | 52 + pkg/utils/captcha/siprng.go | 277 +++++ pkg/utils/captcha/siprng_test.go | 33 + pkg/utils/debug.go | 478 ++++++++ pkg/utils/debug_test.go | 46 + pkg/utils/file.go | 101 ++ pkg/utils/file_test.go | 75 ++ pkg/utils/mail.go | 424 +++++++ pkg/utils/mail_test.go | 41 + pkg/utils/pagination/controller.go | 26 + pkg/utils/pagination/doc.go | 58 + pkg/utils/pagination/paginator.go | 189 +++ pkg/utils/pagination/utils.go | 34 + pkg/utils/rand.go | 44 + pkg/utils/rand_test.go | 33 + pkg/utils/safemap.go | 91 ++ pkg/utils/safemap_test.go | 89 ++ pkg/utils/slice.go | 170 +++ pkg/utils/slice_test.go | 29 + pkg/utils/testdata/grepe.test | 7 + pkg/utils/utils.go | 89 ++ pkg/utils/utils_test.go | 36 + pkg/validation/README.md | 147 +++ pkg/validation/util.go | 298 +++++ pkg/validation/util_test.go | 128 ++ pkg/validation/validation.go | 456 +++++++ pkg/validation/validation_test.go | 609 ++++++++++ pkg/validation/validators.go | 738 ++++++++++++ 194 files changed, 39077 insertions(+), 69 deletions(-) delete mode 100644 orm/orm_alias_adapt_test.go create mode 100644 pkg/admin.go create mode 100644 pkg/admin_test.go create mode 100644 pkg/adminui.go create mode 100644 pkg/app.go create mode 100644 pkg/beego.go create mode 100644 
pkg/build_info.go create mode 100644 pkg/cache/README.md create mode 100644 pkg/cache/cache.go create mode 100644 pkg/cache/cache_test.go create mode 100644 pkg/cache/conv.go create mode 100644 pkg/cache/conv_test.go create mode 100644 pkg/cache/file.go create mode 100644 pkg/cache/memcache/memcache.go create mode 100644 pkg/cache/memcache/memcache_test.go create mode 100644 pkg/cache/memory.go create mode 100644 pkg/cache/redis/redis.go create mode 100644 pkg/cache/redis/redis_test.go create mode 100644 pkg/cache/ssdb/ssdb.go create mode 100644 pkg/cache/ssdb/ssdb_test.go create mode 100644 pkg/config.go create mode 100644 pkg/config/config.go create mode 100644 pkg/config/config_test.go create mode 100644 pkg/config/env/env.go create mode 100644 pkg/config/env/env_test.go create mode 100644 pkg/config/fake.go create mode 100644 pkg/config/ini.go create mode 100644 pkg/config/ini_test.go create mode 100644 pkg/config/json.go create mode 100644 pkg/config/json_test.go create mode 100644 pkg/config/xml/xml.go create mode 100644 pkg/config/xml/xml_test.go create mode 100644 pkg/config/yaml/yaml.go create mode 100644 pkg/config/yaml/yaml_test.go create mode 100644 pkg/config_test.go create mode 100644 pkg/context/acceptencoder.go create mode 100644 pkg/context/acceptencoder_test.go create mode 100644 pkg/context/context.go create mode 100644 pkg/context/context_test.go create mode 100644 pkg/context/input.go create mode 100644 pkg/context/input_test.go create mode 100644 pkg/context/output.go create mode 100644 pkg/context/param/conv.go create mode 100644 pkg/context/param/methodparams.go create mode 100644 pkg/context/param/options.go create mode 100644 pkg/context/param/parsers.go create mode 100644 pkg/context/param/parsers_test.go create mode 100644 pkg/context/renderer.go create mode 100644 pkg/context/response.go create mode 100644 pkg/controller.go create mode 100644 pkg/controller_test.go create mode 100644 pkg/doc.go create mode 100644 pkg/error.go create mode 100644 pkg/error_test.go create mode 100644 pkg/filter.go create mode 100644 pkg/filter_test.go create mode 100644 pkg/flash.go create mode 100644 pkg/flash_test.go create mode 100644 pkg/fs.go create mode 100644 pkg/grace/grace.go create mode 100644 pkg/grace/server.go create mode 100644 pkg/hooks.go create mode 100644 pkg/httplib/README.md create mode 100644 pkg/httplib/httplib.go create mode 100644 pkg/httplib/httplib_test.go create mode 100644 pkg/log.go create mode 100644 pkg/logs/README.md create mode 100644 pkg/logs/accesslog.go create mode 100644 pkg/logs/alils/alils.go create mode 100755 pkg/logs/alils/config.go create mode 100755 pkg/logs/alils/log.pb.go create mode 100755 pkg/logs/alils/log_config.go create mode 100755 pkg/logs/alils/log_project.go create mode 100755 pkg/logs/alils/log_store.go create mode 100755 pkg/logs/alils/machine_group.go create mode 100755 pkg/logs/alils/request.go create mode 100755 pkg/logs/alils/signature.go create mode 100644 pkg/logs/conn.go create mode 100644 pkg/logs/conn_test.go create mode 100644 pkg/logs/console.go create mode 100644 pkg/logs/console_test.go create mode 100644 pkg/logs/es/es.go create mode 100644 pkg/logs/file.go create mode 100644 pkg/logs/file_test.go create mode 100644 pkg/logs/jianliao.go create mode 100644 pkg/logs/log.go create mode 100644 pkg/logs/logger.go create mode 100644 pkg/logs/logger_test.go create mode 100644 pkg/logs/multifile.go create mode 100644 pkg/logs/multifile_test.go create mode 100644 pkg/logs/slack.go create mode 100644 pkg/logs/smtp.go 
create mode 100644 pkg/logs/smtp_test.go create mode 100644 pkg/metric/prometheus.go create mode 100644 pkg/metric/prometheus_test.go create mode 100644 pkg/migration/ddl.go create mode 100644 pkg/migration/doc.go create mode 100644 pkg/migration/migration.go create mode 100644 pkg/mime.go create mode 100644 pkg/namespace.go create mode 100644 pkg/namespace_test.go create mode 100644 pkg/parser.go create mode 100644 pkg/plugins/apiauth/apiauth.go create mode 100644 pkg/plugins/apiauth/apiauth_test.go create mode 100644 pkg/plugins/auth/basic.go create mode 100644 pkg/plugins/authz/authz.go create mode 100644 pkg/plugins/authz/authz_model.conf create mode 100644 pkg/plugins/authz/authz_policy.csv create mode 100644 pkg/plugins/authz/authz_test.go create mode 100644 pkg/plugins/cors/cors.go create mode 100644 pkg/plugins/cors/cors_test.go create mode 100644 pkg/policy.go create mode 100644 pkg/router.go create mode 100644 pkg/router_test.go create mode 100644 pkg/session/README.md create mode 100644 pkg/session/couchbase/sess_couchbase.go create mode 100644 pkg/session/ledis/ledis_session.go create mode 100644 pkg/session/memcache/sess_memcache.go create mode 100644 pkg/session/mysql/sess_mysql.go create mode 100644 pkg/session/postgres/sess_postgresql.go create mode 100644 pkg/session/redis/sess_redis.go create mode 100644 pkg/session/redis_cluster/redis_cluster.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel.go create mode 100644 pkg/session/redis_sentinel/sess_redis_sentinel_test.go create mode 100644 pkg/session/sess_cookie.go create mode 100644 pkg/session/sess_cookie_test.go create mode 100644 pkg/session/sess_file.go create mode 100644 pkg/session/sess_file_test.go create mode 100644 pkg/session/sess_mem.go create mode 100644 pkg/session/sess_mem_test.go create mode 100644 pkg/session/sess_test.go create mode 100644 pkg/session/sess_utils.go create mode 100644 pkg/session/session.go create mode 100644 pkg/session/ssdb/sess_ssdb.go create mode 100644 pkg/staticfile.go create mode 100644 pkg/staticfile_test.go create mode 100644 pkg/swagger/swagger.go create mode 100644 pkg/template.go create mode 100644 pkg/template_test.go create mode 100644 pkg/templatefunc.go create mode 100644 pkg/templatefunc_test.go create mode 100644 pkg/testdata/Makefile create mode 100644 pkg/testdata/bindata.go create mode 100644 pkg/testdata/views/blocks/block.tpl create mode 100644 pkg/testdata/views/header.tpl create mode 100644 pkg/testdata/views/index.tpl create mode 100644 pkg/testing/assertions.go create mode 100644 pkg/testing/client.go create mode 100644 pkg/toolbox/healthcheck.go create mode 100644 pkg/toolbox/profile.go create mode 100644 pkg/toolbox/profile_test.go create mode 100644 pkg/toolbox/statistics.go create mode 100644 pkg/toolbox/statistics_test.go create mode 100644 pkg/toolbox/task.go create mode 100644 pkg/toolbox/task_test.go create mode 100644 pkg/tree.go create mode 100644 pkg/tree_test.go create mode 100644 pkg/unregroute_test.go create mode 100644 pkg/utils/caller.go create mode 100644 pkg/utils/caller_test.go create mode 100644 pkg/utils/captcha/LICENSE create mode 100644 pkg/utils/captcha/README.md create mode 100644 pkg/utils/captcha/captcha.go create mode 100644 pkg/utils/captcha/image.go create mode 100644 pkg/utils/captcha/image_test.go create mode 100644 pkg/utils/captcha/siprng.go create mode 100644 pkg/utils/captcha/siprng_test.go create mode 100644 pkg/utils/debug.go create mode 100644 pkg/utils/debug_test.go create mode 100644 pkg/utils/file.go 
create mode 100644 pkg/utils/file_test.go create mode 100644 pkg/utils/mail.go create mode 100644 pkg/utils/mail_test.go create mode 100644 pkg/utils/pagination/controller.go create mode 100644 pkg/utils/pagination/doc.go create mode 100644 pkg/utils/pagination/paginator.go create mode 100644 pkg/utils/pagination/utils.go create mode 100644 pkg/utils/rand.go create mode 100644 pkg/utils/rand_test.go create mode 100644 pkg/utils/safemap.go create mode 100644 pkg/utils/safemap_test.go create mode 100644 pkg/utils/slice.go create mode 100644 pkg/utils/slice_test.go create mode 100644 pkg/utils/testdata/grepe.test create mode 100644 pkg/utils/utils.go create mode 100644 pkg/utils/utils_test.go create mode 100644 pkg/validation/README.md create mode 100644 pkg/validation/util.go create mode 100644 pkg/validation/util_test.go create mode 100644 pkg/validation/validation.go create mode 100644 pkg/validation/validation_test.go create mode 100644 pkg/validation/validators.go diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index eac85091..61f17346 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -178,9 +178,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex column += " " + "NOT NULL" } - // if fi.initial.String() != "" { + //if fi.initial.String() != "" { // column += " DEFAULT " + fi.initial.String() - // } + //} // Append attribute DEFAULT column += getColumnDefault(fi) @@ -197,9 +197,9 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex if strings.Contains(column, "%COL%") { column = strings.Replace(column, "%COL%", fi.column, -1) } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + + if fi.description != "" && al.Driver!=DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) } columns = append(columns, column) diff --git a/orm/db_alias.go b/orm/db_alias.go index a84070b4..bf6c350c 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -12,21 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Deprecated: we will remove this package, please using pkg/orm package orm import ( "context" "database/sql" "fmt" + lru "github.com/hashicorp/golang-lru" "reflect" "sync" "time" - - lru "github.com/hashicorp/golang-lru" - - "github.com/astaxie/beego/pkg/common" - orm2 "github.com/astaxie/beego/pkg/orm" ) // DriverType database driver constant int. 
@@ -68,7 +63,7 @@ var ( "tidb": DRTiDB, "oracle": DROracle, "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, // https://github.com/rana/ora + "ora": DROracle, //https://github.com/rana/ora } dbBasers = map[DriverType]dbBaser{ DRMySQL: newdbBaseMysql(), @@ -124,7 +119,7 @@ func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) return d.DB.BeginTx(ctx, opts) } -// su must call release to release *sql.Stmt after using +//su must call release to release *sql.Stmt after using func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { d.RLock() c, ok := d.stmtDecorators.Get(query) @@ -294,26 +289,82 @@ func detectTZ(al *alias) { } } +func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { + al := new(alias) + al.Name = aliasName + al.DriverName = driverName + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + stmtDecorators: newStmtDecoratorLruWithEvict(), + } + + if dr, ok := drivers[driverName]; ok { + al.DbBaser = dbBasers[dr] + al.Driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) + } + + if !dataBaseCache.add(aliasName, al) { + return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) + } + + return al, nil +} + // AddAliasWthDB add a aliasName for the drivename -// Deprecated: please using pkg/orm func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - return orm2.AddAliasWthDB(aliasName, driverName, db) + _, err := addAliasWthDB(aliasName, driverName, db) + return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - kvs := make([]common.KV, 0, 2) + var ( + err error + db *sql.DB + al *alias + ) + + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + al, err = addAliasWthDB(aliasName, driverName, db) + if err != nil { + goto end + } + + al.DataSource = dataSource + + detectTZ(al) + for i, v := range params { switch i { case 0: - kvs = append(kvs, common.KV{Key: orm2.MaxIdleConnsKey, Value: v}) + SetMaxIdleConns(al.Name, v) case 1: - kvs = append(kvs, common.KV{Key: orm2.MaxOpenConnsKey, Value: v}) - case 2: - kvs = append(kvs, common.KV{Key: orm2.ConnMaxLifetimeKey, Value: time.Duration(v) * time.Millisecond}) + SetMaxOpenConns(al.Name, v) } } - return orm2.RegisterDataBase(aliasName, driverName, dataSource, kvs...) + +end: + if err != nil { + if db != nil { + db.Close() + } + DebugLog.Println(err.Error()) + } + + return err } // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. 
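For context, the RegisterDataBase restored in this hunk consumes its variadic ints positionally: the first value becomes the pool's max idle connections (SetMaxIdleConns) and the second the max open connections (SetMaxOpenConns). A minimal usage sketch under those assumptions — the MySQL driver import and the DSN below are placeholders, not part of the patch:

    package main

    import (
        // blank import registers the driver, as in the adapter test removed below
        _ "github.com/go-sql-driver/mysql"

        "github.com/astaxie/beego/orm"
    )

    func main() {
        dsn := "user:password@tcp(127.0.0.1:3306)/beego_test?charset=utf8" // placeholder DSN
        // 30 -> SetMaxIdleConns, 50 -> SetMaxOpenConns for the "default" alias
        if err := orm.RegisterDataBase("default", "mysql", dsn, 30, 50); err != nil {
            panic(err)
        }
    }
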
@@ -373,7 +424,7 @@ func GetDB(aliasNames ...string) (*sql.DB, error) { } type stmtDecorator struct { - wg sync.WaitGroup + wg sync.WaitGroup stmt *sql.Stmt } @@ -393,7 +444,7 @@ func (s *stmtDecorator) release() { s.wg.Done() } -// garbage recycle for stmt +//garbage recycle for stmt func (s *stmtDecorator) destroy() { go func() { s.wg.Wait() diff --git a/orm/orm_alias_adapt_test.go b/orm/orm_alias_adapt_test.go deleted file mode 100644 index d7724527..00000000 --- a/orm/orm_alias_adapt_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 beego-dev -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "os" - "testing" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" -) - -var DBARGS = struct { - Driver string - Source string - Debug string -}{ - os.Getenv("ORM_DRIVER"), - os.Getenv("ORM_SOURCE"), - os.Getenv("ORM_DEBUG"), -} - -func TestRegisterDataBase(t *testing.T) { - err := RegisterDataBase("test-adapt1", DBARGS.Driver, DBARGS.Source) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt2", DBARGS.Driver, DBARGS.Source, 20) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt3", DBARGS.Driver, DBARGS.Source, 20, 300) - assert.Nil(t, err) - err = RegisterDataBase("test-adapt4", DBARGS.Driver, DBARGS.Source, 20, 300, 60*1000) - assert.Nil(t, err) -} diff --git a/pkg/admin.go b/pkg/admin.go new file mode 100644 index 00000000..db52647e --- /dev/null +++ b/pkg/admin.go @@ -0,0 +1,458 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "os" + "reflect" + "strconv" + "text/template" + "time" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// BeeAdminApp is the default adminApp used by admin module. +var beeAdminApp *adminApp + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. 
+// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + beeAdminApp = &adminApp{ + routers: make(map[string]http.HandlerFunc), + } + // keep in mind that all data should be html escaped to avoid XSS attack + beeAdminApp.Route("/", adminIndex) + beeAdminApp.Route("/qps", qpsIndex) + beeAdminApp.Route("/prof", profIndex) + beeAdminApp.Route("/healthcheck", healthcheck) + beeAdminApp.Route("/task", taskStatus) + beeAdminApp.Route("/listconf", listConf) + beeAdminApp.Route("/metrics", promhttp.Handler().ServeHTTP) + FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } +} + +// AdminIndex is the default http.Handler for admin module. +// it matches url pattern "/". +func adminIndex(rw http.ResponseWriter, _ *http.Request) { + writeTemplate(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) +} + +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. +func qpsIndex(rw http.ResponseWriter, _ *http.Request) { + data := make(map[interface{}]interface{}) + data["Content"] = toolbox.StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(M); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + + writeTemplate(rw, data, qpsTpl, defaultScriptsTpl) +} + +// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. +// it's registered with url pattern "/listconf" in admin module. 
+func listConf(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + rw.Write([]byte("command not support")) + return + } + + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(M) + list("BConfig", BConfig, m) + m["AppConfigPath"] = template.HTMLEscapeString(appConfigPath) + m["AppConfigProvider"] = template.HTMLEscapeString(appConfigProvider) + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + + data["Content"] = m + + tmpl.Execute(rw, data) + + case "router": + content := PrintTree() + content["Fields"] = []string{ + "Router Pattern", + "Methods", + "Controller", + } + data["Content"] = content + data["Title"] = "Routers" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + var ( + content = M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + filterTypes = []string{} + filterTypeData = make(M) + ) + + if BeeApp.Handlers.enableFilter { + var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router"} { + if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { + filterType = fr + filterTypes = append(filterTypes, filterType) + resultList := new([][]string) + for _, f := range bf { + var result = []string{ + // void xss + template.HTMLEscapeString(f.pattern), + template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[filterType] = resultList + } + } + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) + } +} + +func list(root string, p interface{}, m M) { + pt := reflect.TypeOf(p) + pv := reflect.ValueOf(p) + if pt.Kind() == reflect.Ptr { + pt = pt.Elem() + pv = pv.Elem() + } + for i := 0; i < pv.NumField(); i++ { + var key string + if root == "" { + key = pt.Field(i).Name + } else { + key = root + "." + pt.Field(i).Name + } + if pv.Field(i).Kind() == reflect.Struct { + list(key, pv.Field(i).Interface(), m) + } else { + m[key] = pv.Field(i).Interface() + } + } +} + +// PrintTree prints all registered routers. 
+func PrintTree() M { + var ( + content = M{} + methods = []string{} + methodsData = make(M) + ) + for method, t := range BeeApp.Handlers.routers { + + resultList := new([][]string) + + printTree(resultList, t) + + methods = append(methods, template.HTMLEscapeString(method)) + methodsData[template.HTMLEscapeString(method)] = resultList + } + + content["Data"] = methodsData + content["Methods"] = methods + return content +} + +func printTree(resultList *[][]string, t *Tree) { + for _, tr := range t.fixrouters { + printTree(resultList, tr) + } + if t.wildcard != nil { + printTree(resultList, t.wildcard) + } + for _, l := range t.leaves { + if v, ok := l.runObject.(*ControllerInfo); ok { + if v.routerType == routerTypeBeego { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + template.HTMLEscapeString(v.controllerType.String()), + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeRESTFul { + var result = []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + "", + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeHandler { + var result = []string{ + template.HTMLEscapeString(v.pattern), + "", + "", + } + *resultList = append(*resultList, result) + } + } + } +} + +// ProfIndex is a http.Handler for showing profile command. +// it's in url pattern "/prof" in admin module. +func profIndex(rw http.ResponseWriter, r *http.Request) { + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + return + } + + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + toolbox.ProcessInput(command, &result) + data["Content"] = template.HTMLEscapeString(result.String()) + + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(rw, dataJSON) + return + } + + data["Title"] = template.HTMLEscapeString(command) + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + writeTemplate(rw, data, profillingTpl, defaultTpl) +} + +// Healthcheck is a http.Handler calling health checking and showing the result. +// it's in "/healthcheck" pattern in admin module. 
+func healthcheck(rw http.ResponseWriter, r *http.Request) { + var ( + result []string + data = make(map[interface{}]interface{}) + resultList = new([][]string) + content = M{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) + + for name, h := range toolbox.AdminCheckList { + if err := h.Check(); err != nil { + result = []string{ + "error", + template.HTMLEscapeString(name), + template.HTMLEscapeString(err.Error()), + } + } else { + result = []string{ + "success", + template.HTMLEscapeString(name), + "OK", + } + } + *resultList = append(*resultList, result) + } + + queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) + + if shouldReturnJSON { + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) + + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(rw, jsonResponse) + } + return + } + + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Health Check" + + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) +} + +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) + + for i, healthCheckResult := range *healthCheckResults { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] + + response[i] = currentResultMap + } + + return response + +} + +func writeJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + +// TaskStatus is a http.Handler with running task status (task name, status and the last execution). +// it's in "/task" pattern in admin module. +func taskStatus(rw http.ResponseWriter, req *http.Request) { + data := make(map[interface{}]interface{}) + + // Run Task + req.ParseForm() + taskname := req.Form.Get("taskname") + if taskname != "" { + if t, ok := toolbox.AdminTaskList[taskname]; ok { + if err := t.Run(); err != nil { + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} + } + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} + } else { + data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} + } + } + + // List Tasks + content := make(M) + resultList := new([][]string) + var fields = []string{ + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", + } + for tname, tk := range toolbox.AdminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec()), + template.HTMLEscapeString(tk.GetStatus()), + template.HTMLEscapeString(tk.GetPrev().String()), + } + *resultList = append(*resultList, result) + } + + content["Fields"] = fields + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Tasks" + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) +} + +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } + tmpl.Execute(rw, data) +} + +// adminApp is an http.HandlerFunc map used as beeAdminApp. +type adminApp struct { + routers map[string]http.HandlerFunc +} + +// Route adds http.HandlerFunc to adminApp with url pattern. +func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { + admin.routers[pattern] = f +} + +// Run adminApp http server. +// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. +func (admin *adminApp) Run() { + if len(toolbox.AdminTaskList) > 0 { + toolbox.StartTask() + } + addr := BConfig.Listen.AdminAddr + + if BConfig.Listen.AdminPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) + } + for p, f := range admin.routers { + http.Handle(p, f) + } + logs.Info("Admin server Running on %s", addr) + + var err error + if BConfig.Listen.Graceful { + err = grace.ListenAndServe(addr, nil) + } else { + err = http.ListenAndServe(addr, nil) + } + if err != nil { + logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + } +} diff --git a/pkg/admin_test.go b/pkg/admin_test.go new file mode 100644 index 00000000..3f3612e4 --- /dev/null +++ b/pkg/admin_test.go @@ -0,0 +1,239 @@ +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "github.com/astaxie/beego/toolbox" +) + +type SampleDatabaseCheck struct { +} + +type SampleCacheCheck struct { +} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + +func TestList_01(t *testing.T) { + m := make(M) + list("BConfig", BConfig, m) + t.Log(m) + om := oldMap() + for k, v := range om { + if fmt.Sprint(m[k]) != fmt.Sprint(v) { + t.Log(k, "old-key", v, "new-key", m[k]) + t.FailNow() + } + } +} + +func oldMap() M { + m := make(M) + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + 
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize + m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs + m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + return m +} + +func TestWriteJSON(t *testing.T) { + t.Log("Testing the adding of JSON to the response") + + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + writeJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + + if err != nil { + t.Fatal("Could not decode response body into slice.") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body 
slice", originalBody[i], decodedBody[i]) + } + } +} + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } + +} + +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + []string{ + "error", + "Database", + "Error occured whie starting the db", + }, + []string{ + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + + } + +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + + toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) + toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(healthcheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + + if !reflect.DeepEqual(decodedResponseBody, expectedResponseBody) { + t.Errorf("handler returned unexpected body: got %v want %v", + decodedResponseBody, expectedResponseBody) + } + +} diff --git a/pkg/adminui.go b/pkg/adminui.go new file mode 100644 index 00000000..cdcdef33 --- /dev/null +++ b/pkg/adminui.go @@ -0,0 +1,356 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
[pkg/adminui.go defines the HTML templates rendered by the admin module: indexTpl (the "Beego Admin Dashboard" landing page linking to the document, Toolbox and Live Monitor), profillingTpl, defaultScriptsTpl, gcAjaxTpl, qpsTpl ("Requests statistics"), configTpl ("Configurations"), routerAndFilterTpl, tasksTpl, healthCheckTpl and the base dashboardTpl ("Welcome to Beego Admin Dashboard"); the template markup itself is omitted.]
+ + + + + + + +{{template "scripts" .}} + + +` diff --git a/pkg/app.go b/pkg/app.go new file mode 100644 index 00000000..f3fe6f7b --- /dev/null +++ b/pkg/app.go @@ -0,0 +1,496 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/fcgi" + "os" + "path" + "strings" + "time" + + "github.com/astaxie/beego/grace" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" + "golang.org/x/crypto/acme/autocert" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = NewApp() +} + +// App defines beego application with a new PatternServeMux. +type App struct { + Handlers *ControllerRegister + Server *http.Server +} + +// NewApp returns a new beego application. +func NewApp() *App { + cr := NewControllerRegister() + app := &App{Handlers: cr, Server: &http.Server{}} + return app +} + +// MiddleWare function for http.Handler +type MiddleWare func(http.Handler) http.Handler + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + addr := BConfig.Listen.HTTPAddr + + if BConfig.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) + } + + var ( + err error + l net.Listener + endRunning = make(chan bool, 1) + ) + + // run cgi server + if BConfig.Listen.EnableFcgi { + if BConfig.Listen.EnableStdIo { + if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O + logs.Info("Use FCGI via standard I/O") + } else { + logs.Critical("Cannot use FCGI via standard I/O", err) + } + return + } + if BConfig.Listen.HTTPPort == 0 { + // remove the Socket file before start + if utils.FileExists(addr) { + os.Remove(addr) + } + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + if err != nil { + logs.Critical("Listen: ", err) + } + if err = fcgi.Serve(l, app.Handlers); err != nil { + logs.Critical("fcgi.Serve: ", err) + } + return + } + + app.Server.Handler = app.Handlers + for i := len(mws) - 1; i >= 0; i-- { + if mws[i] == nil { + continue + } + app.Server.Handler = mws[i](app.Server.Handler) + } + app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") + + // run graceful mode + if BConfig.Listen.Graceful { + httpsAddr := BConfig.Listen.HTTPSAddr + app.Server.Addr = httpsAddr + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + app.Server.Addr = httpsAddr + } + server := grace.NewServer(httpsAddr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = 
app.Server.WriteTimeout + if BConfig.Listen.EnableMutualHTTPS { + if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } else { + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } + if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + } + endRunning <- true + }() + } + if BConfig.Listen.EnableHTTP { + go func() { + server := grace.NewServer(addr, app.Server.Handler) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if BConfig.Listen.ListenTCP4 { + server.Network = "tcp4" + } + if err := server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + endRunning <- true + }() + } + <-endRunning + return + } + + // run normal mode + if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if BConfig.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) + } else if BConfig.Listen.EnableHTTP { + logs.Info("Start https server error, conflict with http. 
Please reset https port") + return + } + logs.Info("https server Running on https://%s", app.Server.Addr) + if BConfig.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), + Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" + } else if BConfig.Listen.EnableMutualHTTPS { + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) + if err != nil { + logs.Info("MutualHTTPS should provide TrustCaFile") + return + } + pool.AppendCertsFromPEM(data) + app.Server.TLSConfig = &tls.Config{ + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + } + if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + + } + if BConfig.Listen.EnableHTTP { + go func() { + app.Server.Addr = addr + logs.Info("http server Running on http://%s", app.Server.Addr) + if BConfig.Listen.ListenTCP4 { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + if err = app.Server.Serve(ln); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + if err := app.Server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } + }() + } + <-endRunning +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + BeeApp.Handlers.Add(rootpath, c, mappingMethods...) + return BeeApp +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. 
+// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + subPaths := splitPath(fixedRoute) + if method == "" || method == "*" { + for m := range HTTPMETHOD { + if _, ok := BeeApp.Handlers.routers[m]; !ok { + continue + } + if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) + continue + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) + } + return BeeApp + } + // Single HTTP method + um := strings.ToUpper(method) + if _, ok := BeeApp.Handlers.routers[um]; ok { + if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) + return BeeApp + } + findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) + } + return BeeApp +} + +func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { + for i := range entryPointTree.fixrouters { + if entryPointTree.fixrouters[i].prefix == paths[0] { + if len(paths) == 1 { + if len(entryPointTree.fixrouters[i].fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.fixrouters[i].leaves) > 0 { + entryPointTree.fixrouters[i].leaves[0] = nil + entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] + } + } else { + // Remove the *Tree from the fixrouters slice + entryPointTree.fixrouters[i] = nil + + if i == len(entryPointTree.fixrouters)-1 { + entryPointTree.fixrouters = entryPointTree.fixrouters[:i] + } else { + entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) + } + } + return + } + findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) + } + } +} + +func findAndRemoveSingleTree(entryPointTree *Tree) { + if entryPointTree == nil { + return + } + if len(entryPointTree.fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.leaves) > 0 { + entryPointTree.leaves[0] = nil + entryPointTree.leaves = entryPointTree.leaves[1:] + } + } +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +//} +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + BeeApp.Handlers.Include(cList...) + return BeeApp +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. 
+func RESTRouter(rootpath string, c ControllerInterface) *App { + Router(rootpath, c) + Router(path.Join(rootpath, ":objectId"), c) + return BeeApp +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + BeeApp.Handlers.AddAuto(c) + return BeeApp +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + BeeApp.Handlers.AddAutoPrefix(prefix, c) + return BeeApp +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Get(rootpath, f) + return BeeApp +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Post(rootpath, f) + return BeeApp +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Delete(rootpath, f) + return BeeApp +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Put(rootpath, f) + return BeeApp +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Head(rootpath, f) + return BeeApp +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Options(rootpath, f) + return BeeApp +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Patch(rootpath, f) + return BeeApp +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + BeeApp.Handlers.Any(rootpath, f) + return BeeApp +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + BeeApp.Handlers.Handler(rootpath, h, options...) + return BeeApp +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. 
+// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) + return BeeApp +} diff --git a/pkg/beego.go b/pkg/beego.go new file mode 100644 index 00000000..8ebe0bab --- /dev/null +++ b/pkg/beego.go @@ -0,0 +1,123 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "os" + "path/filepath" + "strconv" + "strings" +) + +const ( + // VERSION represent beego web framework version. + VERSION = "1.12.2" + + // DEV is for develop + DEV = "dev" + // PROD is for production + PROD = "prod" +) + +// M is Map shortcut +type M map[string]interface{} + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + hooks = append(hooks, hf...) +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + + initBeforeHTTPRun() + + if len(params) > 0 && params[0] != "" { + strs := strings.Split(params[0], ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BConfig.Listen.Domains = params + } + + BeeApp.Run() +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + initBeforeHTTPRun() + + strs := strings.Split(addr, ":") + if len(strs) > 0 && strs[0] != "" { + BConfig.Listen.HTTPAddr = strs[0] + BConfig.Listen.Domains = []string{strs[0]} + } + if len(strs) > 1 && strs[1] != "" { + BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } + + BeeApp.Run(mws...) 
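// Editor's note (illustrative sketch, not part of this patch): both Run and
// RunWithMiddleWares split the address on ":" as shown above, so any of these
// forms work (mw1, mw2 stand for hypothetical beego.MiddleWare values):
//	beego.Run()                  // listen on BConfig.Listen.HTTPAddr:HTTPPort
//	beego.Run(":8089")           // keep the configured host, override the port
//	beego.Run("127.0.0.1:8089")  // override host and port
//	beego.RunWithMiddleWares(":8089", mw1, mw2)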
+} + +func initBeforeHTTPRun() { + //init hooks + AddAPPStartHook( + registerMime, + registerDefaultErrorHandler, + registerSession, + registerTemplate, + registerAdmin, + registerGzip, + ) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } + } +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + path := filepath.Join(ap, "conf", "app.conf") + os.Chdir(ap) + InitBeegoBeforeTest(path) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil { + panic(err) + } + BConfig.RunMode = "test" + initBeforeHTTPRun() +} diff --git a/pkg/build_info.go b/pkg/build_info.go new file mode 100644 index 00000000..6dc2835e --- /dev/null +++ b/pkg/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/cache/README.md b/pkg/cache/README.md new file mode 100644 index 00000000..b467760a --- /dev/null +++ b/pkg/cache/README.md @@ -0,0 +1,59 @@ +## cache +cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . + + +## How to install? + + go get github.com/astaxie/beego/cache + + +## What adapters are supported? + +As of now this cache support memory, Memcache and Redis. + + +## How to use it? + +First you must import it + + import ( + "github.com/astaxie/beego/cache" + ) + +Then init a Cache (example with memory adapter) + + bm, err := cache.NewCache("memory", `{"interval":60}`) + +Use it like this: + + bm.Put("astaxie", 1, 10 * time.Second) + bm.Get("astaxie") + bm.IsExist("astaxie") + bm.Delete("astaxie") + + +## Memory adapter + +Configure memory adapter like this: + + {"interval":60} + +interval means the gc time. The cache will check at each time interval, whether item has expired. + + +## Memcache adapter + +Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. + +Configure like this: + + {"conn":"127.0.0.1:11211"} + + +## Redis adapter + +Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. + +Configure like this: + + {"conn":":6039"} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 00000000..82585c4e --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,103 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + "time" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 00000000..470c0a43 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
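// Editor's note (illustrative sketch, not part of this patch): a custom adapter
// hooks into the registry above by calling Register from its init function; the
// adapter name "mycache" and type myCache are assumptions:
//	func init() {
//		cache.Register("mycache", func() cache.Cache { return &myCache{} })
//	}
//	// callers then construct it through NewCache, which also runs StartAndGC:
//	bm, err := cache.NewCache("mycache", `{"interval":60}`)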
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "os" + "sync" + "testing" + "time" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + //timeoutDuration := 10 * time.Second + + bm.Put("edwardhey", 0, time.Second*20) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + bm.Incr("edwardhey") + }() + } + wg.Wait() + if bm.Get("edwardhey").(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(30 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test GetMulti + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") 
+ } + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if vv[0].(string) != "author" { + t.Error("GetMulti ERROR") + } + if vv[1].(string) != "author1" { + t.Error("GetMulti ERROR") + } + + os.RemoveAll("cache") +} diff --git a/pkg/cache/conv.go b/pkg/cache/conv.go new file mode 100644 index 00000000..87800586 --- /dev/null +++ b/pkg/cache/conv.go @@ -0,0 +1,100 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "fmt" + "strconv" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + switch result := v.(type) { + case string: + return result + case []byte: + return string(result) + default: + if v != nil { + return fmt.Sprint(result) + } + } + return "" +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + switch result := v.(type) { + case int: + return result + case int32: + return int(result) + case int64: + return int(result) + default: + if d := GetString(v); d != "" { + value, _ := strconv.Atoi(d) + return value + } + } + return 0 +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + switch result := v.(type) { + case int: + return int64(result) + case int32: + return int64(result) + case int64: + return result + default: + + if d := GetString(v); d != "" { + value, _ := strconv.ParseInt(d, 10, 64) + return value + } + } + return 0 +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + switch result := v.(type) { + case float64: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseFloat(d, 64) + return value + } + } + return 0 +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + switch result := v.(type) { + case bool: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseBool(d) + return value + } + } + return false +} diff --git a/pkg/cache/conv_test.go b/pkg/cache/conv_test.go new file mode 100644 index 00000000..b90e224a --- /dev/null +++ b/pkg/cache/conv_test.go @@ -0,0 +1,143 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "testing" +) + +func TestGetString(t *testing.T) { + var t1 = "test1" + if "test1" != GetString(t1) { + t.Error("get string from string error") + } + var t2 = []byte("test2") + if "test2" != GetString(t2) { + t.Error("get string from byte array error") + } + var t3 = 1 + if "1" != GetString(t3) { + t.Error("get string from int error") + } + var t4 int64 = 1 + if "1" != GetString(t4) { + t.Error("get string from int64 error") + } + var t5 = 1.1 + if "1.1" != GetString(t5) { + t.Error("get string from float64 error") + } + + if "" != GetString(nil) { + t.Error("get string from nil error") + } +} + +func TestGetInt(t *testing.T) { + var t1 = 1 + if 1 != GetInt(t1) { + t.Error("get int from int error") + } + var t2 int32 = 32 + if 32 != GetInt(t2) { + t.Error("get int from int32 error") + } + var t3 int64 = 64 + if 64 != GetInt(t3) { + t.Error("get int from int64 error") + } + var t4 = "128" + if 128 != GetInt(t4) { + t.Error("get int from num string error") + } + if 0 != GetInt(nil) { + t.Error("get int from nil error") + } +} + +func TestGetInt64(t *testing.T) { + var i int64 = 1 + var t1 = 1 + if i != GetInt64(t1) { + t.Error("get int64 from int error") + } + var t2 int32 = 1 + if i != GetInt64(t2) { + t.Error("get int64 from int32 error") + } + var t3 int64 = 1 + if i != GetInt64(t3) { + t.Error("get int64 from int64 error") + } + var t4 = "1" + if i != GetInt64(t4) { + t.Error("get int64 from num string error") + } + if 0 != GetInt64(nil) { + t.Error("get int64 from nil") + } +} + +func TestGetFloat64(t *testing.T) { + var f = 1.11 + var t1 float32 = 1.11 + if f != GetFloat64(t1) { + t.Error("get float64 from float32 error") + } + var t2 = 1.11 + if f != GetFloat64(t2) { + t.Error("get float64 from float64 error") + } + var t3 = "1.11" + if f != GetFloat64(t3) { + t.Error("get float64 from string error") + } + + var f2 float64 = 1 + var t4 = 1 + if f2 != GetFloat64(t4) { + t.Error("get float64 from int error") + } + + if 0 != GetFloat64(nil) { + t.Error("get float64 from nil error") + } +} + +func TestGetBool(t *testing.T) { + var t1 = true + if !GetBool(t1) { + t.Error("get bool from bool error") + } + var t2 = "true" + if !GetBool(t2) { + t.Error("get bool from string error") + } + if GetBool(nil) { + t.Error("get bool from nil error") + } +} + +func byteArrayEquals(a []byte, b []byte) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/pkg/cache/file.go b/pkg/cache/file.go new file mode 100644 index 00000000..6f12d3ee --- /dev/null +++ b/pkg/cache/file.go @@ -0,0 +1,258 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
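// Editor's note (illustrative sketch, not part of this patch): the conversion
// helpers tested above are handy when reading loosely typed values back out of
// an adapter, for example with the memory cache:
//	bm, _ := cache.NewCache("memory", `{"interval":60}`)
//	bm.Put("visits", "42", 10*time.Second)
//	n := cache.GetInt(bm.Get("visits")) // n == 42; nil or garbage falls back to 0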
+ +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strconv" + "time" +) + +// FileCacheItem is basic unit of file cache adapter. +// it contains data and expire time. +type FileCacheItem struct { + Data interface{} + Lastaccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC will start and begin gc for file cache. +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} +func (fc *FileCache) StartAndGC(config string) error { + + cfg := make(map[string]string) + err := json.Unmarshal([]byte(config), &cfg) + if err != nil { + return err + } + if _, ok := cfg["CachePath"]; !ok { + cfg["CachePath"] = FileCachePath + } + if _, ok := cfg["FileSuffix"]; !ok { + cfg["FileSuffix"] = FileCacheFileSuffix + } + if _, ok := cfg["DirectoryLevel"]; !ok { + cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) + } + if _, ok := cfg["EmbedExpiry"]; !ok { + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg["CachePath"] + fc.FileSuffix = cfg["FileSuffix"] + fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) + fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) + + fc.Init() + return nil +} + +// Init will make new dir for file cache if not exist. +func (fc *FileCache) Init() { + if ok, _ := exists(fc.CachePath); !ok { // todo : error handle + _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle + } +} + +// get cached file name. it's md5 encoded. +func (fc *FileCache) getCacheFileName(key string) string { + m := md5.New() + io.WriteString(m, key) + keyMd5 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keyMd5[0:2]) + } + + if ok, _ := exists(cachePath); !ok { // todo : error handle + _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) +} + +// Get value from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) Get(key string) interface{} { + fileData, err := FileGetContents(fc.getCacheFileName(key)) + if err != nil { + return "" + } + var to FileCacheItem + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { + return "" + } + return to.Data +} + +// GetMulti gets values from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) GetMulti(keys []string) []interface{} { + var rc []interface{} + for _, key := range keys { + rc = append(rc, fc.Get(key)) + } + return rc +} + +// Put value into file cache. 
+// timeout means how long to keep this file, unit of ms. +// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == time.Duration(fc.EmbedExpiry) { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.Lastaccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + return FilePutContents(fc.getCacheFileName(key), data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(key string) error { + filename := fc.getCacheFileName(key) + if ok, _ := exists(filename); ok { + return os.Remove(filename) + } + return nil +} + +// Incr will increase cached int value. +// fc value is saving forever unless Delete. +func (fc *FileCache) Incr(key string) error { + data := fc.Get(key) + var incr int + if reflect.TypeOf(data).Name() != "int" { + incr = 0 + } else { + incr = data.(int) + 1 + } + fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// Decr will decrease cached int value. +func (fc *FileCache) Decr(key string) error { + data := fc.Get(key) + var decr int + if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { + decr = 0 + } else { + decr = data.(int) - 1 + } + fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) + return nil +} + +// IsExist check value is exist. +func (fc *FileCache) IsExist(key string) bool { + ret, _ := exists(fc.getCacheFileName(key)) + return ret +} + +// ClearAll will clean cached files. +// not implemented. +func (fc *FileCache) ClearAll() error { + return nil +} + +// check file exist. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// FileGetContents Get bytes to file. +// if non-exist, create this file. +func FileGetContents(filename string) (data []byte, e error) { + return ioutil.ReadFile(filename) +} + +// FilePutContents Put bytes to file. +// if non-exist, create this file. +func FilePutContents(filename string, content []byte) error { + return ioutil.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(&to) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go new file mode 100644 index 00000000..19116bfa --- /dev/null +++ b/pkg/cache/memcache/memcache.go @@ -0,0 +1,188 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
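// Editor's note (illustrative sketch, not part of this patch): the file adapter
// above is configured entirely through the JSON passed to NewCache and stores
// gob-encoded items on disk; Get returns an empty string for missing or expired keys:
//	bm, err := cache.NewCache("file",
//		`{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`)
//	if err == nil {
//		bm.Put("astaxie", 1, 10*time.Second)
//		v := bm.Get("astaxie") // interface{} holding 1 until the item expires
//	}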
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for cache provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/memcache" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package memcache + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/astaxie/beego/cache" + "github.com/bradfitz/gomemcache/memcache" +) + +// Cache Memcache adapter. +type Cache struct { + conn *memcache.Client + conninfo []string +} + +// NewMemCache create new memcache adapter. +func NewMemCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + if item, err := rc.conn.Get(key); err == nil { + return item.Value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var rv []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv + } + } + mv, err := rc.conn.GetMulti(keys) + if err == nil { + for _, v := range mv { + rv = append(rv, v.Value) + } + return rv + } + for i := 0; i < size; i++ { + rv = append(rv, err) + } + return rv +} + +// Put put value to memcache. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} + if v, ok := val.([]byte); ok { + item.Value = v + } else if str, ok := val.(string); ok { + item.Value = []byte(str) + } else { + return errors.New("val only support string and []byte") + } + return rc.conn.Set(&item) +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.Delete(key) +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Increment(key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Decrement(key, 1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + _, err := rc.conn.Get(key) + return err == nil +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return rc.conn.FlushAll() +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. 
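// Editor's note (illustrative sketch, not part of this patch): because Put above
// only accepts string or []byte, counters are stored as decimal strings, which is
// what memcached's increment/decrement expects:
//	bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`)
//	if err == nil {
//		bm.Put("counter", "1", 10*time.Second) // "1", not 1
//		bm.Incr("counter")                     // stored value becomes "2"
//		raw := bm.Get("counter").([]byte)      // Get returns the raw item bytes
//	}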
+func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + rc.conn = memcache.New(rc.conninfo...) + return nil +} + +func init() { + cache.Register("memcache", NewMemCache) +} diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/cache/memcache/memcache_test.go new file mode 100644 index 00000000..d9129b69 --- /dev/null +++ b/pkg/cache/memcache/memcache_test.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memcache + +import ( + _ "github.com/bradfitz/gomemcache/memcache" + + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestMemcacheCache(t *testing.T) { + bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie").([]byte); string(v) != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { + t.Error("GetMulti ERROR") + } + if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} diff --git 
a/pkg/cache/memory.go b/pkg/cache/memory.go new file mode 100644 index 00000000..d8314e3c --- /dev/null +++ b/pkg/cache/memory.go @@ -0,0 +1,256 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "encoding/json" + "errors" + "sync" + "time" +) + +var ( + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute +) + +// MemoryItem store memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. +// it contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get cache from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) Get(name string) interface{} { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[name]; ok { + if itm.isExpire() { + return nil + } + return itm.val + } + return nil +} + +// GetMulti gets caches from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) GetMulti(names []string) []interface{} { + var rc []interface{} + for _, name := range names { + rc = append(rc, bc.Get(name)) + } + return rc +} + +// Put cache to memory. +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[name] = &MemoryItem{ + val: value, + createdTime: time.Now(), + lifespan: lifespan, + } + return nil +} + +// Delete cache in memory. +func (bc *MemoryCache) Delete(name string) error { + bc.Lock() + defer bc.Unlock() + if _, ok := bc.items[name]; !ok { + return errors.New("key not exist") + } + delete(bc.items, name) + if _, ok := bc.items[name]; ok { + return errors.New("delete key error") + } + return nil +} + +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val + 1 + case int32: + itm.val = val + 1 + case int64: + itm.val = val + 1 + case uint: + itm.val = val + 1 + case uint32: + itm.val = val + 1 + case uint64: + itm.val = val + 1 + default: + return errors.New("item val is not (u)int (u)int32 (u)int64") + } + return nil +} + +// Decr decrease counter in memory. 
+func (bc *MemoryCache) Decr(key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch val := itm.val.(type) { + case int: + itm.val = val - 1 + case int64: + itm.val = val - 1 + case int32: + itm.val = val - 1 + case uint: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint32: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + case uint64: + if val > 0 { + itm.val = val - 1 + } else { + return errors.New("item val is less than 0") + } + default: + return errors.New("item val is not int int64 int32") + } + return nil +} + +// IsExist check cache exist in memory. +func (bc *MemoryCache) IsExist(name string) bool { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false +} + +// ClearAll will delete all cache in memory. +func (bc *MemoryCache) ClearAll() error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC start memory cache. it will check expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + bc.RLock() + if bc.items == nil { + bc.RUnlock() + return + } + bc.RUnlock() + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns key list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// clearItems removes all the items which key in keys. +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go new file mode 100644 index 00000000..56faf211 --- /dev/null +++ b/pkg/cache/redis/redis.go @@ -0,0 +1,272 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
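// Editor's note (illustrative sketch, not part of this patch): for the memory
// adapter, "interval" only controls how often the vacuum goroutine above sweeps
// expired items; a lifespan of 0 on Put keeps an item until restart:
//	bm, _ := cache.NewCache("memory", `{"interval":20}`) // sweep every 20 seconds
//	bm.Put("token", "abc", 30*time.Second) // expires after 30s, removed by a later sweep
//	bm.Put("config", "xyz", 0)             // never expires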
+ +// Package redis for cache provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/cache/redis" +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) +// +// more docs http://beego.me/docs/module/cache.md +package redis + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/gomodule/redigo/redis" + + "github.com/astaxie/beego/cache" + "strings" +) + +var ( + // DefaultKey the collection name of redis for cache adapter. + DefaultKey = "beecacheRedis" +) + +// Cache is Redis cache adapter. +type Cache struct { + p *redis.Pool // redis connection pool + conninfo string + dbNum int + key string + password string + maxIdle int + + //the timeout to a value less than the redis server's timeout. + timeout time.Duration +} + +// NewRedisCache create new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return &Cache{key: DefaultKey} +} + +// actually do the redis cmds, args[0] must be the key name. +func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { + if len(args) < 1 { + return nil, errors.New("missing required arguments") + } + args[0] = rc.associate(args[0]) + c := rc.p.Get() + defer c.Close() + + return c.Do(commandName, args...) +} + +// associate with config key. +func (rc *Cache) associate(originKey interface{}) string { + return fmt.Sprintf("%s:%s", rc.key, originKey) +} + +// Get cache from redis. +func (rc *Cache) Get(key string) interface{} { + if v, err := rc.do("GET", key); err == nil { + return v + } + return nil +} + +// GetMulti get cache from redis. +func (rc *Cache) GetMulti(keys []string) []interface{} { + c := rc.p.Get() + defer c.Close() + var args []interface{} + for _, key := range keys { + args = append(args, rc.associate(key)) + } + values, err := redis.Values(c.Do("MGET", args...)) + if err != nil { + return nil + } + return values +} + +// Put put cache to redis. +func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { + _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) + return err +} + +// Delete delete cache in redis. +func (rc *Cache) Delete(key string) error { + _, err := rc.do("DEL", key) + return err +} + +// IsExist check cache's existence in redis. +func (rc *Cache) IsExist(key string) bool { + v, err := redis.Bool(rc.do("EXISTS", key)) + if err != nil { + return false + } + return v +} + +// Incr increase counter in redis. +func (rc *Cache) Incr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, 1)) + return err +} + +// Decr decrease counter in redis. +func (rc *Cache) Decr(key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, -1)) + return err +} + +// ClearAll clean all cache in redis. delete this redis collection. +func (rc *Cache) ClearAll() error { + cachedKeys, err := rc.Scan(rc.key + ":*") + if err != nil { + return err + } + c := rc.p.Get() + defer c.Close() + for _, str := range cachedKeys { + if _, err = c.Do("DEL", str); err != nil { + return err + } + } + return err +} + +// Scan scan all keys matching the pattern. 
a better choice than `keys` +func (rc *Cache) Scan(pattern string) (keys []string, err error) { + c := rc.p.Get() + defer c.Close() + var ( + cursor uint64 = 0 // start + result []interface{} + list []string + ) + for { + result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) + if err != nil { + return + } + list, err = redis.Strings(result[1], nil) + if err != nil { + return + } + keys = append(keys, list...) + cursor, err = redis.Uint64(result[0], nil) + if err != nil { + return + } + if cursor == 0 { // over + return + } + } +} + +// StartAndGC start redis cache adapter. +// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} +// the cache item in redis are stored forever, +// so no gc operation. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + + if _, ok := cf["key"]; !ok { + cf["key"] = DefaultKey + } + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + + // Format redis://@: + cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) + if i := strings.Index(cf["conn"], "@"); i > -1 { + cf["password"] = cf["conn"][0:i] + cf["conn"] = cf["conn"][i+1:] + } + + if _, ok := cf["dbNum"]; !ok { + cf["dbNum"] = "0" + } + if _, ok := cf["password"]; !ok { + cf["password"] = "" + } + if _, ok := cf["maxIdle"]; !ok { + cf["maxIdle"] = "3" + } + if _, ok := cf["timeout"]; !ok { + cf["timeout"] = "180s" + } + rc.key = cf["key"] + rc.conninfo = cf["conn"] + rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) + rc.password = cf["password"] + rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) + + if v, err := time.ParseDuration(cf["timeout"]); err == nil { + rc.timeout = v + } else { + rc.timeout = 180 * time.Second + } + + rc.connectInit() + + c := rc.p.Get() + defer c.Close() + + return c.Err() +} + +// connect to redis. +func (rc *Cache) connectInit() { + dialFunc := func() (c redis.Conn, err error) { + c, err = redis.Dial("tcp", rc.conninfo) + if err != nil { + return nil, err + } + + if rc.password != "" { + if _, err := c.Do("AUTH", rc.password); err != nil { + c.Close() + return nil, err + } + } + + _, selecterr := c.Do("SELECT", rc.dbNum) + if selecterr != nil { + c.Close() + return nil, selecterr + } + return + } + // initialize a new pool + rc.p = &redis.Pool{ + MaxIdle: rc.maxIdle, + IdleTimeout: rc.timeout, + Dial: dialFunc, + } +} + +func init() { + cache.Register("redis", NewRedisCache) +} diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go new file mode 100644 index 00000000..60a19180 --- /dev/null +++ b/pkg/cache/redis/redis_test.go @@ -0,0 +1,144 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
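// Editor's note (illustrative sketch, not part of this patch): StartAndGC above
// accepts the connection either as plain host:port or in redis://password@host:port
// form, plus optional keys; every cache key is namespaced as "<key>:<name>":
//	bm, err := cache.NewCache("redis",
//		`{"key":"mycollection","conn":"redis://secret@127.0.0.1:6379","dbNum":"1","timeout":"30s"}`)
//	// ClearAll only scans and deletes keys under the "mycollection:" prefix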
+ +package redis + +import ( + "fmt" + "testing" + "time" + + "github.com/astaxie/beego/cache" + "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" +) + +func TestRedisCache(t *testing.T) { + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + timeoutDuration := 10 * time.Second + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + time.Sleep(11 * time.Second) + + if bm.IsExist("astaxie") { + t.Error("check err") + } + if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Decr Error", err) + } + + if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + + //test string + if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { + t.Error("get err") + } + + //test GetMulti + if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie1") { + t.Error("check err") + } + + vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) + if len(vv) != 2 { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[0], nil); v != "author" { + t.Error("GetMulti ERROR") + } + if v, _ := redis.String(vv[1], nil); v != "author1" { + t.Error("GetMulti ERROR") + } + + // test clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } +} + +func TestCache_Scan(t *testing.T) { + timeoutDuration := 10 * time.Second + // init + bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + if err != nil { + t.Error("init err") + } + // insert all + for i := 0; i < 10000; i++ { + if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { + t.Error("set Error", err) + } + } + // scan all for the first time + keys, err := bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + + assert.Equal(t, 10000, len(keys), "scan all error") + + // clear all + if err = bm.ClearAll(); err != nil { + t.Error("clear all err") + } + + // scan all for the second time + keys, err = bm.(*Cache).Scan(DefaultKey + ":*") + if err != nil { + t.Error("scan Error", err) + } + if len(keys) != 0 { + t.Error("scan all err") + } +} diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/cache/ssdb/ssdb.go new file mode 100644 index 00000000..fa2ce04b --- /dev/null +++ b/pkg/cache/ssdb/ssdb.go @@ -0,0 +1,231 @@ +package ssdb + +import ( + "encoding/json" + "errors" + "strconv" + "strings" + "time" + + "github.com/ssdb/gossdb/ssdb" + + "github.com/astaxie/beego/cache" +) + +// Cache SSDB adapter +type Cache struct { + conn *ssdb.Client + conninfo []string +} + +//NewSsdbCache create new ssdb adapter. +func NewSsdbCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. 
+func (rc *Cache) Get(key string) interface{} { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil + } + } + value, err := rc.conn.Get(key) + if err == nil { + return value + } + return nil +} + +// GetMulti get value from memcache. +func (rc *Cache) GetMulti(keys []string) []interface{} { + size := len(keys) + var values []interface{} + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + for i := 0; i < size; i++ { + values = append(values, err) + } + return values + } + } + res, err := rc.conn.Do("multi_get", keys) + resSize := len(res) + if err == nil { + for i := 1; i < resSize; i += 2 { + values = append(values, res[i+1]) + } + return values + } + for i := 0; i < size; i++ { + values = append(values, err) + } + return values +} + +// DelMulti get value from memcache. +func (rc *Cache) DelMulti(keys []string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("multi_del", keys) + return err +} + +// Put put value to memcache. only support string. +func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + v, ok := value.(string) + if !ok { + return errors.New("value must string") + } + var resp []string + var err error + ttl := int(timeout / time.Second) + if ttl < 0 { + resp, err = rc.conn.Do("set", key, v) + } else { + resp, err = rc.conn.Do("setx", key, v, ttl) + } + if err != nil { + return err + } + if len(resp) == 2 && resp[0] == "ok" { + return nil + } + return errors.New("bad response") +} + +// Delete delete value in memcache. +func (rc *Cache) Delete(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Del(key) + return err +} + +// Incr increase counter. +func (rc *Cache) Incr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, 1) + return err +} + +// Decr decrease counter. +func (rc *Cache) Decr(key string) error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + _, err := rc.conn.Do("incr", key, -1) + return err +} + +// IsExist check value exists in memcache. +func (rc *Cache) IsExist(key string) bool { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return false + } + } + resp, err := rc.conn.Do("exists", key) + if err != nil { + return false + } + if len(resp) == 2 && resp[1] == "1" { + return true + } + return false + +} + +// ClearAll clear all cached in memcache. +func (rc *Cache) ClearAll() error { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + keyStart, keyEnd, limit := "", "", 50 + resp, err := rc.Scan(keyStart, keyEnd, limit) + for err == nil { + size := len(resp) + if size == 1 { + return nil + } + keys := []string{} + for i := 1; i < size; i += 2 { + keys = append(keys, resp[i]) + } + _, e := rc.conn.Do("multi_del", keys) + if e != nil { + return e + } + keyStart = resp[size-2] + resp, err = rc.Scan(keyStart, keyEnd, limit) + } + return err +} + +// Scan key all cached in ssdb. 
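// Editor's note (illustrative sketch, not part of this patch): like the memcache
// adapter, the ssdb adapter only stores string values, and a negative timeout in
// Put above falls through to a plain "set", i.e. the key never expires:
//	bm, err := cache.NewCache("ssdb", `{"conn":"127.0.0.1:8888"}`)
//	if err == nil {
//		bm.Put("ssdb", "value", 10*time.Second)  // "setx" with a 10s ttl
//		bm.Put("pin", "value", -10*time.Second)  // "set", kept until deleted
//	}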
+func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return nil, err + } + } + resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) + if err != nil { + return nil, err + } + return resp, nil +} + +// StartAndGC start memcache adapter. +// config string is like {"conn":"connection info"}. +// if connecting error, return. +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["conn"]; !ok { + return errors.New("config has no conn key") + } + rc.conninfo = strings.Split(cf["conn"], ";") + if rc.conn == nil { + if err := rc.connectInit(); err != nil { + return err + } + } + return nil +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + conninfoArray := strings.Split(rc.conninfo[0], ":") + host := conninfoArray[0] + port, e := strconv.Atoi(conninfoArray[1]) + if e != nil { + return e + } + var err error + rc.conn, err = ssdb.Connect(host, port) + return err +} + +func init() { + cache.Register("ssdb", NewSsdbCache) +} diff --git a/pkg/cache/ssdb/ssdb_test.go b/pkg/cache/ssdb/ssdb_test.go new file mode 100644 index 00000000..dd474960 --- /dev/null +++ b/pkg/cache/ssdb/ssdb_test.go @@ -0,0 +1,104 @@ +package ssdb + +import ( + "strconv" + "testing" + "time" + + "github.com/astaxie/beego/cache" +) + +func TestSsdbcacheCache(t *testing.T) { + ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) + if err != nil { + t.Error("init err") + } + + // test put and exist + if ssdb.IsExist("ssdb") { + t.Error("check err") + } + timeoutDuration := 10 * time.Second + //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + + // Get test done + if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if v := ssdb.Get("ssdb"); v != "ssdb" { + t.Error("get Error") + } + + //inc/dec test done + if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if err = ssdb.Incr("ssdb"); err != nil { + t.Error("incr Error", err) + } + + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + + if err = ssdb.Decr("ssdb"); err != nil { + t.Error("decr error") + } + + // test del + if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { + t.Error("set Error", err) + } + if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { + t.Error("get err") + } + if err := ssdb.Delete("ssdb"); err == nil { + if ssdb.IsExist("ssdb") { + t.Error("delete err") + } + } + + //test string + if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb") { + t.Error("check err") + } + if v := ssdb.Get("ssdb").(string); v != "ssdb" { + t.Error("get err") + } + + //test GetMulti done + if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { + t.Error("set Error", err) + } + if !ssdb.IsExist("ssdb1") { + t.Error("check err") + } + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + if len(vv) != 2 { + t.Error("getmulti error") + } + if vv[0].(string) != "ssdb" { + t.Error("getmulti error") + } + if vv[1].(string) != "ssdb1" { + t.Error("getmulti error") + } + + // test clear all done + if err = 
ssdb.ClearAll(); err != nil { + t.Error("clear all err") + } + if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { + t.Error("check err") + } +} diff --git a/pkg/config.go b/pkg/config.go new file mode 100644 index 00000000..b6c9a99c --- /dev/null +++ b/pkg/config.go @@ -0,0 +1,524 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" + "github.com/astaxie/beego/utils" +) + +// Config is the main struct for BConfig +type Config struct { + AppName string //Application name + RunMode string //Running Mode: dev | prod + RouterCaseSensitive bool + ServerName string + RecoverPanic bool + RecoverFunc func(*context.Context) + CopyRequestBody bool + EnableGzip bool + MaxMemory int64 + EnableErrorsShow bool + EnableErrorsRender bool + Listen Listen + WebConfig WebConfig + Log LogConfig +} + +// Listen holds for http and https related config +type Listen struct { + Graceful bool // Graceful means use graceful module to start the server + ServerTimeOut int64 + ListenTCP4 bool + EnableHTTP bool + HTTPAddr string + HTTPPort int + AutoTLS bool + Domains []string + TLSCacheDir string + EnableHTTPS bool + EnableMutualHTTPS bool + HTTPSAddr string + HTTPSPort int + HTTPSCertFile string + HTTPSKeyFile string + TrustCaFile string + EnableAdmin bool + AdminAddr string + AdminPort int + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O +} + +// WebConfig holds web related config +type WebConfig struct { + AutoRender bool + EnableDocs bool + FlashName string + FlashSeparator string + DirectoryIndex bool + StaticDir map[string]string + StaticExtensionsToGzip []string + StaticCacheFileSize int + StaticCacheFileNum int + TemplateLeft string + TemplateRight string + ViewsPath string + EnableXSRF bool + XSRFKey string + XSRFExpire int + Session SessionConfig +} + +// SessionConfig holds session related config +type SessionConfig struct { + SessionOn bool + SessionProvider string + SessionName string + SessionGCMaxLifetime int64 + SessionProviderConfig string + SessionCookieLifeTime int + SessionAutoSetCookie bool + SessionDomain string + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. 
+ SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader string + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params +} + +// LogConfig holds Log related config +type LogConfig struct { + AccessLogs bool + EnableStaticLogs bool //log static files requests default: false + AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string + FileLineNum bool + Outputs map[string]string // Store Adaptor : config +} + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = newBConfig() + var err error + if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { + panic(err) + } + WorkPath, err = os.Getwd() + if err != nil { + panic(err) + } + var filename = "app.conf" + if os.Getenv("BEEGO_RUNMODE") != "" { + filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" + } + appConfigPath = filepath.Join(WorkPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + appConfigPath = filepath.Join(AppPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} + return + } + } + if err = parseConfig(appConfigPath); err != nil { + panic(err) + } +} + +func recoverPanic(ctx *context.Context) { + if err := recover(); err != nil { + if err == ErrAbort { + return + } + if !BConfig.RecoverPanic { + panic(err) + } + if BConfig.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), ctx) + return + } + } + var stack string + logs.Critical("the request url is ", ctx.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) + } + if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { + showErr(err, ctx, stack) + } + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } else { + ctx.ResponseWriter.WriteHeader(500) + } + } +} + +func newBConfig() *Config { + return &Config{ + AppName: "beego", + RunMode: PROD, + RouterCaseSensitive: true, + ServerName: "beegoServer:" + VERSION, + RecoverPanic: true, + RecoverFunc: recoverPanic, + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, //64MB + EnableErrorsShow: true, + EnableErrorsRender: true, + Listen: Listen{ + Graceful: false, + ServerTimeOut: 0, + ListenTCP4: false, + EnableHTTP: true, + AutoTLS: false, + Domains: []string{}, + TLSCacheDir: ".", + HTTPAddr: "", + HTTPPort: 8080, + EnableHTTPS: false, + HTTPSAddr: "", + HTTPSPort: 10443, + HTTPSCertFile: "", + HTTPSKeyFile: "", + EnableAdmin: false, + AdminAddr: "", + AdminPort: 8088, + EnableFcgi: false, + EnableStdIo: false, + }, + WebConfig: WebConfig{ + AutoRender: true, + EnableDocs: false, + FlashName: "BEEGO_FLASH", + FlashSeparator: "BEEGOFLASH", + 
DirectoryIndex: false, + StaticDir: map[string]string{"/static": "static"}, + StaticExtensionsToGzip: []string{".css", ".js"}, + StaticCacheFileSize: 1024 * 100, + StaticCacheFileNum: 1000, + TemplateLeft: "{{", + TemplateRight: "}}", + ViewsPath: "views", + EnableXSRF: false, + XSRFKey: "beegoxsrf", + XSRFExpire: 0, + Session: SessionConfig{ + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + }, + }, + Log: LogConfig{ + AccessLogs: false, + EnableStaticLogs: false, + AccessLogsFormat: "APACHE_FORMAT", + FileLineNum: true, + Outputs: map[string]string{"console": ""}, + }, + } +} + +// now only support ini, next will support json. +func parseConfig(appConfigPath string) (err error) { + AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) + if err != nil { + return err + } + return assignConfig(AppConfig) +} + +func assignConfig(ac config.Configer) error { + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + // set the run mode first + if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { + BConfig.RunMode = envRunMode + } else if runMode := ac.String("RunMode"); runMode != "" { + BConfig.RunMode = runMode + } + + if sd := ac.String("StaticDir"); sd != "" { + BConfig.WebConfig.StaticDir = map[string]string{} + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] + } else { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] + } + } + } + + if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { + extensions := strings.Split(sgz, ",") + fileExts := []string{} + for _, ext := range extensions { + ext = strings.TrimSpace(ext) + if ext == "" { + continue + } + if !strings.HasPrefix(ext, ".") { + ext = "." 
+ ext + } + fileExts = append(fileExts, ext) + } + if len(fileExts) > 0 { + BConfig.WebConfig.StaticExtensionsToGzip = fileExts + } + } + + if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + BConfig.WebConfig.StaticCacheFileSize = sfs + } + + if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + BConfig.WebConfig.StaticCacheFileNum = sfn + } + + if lo := ac.String("LogOutputs"); lo != "" { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) + los := strings.Split(lo, ";") + for _, v := range los { + if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { + BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] + } else { + continue + } + } + } + + //init log + logs.Reset() + for adaptor, config := range BConfig.Log.Outputs { + err := logs.SetLogger(adaptor, config) + if err != nil { + fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) + } + } + logs.SetLogFuncCall(BConfig.Log.FileLineNum) + + return nil +} + +func assignSingleConfig(p interface{}, ac config.Configer) { + pt := reflect.TypeOf(p) + if pt.Kind() != reflect.Ptr { + return + } + pt = pt.Elem() + if pt.Kind() != reflect.Struct { + return + } + pv := reflect.ValueOf(p).Elem() + + for i := 0; i < pt.NumField(); i++ { + pf := pv.Field(i) + if !pf.CanSet() { + continue + } + name := pt.Field(i).Name + switch pf.Kind() { + case reflect.String: + pf.SetString(ac.DefaultString(name, pf.String())) + case reflect.Int, reflect.Int64: + pf.SetInt(ac.DefaultInt64(name, pf.Int())) + case reflect.Bool: + pf.SetBool(ac.DefaultBool(name, pf.Bool())) + case reflect.Struct: + default: + //do nothing here + } + } + +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + absConfigPath, err := filepath.Abs(configPath) + if err != nil { + return err + } + + if !utils.FileExists(absConfigPath) { + return fmt.Errorf("the target config file: %s don't exist", configPath) + } + + appConfigPath = absConfigPath + appConfigProvider = adapterName + + return parseConfig(appConfigPath) +} + +type beegoAppConfig struct { + innerConfig config.Configer +} + +func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { + ac, err := config.NewConfig(appConfigProvider, appConfigPath) + if err != nil { + return nil, err + } + return &beegoAppConfig{ac}, nil +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { + return v + } + return b.innerConfig.String(key) +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { + return v + } + return b.innerConfig.Strings(key) +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int(key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int64(key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := 
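To make the run-mode scoping of these wrappers concrete, a short sketch: every getter on AppConfig first tries "<RunMode>::<key>" and only then the bare key, and LoadAppConfig swaps in a different config file after startup. The file path below is hypothetical.

package main

import (
	"fmt"

	"github.com/astaxie/beego"
)

func main() {
	// Hypothetical override of the config file discovered in init().
	if err := beego.LoadAppConfig("ini", "conf/app.custom.conf"); err != nil {
		panic(err)
	}
	// With RunMode "dev", this resolves "dev::httpport" first and falls back to "httpport".
	fmt.Println(beego.AppConfig.DefaultInt("httpport", 8080))
}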
b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Bool(key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Float(key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 00000000..bfd79e85 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,242 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +//Examples. 
+// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +//More docs http://beego.me/docs/module/config.md +package config + +import ( + "fmt" + "os" + "reflect" + "time" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error //support section::key type in given key when using ini type. + String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string //get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string //get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + if adapter == nil { + panic("config: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("config: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.Parse(filename) +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + adapter, ok := adapters[adapterName] + if !ok { + return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) + } + return adapter.ParseData(data) +} + +// ExpandValueEnvForMap convert all string value with environment variable. 
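As a quick illustration of the two factories above, a sketch that parses raw bytes rather than a file; the "ini" adapter used here registers itself in its own init (see pkg/config/ini.go later in this patch).

package main

import (
	"fmt"

	"github.com/astaxie/beego/config"
)

func main() {
	raw := []byte("appname = beeapi\nhttpport = 8080\n")
	cnf, err := config.NewConfigData("ini", raw)
	if err != nil {
		panic(err)
	}
	fmt.Println(cnf.String("appname"), cnf.DefaultInt("httpport", 80))
}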
+func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + for k, v := range m { + switch value := v.(type) { + case string: + m[k] = ExpandValueEnv(value) + case map[string]interface{}: + m[k] = ExpandValueEnvForMap(value) + case map[string]string: + for k2, v2 := range value { + value[k2] = ExpandValueEnv(v2) + } + m[k] = value + } + } + return m +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) (realValue string) { + realValue = value + + vLen := len(value) + // 3 = ${} + if vLen < 3 { + return + } + // Need start with "${" and end with "}", then return. + if value[0] != '$' || value[1] != '{' || value[vLen-1] != '}' { + return + } + + key := "" + defaultV := "" + // value start with "${" + for i := 2; i < vLen; i++ { + if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { + key = value[2:i] + defaultV = value[i+2 : vLen-1] // other string is default value. + break + } else if value[i] == '}' { + key = value[2:i] + break + } + } + + realValue = os.Getenv(key) + if realValue == "" { + realValue = defaultV + } + + return +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + if val != nil { + switch v := val.(type) { + case bool: + return v, nil + case string: + switch v { + case "1", "t", "T", "true", "TRUE", "True", "YES", "yes", "Yes", "Y", "y", "ON", "on", "On": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "NO", "no", "No", "N", "n", "OFF", "off", "Off": + return false, nil + } + case int8, int32, int64: + strV := fmt.Sprintf("%d", v) + if strV == "1" { + return true, nil + } else if strV == "0" { + return false, nil + } + case float64: + if v == 1.0 { + return true, nil + } else if v == 0.0 { + return false, nil + } + } + return false, fmt.Errorf("parsing %q: invalid syntax", val) + } + return false, fmt.Errorf("parsing : invalid syntax") +} + +// ToString converts values of any type to string. 
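A small sketch of the two helpers defined above, assuming GOPATH is set and MY_MISSING_VAR (a made-up name) is not:

package main

import (
	"fmt"

	"github.com/astaxie/beego/config"
)

func main() {
	fmt.Println(config.ExpandValueEnv("${GOPATH}"))               // the value of $GOPATH
	fmt.Println(config.ExpandValueEnv("${MY_MISSING_VAR||/tmp}")) // "/tmp", the default after "||"
	fmt.Println(config.ExpandValueEnv("plain"))                   // returned unchanged

	on, _ := config.ParseBool("Y")    // true
	off, _ := config.ParseBool("off") // false
	fmt.Println(on, off)
}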
+func ToString(x interface{}) string { + switch y := x.(type) { + + // Handle dates with special logic + // This needs to come above the fmt.Stringer + // test since time.Time's have a .String() + // method + case time.Time: + return y.Format("A Monday") + + // Handle type string + case string: + return y + + // Handle type with .String() method + case fmt.Stringer: + return y.String() + + // Handle type with .Error() method + case error: + return y.Error() + + } + + // Handle named string type + if v := reflect.ValueOf(x); v.Kind() == reflect.String { + return v.String() + } + + // Fallback to fmt package for anything else like numeric types + return fmt.Sprint(x) +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/config/env/env.go b/pkg/config/env/env.go new file mode 100644 index 00000000..34f094fe --- /dev/null +++ b/pkg/config/env/env.go @@ -0,0 +1,87 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "fmt" + "os" + "strings" + + "github.com/astaxie/beego/utils" +) + +var env *utils.BeeMap + +func init() { + env = utils.NewBeeMap() + for _, e := range os.Environ() { + splits := strings.Split(e, "=") + env.Set(splits[0], os.Getenv(splits[0])) + } +} + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. 
+func Get(key string, defVal string) string { + if val := env.Get(key); val != nil { + return val.(string) + } + return defVal +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + if val := env.Get(key); val != nil { + return val.(string), nil + } + return "", fmt.Errorf("no env variable with %s", key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + err := os.Setenv(key, value) + if err != nil { + return err + } + env.Set(key, value) + return nil +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + items := env.Items() + envs := make(map[string]string, env.Count()) + + for key, val := range items { + switch key := key.(type) { + case string: + switch val := val.(type) { + case string: + envs[key] = val + } + } + } + return envs +} diff --git a/pkg/config/env/env_test.go b/pkg/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/config/fake.go b/pkg/config/fake.go new file mode 100644 index 00000000..d21ab820 --- /dev/null +++ b/pkg/config/fake.go @@ -0,0 +1,134 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
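One point worth spelling out with a sketch: the env package snapshots os.Environ at init, Set mutates only that snapshot, while MustSet also writes through to the real process environment. The variable names below are made up, and the import path is the pre-move one used throughout this patch.

package main

import (
	"fmt"
	"os"

	"github.com/astaxie/beego/config/env"
)

func main() {
	env.Set("ONLY_IN_COPY", "1")             // stored only in the snapshot
	fmt.Println(env.Get("ONLY_IN_COPY", "")) // "1"
	fmt.Println(os.Getenv("ONLY_IN_COPY"))   // "" - the real environment is untouched

	_ = env.MustSet("IN_BOTH", "1") // updates the snapshot and calls os.Setenv
	fmt.Println(os.Getenv("IN_BOTH"))
}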
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "strconv" + "strings" +) + +type fakeConfigContainer struct { + data map[string]string +} + +func (c *fakeConfigContainer) getData(key string) string { + return c.data[strings.ToLower(key)] +} + +func (c *fakeConfigContainer) Set(key, val string) error { + c.data[strings.ToLower(key)] = val + return nil +} + +func (c *fakeConfigContainer) String(key string) string { + return c.getData(key) +} + +func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getData(key), 10, 64) +} + +func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getData(key)) +} + +func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getData(key), 64) +} + +func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return nil, errors.New("key not find") +} + +func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { + return nil, errors.New("not implement in the fakeConfigContainer") +} + +func (c *fakeConfigContainer) SaveConfigFile(filename string) error { + return errors.New("not implement in the fakeConfigContainer") +} + +var _ Configer = new(fakeConfigContainer) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + return &fakeConfigContainer{ + data: make(map[string]string), + } +} diff --git a/pkg/config/ini.go b/pkg/config/ini.go new file mode 100644 index 00000000..002e5e05 --- /dev/null +++ b/pkg/config/ini.go @@ -0,0 +1,504 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
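For completeness, a sketch of the fake container standing in for a real Configer, for example in tests; keys are lower-cased internally, so lookups are case-insensitive.

package main

import (
	"fmt"

	"github.com/astaxie/beego/config"
)

func main() {
	cnf := config.NewFakeConfig()
	_ = cnf.Set("AppName", "demo")
	fmt.Println(cnf.String("appname"))          // "demo" - same key, lower-cased on Set and on lookup
	fmt.Println(cnf.DefaultInt("httpport", 80)) // 80, because the key was never set
}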
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bufio" + "bytes" + "errors" + "io" + "io/ioutil" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" +) + +var ( + defaultSection = "default" // default section means if some ini items not in a section, make them in default section, + bNumComment = []byte{'#'} // number signal + bSemComment = []byte{';'} // semicolon signal + bEmpty = []byte{} + bEqual = []byte{'='} // equal signal + bDQuote = []byte{'"'} // quote signal + sectionStart = []byte{'['} // section start signal + sectionEnd = []byte{']'} // section end signal + lineBreak = "\n" +) + +// IniConfig implements Config to parse ini file. +type IniConfig struct { +} + +// Parse creates a new Config and parses the file configuration from the named file. +func (ini *IniConfig) Parse(name string) (Configer, error) { + return ini.parseFile(name) +} + +func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { + data, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + + return ini.parseData(filepath.Dir(name), data) +} + +func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) { + cfg := &IniConfigContainer{ + data: make(map[string]map[string]string), + sectionComment: make(map[string]string), + keyComment: make(map[string]string), + RWMutex: sync.RWMutex{}, + } + cfg.Lock() + defer cfg.Unlock() + + var comment bytes.Buffer + buf := bufio.NewReader(bytes.NewBuffer(data)) + // check the BOM + head, err := buf.Peek(3) + if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { + for i := 1; i <= 3; i++ { + buf.ReadByte() + } + } + section := defaultSection + tmpBuf := bytes.NewBuffer(nil) + for { + tmpBuf.Reset() + + shouldBreak := false + for { + tmp, isPrefix, err := buf.ReadLine() + if err == io.EOF { + shouldBreak = true + break + } + + //It might be a good idea to throw a error on all unknonw errors? + if _, ok := err.(*os.PathError); ok { + return nil, err + } + + tmpBuf.Write(tmp) + if isPrefix { + continue + } + + if !isPrefix { + break + } + } + if shouldBreak { + break + } + + line := tmpBuf.Bytes() + line = bytes.TrimSpace(line) + if bytes.Equal(line, bEmpty) { + continue + } + var bComment []byte + switch { + case bytes.HasPrefix(line, bNumComment): + bComment = bNumComment + case bytes.HasPrefix(line, bSemComment): + bComment = bSemComment + } + if bComment != nil { + line = bytes.TrimLeft(line, string(bComment)) + // Need append to a new line if multi-line comments. 
+ if comment.Len() > 0 { + comment.WriteByte('\n') + } + comment.Write(line) + continue + } + + if bytes.HasPrefix(line, sectionStart) && bytes.HasSuffix(line, sectionEnd) { + section = strings.ToLower(string(line[1 : len(line)-1])) // section name case insensitive + if comment.Len() > 0 { + cfg.sectionComment[section] = comment.String() + comment.Reset() + } + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + continue + } + + if _, ok := cfg.data[section]; !ok { + cfg.data[section] = make(map[string]string) + } + keyValue := bytes.SplitN(line, bEqual, 2) + + key := string(bytes.TrimSpace(keyValue[0])) // key name case insensitive + key = strings.ToLower(key) + + // handle include "other.conf" + if len(keyValue) == 1 && strings.HasPrefix(key, "include") { + + includefiles := strings.Fields(key) + if includefiles[0] == "include" && len(includefiles) == 2 { + + otherfile := strings.Trim(includefiles[1], "\"") + if !filepath.IsAbs(otherfile) { + otherfile = filepath.Join(dir, otherfile) + } + + i, err := ini.parseFile(otherfile) + if err != nil { + return nil, err + } + + for sec, dt := range i.data { + if _, ok := cfg.data[sec]; !ok { + cfg.data[sec] = make(map[string]string) + } + for k, v := range dt { + cfg.data[sec][k] = v + } + } + + for sec, comm := range i.sectionComment { + cfg.sectionComment[sec] = comm + } + + for k, comm := range i.keyComment { + cfg.keyComment[k] = comm + } + + continue + } + } + + if len(keyValue) != 2 { + return nil, errors.New("read the content error: \"" + string(line) + "\", should key = val") + } + val := bytes.TrimSpace(keyValue[1]) + if bytes.HasPrefix(val, bDQuote) { + val = bytes.Trim(val, `"`) + } + + cfg.data[section][key] = ExpandValueEnv(string(val)) + if comment.Len() > 0 { + cfg.keyComment[section+"."+key] = comment.String() + comment.Reset() + } + + } + return cfg, nil +} + +// ParseData parse ini the data +// When include other.conf,other.conf is either absolute directory +// or under beego in default temporary directory(/tmp/beego[-username]). +func (ini *IniConfig) ParseData(data []byte) (Configer, error) { + dir := "beego" + currentUser, err := user.Current() + if err == nil { + dir = "beego-" + currentUser.Username + } + dir = filepath.Join(os.TempDir(), dir) + if err = os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, err + } + + return ini.parseData(dir, data) +} + +// IniConfigContainer A Config represents the ini configuration. +// When set and get value, support key as section:name type. +type IniConfigContainer struct { + data map[string]map[string]string // section=> key:val + sectionComment map[string]string // section : comment + keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *IniConfigContainer) Bool(key string) (bool, error) { + return ParseBool(c.getdata(key)) +} + +// DefaultBool returns the boolean value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *IniConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getdata(key)) +} + +// DefaultInt returns the integer value for a given key. 
+// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *IniConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getdata(key), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. +func (c *IniConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getdata(key), 64) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *IniConfigContainer) String(key string) string { + return c.getdata(key) +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *IniConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v, nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file. +// +// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. +func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + // Get section or key comments. Fixed #1607 + getCommentStr := func(section, key string) string { + var ( + comment string + ok bool + ) + if len(key) == 0 { + comment, ok = c.sectionComment[section] + } else { + comment, ok = c.keyComment[section+"."+key] + } + + if ok { + // Empty comment + if len(comment) == 0 || len(strings.TrimSpace(comment)) == 0 { + return string(bNumComment) + } + prefix := string(bNumComment) + // Add the line head character "#" + return prefix + strings.Replace(comment, lineBreak, lineBreak+prefix, -1) + } + return "" + } + + buf := bytes.NewBuffer(nil) + // Save default section at first place + if dt, ok := c.data[defaultSection]; ok { + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(defaultSection, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. 
+ if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + // Save named sections + for section, dt := range c.data { + if section != defaultSection { + // Write section comments. + if v := getCommentStr(section, ""); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write section name. + if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil { + return err + } + + for key, val := range dt { + if key != " " { + // Write key comments. + if v := getCommentStr(section, key); len(v) > 0 { + if _, err = buf.WriteString(v + lineBreak); err != nil { + return err + } + } + + // Write key and value. + if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil { + return err + } + } + } + + // Put a line between sections. + if _, err = buf.WriteString(lineBreak); err != nil { + return err + } + } + } + _, err = buf.WriteTo(f) + return err +} + +// Set writes a new value for key. +// if write to one section, the key need be "section::key". +// if the section is not existed, it panics. +func (c *IniConfigContainer) Set(key, value string) error { + c.Lock() + defer c.Unlock() + if len(key) == 0 { + return errors.New("key is empty") + } + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + + if _, ok := c.data[section]; !ok { + c.data[section] = make(map[string]string) + } + c.data[section][k] = value + return nil +} + +// DIY returns the raw value by a given key. +func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[strings.ToLower(key)]; ok { + return v, nil + } + return v, errors.New("key not find") +} + +// section.key or key +func (c *IniConfigContainer) getdata(key string) string { + if len(key) == 0 { + return "" + } + c.RLock() + defer c.RUnlock() + + var ( + section, k string + sectionKey = strings.Split(strings.ToLower(key), "::") + ) + if len(sectionKey) >= 2 { + section = sectionKey[0] + k = sectionKey[1] + } else { + section = defaultSection + k = sectionKey[0] + } + if v, ok := c.data[section]; ok { + if vv, ok := v[k]; ok { + return vv + } + } + return "" +} + +func init() { + Register("ini", &IniConfig{}) +} diff --git a/pkg/config/ini_test.go b/pkg/config/ini_test.go new file mode 100644 index 00000000..ffcdb294 --- /dev/null +++ b/pkg/config/ini_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
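A sketch of the write path just implemented: Set accepts "section::key" (and, despite its doc comment, creates a missing section rather than panicking), and SaveConfigFile re-emits the data with stored comments prefixed by "#". The output file name below is hypothetical.

package main

import (
	"fmt"

	"github.com/astaxie/beego/config"
)

func main() {
	cnf, err := config.NewConfigData("ini", []byte("appname = beeapi\n[db]\nname = mysql\n"))
	if err != nil {
		panic(err)
	}
	_ = cnf.Set("db::user", "root")   // "section::key" writes into the [db] section
	fmt.Println(cnf.String("db::name")) // "mysql", read back the same way
	if err := cnf.SaveConfigFile("out.conf"); err != nil {
		panic(err)
	}
}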
+ +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(k) + case int64: + value, err = iniconf.Int64(k) + case float64: + value, err = iniconf.Float(k) + case bool: + value, err = iniconf.Bool(k) + case []string: + value = iniconf.Strings(k) + case string: + value = iniconf.String(k) + default: + value, err = iniconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if iniconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/config/json.go b/pkg/config/json.go new file mode 100644 index 00000000..c4ef25cd --- /dev/null +++ b/pkg/config/json.go @@ -0,0 +1,269 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" +) + +// JSONConfig is a json config parser and implements Config interface. +type JSONConfig struct { +} + +// Parse returns a ConfigContainer with parsed json config map. +func (js *JSONConfig) Parse(filename string) (Configer, error) { + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + content, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + return js.ParseData(content) +} + +// ParseData returns a ConfigContainer with json string +func (js *JSONConfig) ParseData(data []byte) (Configer, error) { + x := &JSONConfigContainer{ + data: make(map[string]interface{}), + } + err := json.Unmarshal(data, &x.data) + if err != nil { + var wrappingArray []interface{} + err2 := json.Unmarshal(data, &wrappingArray) + if err2 != nil { + return nil, err + } + x.data["rootArray"] = wrappingArray + } + + x.data = ExpandValueEnvForMap(x.data) + + return x, nil +} + +// JSONConfigContainer A Config represents the json configuration. +// Only when get value, support key as section:name type. +type JSONConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *JSONConfigContainer) Bool(key string) (bool, error) { + val := c.getData(key) + if val != nil { + return ParseBool(val) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { + if v, err := c.Bool(key); err == nil { + return v + } + return defaultval +} + +// Int returns the integer value for a given key. +func (c *JSONConfigContainer) Int(key string) (int, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int(v), nil + } else if v, ok := val.(string); ok { + return strconv.Atoi(v) + } + return 0, errors.New("not valid value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { + if v, err := c.Int(key); err == nil { + return v + } + return defaultval +} + +// Int64 returns the int64 value for a given key. +func (c *JSONConfigContainer) Int64(key string) (int64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return int64(v), nil + } + return 0, errors.New("not int64 value") + } + return 0, errors.New("not exist key:" + key) +} + +// DefaultInt64 returns the int64 value for a given key. 
+// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + if v, err := c.Int64(key); err == nil { + return v + } + return defaultval +} + +// Float returns the float value for a given key. +func (c *JSONConfigContainer) Float(key string) (float64, error) { + val := c.getData(key) + if val != nil { + if v, ok := val.(float64); ok { + return v, nil + } + return 0.0, errors.New("not float64 value") + } + return 0.0, errors.New("not exist key:" + key) +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + if v, err := c.Float(key); err == nil { + return v + } + return defaultval +} + +// String returns the string value for a given key. +func (c *JSONConfigContainer) String(key string) string { + val := c.getData(key) + if val != nil { + if v, ok := val.(string); ok { + return v + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { + // TODO FIXME should not use "" to replace non existence + if v := c.String(key); v != "" { + return v + } + return defaultval +} + +// Strings returns the []string value for a given key. +func (c *JSONConfigContainer) Strings(key string) []string { + stringVal := c.String(key) + if stringVal == "" { + return nil + } + return strings.Split(c.String(key), ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { + if v := c.Strings(key); v != nil { + return v + } + return defaultval +} + +// GetSection returns map for the given section +func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("nonexist section " + section) +} + +// SaveConfigFile save the config into file +func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := json.MarshalIndent(c.data, "", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *JSONConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. 
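To make the "::" traversal used by these getters concrete, a short sketch with an inline document; JSON numbers decode as float64, which is why Int goes through the float64 branch.

package main

import (
	"fmt"

	"github.com/astaxie/beego/config"
)

func main() {
	raw := []byte(`{"database": {"conns": {"maxconnection": 12}}}`)
	cnf, err := config.NewConfigData("json", raw)
	if err != nil {
		panic(err)
	}
	// Each "::" segment descends one level into the decoded map.
	fmt.Println(cnf.DefaultInt("database::conns::maxconnection", 0)) // 12
}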
+func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { + val := c.getData(key) + if val != nil { + return val, nil + } + return nil, errors.New("not exist key") +} + +// section.key or key +func (c *JSONConfigContainer) getData(key string) interface{} { + if len(key) == 0 { + return nil + } + + c.RLock() + defer c.RUnlock() + + sectionKeys := strings.Split(key, "::") + if len(sectionKeys) >= 2 { + curValue, ok := c.data[sectionKeys[0]] + if !ok { + return nil + } + for _, key := range sectionKeys[1:] { + if v, ok := curValue.(map[string]interface{}); ok { + if curValue, ok = v[key]; !ok { + return nil + } + } + } + return curValue + } + if v, ok := c.data[key]; ok { + return v + } + return nil +} + +func init() { + Register("json", &JSONConfig{}) +} diff --git a/pkg/config/json_test.go b/pkg/config/json_test.go new file mode 100644 index 00000000..16f42409 --- /dev/null +++ b/pkg/config/json_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY("rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + 
"testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(k) + case int64: + value, err = jsonconf.Int64(k) + case float64: + value, err = jsonconf.Float(k) + case bool: + value, err = jsonconf.Bool(k) + case []string: + value = jsonconf.Strings(k) + case string: + value = jsonconf.String(k) + default: + value, err = jsonconf.DIY(k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if jsonconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY("database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val := jsonconf.String("unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool("unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool("unknown", true) { + t.Error("unknown keys with default value wrong") + } +} diff --git a/pkg/config/xml/xml.go b/pkg/config/xml/xml.go new file mode 100644 index 00000000..494242d3 --- /dev/null +++ b/pkg/config/xml/xml.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +//More docs http://beego.me/docs/module/config.md +package xml + +import ( + "encoding/xml" + "errors" + "fmt" + "io/ioutil" + "os" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/x2j" +) + +// Config is a xml config parser and implements Config interface. +// xml configurations should be included in tag. +// only support key/value pair as value as each item. +type Config struct{} + +// Parse returns a ConfigContainer with parsed xml config map. +func (xc *Config) Parse(filename string) (config.Configer, error) { + context, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + return xc.ParseData(context) +} + +// ParseData xml data +func (xc *Config) ParseData(data []byte) (config.Configer, error) { + x := &ConfigContainer{data: make(map[string]interface{})} + + d, err := x2j.DocToMap(string(data)) + if err != nil { + return nil, err + } + + x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) + + return x, nil +} + +// ConfigContainer A Config represents the xml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.Mutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + if v := c.data[key]; v != nil { + return config.ParseBool(v) + } + return false, fmt.Errorf("not exist key: %q", key) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.data[key].(string)) +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.data[key].(string), 10, 64) +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v + +} + +// Float returns the float value for a given key. +func (c *ConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.data[key].(string), 64) +} + +// DefaultFloat returns the float64 value for a given key. 
+// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, ok := c.data[key].(string); ok { + return v + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + if v, ok := c.data[section].(map[string]interface{}); ok { + mapstr := make(map[string]string) + for k, val := range v { + mapstr[k] = config.ToString(val) + } + return mapstr, nil + } + return nil, fmt.Errorf("section '%s' not found", section) +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + b, err := xml.MarshalIndent(c.data, " ", " ") + if err != nil { + return err + } + _, err = f.Write(b) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + if v, ok := c.data[key]; ok { + return v, nil + } + return nil, errors.New("not exist key") +} + +func init() { + config.Register("xml", &Config{}) +} diff --git a/pkg/config/xml/xml_test.go b/pkg/config/xml/xml_test.go new file mode 100644 index 00000000..346c866e --- /dev/null +++ b/pkg/config/xml/xml_test.go @@ -0,0 +1,125 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
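The XML container above does all typed conversion at read time, so applications normally go through the registered adapter rather than this type directly. A minimal usage sketch, assuming the import path from the package comment and a made-up conf/app.xml holding httpport and runmode keys:

package main

import (
	"fmt"

	"github.com/astaxie/beego/config"
	// The blank import runs init(), which calls config.Register("xml", &Config{}).
	_ "github.com/astaxie/beego/config/xml"
)

func main() {
	cnf, err := config.NewConfig("xml", "conf/app.xml")
	if err != nil {
		panic(err)
	}
	// Typed getters return an error for missing keys; the Default* variants fall back.
	fmt.Println(cnf.DefaultInt("httpport", 8080))
	fmt.Println(cnf.DefaultString("runmode", "dev"))
}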
+ +package xml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestXML(t *testing.T) { + + var ( + //xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection("mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(k) + case int64: + value, err = xmlconf.Int64(k) + case float64: + value, err = xmlconf.Float(k) + case bool: + value, err = xmlconf.Bool(k) + case []string: + value = xmlconf.Strings(k) + case string: + value = xmlconf.String(k) + default: + value, err = xmlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if xmlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } +} diff --git a/pkg/config/yaml/yaml.go b/pkg/config/yaml/yaml.go new file mode 100644 index 00000000..5def2da3 --- /dev/null +++ b/pkg/config/yaml/yaml.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +//More docs http://beego.me/docs/module/config.md +package yaml + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + "sync" + + "github.com/astaxie/beego/config" + "github.com/beego/goyaml2" +) + +// Config is a yaml config parser and implements Config interface. +type Config struct{} + +// Parse returns a ConfigContainer with parsed yaml config map. 
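Every config provider in this patch wires itself up the same way: an init() that calls config.Register, and callers that ask config.NewConfig for an adapter by name. The sketch below shows the shape of that registry pattern with illustrative names; it is not beego's actual config.Register/NewConfig code:

package main

import "fmt"

// Parser stands in for the adapter interface the providers implement.
type Parser interface {
	Parse(filename string) (map[string]interface{}, error)
}

var adapters = map[string]Parser{}

// Register panics on nil or duplicate adapters, mirroring what the init()
// functions above rely on.
func Register(name string, p Parser) {
	if p == nil {
		panic("config: Register parser is nil")
	}
	if _, dup := adapters[name]; dup {
		panic("config: Register called twice for adapter " + name)
	}
	adapters[name] = p
}

// NewParser looks an adapter up by name, the way config.NewConfig does.
func NewParser(name string) (Parser, error) {
	p, ok := adapters[name]
	if !ok {
		return nil, fmt.Errorf("config: unknown adapter %q", name)
	}
	return p, nil
}

type fakeYAML struct{}

func (fakeYAML) Parse(string) (map[string]interface{}, error) { return nil, nil }

func main() {
	Register("yaml", fakeYAML{})
	p, err := NewParser("yaml")
	fmt.Println(p != nil, err) // true <nil>
}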
+func (yaml *Config) Parse(filename string) (y config.Configer, err error) { + cnf, err := ReadYmlReader(filename) + if err != nil { + return + } + y = &ConfigContainer{ + data: cnf, + } + return +} + +// ParseData parse yaml data +func (yaml *Config) ParseData(data []byte) (config.Configer, error) { + cnf, err := parseYML(data) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: cnf, + }, nil +} + +// ReadYmlReader Read yaml file to map. +// if json like, use json package, unless goyaml2 package. +func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { + buf, err := ioutil.ReadFile(path) + if err != nil { + return + } + + return parseYML(buf) +} + +// parseYML parse yaml formatted []byte to map. +func parseYML(buf []byte) (cnf map[string]interface{}, err error) { + if len(buf) < 3 { + return + } + + if string(buf[0:1]) == "{" { + log.Println("Look like a Json, try json umarshal") + err = json.Unmarshal(buf, &cnf) + if err == nil { + log.Println("It is Json Map") + return + } + } + + data, err := goyaml2.Read(bytes.NewReader(buf)) + if err != nil { + log.Println("Goyaml2 ERR>", string(buf), err) + return + } + + if data == nil { + log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) + return + } + cnf, ok := data.(map[string]interface{}) + if !ok { + log.Println("Not a Map? >> ", string(buf), data) + cnf = nil + } + cnf = config.ExpandValueEnvForMap(cnf) + return +} + +// ConfigContainer A Config represents the yaml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Bool returns the boolean value for a given key. +func (c *ConfigContainer) Bool(key string) (bool, error) { + v, err := c.getData(key) + if err != nil { + return false, err + } + return config.ParseBool(v) +} + +// DefaultBool return the bool value if has no error +// otherwise return the defaultval +func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { + v, err := c.Bool(key) + if err != nil { + return defaultval + } + return v +} + +// Int returns the integer value for a given key. +func (c *ConfigContainer) Int(key string) (int, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int); ok { + return vv, nil + } else if vv, ok := v.(int64); ok { + return int(vv), nil + } + return 0, errors.New("not int value") +} + +// DefaultInt returns the integer value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { + v, err := c.Int(key) + if err != nil { + return defaultval + } + return v +} + +// Int64 returns the int64 value for a given key. +func (c *ConfigContainer) Int64(key string) (int64, error) { + if v, err := c.getData(key); err != nil { + return 0, err + } else if vv, ok := v.(int64); ok { + return vv, nil + } + return 0, errors.New("not bool value") +} + +// DefaultInt64 returns the int64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { + v, err := c.Int64(key) + if err != nil { + return defaultval + } + return v +} + +// Float returns the float value for a given key. 
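parseYML above tries encoding/json first whenever the buffer starts with '{' and only then hands the bytes to goyaml2. A standalone sketch of that heuristic, with an illustrative helper name and the YAML fallback omitted:

package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// parseMaybeJSON mirrors the branch in parseYML: if the payload looks like a
// JSON object, try encoding/json first and only fall back to YAML on failure.
func parseMaybeJSON(buf []byte) (map[string]interface{}, error) {
	var cnf map[string]interface{}
	if len(buf) > 0 && buf[0] == '{' {
		if err := json.Unmarshal(buf, &cnf); err == nil {
			return cnf, nil // JSON-shaped input parsed without touching YAML
		}
	}
	// The provider above would hand the buffer to goyaml2.Read here;
	// this sketch stops short of the YAML fallback.
	return nil, errors.New("not valid JSON; YAML fallback omitted")
}

func main() {
	m, err := parseMaybeJSON([]byte(`{"appname": "beeapi", "httpport": 8080}`))
	fmt.Println(m, err)
}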
+func (c *ConfigContainer) Float(key string) (float64, error) { + if v, err := c.getData(key); err != nil { + return 0.0, err + } else if vv, ok := v.(float64); ok { + return vv, nil + } else if vv, ok := v.(int); ok { + return float64(vv), nil + } else if vv, ok := v.(int64); ok { + return float64(vv), nil + } + return 0.0, errors.New("not float64 value") +} + +// DefaultFloat returns the float64 value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { + v, err := c.Float(key) + if err != nil { + return defaultval + } + return v +} + +// String returns the string value for a given key. +func (c *ConfigContainer) String(key string) string { + if v, err := c.getData(key); err == nil { + if vv, ok := v.(string); ok { + return vv + } + } + return "" +} + +// DefaultString returns the string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultString(key string, defaultval string) string { + v := c.String(key) + if v == "" { + return defaultval + } + return v +} + +// Strings returns the []string value for a given key. +func (c *ConfigContainer) Strings(key string) []string { + v := c.String(key) + if v == "" { + return nil + } + return strings.Split(v, ";") +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { + v := c.Strings(key) + if v == nil { + return defaultval + } + return v +} + +// GetSection returns map for the given section +func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { + + if v, ok := c.data[section]; ok { + return v.(map[string]string), nil + } + return nil, errors.New("not exist section") +} + +// SaveConfigFile save the config into file +func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + err = goyaml2.Write(f, c.data) + return err +} + +// Set writes a new value for key. +func (c *ConfigContainer) Set(key, val string) error { + c.Lock() + defer c.Unlock() + c.data[key] = val + return nil +} + +// DIY returns the raw value by a given key. +func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { + return c.getData(key) +} + +func (c *ConfigContainer) getData(key string) (interface{}, error) { + + if len(key) == 0 { + return nil, errors.New("key is empty") + } + c.RLock() + defer c.RUnlock() + + keys := strings.Split(key, ".") + tmpData := c.data + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + if idx == len(keys) - 1 { + return tmpData, nil + } + } + default: + { + return v, nil + } + + } + } + } + return nil, fmt.Errorf("not exist key %q", key) +} + +func init() { + config.Register("yaml", &Config{}) +} diff --git a/pkg/config/yaml/yaml_test.go b/pkg/config/yaml/yaml_test.go new file mode 100644 index 00000000..49cc1d1e --- /dev/null +++ b/pkg/config/yaml/yaml_test.go @@ -0,0 +1,115 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + if yamlconf.String("appname") != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value = yamlconf.Strings(k) + case string: + value = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + if yamlconf.String("name") != "astaxie" { + t.Fatal("get name error") + } + +} diff --git a/pkg/config_test.go b/pkg/config_test.go new file mode 100644 index 00000000..5f71f1c3 --- /dev/null +++ b/pkg/config_test.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
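The YAML container resolves keys such as database.host by splitting on "." and walking nested maps in getData. A standalone sketch of that traversal over a plain map, with illustrative names and data:

package main

import (
	"fmt"
	"strings"
)

// lookup walks nested map[string]interface{} levels, one segment per
// "."-separated part of the key, and returns whatever sits at the last segment.
func lookup(data map[string]interface{}, key string) (interface{}, bool) {
	var cur interface{} = data
	for _, part := range strings.Split(key, ".") {
		m, ok := cur.(map[string]interface{})
		if !ok {
			return nil, false
		}
		v, ok := m[part]
		if !ok {
			return nil, false
		}
		cur = v
	}
	return cur, true
}

func main() {
	cfg := map[string]interface{}{
		"database": map[string]interface{}{
			"host": "127.0.0.1",
			"port": 5432,
		},
	}
	fmt.Println(lookup(cfg, "database.host")) // 127.0.0.1 true
	fmt.Println(lookup(cfg, "database.user")) // <nil> false
}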
+ +package beego + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/astaxie/beego/config" +) + +func TestDefaults(t *testing.T) { + if BConfig.WebConfig.FlashName != "BEEGO_FLASH" { + t.Errorf("FlashName was not set to default.") + } + + if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { + t.Errorf("FlashName was not set to default.") + } +} + +func TestAssignConfig_01(t *testing.T) { + _BConfig := &Config{} + _BConfig.AppName = "beego_test" + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) + assignSingleConfig(_BConfig, ac) + if _BConfig.AppName != "beego_json" { + t.Log(_BConfig) + t.FailNow() + } +} + +func TestAssignConfig_02(t *testing.T) { + _BConfig := &Config{} + bs, _ := json.Marshal(newBConfig()) + + jsonMap := M{} + json.Unmarshal(bs, &jsonMap) + + configMap := M{} + for k, v := range jsonMap { + if reflect.TypeOf(v).Kind() == reflect.Map { + for k1, v1 := range v.(M) { + if reflect.TypeOf(v1).Kind() == reflect.Map { + for k2, v2 := range v1.(M) { + configMap[k2] = v2 + } + } else { + configMap[k1] = v1 + } + } + } else { + configMap[k] = v + } + } + configMap["MaxMemory"] = 1024 + configMap["Graceful"] = true + configMap["XSRFExpire"] = 32 + configMap["SessionProviderConfig"] = "file" + configMap["FileLineNum"] = true + + jcf := &config.JSONConfig{} + bs, _ = json.Marshal(configMap) + ac, _ := jcf.ParseData(bs) + + for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + + if _BConfig.MaxMemory != 1024 { + t.Log(_BConfig.MaxMemory) + t.FailNow() + } + + if !_BConfig.Listen.Graceful { + t.Log(_BConfig.Listen.Graceful) + t.FailNow() + } + + if _BConfig.WebConfig.XSRFExpire != 32 { + t.Log(_BConfig.WebConfig.XSRFExpire) + t.FailNow() + } + + if _BConfig.WebConfig.Session.SessionProviderConfig != "file" { + t.Log(_BConfig.WebConfig.Session.SessionProviderConfig) + t.FailNow() + } + + if !_BConfig.Log.FileLineNum { + t.Log(_BConfig.Log.FileLineNum) + t.FailNow() + } + +} + +func TestAssignConfig_03(t *testing.T) { + jcf := &config.JSONConfig{} + ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) + ac.Set("AppName", "test_app") + ac.Set("RunMode", "online") + ac.Set("StaticDir", "download:down download2:down2") + ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set("StaticCacheFileSize", "87456") + ac.Set("StaticCacheFileNum", "1254") + assignConfig(ac) + + t.Logf("%#v", BConfig) + + if BConfig.AppName != "test_app" { + t.FailNow() + } + + if BConfig.RunMode != "online" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download"] != "down" { + t.FailNow() + } + if BConfig.WebConfig.StaticDir["/download2"] != "down2" { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileSize != 87456 { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileNum != 1254 { + t.FailNow() + } + if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { + t.FailNow() + } +} diff --git a/pkg/context/acceptencoder.go b/pkg/context/acceptencoder.go new file mode 100644 index 00000000..b4e2492c --- /dev/null +++ b/pkg/context/acceptencoder.go @@ -0,0 +1,232 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" + "io" + "net/http" + "os" + "strconv" + "strings" + "sync" +) + +var ( + //Default size==20B same as nginx + defaultGzipMinLength = 20 + //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + gzipMinLength = defaultGzipMinLength + //The compression level used for deflate compression. (0-9). + gzipCompressLevel int + //List of HTTP methods to compress. If not set, only GET requests are compressed. + includedMethods map[string]bool + getMethodOnly bool +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + if minLength >= 0 { + gzipMinLength = minLength + } + gzipCompressLevel = compressLevel + if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression { + gzipCompressLevel = flate.BestSpeed + } + getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET") + includedMethods = make(map[string]bool, len(methods)) + for _, v := range methods { + includedMethods[strings.ToUpper(v)] = true + } +} + +type resetWriter interface { + io.Writer + Reset(w io.Writer) +} + +type nopResetWriter struct { + io.Writer +} + +func (n nopResetWriter) Reset(w io.Writer) { + //do nothing +} + +type acceptEncoder struct { + name string + levelEncode func(int) resetWriter + customCompressLevelPool *sync.Pool + bestCompressionPool *sync.Pool +} + +func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return nopResetWriter{wr} + } + var rwr resetWriter + switch level { + case flate.BestSpeed: + rwr = ac.customCompressLevelPool.Get().(resetWriter) + case flate.BestCompression: + rwr = ac.bestCompressionPool.Get().(resetWriter) + default: + rwr = ac.levelEncode(level) + } + rwr.Reset(wr) + return rwr +} + +func (ac acceptEncoder) put(wr resetWriter, level int) { + if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil { + return + } + wr.Reset(nil) + + //notice + //compressionLevel==BestCompression DOES NOT MATTER + //sync.Pool will not memory leak + + switch level { + case gzipCompressLevel: + ac.customCompressLevelPool.Put(wr) + case flate.BestCompression: + ac.bestCompressionPool.Put(wr) + } +} + +var ( + noneCompressEncoder = acceptEncoder{"", nil, nil, nil} + gzipCompressEncoder = acceptEncoder{ + name: "gzip", + levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } + + //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed + //deflate + //The "zlib" format defined in RFC 1950 [31] in combination with + //the "deflate" compression mechanism described in RFC 1951 
[29]. + deflateCompressEncoder = acceptEncoder{ + name: "deflate", + levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, + customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }}, + bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }}, + } +) + +var ( + encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress + } +) + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return writeLevel(encoding, writer, file, flate.BestCompression) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + if encoding == "" || len(content) < gzipMinLength { + _, err := writer.Write(content) + return false, "", err + } + return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) +} + +// writeLevel reads from reader,writes to writer by specific encoding and compress level +// the compress level is defined by deflate package +func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { + var outputWriter resetWriter + var err error + var ce = noneCompressEncoder + + if cf, ok := encoderMap[encoding]; ok { + ce = cf + } + encoding = ce.name + outputWriter = ce.encode(writer, level) + defer ce.put(outputWriter, level) + + _, err = io.Copy(outputWriter, reader) + if err != nil { + return false, "", err + } + + switch outputWriter.(type) { + case io.WriteCloser: + outputWriter.(io.WriteCloser).Close() + } + return encoding != "", encoding, nil +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + if r == nil { + return "" + } + if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] { + return parseEncoding(r) + } + return "" +} + +type q struct { + name string + value float64 +} + +func parseEncoding(r *http.Request) string { + acceptEncoding := r.Header.Get("Accept-Encoding") + if acceptEncoding == "" { + return "" + } + var lastQ q + for _, v := range strings.Split(acceptEncoding, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } + vs := strings.Split(v, ";") + var cf acceptEncoder + var ok bool + if cf, ok = encoderMap[vs[0]]; !ok { + continue + } + if len(vs) == 1 { + return cf.name + } + if len(vs) == 2 { + f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64) + if f == 0 { + continue + } + if f > lastQ.value { + lastQ = q{cf.name, f} + } + } + } + return lastQ.name +} diff --git a/pkg/context/acceptencoder_test.go b/pkg/context/acceptencoder_test.go new file mode 100644 index 00000000..e3d61e27 --- /dev/null +++ b/pkg/context/acceptencoder_test.go @@ -0,0 +1,59 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "testing" +) + +func Test_ExtractEncoding(t *testing.T) { + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" { + t.Fail() + } + if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" { + t.Fail() + } +} diff --git a/pkg/context/context.go b/pkg/context/context.go new file mode 100644 index 00000000..de248ed2 --- /dev/null +++ b/pkg/context/context.go @@ -0,0 +1,263 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
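The three exported helpers above (InitGzip, ParseEncoding, WriteBody) are enough to negotiate and apply compression in a plain net/http handler. A hedged usage sketch: the handler, payload, port and the beego import path from the package comment are illustrative, and the body is buffered so Content-Encoding can be set before anything is written out:

package main

import (
	"bytes"
	"compress/flate"
	"log"
	"net/http"

	"github.com/astaxie/beego/context"
)

func main() {
	// InitGzip must run first so the helpers know the minimum length,
	// compression level and allowed methods (values here are arbitrary).
	context.InitGzip(20, flate.BestSpeed, []string{"GET"})

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		payload := []byte(`{"status":"ok","data":"some payload to compress"}`)
		var buf bytes.Buffer
		encoding := context.ParseEncoding(r) // "gzip", "deflate" or ""
		compressed, name, err := context.WriteBody(encoding, &buf, payload)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if compressed {
			// The header has to be set before the body reaches the client.
			w.Header().Set("Content-Encoding", name)
		}
		w.Write(buf.Bytes())
	})
	log.Fatal(http.ListenAndServe(":8080", nil))
}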
+ +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/utils" +) + +//commonly used mime-types +const ( + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return &Context{ + Input: NewInput(), + Output: NewOutput(), + } +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context struct { + Input *BeegoInput + Output *BeegoOutput + Request *http.Request + ResponseWriter *Response + _xsrfToken string +} + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + ctx.Request = r + if ctx.ResponseWriter == nil { + ctx.ResponseWriter = &Response{} + } + ctx.ResponseWriter.reset(rw) + ctx.Input.Reset(ctx) + ctx.Output.Reset(ctx) + ctx._xsrfToken = "" +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) + panic(body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + ctx.ResponseWriter.Write([]byte(content)) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return ctx.Input.Cookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + ctx.Output.Cookie(name, value, others...) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.Input.Cookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha256.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.Output.Cookie(name, cookie, others...) +} + +// XSRFToken creates a xsrf token string and returns. 
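GetSecureCookie and SetSecureCookie above store cookies as base64(value)|timestamp|signature, where the signature is an HMAC-SHA256 over the first two fields. A standalone sketch of that signing scheme with illustrative function names:

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"
)

// sign builds the "base64(value)|timestamp|sig" form used above.
func sign(secret, value string) string {
	vs := base64.URLEncoding.EncodeToString([]byte(value))
	ts := strconv.FormatInt(time.Now().UnixNano(), 10)
	h := hmac.New(sha256.New, []byte(secret))
	fmt.Fprintf(h, "%s%s", vs, ts)
	return strings.Join([]string{vs, ts, fmt.Sprintf("%02x", h.Sum(nil))}, "|")
}

// verify recomputes the signature and decodes the value on a match.
func verify(secret, cookie string) (string, bool) {
	parts := strings.SplitN(cookie, "|", 3)
	if len(parts) != 3 {
		return "", false
	}
	h := hmac.New(sha256.New, []byte(secret))
	fmt.Fprintf(h, "%s%s", parts[0], parts[1])
	if fmt.Sprintf("%02x", h.Sum(nil)) != parts[2] {
		return "", false
	}
	raw, err := base64.URLEncoding.DecodeString(parts[0])
	if err != nil {
		return "", false
	}
	return string(raw), true
}

func main() {
	c := sign("secret-key", "hello")
	fmt.Println(verify("secret-key", c)) // hello true
}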
+func (ctx *Context) XSRFToken(key string, expire int64) string { + if ctx._xsrfToken == "" { + token, ok := ctx.GetSecureCookie(key, "_xsrf") + if !ok { + token = string(utils.RandomCreateBytes(32)) + ctx.SetSecureCookie(key, "_xsrf", token, expire) + } + ctx._xsrfToken = token + } + return ctx._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + token := ctx.Input.Query("_xsrf") + if token == "" { + token = ctx.Request.Header.Get("X-Xsrftoken") + } + if token == "" { + token = ctx.Request.Header.Get("X-Csrftoken") + } + if token == "" { + ctx.Abort(422, "422") + return false + } + if ctx._xsrfToken != token { + ctx.Abort(417, "417") + return false + } + return true +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + +//Response is a wrapper for the http.ResponseWriter +//started set to true if response was written to then don't execute other handler +type Response struct { + http.ResponseWriter + Started bool + Status int + Elapsed time.Duration +} + +func (r *Response) reset(rw http.ResponseWriter) { + r.ResponseWriter = rw + r.Status = 0 + r.Started = false +} + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + r.Started = true + return r.ResponseWriter.Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + if r.Status > 0 { + //prevent multiple response.WriteHeader calls + return + } + r.Status = code + r.Started = true + r.ResponseWriter.WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := r.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + if pusher, ok := r.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/pkg/context/context_test.go b/pkg/context/context_test.go new file mode 100644 index 00000000..7c0535e0 --- /dev/null +++ b/pkg/context/context_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} diff --git a/pkg/context/input.go b/pkg/context/input.go new file mode 100644 index 00000000..385549c1 --- /dev/null +++ b/pkg/context/input.go @@ -0,0 +1,689 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "compress/gzip" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" +) + +// Regexes for checking the accept headers +// TODO make sure these are correct +var ( + acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) + acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) + acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) + acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`) + maxParam = 50 +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput struct { + Context *Context + CruSession session.Store + pnames []string + pvalues []string + data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + dataLock sync.RWMutex + RequestBody []byte + RunMethod string + RunController reflect.Type +} + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return &BeegoInput{ + pnames: make([]string, 0, maxParam), + pvalues: make([]string, 0, maxParam), + data: make(map[interface{}]interface{}), + } +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + input.Context = ctx + input.CruSession = nil + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] + input.dataLock.Lock() + input.data = nil + input.dataLock.Unlock() + input.RequestBody = []byte{} +} + +// Protocol returns request protocol name, such as HTTP/1.1 . 
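The Response wrapper added in context.go records the first status code and ignores later WriteHeader calls. A standalone sketch of that write-once behaviour around an httptest recorder; the type and field names are illustrative:

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type statusRecorder struct {
	http.ResponseWriter
	status  int
	started bool
}

// WriteHeader keeps the first status and drops duplicates.
func (w *statusRecorder) WriteHeader(code int) {
	if w.status > 0 {
		return
	}
	w.status = code
	w.started = true
	w.ResponseWriter.WriteHeader(code)
}

// Write marks the response as started before delegating.
func (w *statusRecorder) Write(p []byte) (int, error) {
	w.started = true
	return w.ResponseWriter.Write(p)
}

func main() {
	rec := httptest.NewRecorder()
	w := &statusRecorder{ResponseWriter: rec}
	w.WriteHeader(http.StatusNotFound)
	w.WriteHeader(http.StatusOK) // ignored: first status wins
	fmt.Println(w.status, rec.Code) // 404 404
}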
+func (input *BeegoInput) Protocol() string { + return input.Context.Request.Proto +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return input.Context.Request.URL.EscapedPath() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return input.Scheme() + "://" + input.Domain() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { + return scheme + } + if input.Context.Request.URL.Scheme != "" { + return input.Context.Request.URL.Scheme + } + if input.Context.Request.TLS == nil { + return "http" + } + return "https" +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return input.Host() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + if input.Context.Request.Host != "" { + if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + return hostPart + } + return input.Context.Request.Host + } + return "localhost" +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return input.Context.Request.Method +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return input.Method() == method +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return input.Is("GET") +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return input.Is("POST") +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return input.Is("HEAD") +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return input.Is("OPTIONS") +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return input.Is("PUT") +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return input.Is("DELETE") +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return input.Is("PATCH") +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return input.Header("X-Requested-With") == "XMLHttpRequest" +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return input.Scheme() == "https" +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return input.Header("Upgrade") == "websocket" +} + +// IsUpload returns boolean of whether file uploads in this request or not.. 
+func (input *BeegoInput) IsUpload() bool { + return strings.Contains(input.Header("Content-Type"), "multipart/form-data") +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return acceptsHTMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return acceptsXMLRegex.MatchString(input.Header("Accept")) +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return acceptsJSONRegex.MatchString(input.Header("Accept")) +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return acceptsYAMLRegex.MatchString(input.Header("Accept")) +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + ips := input.Proxy() + if len(ips) > 0 && ips[0] != "" { + rip, _, err := net.SplitHostPort(ips[0]) + if err != nil { + rip = ips[0] + } + return rip + } + if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil { + return ip + } + return input.Context.Request.RemoteAddr +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + if ips := input.Header("X-Forwarded-For"); ips != "" { + return strings.Split(ips, ",") + } + return []string{} +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return input.Header("Referer") +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return input.Referer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + parts := strings.Split(input.Host(), ".") + if len(parts) >= 3 { + return strings.Join(parts[:len(parts)-2], ".") + } + return "" +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil { + port, _ := strconv.Atoi(portPart) + return port + } + return 80 +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return input.Header("User-Agent") +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return len(input.pnames) +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + // we cannot use url.PathEscape(input.pvalues[i]) + // for example, if the value is /a/b + // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb + // However, the value is used in ControllerRegister.ServeHTTP + // and split by "/", so function crash... + return input.pvalues[i] + } + } + return "" +} + +// Params returns the map[key]value. 
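IP() above prefers the first entry of X-Forwarded-For and only falls back to RemoteAddr when no proxy header is present. A standalone sketch of that resolution order with an illustrative helper name:

package main

import (
	"fmt"
	"net"
	"net/http"
	"strings"
)

// clientIP mirrors the order used above: proxy header first, RemoteAddr second,
// stripping the port where one is present.
func clientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := strings.Split(xff, ",")[0]
		if host, _, err := net.SplitHostPort(first); err == nil {
			return host
		}
		return first
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

func main() {
	r, _ := http.NewRequest("GET", "http://example.com/", nil)
	r.RemoteAddr = "10.0.0.1:54321"
	r.Header.Set("X-Forwarded-For", "203.0.113.7, 10.0.0.2")
	fmt.Println(clientIP(r)) // 203.0.113.7
}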
+func (input *BeegoInput) Params() map[string]string { + m := make(map[string]string) + for i, v := range input.pnames { + if i <= len(input.pvalues) { + m[v] = input.pvalues[i] + } + } + return m +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + // check if already exists + for i, v := range input.pnames { + if v == key && i <= len(input.pvalues) { + input.pvalues[i] = val + return + } + } + input.pvalues = append(input.pvalues, val) + input.pnames = append(input.pnames, key) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + input.pnames = input.pnames[:0] + input.pvalues = input.pvalues[:0] +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + if val := input.Param(key); val != "" { + return val + } + if input.Context.Request.Form == nil { + input.dataLock.Lock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + input.dataLock.Unlock() + } + input.dataLock.RLock() + defer input.dataLock.RUnlock() + return input.Context.Request.Form.Get(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return input.Context.Request.Header.Get(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + ck, err := input.Context.Request.Cookie(key) + if err != nil { + return "" + } + return ck.Value +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return input.CruSession.Get(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + if input.Context.Request.Body == nil { + return []byte{} + } + + var requestbody []byte + safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} + if input.Header("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(safe) + if err != nil { + return nil + } + requestbody, _ = ioutil.ReadAll(reader) + } else { + requestbody, _ = ioutil.ReadAll(safe) + } + + input.Context.Request.Body.Close() + bf := bytes.NewBuffer(requestbody) + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) + input.RequestBody = requestbody + return requestbody +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + return input.data +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if v, ok := input.data[key]; ok { + return v + } + return nil +} + +// SetData stores data with given key in this context. +// This data are only available in this context. 
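CopyBody above drains the request body through an io.LimitedReader, keeps the bytes, and puts a fresh reader back on the request (wrapped in http.MaxBytesReader) so later readers still work. A simplified standalone sketch of that idea; the helper name is illustrative and the MaxBytesReader wrapping is left out:

package main

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"strings"
)

// copyBody reads at most maxMemory bytes, then restores the body so that
// ParseForm or a later handler can consume it again.
func copyBody(r *http.Request, maxMemory int64) []byte {
	if r.Body == nil {
		return nil
	}
	limited := &io.LimitedReader{R: r.Body, N: maxMemory}
	body, _ := ioutil.ReadAll(limited)
	r.Body.Close()
	r.Body = ioutil.NopCloser(bytes.NewReader(body))
	return body
}

func main() {
	r, _ := http.NewRequest("POST", "http://example.com/", strings.NewReader(`{"k":"v"}`))
	b := copyBody(r, 1<<20)
	again, _ := ioutil.ReadAll(r.Body) // the body is still readable afterwards
	fmt.Println(string(b), string(again))
}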
+func (input *BeegoInput) SetData(key, val interface{}) { + input.dataLock.Lock() + defer input.dataLock.Unlock() + if input.data == nil { + input.data = make(map[interface{}]interface{}) + } + input.data[key] = val +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + // Parse the body depending on the content type. + if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + } else if err := input.Context.Request.ParseForm(); err != nil { + return errors.New("Error parsing request body:" + err.Error()) + } + return nil +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + value := reflect.ValueOf(dest) + if value.Kind() != reflect.Ptr { + return errors.New("beego: non-pointer passed to Bind: " + key) + } + value = value.Elem() + if !value.CanSet() { + return errors.New("beego: non-settable variable passed to Bind: " + key) + } + typ := value.Type() + // Get real type if dest define with interface{}. + // e.g var dest interface{} dest=1.0 + if value.Kind() == reflect.Interface { + typ = value.Elem().Type() + } + rv := input.bind(key, typ) + if !rv.IsValid() { + return errors.New("beego: reflect value is empty") + } + value.Set(rv) + return nil +} + +func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindFloat(val, typ) + case reflect.String: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindString(val, typ) + case reflect.Bool: + val := input.Query(key) + if len(val) == 0 { + return rv + } + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&input.Context.Request.Form, key, typ) + case reflect.Struct: + rv = input.bindStruct(&input.Context.Request.Form, key, typ) + case reflect.Ptr: + rv = input.bindPoint(key, typ) + case reflect.Map: + rv = input.bindMap(&input.Context.Request.Form, key, typ) + } + return rv +} + +func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { + rv := reflect.Zero(typ) + switch typ.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + rv = input.bindInt(val, typ) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 
+ rv = input.bindUint(val, typ) + case reflect.Float32, reflect.Float64: + rv = input.bindFloat(val, typ) + case reflect.String: + rv = input.bindString(val, typ) + case reflect.Bool: + rv = input.bindBool(val, typ) + case reflect.Slice: + rv = input.bindSlice(&url.Values{"": {val}}, "", typ) + case reflect.Struct: + rv = input.bindStruct(&url.Values{"": {val}}, "", typ) + case reflect.Ptr: + rv = input.bindPoint(val, typ) + case reflect.Map: + rv = input.bindMap(&url.Values{"": {val}}, "", typ) + } + return rv +} + +func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value { + intValue, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetInt(intValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value { + uintValue, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetUint(uintValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value { + floatValue, err := strconv.ParseFloat(val, 64) + if err != nil { + return reflect.Zero(typ) + } + pValue := reflect.New(typ) + pValue.Elem().SetFloat(floatValue) + return pValue.Elem() +} + +func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value { + return reflect.ValueOf(val) +} + +func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value { + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "true", "on", "1": + return reflect.ValueOf(true) + } + return reflect.ValueOf(false) +} + +type sliceValue struct { + index int // Index extracted from brackets. If -1, no index was provided. + value reflect.Value // the bound value for this slice element. +} + +func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value { + maxIndex := -1 + numNoIndex := 0 + sliceValues := []sliceValue{} + for reqKey, vals := range *params { + if !strings.HasPrefix(reqKey, key+"[") { + continue + } + // Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey) + index := -1 + leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key) + if rightBracket > leftBracket+1 { + index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket]) + } + subKeyIndex := rightBracket + 1 + + // Handle the indexed case. + if index > -1 { + if index > maxIndex { + maxIndex = index + } + sliceValues = append(sliceValues, sliceValue{ + index: index, + value: input.bind(reqKey[:subKeyIndex], typ.Elem()), + }) + continue + } + + // It's an un-indexed element. (e.g. element[]) + numNoIndex += len(vals) + for _, val := range vals { + // Unindexed values can only be direct-bound. 
+ sliceValues = append(sliceValues, sliceValue{ + index: -1, + value: input.bindValue(val, typ.Elem()), + }) + } + } + resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex) + for _, sv := range sliceValues { + if sv.index != -1 { + resultArray.Index(sv.index).Set(sv.value) + } else { + resultArray = reflect.Append(resultArray, sv.value) + } + } + return resultArray +} + +func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value { + result := reflect.New(typ).Elem() + fieldValues := make(map[string]reflect.Value) + for reqKey, val := range *params { + var fieldName string + if strings.HasPrefix(reqKey, key+".") { + fieldName = reqKey[len(key)+1:] + } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' { + fieldName = reqKey[len(key)+1 : len(reqKey)-1] + } else { + continue + } + + if _, ok := fieldValues[fieldName]; !ok { + // Time to bind this field. Get it and make sure we can set it. + fieldValue := result.FieldByName(fieldName) + if !fieldValue.IsValid() { + continue + } + if !fieldValue.CanSet() { + continue + } + boundVal := input.bindValue(val[0], fieldValue.Type()) + fieldValue.Set(boundVal) + fieldValues[fieldName] = boundVal + } + } + + return result +} + +func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value { + return input.bind(key, typ.Elem()).Addr() +} + +func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value { + var ( + result = reflect.MakeMap(typ) + keyType = typ.Key() + valueType = typ.Elem() + ) + for paramName, values := range *params { + if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' { + continue + } + + key := paramName[len(key)+1 : len(paramName)-1] + result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType)) + } + return result +} diff --git a/pkg/context/input_test.go b/pkg/context/input_test.go new file mode 100644 index 00000000..3a6c2e7b --- /dev/null +++ b/pkg/context/input_test.go @@ -0,0 +1,217 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package context + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestBind(t *testing.T) { + type testItem struct { + field string + empty interface{} + want interface{} + } + type Human struct { + ID int + Nick string + Pwd string + Ms bool + } + + cases := []struct { + request string + valueGp []testItem + }{ + {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}}, + + {"/?p=", []testItem{{"p", "", ""}}}, + {"/?p=str", []testItem{{"p", "", "str"}}}, + + {"/?p=123", []testItem{{"p", 0, 123}}}, + {"/?p=123", []testItem{{"p", uint(0), uint(123)}}}, + + {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + + {"/?p=true", []testItem{{"p", false, true}}}, + {"/?p=ON", []testItem{{"p", false, true}}}, + {"/?p=on", []testItem{{"p", false, true}}}, + {"/?p=1", []testItem{{"p", false, true}}}, + {"/?p=2", []testItem{{"p", false, false}}}, + {"/?p=false", []testItem{{"p", false, false}}}, + + {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}}, + {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}}, + + {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}}, + {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}}, + + {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}}, + {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}}, + + {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}}, + + {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, + {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, + {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + []testItem{{"human", []Human{}, []Human{ + {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, + {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, + }}}}, + + { + "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", + []testItem{ + {"id", 0, 123}, + {"isok", false, true}, + {"ft", 0.0, 1.2}, + {"ol", []int{}, []int{1, 2}}, + {"ul", []string{}, []string{"str", "array"}}, + {"human", Human{}, Human{Nick: "astaxie"}}, + }, + }, + } + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + for _, item := range c.valueGp { + got := item.empty + err := beegoInput.Bind(&got, item.field) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, item.want) { + t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got) + } + } + + } +} + +func TestSubDomain(t *testing.T) { + r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Reset(httptest.NewRecorder(), r) + + subdomain := 
beegoInput.SubDomains() + if subdomain != "www" { + t.Fatal("Subdomain parse error, got" + subdomain) + } + + r, _ = http.NewRequest("GET", "http://localhost/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + /* TODO Fix this + r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + */ + + r, _ = http.NewRequest("GET", "http://example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } + + r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) + beegoInput.Context.Request = r + if beegoInput.SubDomains() != "aa.bb.cc.dd" { + t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) + } +} + +func TestParams(t *testing.T) { + inp := NewInput() + + inp.SetParam("p1", "val1_ver1") + inp.SetParam("p2", "val2_ver1") + inp.SetParam("p3", "val3_ver1") + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1") + } + if val := inp.Param("p3"); val != "val3_ver1" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1") + } + vals := inp.Params() + expected := map[string]string{ + "p1": "val1_ver1", + "p2": "val2_ver1", + "p3": "val3_ver1", + } + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + // overwriting existing params + inp.SetParam("p1", "val1_ver2") + inp.SetParam("p2", "val2_ver2") + expected = map[string]string{ + "p1": "val1_ver2", + "p2": "val2_ver2", + "p3": "val3_ver1", + } + vals = inp.Params() + if !reflect.DeepEqual(vals, expected) { + t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected) + } + + if l := inp.ParamsLen(); l != 3 { + t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3) + } + + if val := inp.Param("p1"); val != "val1_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + + if val := inp.Param("p2"); val != "val2_ver2" { + t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") + } + +} +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) +} diff --git a/pkg/context/output.go b/pkg/context/output.go new file mode 100644 index 00000000..238dcf45 --- /dev/null +++ b/pkg/context/output.go @@ -0,0 +1,408 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "html/template" + "io" + "mime" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + yaml "gopkg.in/yaml.v2" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput struct { + Context *Context + Status int + EnableGzip bool +} + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return &BeegoOutput{} +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + output.Context = ctx + output.Status = 0 +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + output.Context.ResponseWriter.Header().Set(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + var encoding string + var buf = &bytes.Buffer{} + if output.EnableGzip { + encoding = ParseEncoding(output.Context.Request) + } + if b, n, _ := WriteBody(encoding, buf, content); b { + output.Header("Content-Encoding", n) + output.Header("Content-Length", strconv.Itoa(buf.Len())) + } else { + output.Header("Content-Length", strconv.Itoa(len(content))) + } + // Write status code if it has been set manually + // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" + if output.Status != 0 { + output.Context.ResponseWriter.WriteHeader(output.Status) + output.Status = 0 + } else { + output.Context.ResponseWriter.Started = true + } + io.Copy(output.Context.ResponseWriter, buf) + return nil +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. 
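+// An illustrative call (all values are hypothetical) following the option
+// order described above (max age, path, domain, secure, httponly):
+//
+//	output.Cookie("sessionid", "12345", 3600, "/", "example.com", true, true)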
+func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + var b bytes.Buffer + fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) + + //fix cookie not work in IE + if len(others) > 0 { + var maxAge int64 + + switch v := others[0].(type) { + case int: + maxAge = int64(v) + case int32: + maxAge = int64(v) + case int64: + maxAge = v + } + + switch { + case maxAge > 0: + fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge) + case maxAge < 0: + fmt.Fprintf(&b, "; Max-Age=0") + } + } + + // the settings below + // Path, Domain, Secure, HttpOnly + // can use nil skip set + + // default "/" + if len(others) > 1 { + if v, ok := others[1].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + } + } else { + fmt.Fprintf(&b, "; Path=%s", "/") + } + + // default empty + if len(others) > 2 { + if v, ok := others[2].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v)) + } + } + + // default empty + if len(others) > 3 { + var secure bool + switch v := others[3].(type) { + case bool: + secure = v + default: + if others[3] != nil { + secure = true + } + } + if secure { + fmt.Fprintf(&b, "; Secure") + } + } + + // default false. for session cookie default true + if len(others) > 4 { + if v, ok := others[4].(bool); ok && v { + fmt.Fprintf(&b, "; HttpOnly") + } + } + + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") + +func sanitizeValue(v string) string { + return cookieValueSanitizer.Replace(v) +} + +func jsonRenderer(value interface{}) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.JSON(value, false, false) + }) +} + +func errorRenderer(err error) Renderer { + return rendererFunc(func(ctx *Context) { + ctx.Output.SetStatus(500) + ctx.Output.Body([]byte(err.Error())) + }) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + output.Header("Content-Type", "application/json; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + if encoding { + content = []byte(stringsToJSON(string(content))) + } + return output.Body(content) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + output.Header("Content-Type", "application/x-yaml; charset=utf-8") + var content []byte + var err error + content, err = yaml.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// JSONP writes jsonp to response body. 
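+// The callback name is taken from the "callback" query parameter and an error
+// is returned when it is missing. For a request such as /api/list?callback=cb
+// (an illustrative URL), the body is rendered as:
+//
+//	 if(window.cb)cb({"key":"value"});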
+func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/javascript; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = json.MarshalIndent(data, "", " ") + } else { + content, err = json.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + callback := output.Context.Input.Query("callback") + if callback == "" { + return errors.New(`"callback" parameter required`) + } + callback = template.JSEscapeString(callback) + callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback) + callbackContent.WriteString("(") + callbackContent.Write(content) + callbackContent.WriteString(");\r\n") + return output.Body(callbackContent.Bytes()) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + output.Header("Content-Type", "application/xml; charset=utf-8") + var content []byte + var err error + if hasIndent { + content, err = xml.MarshalIndent(data, "", " ") + } else { + content, err = xml.Marshal(data) + } + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + accept := output.Context.Input.Header("Accept") + switch accept { + case ApplicationYAML: + output.YAML(data) + case ApplicationXML, TextXML: + output.XML(data, hasIndent) + default: + output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) + } +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + // check get file error, file not found or other error. + if _, err := os.Stat(file); err != nil { + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) + return + } + + var fName string + if len(filename) > 0 && filename[0] != "" { + fName = filename[0] + } else { + fName = filepath.Base(file) + } + //https://tools.ietf.org/html/rfc6266#section-4.3 + fn := url.PathEscape(fName) + if fName == fn { + fn = "filename=" + fn + } else { + /** + The parameters "filename" and "filename*" differ only in that + "filename*" uses the encoding defined in [RFC5987], allowing the use + of characters not present in the ISO-8859-1 character set + ([ISO-8859-1]). + */ + fn = "filename=" + fName + "; filename*=utf-8''" + fn + } + output.Header("Content-Disposition", "attachment; "+fn) + output.Header("Content-Description", "File Transfer") + output.Header("Content-Type", "application/octet-stream") + output.Header("Content-Transfer-Encoding", "binary") + output.Header("Expires", "0") + output.Header("Cache-Control", "must-revalidate") + output.Header("Pragma", "public") + http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + ctype := mime.TypeByExtension(ext) + if ctype != "" { + output.Header("Content-Type", ctype) + } +} + +// SetStatus sets response status code. +// It writes response header directly. 
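+// The stored code is flushed to the client by the next Body call, e.g.
+// (illustrative):
+//
+//	output.SetStatus(404)
+//	output.Body([]byte("not found"))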
+func (output *BeegoOutput) SetStatus(status int) { + output.Status = status +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return output.Status >= 200 && output.Status < 300 || output.Status == 304 +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return output.Status == 201 || output.Status == 204 || output.Status == 304 +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return output.Status == 200 +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return output.Status >= 200 && output.Status < 300 +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return output.Status == 403 +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return output.Status == 404 +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return output.Status >= 400 && output.Status < 500 +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return output.Status >= 500 && output.Status < 600 +} + +func stringsToJSON(str string) string { + var jsons bytes.Buffer + for _, r := range str { + rint := int(r) + if rint < 128 { + jsons.WriteRune(r) + } else { + jsons.WriteString("\\u") + if rint < 0x100 { + jsons.WriteString("00") + } else if rint < 0x1000 { + jsons.WriteString("0") + } + jsons.WriteString(strconv.FormatInt(int64(rint), 16)) + } + } + return jsons.String() +} + +// Session sets session item value with given key. 
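+// A minimal sketch, assuming a session store has already been attached to the
+// request (the key and value are arbitrary examples):
+//
+//	ctx.Output.Session("uid", 1)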
+func (output *BeegoOutput) Session(name interface{}, value interface{}) { + output.Context.Input.CruSession.Set(name, value) +} diff --git a/pkg/context/param/conv.go b/pkg/context/param/conv.go new file mode 100644 index 00000000..c200e008 --- /dev/null +++ b/pkg/context/param/conv.go @@ -0,0 +1,78 @@ +package param + +import ( + "fmt" + "reflect" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + result = make([]reflect.Value, 0, len(methodParams)) + for i := 0; i < len(methodParams); i++ { + reflectValue := convertParam(methodParams[i], methodType.In(i), ctx) + result = append(result, reflectValue) + } + return +} + +func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) { + paramValue := getParamValue(param, ctx) + if paramValue == "" { + if param.required { + ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name)) + } else { + paramValue = param.defaultValue + } + } + + reflectValue, err := parseValue(param, paramValue, paramType) + if err != nil { + logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err)) + ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType)) + } + + return reflectValue +} + +func getParamValue(param *MethodParam, ctx *beecontext.Context) string { + switch param.in { + case body: + return string(ctx.Input.RequestBody) + case header: + return ctx.Input.Header(param.name) + case path: + return ctx.Input.Query(":" + param.name) + default: + return ctx.Input.Query(param.name) + } +} + +func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) { + if paramValue == "" { + return reflect.Zero(paramType), nil + } + parser := getParser(param, paramType) + value, err := parser.parse(paramValue, paramType) + if err != nil { + return result, err + } + + return safeConvert(reflect.ValueOf(value), paramType) +} + +func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } + } + }() + result = value.Convert(t) + return +} diff --git a/pkg/context/param/methodparams.go b/pkg/context/param/methodparams.go new file mode 100644 index 00000000..cd6708a2 --- /dev/null +++ b/pkg/context/param/methodparams.go @@ -0,0 +1,69 @@ +package param + +import ( + "fmt" + "strings" +) + +//MethodParam keeps param information to be auto passed to controller methods +type MethodParam struct { + name string + in paramType + required bool + defaultValue string +} + +type paramType byte + +const ( + param paramType = iota + path + body + header +) + +//New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + return newParam(name, nil, opts) +} + +func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) { + param = &MethodParam{name: name} + for _, option := range opts { + option(param) + } + return +} + +//Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if 
len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + options := []string{} + result := "param.New(\"" + mp.name + "\"" + if mp.required { + options = append(options, "param.IsRequired") + } + switch mp.in { + case path: + options = append(options, "param.InPath") + case body: + options = append(options, "param.InBody") + case header: + options = append(options, "param.InHeader") + } + if mp.defaultValue != "" { + options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue)) + } + if len(options) > 0 { + result += ", " + } + result += strings.Join(options, ", ") + result += ")" + return result +} diff --git a/pkg/context/param/options.go b/pkg/context/param/options.go new file mode 100644 index 00000000..3d5ba013 --- /dev/null +++ b/pkg/context/param/options.go @@ -0,0 +1,37 @@ +package param + +import ( + "fmt" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be omitted from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + p.required = true +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + p.in = header +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + p.in = path +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + p.in = body +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return func(p *MethodParam) { + if defaultValue != nil { + p.defaultValue = fmt.Sprint(defaultValue) + } + } +} diff --git a/pkg/context/param/parsers.go b/pkg/context/param/parsers.go new file mode 100644 index 00000000..421aecf0 --- /dev/null +++ b/pkg/context/param/parsers.go @@ -0,0 +1,149 @@ +package param + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +type paramParser interface { + parse(value string, toType reflect.Type) (interface{}, error) +} + +func getParser(param *MethodParam, t reflect.Type) paramParser { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return intParser{} + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + return stringParser{} + } + if param.in == body { + return jsonParser{} + } + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return sliceParser(elemParser) + case reflect.Bool: + return boolParser{} + case reflect.String: + return stringParser{} + case reflect.Float32, reflect.Float64: + return floatParser{} + case reflect.Ptr: + elemParser := getParser(param, t.Elem()) + if elemParser == (jsonParser{}) { + return elemParser + } + return ptrParser(elemParser) + default: + if t.PkgPath() == "time" && t.Name() == "Time" { + return timeParser{} + } + return jsonParser{} + } +} + +type parserFunc func(value string, toType reflect.Type) (interface{}, error) + +func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) { + return f(value, toType) +} + +type boolParser struct { +} + +func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { + return 
strconv.ParseBool(value) +} + +type stringParser struct { +} + +func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { + return value, nil +} + +type intParser struct { +} + +func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { + return strconv.Atoi(value) +} + +type floatParser struct { +} + +func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { + if toType.Kind() == reflect.Float32 { + res, err := strconv.ParseFloat(value, 32) + if err != nil { + return nil, err + } + return float32(res), nil + } + return strconv.ParseFloat(value, 64) +} + +type timeParser struct { +} + +func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { + result, err = time.Parse(time.RFC3339, value) + if err != nil { + result, err = time.Parse("2006-01-02", value) + } + return +} + +type jsonParser struct { +} + +func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { + pResult := reflect.New(toType) + v := pResult.Interface() + err := json.Unmarshal([]byte(value), v) + if err != nil { + return nil, err + } + return pResult.Elem().Interface(), nil +} + +func sliceParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + values := strings.Split(value, ",") + result := reflect.MakeSlice(toType, 0, len(values)) + elemType := toType.Elem() + for _, v := range values { + parsedValue, err := elemParser.parse(v, elemType) + if err != nil { + return nil, err + } + result = reflect.Append(result, reflect.ValueOf(parsedValue)) + } + return result.Interface(), nil + }) +} + +func ptrParser(elemParser paramParser) paramParser { + return parserFunc(func(value string, toType reflect.Type) (interface{}, error) { + parsedValue, err := elemParser.parse(value, toType.Elem()) + if err != nil { + return nil, err + } + newValPtr := reflect.New(toType.Elem()) + newVal := reflect.Indirect(newValPtr) + convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem()) + if err != nil { + return nil, err + } + + newVal.Set(convertedVal) + return newValPtr.Interface(), nil + }) +} diff --git a/pkg/context/param/parsers_test.go b/pkg/context/param/parsers_test.go new file mode 100644 index 00000000..7065a28e --- /dev/null +++ b/pkg/context/param/parsers_test.go @@ -0,0 +1,84 @@ +package param + +import "testing" +import "reflect" +import "time" + +type testDefinition struct { + strValue string + expectedValue interface{} + expectedParser paramParser +} + +func Test_Parsers(t *testing.T) { + + //ints + checkParser(testDefinition{"1", 1, intParser{}}, t) + checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) + checkParser(testDefinition{"1", uint64(1), intParser{}}, t) + + //floats + checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) + checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) + + //strings + checkParser(testDefinition{"AB", "AB", stringParser{}}, t) + checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) + + //bools + checkParser(testDefinition{"true", true, boolParser{}}, t) + checkParser(testDefinition{"0", false, boolParser{}}, t) + + //timeParser + checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) + checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) + + //json + checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct 
{ + X int + Y string + }{5, "Z"}, jsonParser{}}, t) + + //slice in query is parsed as comma delimited + checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) + + //slice in body is parsed as json + checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) + + //pointers + var someInt = 1 + checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) + + var someStruct = struct{ X int }{5} + checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) + +} + +func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { + toType := reflect.TypeOf(def.expectedValue) + var mp MethodParam + if len(methodParam) == 0 { + mp = MethodParam{} + } else { + mp = methodParam[0] + } + parser := getParser(&mp, toType) + + if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) { + t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name()) + return + } + result, err := parser.parse(def.strValue, toType) + if err != nil { + t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err) + return + } + convResult, err := safeConvert(reflect.ValueOf(result), toType) + if err != nil { + t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) + return + } + if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { + t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result) + } +} diff --git a/pkg/context/renderer.go b/pkg/context/renderer.go new file mode 100644 index 00000000..36a7cb53 --- /dev/null +++ b/pkg/context/renderer.go @@ -0,0 +1,12 @@ +package context + +// Renderer defines an http response renderer +type Renderer interface { + Render(ctx *Context) +} + +type rendererFunc func(ctx *Context) + +func (f rendererFunc) Render(ctx *Context) { + f(ctx) +} diff --git a/pkg/context/response.go b/pkg/context/response.go new file mode 100644 index 00000000..9c3c715a --- /dev/null +++ b/pkg/context/response.go @@ -0,0 +1,27 @@ +package context + +import ( + "strconv" + + "net/http" +) + +const ( + //BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + //NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/pkg/controller.go b/pkg/controller.go new file mode 100644 index 00000000..0e8853b3 --- /dev/null +++ b/pkg/controller.go @@ -0,0 +1,706 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package beego + +import ( + "bytes" + "errors" + "fmt" + "html/template" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "reflect" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/session" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = errors.New("user stop run") + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = make(map[string][]ControllerComments) +) + +// ControllerFilter store the filter for controller +type ControllerFilter struct { + Pattern string + Pos int + Filter FilterFunc + ReturnOnOutput bool + ResetParams bool +} + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments struct { + Pattern string + Pos int + Filter string // NOQA + ReturnOnOutput bool + ResetParams bool +} + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments struct { + ImportPath string + ImportAlias string +} + +// ControllerComments store the comment for the controller method +type ControllerComments struct { + Method string + Router string + Filters []*ControllerFilter + ImportComments []*ControllerImportComments + FilterComments []*ControllerFilterComments + AllowHTTPMethods []string + Params []map[string]string + MethodParams []*param.MethodParam +} + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice []ControllerComments + +func (p ControllerCommentsSlice) Len() int { return len(p) } +func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } +func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller struct { + // context data + Ctx *context.Context + Data map[interface{}]interface{} + + // route controller info + controllerName string + actionName string + methodMapping map[string]func() //method:routertree + AppController interface{} + + // template data + TplName string + ViewPath string + Layout string + LayoutSections map[string]string // the key is the section name and the value is the template name + TplPrefix string + TplExt string + EnableRender bool + + // xsrf data + _xsrfToken string + XSRFExpire int + EnableXSRF bool + + // session + CruSession session.Store +} + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface interface { + Init(ct *context.Context, controllerName, actionName string, app interface{}) + Prepare() + Get() + Post() + Delete() + Put() + Head() + Patch() + Options() + Trace() + Finish() + Render() error + XSRFToken() string + CheckXSRFCookie() bool + HandlerFunc(fn string) bool + URLMapping() +} + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + c.Layout = "" + c.TplName = "" + c.controllerName = controllerName + c.actionName = actionName + c.Ctx = ctx + c.TplExt = "tpl" + c.AppController = app + c.EnableRender = true + c.EnableXSRF = true + c.Data = ctx.Input.Data() + c.methodMapping = make(map[string]func()) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() {} + +// Finish runs after request function execution. 
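+// Both Prepare and Finish are no-ops by default and are meant to be overridden
+// by controllers that embed Controller, e.g. (an illustrative controller):
+//
+//	type MainController struct {
+//		beego.Controller
+//	}
+//
+//	func (c *MainController) Prepare() { /* e.g. auth checks, shared data */ }
+//	func (c *MainController) Finish()  { /* e.g. cleanup, metrics */ }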
+func (c *Controller) Finish() {} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed) +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + ts := func(h http.Header) (hs string) { + for k, v := range h { + hs += fmt.Sprintf("\r\n%s: %s", k, v) + } + return + } + hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header)) + c.Ctx.Output.Header("Content-Type", "message/http") + c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs))) + c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate") + c.Ctx.WriteString(hs) +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + if v, ok := c.methodMapping[fnname]; ok { + v() + return true + } + return false +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() {} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + c.methodMapping[method] = fn +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + if !c.EnableRender { + return nil + } + rb, err := c.RenderBytes() + if err != nil { + return err + } + + if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" { + c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8") + } + + return c.Ctx.Output.Body(rb) +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + b, e := c.RenderBytes() + return string(b), e +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. 
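+// Illustrative use inside a handler (the template name is hypothetical):
+//
+//	c.TplName = "index.tpl"
+//	html, err := c.RenderBytes()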
+func (c *Controller) RenderBytes() ([]byte, error) { + buf, err := c.renderTemplate() + //if the controller has set layout, then first get the tplName's content set the content to the layout + if err == nil && c.Layout != "" { + c.Data["LayoutContent"] = template.HTML(buf.String()) + + if c.LayoutSections != nil { + for sectionName, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + c.Data[sectionName] = "" + continue + } + buf.Reset() + err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) + if err != nil { + return nil, err + } + c.Data[sectionName] = template.HTML(buf.String()) + } + } + + buf.Reset() + ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) + } + return buf.Bytes(), err +} + +func (c *Controller) renderTemplate() (bytes.Buffer, error) { + var buf bytes.Buffer + if c.TplName == "" { + c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt + } + if c.TplPrefix != "" { + c.TplName = c.TplPrefix + c.TplName + } + if BConfig.RunMode == DEV { + buildFiles := []string{c.TplName} + if c.Layout != "" { + buildFiles = append(buildFiles, c.Layout) + if c.LayoutSections != nil { + for _, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + continue + } + buildFiles = append(buildFiles, sectionTpl) + } + } + } + BuildTemplate(c.viewPath(), buildFiles...) + } + return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data) +} + +func (c *Controller) viewPath() string { + if c.ViewPath == "" { + return BConfig.WebConfig.ViewsPath + } + return c.ViewPath +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + LogAccess(c.Ctx, nil, code) + c.Ctx.Redirect(code, url) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + accept := c.Ctx.Input.Header("Accept") + switch accept { + case context.ApplicationYAML: + c.Data["yaml"] = data + case context.ApplicationXML, context.TextXML: + c.Data["xml"] = data + default: + c.Data["json"] = data + } +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + status, err := strconv.Atoi(code) + if err != nil { + status = 200 + } + c.CustomAbort(status, code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + // first panic from ErrorMaps, it is user defined error functions. + if _, ok := ErrorMaps[body]; ok { + c.Ctx.Output.Status = status + panic(body) + } + // last panic user string + c.Ctx.ResponseWriter.WriteHeader(status) + c.Ctx.ResponseWriter.Write([]byte(body)) + panic(ErrAbort) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + panic(ErrAbort) +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + if len(endpoint) == 0 { + return "" + } + if endpoint[0] == '.' { + return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) + } + return URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. 
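+// Typical usage inside a handler method (the payload is illustrative): set
+// c.Data["json"] and then serve it, e.g.
+//
+//	c.Data["json"] = map[string]string{"status": "ok"}
+//	c.ServeJSON()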
+func (c *Controller) ServeJSON(encoding ...bool) { + var ( + hasIndent = BConfig.RunMode != PROD + hasEncoding = len(encoding) > 0 && encoding[0] + ) + + c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + hasIndent := BConfig.RunMode != PROD + c.Ctx.Output.XML(c.Data["xml"], hasIndent) +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + c.Ctx.Output.YAML(c.Data["yaml"]) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + hasIndent := BConfig.RunMode != PROD + hasEncoding := len(encoding) > 0 && encoding[0] + c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + if c.Ctx.Request.Form == nil { + c.Ctx.Request.ParseForm() + } + return c.Ctx.Request.Form +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return ParseForm(c.Input(), obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + if v := c.Ctx.Input.Query(key); v != "" { + return v + } + if len(def) > 0 { + return def[0] + } + return "" +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. 
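+// For example, a query string such as ?tag=go&tag=web (hypothetical) yields:
+//
+//	tags := c.GetStrings("tag") // []string{"go", "web"}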
+func (c *Controller) GetStrings(key string, def ...[]string) []string { + var defv []string + if len(def) > 0 { + defv = def[0] + } + + if f := c.Input(); f == nil { + return defv + } else if vs := f[key]; len(vs) > 0 { + return vs + } + + return defv +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.Atoi(strv) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 8) + return int8(i64), err +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 8) + return uint8(u64), err +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 16) + return int16(i64), err +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 16) + return uint16(u64), err +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + i64, err := strconv.ParseInt(strv, 10, 32) + return int32(i64), err +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + u64, err := strconv.ParseUint(strv, 10, 32) + return uint32(u64), err +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseInt(strv, 10, 64) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseUint(strv, 10, 64) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. 
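+// For example, with ?enabled=true in the query string (illustrative):
+//
+//	enabled, err := c.GetBool("enabled", false) // true, nil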
+func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseBool(strv) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + strv := c.Ctx.Input.Query(key) + if len(strv) == 0 && len(def) > 0 { + return def[0], nil + } + return strconv.ParseFloat(strv, 64) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return c.Ctx.Request.FormFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { + return files, nil + } + return nil, http.ErrMissingFile +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + file, _, err := c.Ctx.Request.FormFile(fromfile) + if err != nil { + return err + } + defer file.Close() + f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return err + } + defer f.Close() + io.Copy(f, file) + return nil +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + if c.CruSession == nil { + c.CruSession = c.Ctx.Input.CruSession + } + return c.CruSession +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Set(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + if c.CruSession == nil { + c.StartSession() + } + return c.CruSession.Get(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + if c.CruSession == nil { + c.StartSession() + } + c.CruSession.Delete(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + if c.CruSession != nil { + c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + } + c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) + c.Ctx.Input.CruSession = c.CruSession +} + +// DestroySession cleans session data and session cookie. 
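+// Typically called from a logout handler, e.g. (AccountController is a
+// hypothetical controller embedding Controller):
+//
+//	func (c *AccountController) Logout() {
+//		c.DestroySession()
+//		c.Redirect("/login", 302)
+//	}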
+func (c *Controller) DestroySession() { + c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession = nil + GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return c.Ctx.Input.IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return c.Ctx.GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + c.Ctx.SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + if c._xsrfToken == "" { + expire := int64(BConfig.WebConfig.XSRFExpire) + if c.XSRFExpire > 0 { + expire = int64(c.XSRFExpire) + } + c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire) + } + return c._xsrfToken +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + if !c.EnableXSRF { + return true + } + return c.Ctx.CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return `` +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return c.controllerName, c.actionName +} diff --git a/pkg/controller_test.go b/pkg/controller_test.go new file mode 100644 index 00000000..1e53416d --- /dev/null +++ b/pkg/controller_test.go @@ -0,0 +1,181 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
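The secure-cookie helpers above delegate to c.Ctx; a minimal sketch in which the secret, cookie name and controller are placeholders:

    // CookieController is illustrative; it embeds beego.Controller.
    type CookieController struct {
        beego.Controller
    }

    func (c *CookieController) Get() {
        c.SetSecureCookie("my-secret-key", "uid", "42")
        if v, ok := c.GetSecureCookie("my-secret-key", "uid"); ok {
            c.Ctx.WriteString("uid=" + v)
        }
    }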
+ +package beego + +import ( + "math" + "strconv" + "testing" + + "github.com/astaxie/beego/context" + "os" + "path/filepath" +) + +func TestGetInt(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt("age") + if val != 40 { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } +} + +func TestGetInt8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt8("age") + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } + //Output: int8 +} + +func TestGetInt16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt16("age") + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt32("age") + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt64("age") + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } +} + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} + +func TestAdditionalViewPaths(t *testing.T) { + dir1 := "_beeTmp" + dir2 := "_beeTmp2" + defer os.RemoveAll(dir1) + defer os.RemoveAll(dir2) + + dir1file := "file1.tpl" + dir2file := "file2.tpl" + + genFile := func(dir string, name string, content string) { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + defer f.Close() + f.WriteString(content) + f.Close() + } + + } + genFile(dir1, dir1file, `
{{.Content}}
`) + genFile(dir2, dir2file, `{{.Content}}`) + + AddViewPath(dir1) + AddViewPath(dir2) + + ctrl := Controller{ + TplName: "file1.tpl", + ViewPath: dir1, + } + ctrl.Data = map[interface{}]interface{}{ + "Content": "value2", + } + if result, err := ctrl.RenderString(); err != nil { + t.Fatal(err) + } else { + if result != "
value2
" { + t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) + } + } + + func() { + ctrl.TplName = "file2.tpl" + defer func() { + if r := recover(); r == nil { + t.Fatal("TestAdditionalViewPaths expected error") + } + }() + ctrl.RenderString() + }() + + ctrl.TplName = "file2.tpl" + ctrl.ViewPath = dir2 + ctrl.RenderString() +} diff --git a/pkg/doc.go b/pkg/doc.go new file mode 100644 index 00000000..8825bd29 --- /dev/null +++ b/pkg/doc.go @@ -0,0 +1,17 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/astaxie/beego" + + func main() { + beego.Run() + } + +more information: http://beego.me +*/ +package beego diff --git a/pkg/error.go b/pkg/error.go new file mode 100644 index 00000000..f268f723 --- /dev/null +++ b/pkg/error.go @@ -0,0 +1,488 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "html/template" + "net/http" + "reflect" + "runtime" + "strconv" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +// render default application error page with error and stack string. +func showErr(err interface{}, ctx *context.Context, stack string) { + t, _ := template.New("beegoerrortemp").Parse(tpl) + data := map[string]string{ + "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err), + "RequestMethod": ctx.Input.Method(), + "RequestURL": ctx.Input.URI(), + "RemoteAddr": ctx.Input.IP(), + "Stack": stack, + "BeegoVersion": VERSION, + "GoVersion": runtime.Version(), + } + t.Execute(ctx.ResponseWriter, data) +} + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +type errorInfo struct { + controllerType reflect.Type + handler http.HandlerFunc + method string + errorType int +} + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = make(map[string]*errorInfo, 10) + +// show 401 unauthorized error. +func unauthorized(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 401, + "
The page you have requested can't be authorized."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 402 Payment Required +func paymentRequired(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 402, + "
The page you have requested Payment Required."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The credentials you supplied are incorrect"+ + "
    There are errors in the website address"+ + "
", + ) +} + +// show 403 forbidden error. +func forbidden(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 403, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    Your address may be blocked"+ + "
    The site may be disabled"+ + "
    You need to log in"+ + "
", + ) +} + +// show 422 missing xsrf token +func missingxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 422, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    '_xsrf' argument missing from POST"+ + "
", + ) +} + +// show 417 invalid xsrf token +func invalidxsrf(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 417, + "
The page you have requested is forbidden."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    expected XSRF not found"+ + "
", + ) +} + +// show 404 not found error. +func notFound(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 404, + "
The page you have requested has flown the coop."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The page has moved"+ + "
    The page no longer exists"+ + "
    You were looking for your puppy and got lost"+ + "
    You like 404 pages"+ + "
", + ) +} + +// show 405 Method Not Allowed +func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 405, + "
The method you have requested Not Allowed."+ + "
Perhaps you are here because:"+ + "

    "+ + "
    The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+ + "
    The response MUST include an Allow header containing a list of valid methods for the requested resource."+ + "
", + ) +} + +// show 500 internal server error. +func internalServerError(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 500, + "
The page you have requested is down right now."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 501 Not Implemented. +func notImplemented(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 501, + "
The page you have requested is Not Implemented."+ + "

    "+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 502 Bad Gateway. +func badGateway(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 502, + "
The page you have requested is down right now."+ + "

    "+ + "
    The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+ + "
    Please try again later and report the error to the website administrator"+ + "
", + ) +} + +// show 503 service unavailable error. +func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 503, + "
The page you have requested is unavailable."+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The page is overloaded"+ + "
    Please try again later."+ + "
", + ) +} + +// show 504 Gateway Timeout. +func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 504, + "
The page you have requested is unavailable"+ + "
Perhaps you are here because:"+ + "

    "+ + "

    The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+ + "
    Please try again later."+ + "
", + ) +} + +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, + ) +} + +func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { + t, _ := template.New("beegoerrortemp").Parse(errtpl) + data := M{ + "Title": http.StatusText(errCode), + "BeegoVersion": VERSION, + "Content": template.HTML(errContent), + } + t.Execute(rw, data) +} + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + ErrorMaps[code] = &errorInfo{ + errorType: errorTypeHandler, + handler: h, + method: code, + } + return BeeApp +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + for i := 0; i < rt.NumMethod(); i++ { + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") { + errName := strings.TrimPrefix(methodName, "Error") + ErrorMaps[errName] = &errorInfo{ + errorType: errorTypeController, + controllerType: ct, + method: methodName, + } + } + } + return BeeApp +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + exception(strconv.FormatUint(errCode, 10), ctx) +} + +// show error string as simple text message. +// if error string is empty, show 503 or 500 error as default. +func exception(errCode string, ctx *context.Context) { + atoi := func(code string) int { + v, err := strconv.Atoi(code) + if err == nil { + return v + } + if ctx.Output.Status == 0 { + return 503 + } + return ctx.Output.Status + } + + for _, ec := range []string{errCode, "503", "500"} { + if h, ok := ErrorMaps[ec]; ok { + executeError(h, ctx, atoi(ec)) + return + } + } + //if 50x error has been removed from errorMap + ctx.ResponseWriter.WriteHeader(atoi(errCode)) + ctx.WriteString(errCode) +} + +func executeError(err *errorInfo, ctx *context.Context, code int) { + //make sure to log the error in the access log + LogAccess(ctx, nil, code) + + if err.errorType == errorTypeHandler { + ctx.ResponseWriter.WriteHeader(code) + err.handler(ctx.ResponseWriter, ctx.Request) + return + } + if err.errorType == errorTypeController { + ctx.Output.SetStatus(code) + //Invoke the request handler + vc := reflect.New(err.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + //call the controller init function + execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) + + //call prepare function + execController.Prepare() + + execController.URLMapping() + + method := vc.MethodByName(err.method) + method.Call([]reflect.Value{}) + + //render template + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + panic(err) + } + } + + // finish all runrouter. release resource + execController.Finish() + } +} diff --git a/pkg/error_test.go b/pkg/error_test.go new file mode 100644 index 00000000..378aa953 --- /dev/null +++ b/pkg/error_test.go @@ -0,0 +1,88 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +type errorTestController struct { + Controller +} + +const parseCodeError = "parse code error" + +func (ec *errorTestController) Get() { + errorCode, err := ec.GetInt("code") + if err != nil { + ec.Abort(parseCodeError) + } + if errorCode != 0 { + ec.CustomAbort(errorCode, ec.GetString("code")) + } + ec.Abort("404") +} + +func TestErrorCode_01(t *testing.T) { + registerDefaultErrorHandler() + for k := range ErrorMaps { + r, _ := http.NewRequest("GET", "/error?code="+k, nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + code, _ := strconv.Atoi(k) + if w.Code != code { + t.Fail() + } + if !strings.Contains(w.Body.String(), http.StatusText(code)) { + t.Fail() + } + } +} + +func TestErrorCode_02(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=0", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 404 { + t.Fail() + } +} + +func TestErrorCode_03(t *testing.T) { + registerDefaultErrorHandler() + r, _ := http.NewRequest("GET", "/error?code=panic", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/error", &errorTestController{}) + handler.ServeHTTP(w, r) + if w.Code != 200 { + t.Fail() + } + if w.Body.String() != parseCodeError { + t.Fail() + } +} diff --git a/pkg/filter.go b/pkg/filter.go new file mode 100644 index 00000000..9cc6e913 --- /dev/null +++ b/pkg/filter.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import "github.com/astaxie/beego/context" + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter struct { + filterFunc FilterFunc + tree *Tree + pattern string + returnOnOutput bool + resetParams bool +} + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. 
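ErrorHandler and ErrorController shown earlier in error.go are normally wired up once at startup. A sketch of a custom error controller, with the type name, templates and data keys chosen for illustration:

    // ErrorXxx methods map to status codes via the "Error" prefix that
    // ErrorController() strips when it fills ErrorMaps.
    type ErrorController struct {
        beego.Controller
    }

    func (c *ErrorController) Error404() {
        c.Data["content"] = "page not found"
        c.TplName = "404.tpl"
    }

    func (c *ErrorController) Error500() {
        c.Data["content"] = "internal server error"
        c.TplName = "500.tpl"
    }

    // registered once at startup:
    //   beego.ErrorController(&ErrorController{})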
+func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + isOk := f.tree.Match(url, ctx) + if isOk != nil { + if b, ok := isOk.(bool); ok { + return b + } + } + return false +} diff --git a/pkg/filter_test.go b/pkg/filter_test.go new file mode 100644 index 00000000..4ca4d2b8 --- /dev/null +++ b/pkg/filter_test.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/context" +) + +var FilterUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first"))) +} + +func TestFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser) + handler.Add("/person/:last/:first", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am astaXie" { + t.Errorf("user define func can't run") + } +} + +var FilterAdminUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am admin")) +} + +// Filter pattern /admin/:all +// all url like /admin/ /admin/xie will all get filter + +func TestPatternTwo(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/ can't run") + } +} + +func TestPatternThree(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/astaxie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/astaxie can't run") + } +} diff --git a/pkg/flash.go b/pkg/flash.go new file mode 100644 index 00000000..a6485a17 --- /dev/null +++ b/pkg/flash.go @@ -0,0 +1,110 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "fmt" + "net/url" + "strings" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData struct { + Data map[string]string +} + +// NewFlash return a new empty FlashData struct. 
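Outside of tests, a filter is registered once at startup. A minimal sketch in which the pattern, session key and redirect target are placeholders and sessions are assumed to be enabled:

    package main

    import (
        "github.com/astaxie/beego"
        "github.com/astaxie/beego/context"
    )

    // requireLogin bounces requests that carry no "uid" in the session.
    var requireLogin = func(ctx *context.Context) {
        if ctx.Input.Session("uid") == nil {
            ctx.Redirect(302, "/login")
        }
    }

    func main() {
        beego.InsertFilter("/admin/*", beego.BeforeRouter, requireLogin)
        beego.Run()
    }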
+func NewFlash() *FlashData { + return &FlashData{ + Data: make(map[string]string), + } +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data[key] = msg + } else { + fd.Data[key] = fmt.Sprintf(msg, args...) + } +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["success"] = msg + } else { + fd.Data["success"] = fmt.Sprintf(msg, args...) + } +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["notice"] = msg + } else { + fd.Data["notice"] = fmt.Sprintf(msg, args...) + } +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["warning"] = msg + } else { + fd.Data["warning"] = fmt.Sprintf(msg, args...) + } +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + if len(args) == 0 { + fd.Data["error"] = msg + } else { + fd.Data["error"] = fmt.Sprintf(msg, args...) + } +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + c.Data["flash"] = fd.Data + var flashValue string + for key, value := range fd.Data { + flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00" + } + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/") +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + flash := NewFlash() + if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil { + v, _ := url.QueryUnescape(cookie.Value) + vals := strings.Split(v, "\x00") + for _, v := range vals { + if len(v) > 0 { + kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23") + if len(kv) == 2 { + flash.Data[kv[0]] = kv[1] + } + } + } + //read one time then delete it + c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") + } + c.Data["flash"] = flash.Data + return flash +} diff --git a/pkg/flash_test.go b/pkg/flash_test.go new file mode 100644 index 00000000..d5e9608d --- /dev/null +++ b/pkg/flash_test.go @@ -0,0 +1,54 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
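A sketch of the round trip that Store and ReadFromRequest implement, as seen from user code; the controller, route and template names are illustrative:

    type SettingsController struct {
        beego.Controller
    }

    func (c *SettingsController) Post() {
        flash := beego.NewFlash()
        flash.Success("settings saved")
        flash.Store(&c.Controller) // encodes the messages into the flash cookie
        c.Redirect("/settings", 302)
    }

    func (c *SettingsController) Get() {
        beego.ReadFromRequest(&c.Controller) // fills c.Data["flash"] and expires the cookie
        c.TplName = "settings.tpl"
    }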
+ +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type TestFlashController struct { + Controller +} + +func (t *TestFlashController) TestWriteFlash() { + flash := NewFlash() + flash.Notice("TestFlashString") + flash.Store(&t.Controller) + // we choose to serve json because we don't want to load a template html file + t.ServeJSON(true) +} + +func TestFlashHeader(t *testing.T) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.ServeHTTP(w, r) + + // get the Set-Cookie value + sc := w.Header().Get("Set-Cookie") + // match for the expected header + res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") + // validate the assertion + if !res { + t.Errorf("TestFlashHeader() unable to validate flash message") + } +} diff --git a/pkg/fs.go b/pkg/fs.go new file mode 100644 index 00000000..41cc6f6e --- /dev/null +++ b/pkg/fs.go @@ -0,0 +1,74 @@ +package beego + +import ( + "net/http" + "os" + "path/filepath" +) + +type FileSystem struct { +} + +func (d FileSystem) Open(name string) (http.File, error) { + return os.Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + + f, err := fs.Open(root) + if err != nil { + return err + } + info, err := f.Stat() + if err != nil { + err = walkFn(root, nil, err) + } else { + err = walk(fs, root, info, walkFn) + } + if err == filepath.SkipDir { + return nil + } + return err +} + +// walk recursively descends path, calling walkFn. +func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error { + var err error + if !info.IsDir() { + return walkFn(path, info, nil) + } + + dir, err := fs.Open(path) + if err != nil { + if err1 := walkFn(path, info, err); err1 != nil { + return err1 + } + return err + } + defer dir.Close() + dirs, err := dir.Readdir(-1) + err1 := walkFn(path, info, err) + // If err != nil, walk can't walk into this directory. + // err1 != nil means walkFn want walk to skip this directory or stop walking. + // Therefore, if one of err and err1 isn't nil, walk will return. + if err != nil || err1 != nil { + // The caller's behavior is controlled by the return value, which is decided + // by walkFn. walkFn may ignore err and return nil. + // If walkFn returns SkipDir, it will be handled by the caller. + // So walk should return whatever walkFn returns. + return err1 + } + + for _, fileInfo := range dirs { + filename := filepath.Join(path, fileInfo.Name()) + if err = walk(fs, filename, fileInfo, walkFn); err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + return nil +} diff --git a/pkg/grace/grace.go b/pkg/grace/grace.go new file mode 100644 index 00000000..fb0cb7bb --- /dev/null +++ b/pkg/grace/grace.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
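Walk in fs.go mirrors filepath.Walk but runs over an http.FileSystem. A small self-contained sketch; the directory and the .tpl filter are arbitrary choices:

    package main

    import (
        "fmt"
        "log"
        "os"
        "path/filepath"

        "github.com/astaxie/beego"
    )

    func main() {
        // list every .tpl file under ./views through the exported FileSystem type
        err := beego.Walk(beego.FileSystem{}, "./views", func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.IsDir() && filepath.Ext(path) == ".tpl" {
                fmt.Println(path)
            }
            return nil
        })
        if err != nil {
            log.Println(err)
        }
    }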
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "flag" + "net/http" + "os" + "strings" + "sync" + "syscall" + "time" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + regLock *sync.Mutex + runningServers map[string]*Server + runningServersOrder []string + socketPtrOffsetMap map[string]uint + runningServersForked bool + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = 60 * time.Second + + isChild bool + socketOrder string + + hookableSignals []os.Signal +) + +func init() { + flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") + flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") + + regLock = &sync.Mutex{} + runningServers = make(map[string]*Server) + runningServersOrder = []string{} + socketPtrOffsetMap = make(map[string]uint) + + hookableSignals = []os.Signal{ + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + } +} + +// NewServer returns a new graceServer. 
+func NewServer(addr string, handler http.Handler) (srv *Server) { + regLock.Lock() + defer regLock.Unlock() + + if !flag.Parsed() { + flag.Parse() + } + if len(socketOrder) > 0 { + for i, addr := range strings.Split(socketOrder, ",") { + socketPtrOffsetMap[addr] = uint(i) + } + } else { + socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) + } + + srv = &Server{ + sigChan: make(chan os.Signal), + isChild: isChild, + SignalHooks: map[int]map[os.Signal][]func(){ + PreSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + PostSignal: { + syscall.SIGHUP: {}, + syscall.SIGINT: {}, + syscall.SIGTERM: {}, + }, + }, + state: StateInit, + Network: "tcp", + terminalChan: make(chan error), //no cache channel + } + srv.Server = &http.Server{ + Addr: addr, + ReadTimeout: DefaultReadTimeOut, + WriteTimeout: DefaultWriteTimeOut, + MaxHeaderBytes: DefaultMaxHeaderBytes, + Handler: handler, + } + + runningServersOrder = append(runningServersOrder, addr) + runningServers[addr] = srv + return srv +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/grace/server.go b/pkg/grace/server.go new file mode 100644 index 00000000..008a6171 --- /dev/null +++ b/pkg/grace/server.go @@ -0,0 +1,356 @@ +package grace + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + "time" +) + +// Server embedded http.Server +type Server struct { + *http.Server + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error +} + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + srv.state = StateRunning + defer func() { srv.state = StateTerminate }() + + // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS + // immediately return ErrServerClosed. Make sure the program doesn't exit + // and waits instead for Shutdown to return. + if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + log.Println(syscall.Getpid(), "Server.Serve() error:", err) + return err + } + + log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + // wait for Shutdown to return + if shutdownErr := <-srv.terminalChan; shutdownErr != nil { + return shutdownErr + } + return +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. 
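Beyond the package-level wrappers above, the returned *Server can be tuned before it starts. A sketch, with the 30-second drain timeout and the SIGTERM hook chosen for illustration:

    package main

    import (
        "log"
        "net/http"
        "syscall"
        "time"

        "github.com/astaxie/beego/grace"
    )

    func main() {
        mux := http.NewServeMux()
        mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
            w.Write([]byte("WORLD!"))
        })

        grace.DefaultTimeout = 30 * time.Second // wait at most 30s for in-flight requests
        srv := grace.NewServer("localhost:8080", mux)
        srv.RegisterSignalHook(grace.PreSignal, syscall.SIGTERM, func() {
            log.Println("SIGTERM received, draining connections")
        })
        log.Println(srv.ListenAndServe())
    }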
+func (srv *Server) ListenAndServe() (err error) { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + + go srv.handleSignals() + + srv.ln, err = srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.NextProtos == nil { + srv.TLSConfig.NextProtos = []string{"http/1.1"} + } + + srv.TLSConfig.Certificates = make([]tls.Certificate, 1) + srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return + } + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(trustFile) + if err != nil { + log.Println(err) + return err + } + pool.AppendCertsFromPEM(data) + srv.TLSConfig.ClientCAs = pool + log.Println("Mutual HTTPS") + go srv.handleSignals() + + ln, err := srv.getListener(addr) + if err != nil { + log.Println(err) + return err + } + srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + + if srv.isChild { + process, err := os.FindProcess(os.Getppid()) + if err != nil { + log.Println(err) + return err + } + err = process.Signal(syscall.SIGTERM) + if err != nil { + return err + } + } + + log.Println(os.Getpid(), srv.Addr) + return srv.Serve() +} + +// getListener either opens a new socket to listen on, or takes the acceptor socket +// it got passed when restarted. 
+func (srv *Server) getListener(laddr string) (l net.Listener, err error) { + if srv.isChild { + var ptrOffset uint + if len(socketPtrOffsetMap) > 0 { + ptrOffset = socketPtrOffsetMap[laddr] + log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) + } + + f := os.NewFile(uintptr(3+ptrOffset), "") + l, err = net.FileListener(f) + if err != nil { + err = fmt.Errorf("net.FileListener error: %v", err) + return + } + } else { + l, err = net.Listen(srv.Network, laddr) + if err != nil { + err = fmt.Errorf("net.Listen error: %v", err) + return + } + } + return +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// handleSignals listens for os Signals and calls any hooked in function that the +// user had registered with the signal. +func (srv *Server) handleSignals() { + var sig os.Signal + + signal.Notify( + srv.sigChan, + hookableSignals..., + ) + + pid := syscall.Getpid() + for { + sig = <-srv.sigChan + srv.signalHooks(PreSignal, sig) + switch sig { + case syscall.SIGHUP: + log.Println(pid, "Received SIGHUP. forking.") + err := srv.fork() + if err != nil { + log.Println("Fork err:", err) + } + case syscall.SIGINT: + log.Println(pid, "Received SIGINT.") + srv.shutdown() + case syscall.SIGTERM: + log.Println(pid, "Received SIGTERM.") + srv.shutdown() + default: + log.Printf("Received %v: nothing i care about...\n", sig) + } + srv.signalHooks(PostSignal, sig) + } +} + +func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { + if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { + return + } + for _, f := range srv.SignalHooks[ppFlag][sig] { + f() + } +} + +// shutdown closes the listener so that no new connections are accepted. it also +// starts a goroutine that will serverTimeout (stop all running requests) the server +// after DefaultTimeout. +func (srv *Server) shutdown() { + if srv.state != StateRunning { + return + } + + srv.state = StateShuttingDown + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + ctx := context.Background() + if DefaultTimeout >= 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() + } + srv.terminalChan <- srv.Server.Shutdown(ctx) +} + +func (srv *Server) fork() (err error) { + regLock.Lock() + defer regLock.Unlock() + if runningServersForked { + return + } + runningServersForked = true + + var files = make([]*os.File, len(runningServers)) + var orderArgs = make([]string, len(runningServers)) + for _, srvPtr := range runningServers { + f, _ := srvPtr.ln.(*net.TCPListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f + orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + } + + log.Println(files) + path := os.Args[0] + var args []string + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + if arg == "-graceful" { + break + } + args = append(args, arg) + } + } + args = append(args, "-graceful") + if len(runningServers) > 1 { + args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) + log.Println(args) + } + cmd := exec.Command(path, args...) 
+ cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = files + err = cmd.Start() + if err != nil { + log.Fatalf("Restart: Failed to launch, error: %v", err) + } + + return +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { + if ppFlag != PreSignal && ppFlag != PostSignal { + err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") + return + } + for _, s := range hookableSignals { + if s == sig { + srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) + return + } + } + err = fmt.Errorf("Signal '%v' is not supported", sig) + return +} diff --git a/pkg/hooks.go b/pkg/hooks.go new file mode 100644 index 00000000..49c42d5a --- /dev/null +++ b/pkg/hooks.go @@ -0,0 +1,104 @@ +package beego + +import ( + "encoding/json" + "mime" + "net/http" + "path/filepath" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/session" +) + +// register MIME type with content type +func registerMime() error { + for k, v := range mimemaps { + mime.AddExtensionType(k, v) + } + return nil +} + +// register default error http handlers, 404,401,403,500 and 503. +func registerDefaultErrorHandler() error { + m := map[string]func(http.ResponseWriter, *http.Request){ + "401": unauthorized, + "402": paymentRequired, + "403": forbidden, + "404": notFound, + "405": methodNotAllowed, + "500": internalServerError, + "501": notImplemented, + "502": badGateway, + "503": serviceUnavailable, + "504": gatewayTimeout, + "417": invalidxsrf, + "422": missingxsrf, + "413": payloadTooLarge, + } + for e, h := range m { + if _, ok := ErrorMaps[e]; !ok { + ErrorHandler(e, h) + } + } + return nil +} + +func registerSession() error { + if BConfig.WebConfig.Session.SessionOn { + var err error + sessionConfig := AppConfig.String("sessionConfig") + conf := new(session.ManagerConfig) + if sessionConfig == "" { + conf.CookieName = BConfig.WebConfig.Session.SessionName + conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie + conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime + conf.Secure = BConfig.Listen.EnableHTTPS + conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime + conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) + conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly + conf.Domain = BConfig.WebConfig.Session.SessionDomain + conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader + conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader + conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + } else { + if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { + return err + } + } + if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil { + return err + } + go GlobalSessions.GC() + } + return nil +} + +func registerTemplate() error { + defer lockViewPaths() + if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil { + if BConfig.RunMode == DEV { + logs.Warn(err) + } + return err + } + return nil +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + go beeAdminApp.Run() + } + return nil +} + +func registerGzip() error { + if BConfig.EnableGzip { + context.InitGzip( + AppConfig.DefaultInt("gzipMinLength", -1), + 
AppConfig.DefaultInt("gzipCompressLevel", -1), + AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + ) + } + return nil +} diff --git a/pkg/httplib/README.md b/pkg/httplib/README.md new file mode 100644 index 00000000..97df8e6b --- /dev/null +++ b/pkg/httplib/README.md @@ -0,0 +1,97 @@ +# httplib +httplib is an libs help you to curl remote url. + +# How to use? + +## GET +you can use Get to crawl data. + + import "github.com/astaxie/beego/httplib" + + str, err := httplib.Get("http://beego.me/").String() + if err != nil { + // error + } + fmt.Println(str) + +## POST +POST data to remote url + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.Param("password","123456") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + +## Set timeout + +The default timeout is `60` seconds, function prototype: + + SetTimeout(connectTimeout, readWriteTimeout time.Duration) + +Example: + + // GET + httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + // POST + httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + + +## Debug + +If you want to debug the request info, set the debug on + + httplib.Get("http://beego.me/").Debug(true) + +## Set HTTP Basic Auth + + str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() + if err != nil { + // error + } + fmt.Println(str) + +## Set HTTPS + +If request url is https, You can set the client support TSL: + + httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) + +More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config + +## Set HTTP Version + +some servers need to specify the protocol version of HTTP + + httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") + +## Set Cookie + +some http request need setcookie. So set it like this: + + cookie := &http.Cookie{} + cookie.Name = "username" + cookie.Value = "astaxie" + httplib.Get("http://beego.me/").SetCookie(cookie) + +## Upload file + +httplib support mutil file upload, use `req.PostFile()` + + req := httplib.Post("http://beego.me/") + req.Param("username","astaxie") + req.PostFile("uploadfile1", "httplib.pdf") + str, err := req.String() + if err != nil { + // error + } + fmt.Println(str) + + +See godoc for further documentation and examples. + +* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go new file mode 100644 index 00000000..60aa4e8b --- /dev/null +++ b/pkg/httplib/httplib.go @@ -0,0 +1,654 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "bytes" + "compress/gzip" + "crypto/tls" + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "log" + "mime/multipart" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httputil" + "net/url" + "os" + "path" + "strings" + "sync" + "time" + + "gopkg.in/yaml.v2" +) + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + DumpBody: true, +} + +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + var resp http.Response + u, err := url.Parse(rawurl) + if err != nil { + log.Println("Httplib:", err) + } + req := http.Request{ + URL: u, + Method: method, + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + return &BeegoHTTPRequest{ + url: rawurl, + req: &req, + params: map[string][]string{}, + files: map[string]string{}, + setting: defaultSetting, + resp: &resp, + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + ShowDebug bool + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + DumpBody bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration +} + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. 
+type BeegoHTTPRequest struct { + url string + req *http.Request + params map[string][]string + files map[string]string + setting BeegoHTTPSettings + resp *http.Response + body []byte + dump []byte +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.req +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.setting = setting + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.req.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.setting.EnableCookie = enable + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.setting.UserAgent = useragent + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.setting.ShowDebug = isdebug + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.setting.Retries = times + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.setting.DumpBody = isdump + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.dump +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.setting.ConnectTimeout = connectTimeout + b.setting.ReadWriteTimeout = readWriteTimeout + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.setting.TLSClientConfig = config + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.req.Header.Set(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.req.Host = host + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + if len(vers) == 0 { + vers = "HTTP/1.1" + } + + major, minor, ok := http.ParseHTTPVersion(vers) + if ok { + b.req.Proto = vers + b.req.ProtoMajor = major + b.req.ProtoMinor = minor + } + + return b +} + +// SetCookie add cookie into request. 
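The setters on BeegoHTTPRequest are designed to chain. A sketch combining timeouts, retries and JSON decoding; the URL and payload shape are placeholders:

    package main

    import (
        "fmt"
        "time"

        "github.com/astaxie/beego/httplib"
    )

    type user struct {
        Name string `json:"name"`
    }

    func main() {
        var u user
        err := httplib.Get("http://example.com/api/user/1").
            SetTimeout(10*time.Second, 30*time.Second).
            Retries(3).                         // retry transport errors up to 3 times
            RetryDelay(400 * time.Millisecond). // pause between attempts
            Header("Accept", "application/json").
            ToJSON(&u)
        if err != nil {
            fmt.Println(err)
            return
        }
        fmt.Println(u.Name)
    }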
+func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.req.Header.Add("Cookie", cookie.String()) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.setting.Transport = transport + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.setting.Proxy = proxy + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.setting.CheckRedirect = redirect + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + if param, ok := b.params[key]; ok { + b.params[key] = append(param, value) + } else { + b.params[key] = []string{value} + } + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.files[formname] = filename + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + switch t := data.(type) { + case string: + bf := bytes.NewBufferString(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + case []byte: + bf := bytes.NewBuffer(t) + b.req.Body = ioutil.NopCloser(bf) + b.req.ContentLength = int64(len(t)) + } + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := xml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/xml") + } + return b, nil +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := yaml.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/x+yaml") + } + return b, nil +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + if b.req.Body == nil && obj != nil { + byts, err := json.Marshal(obj) + if err != nil { + return b, err + } + b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.ContentLength = int64(len(byts)) + b.req.Header.Set("Content-Type", "application/json") + } + return b, nil +} + +func (b *BeegoHTTPRequest) buildURL(paramBody string) { + // build GET url with query string + if b.req.Method == "GET" && len(paramBody) > 0 { + if strings.Contains(b.url, "?") { + b.url += "&" + paramBody + } else { + b.url = b.url + "?" 
+ paramBody + } + return + } + + // build POST/PUT/PATCH url and body + if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { + // with files + if len(b.files) > 0 { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + if err != nil { + log.Println("Httplib:", err) + } + fh, err := os.Open(filename) + if err != nil { + log.Println("Httplib:", err) + } + //iocopy + _, err = io.Copy(fileWriter, fh) + fh.Close() + if err != nil { + log.Println("Httplib:", err) + } + } + for k, v := range b.params { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + pw.Close() + }() + b.Header("Content-Type", bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") + return + } + + // with params + if len(paramBody) > 0 { + b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Body(paramBody) + } + } +} + +func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { + if b.resp.StatusCode != 0 { + return b.resp, nil + } + resp, err := b.DoRequest() + if err != nil { + return nil, err + } + b.resp = resp + return resp, nil +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + + b.buildURL(paramBody) + urlParsed, err := url.Parse(b.url) + if err != nil { + return nil, err + } + + b.req.URL = urlParsed + + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.Dial == nil { + t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + } + + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + + client := &http.Client{ + Transport: trans, + Jar: jar, + } + + if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" { + b.req.Header.Set("User-Agent", b.setting.UserAgent) + } + + if b.setting.CheckRedirect != nil { + client.CheckRedirect = b.setting.CheckRedirect + } + + if b.setting.ShowDebug { + dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) + if err != nil { + log.Println(err.Error()) + } + b.dump = dump + } + // retries default value is 0, it will run once. + // retries equal to -1, it will run forever until success + // retries is setted, it will retries fixed times. 
+	// Sleep for the configured RetryDelay between attempts to avoid hammering the server.
+	for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
+		resp, err = client.Do(b.req)
+		if err == nil {
+			break
+		}
+		time.Sleep(b.setting.RetryDelay)
+	}
+	return resp, err
+}
+
+// String returns the response body as a string.
+// It calls Response internally.
+func (b *BeegoHTTPRequest) String() (string, error) {
+	data, err := b.Bytes()
+	if err != nil {
+		return "", err
+	}
+
+	return string(data), nil
+}
+
+// Bytes returns the response body as a []byte.
+// It calls Response internally.
+func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
+	if b.body != nil {
+		return b.body, nil
+	}
+	resp, err := b.getResponse()
+	if err != nil {
+		return nil, err
+	}
+	if resp.Body == nil {
+		return nil, nil
+	}
+	defer resp.Body.Close()
+	if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
+		reader, err := gzip.NewReader(resp.Body)
+		if err != nil {
+			return nil, err
+		}
+		b.body, err = ioutil.ReadAll(reader)
+		return b.body, err
+	}
+	b.body, err = ioutil.ReadAll(resp.Body)
+	return b.body, err
+}
+
+// ToFile saves the response body to a file.
+// It calls Response internally.
+func (b *BeegoHTTPRequest) ToFile(filename string) error {
+	resp, err := b.getResponse()
+	if err != nil {
+		return err
+	}
+	if resp.Body == nil {
+		return nil
+	}
+	defer resp.Body.Close()
+	err = pathExistAndMkdir(filename)
+	if err != nil {
+		return err
+	}
+	f, err := os.Create(filename)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	_, err = io.Copy(f, resp.Body)
+	return err
+}
+
+// pathExistAndMkdir ensures the directory of filename exists, creating it if necessary.
+func pathExistAndMkdir(filename string) (err error) {
+	filename = path.Dir(filename)
+	_, err = os.Stat(filename)
+	if err == nil {
+		return nil
+	}
+	if os.IsNotExist(err) {
+		err = os.MkdirAll(filename, os.ModePerm)
+		if err == nil {
+			return nil
+		}
+	}
+	return err
+}
+
+// ToJSON unmarshals the response body as JSON into v.
+// It calls Response internally.
+func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
+	data, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	return json.Unmarshal(data, v)
+}
+
+// ToXML unmarshals the response body as XML into v.
+// It calls Response internally.
+func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
+	data, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	return xml.Unmarshal(data, v)
+}
+
+// ToYAML unmarshals the response body as YAML into v.
+// It calls Response internally.
+func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
+	data, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	return yaml.Unmarshal(data, v)
+}
+
+// Response executes the request if needed and returns the raw *http.Response.
+func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
+	return b.getResponse()
+}
+
+// TimeoutDialer returns a dial function with connect and read-write timeouts, suitable for the http.Transport Dial field.
+func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
+	return func(netw, addr string) (net.Conn, error) {
+		conn, err := net.DialTimeout(netw, addr, cTimeout)
+		if err != nil {
+			return nil, err
+		}
+		err = conn.SetDeadline(time.Now().Add(rwTimeout))
+		return conn, err
+	}
+}
diff --git a/pkg/httplib/httplib_test.go b/pkg/httplib/httplib_test.go
new file mode 100644
index 00000000..f6be8571
--- /dev/null
+++ b/pkg/httplib/httplib_test.go
@@ -0,0 +1,286 @@
+// Copyright 2014 beego Author.
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +//func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +//} + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + 
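
The tests above exercise the fluent API with query parameters and cookies. As an editorial illustration (not part of the patch), the sketch below combines `JSONBody` and `ToJSON` for a JSON round trip; the httpbin.org endpoint, the payload fields, and the import path are assumptions of the example.

```golang
package main

import (
	"fmt"
	"log"
	"time"

	"github.com/astaxie/beego/httplib"
)

func main() {
	// Build a POST request with timeouts and a small retry budget.
	req := httplib.Post("http://httpbin.org/post").
		SetTimeout(10*time.Second, 30*time.Second).
		Retries(2).
		RetryDelay(500 * time.Millisecond)

	// Encode the payload as JSON; this also sets the Content-Type header.
	if _, err := req.JSONBody(map[string]string{"username": "smallfish"}); err != nil {
		log.Fatal(err)
	}

	// httpbin echoes the posted document under the "json" key.
	var echo struct {
		JSON map[string]string `json:"json"`
	}
	if err := req.ToJSON(&echo); err != nil {
		log.Fatal(err)
	}
	fmt.Println(echo.JSON["username"])
}
```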
+func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/pkg/log.go b/pkg/log.go new file mode 100644 index 00000000..cc4c0f81 --- /dev/null +++ b/pkg/log.go @@ -0,0 +1,127 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
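
The pkg/log.go shim that follows keeps the old package-level logging helpers compiling while steering callers toward the logs subpackage. A hedged before/after sketch (editorial, not part of the patch; import paths follow the pre-move layout that the shim itself references, and the adapter name and messages are placeholders):

```golang
package main

import (
	"github.com/astaxie/beego"
	"github.com/astaxie/beego/logs"
)

func main() {
	// Old style: deprecated helpers re-exported by the shim below.
	beego.SetLevel(beego.LevelInformational)
	beego.Informational("service", "started") // arguments joined with "%v " verbs internally

	// Preferred style: call the logs package directly.
	logs.SetLogger(logs.AdapterConsole, "")
	logs.Informational("service started on %s", ":8080")
}
```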
+
+package beego
+
+import (
+	"strings"
+
+	"github.com/astaxie/beego/logs"
+)
+
+// Log levels to control the logging output.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+const (
+	LevelEmergency = iota
+	LevelAlert
+	LevelCritical
+	LevelError
+	LevelWarning
+	LevelNotice
+	LevelInformational
+	LevelDebug
+)
+
+// BeeLogger references the used application logger.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+var BeeLogger = logs.GetBeeLogger()
+
+// SetLevel sets the global log level used by the simple logger.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func SetLevel(l int) {
+	logs.SetLevel(l)
+}
+
+// SetLogFuncCall sets whether to log the caller's file and line; the default CallDepth is 3.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func SetLogFuncCall(b bool) {
+	logs.SetLogFuncCall(b)
+}
+
+// SetLogger sets a new logger.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func SetLogger(adaptername string, config string) error {
+	return logs.SetLogger(adaptername, config)
+}
+
+// Emergency logs a message at emergency level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Emergency(v ...interface{}) {
+	logs.Emergency(generateFmtStr(len(v)), v...)
+}
+
+// Alert logs a message at alert level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Alert(v ...interface{}) {
+	logs.Alert(generateFmtStr(len(v)), v...)
+}
+
+// Critical logs a message at critical level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Critical(v ...interface{}) {
+	logs.Critical(generateFmtStr(len(v)), v...)
+}
+
+// Error logs a message at error level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Error(v ...interface{}) {
+	logs.Error(generateFmtStr(len(v)), v...)
+}
+
+// Warning logs a message at warning level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Warning(v ...interface{}) {
+	logs.Warning(generateFmtStr(len(v)), v...)
+}
+
+// Warn is a compatibility alias for Warning().
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Warn(v ...interface{}) {
+	logs.Warn(generateFmtStr(len(v)), v...)
+}
+
+// Notice logs a message at notice level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Notice(v ...interface{}) {
+	logs.Notice(generateFmtStr(len(v)), v...)
+}
+
+// Informational logs a message at info level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Informational(v ...interface{}) {
+	logs.Informational(generateFmtStr(len(v)), v...)
+}
+
+// Info is a compatibility alias for Informational().
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Info(v ...interface{}) {
+	logs.Info(generateFmtStr(len(v)), v...)
+}
+
+// Debug logs a message at debug level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Debug(v ...interface{}) {
+	logs.Debug(generateFmtStr(len(v)), v...)
+}
+
+// Trace logs a message at trace level.
+// It is a compatibility alias for Debug().
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Trace(v ...interface{}) {
+	logs.Trace(generateFmtStr(len(v)), v...)
+}
+
+func generateFmtStr(n int) string {
+	return strings.Repeat("%v ", n)
+}
diff --git a/pkg/logs/README.md b/pkg/logs/README.md
new file mode 100644
index 00000000..c05bcc04
--- /dev/null
+++ b/pkg/logs/README.md
@@ -0,0 +1,72 @@
+## logs
+logs is a Go logging manager. It supports many log adapters and is inspired by `database/sql`.
+
+
+## How to install?
+
+    go get github.com/astaxie/beego/logs
+
+
+## What adapters are supported?
+
+As of now, this package supports the console, file, smtp and conn adapters.
+
+
+## How to use it?
+
+First, import the package:
+
+```golang
+import (
+	"github.com/astaxie/beego/logs"
+)
+```
+
+Then create a logger (example with the console adapter):
+
+```golang
+log := logs.NewLogger(10000)
+log.SetLogger("console", "")
+```
+
+> the first parameter sets the buffer size of the message channel
+
+Use it like this:
+
+```golang
+log.Trace("trace")
+log.Info("info")
+log.Warn("warning")
+log.Debug("debug")
+log.Critical("critical")
+```
+
+## File adapter
+
+Configure the file adapter like this:
+
+```golang
+log := NewLogger(10000)
+log.SetLogger("file", `{"filename":"test.log"}`)
+```
+
+## Conn adapter
+
+Configure like this:
+
+```golang
+log := NewLogger(1000)
+log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
+log.Info("info")
+```
+
+## SMTP adapter
+
+Configure like this:
+
+```golang
+log := NewLogger(10000)
+log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
+log.Critical("sendmail critical")
+time.Sleep(time.Second * 30)
+```
diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go
new file mode 100644
index 00000000..3ff9e20f
--- /dev/null
+++ b/pkg/logs/accesslog.go
@@ -0,0 +1,83 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"strings"
+	"time"
+)
+
+const (
+	apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s"
+	apacheFormat        = "APACHE_FORMAT"
+	jsonFormat          = "JSON_FORMAT"
+)
+
+// AccessLogRecord is a struct for holding access log data.
+type AccessLogRecord struct {
+	RemoteAddr     string        `json:"remote_addr"`
+	RequestTime    time.Time     `json:"request_time"`
+	RequestMethod  string        `json:"request_method"`
+	Request        string        `json:"request"`
+	ServerProtocol string        `json:"server_protocol"`
+	Host           string        `json:"host"`
+	Status         int           `json:"status"`
+	BodyBytesSent  int64         `json:"body_bytes_sent"`
+	ElapsedTime    time.Duration `json:"elapsed_time"`
+	HTTPReferrer   string        `json:"http_referrer"`
+	HTTPUserAgent  string        `json:"http_user_agent"`
+	RemoteUser     string        `json:"remote_user"`
+}
+
+func (r *AccessLogRecord) json() ([]byte, error) {
+	buffer := &bytes.Buffer{}
+	encoder := json.NewEncoder(buffer)
+	disableEscapeHTML(encoder)
+
+	err := encoder.Encode(r)
+	return buffer.Bytes(), err
+}
+
+func disableEscapeHTML(i interface{}) {
+	if e, ok := i.(interface {
+		SetEscapeHTML(bool)
+	}); ok {
+		e.SetEscapeHTML(false)
+	}
+}
+
+// AccessLog formats an access log record and writes it to the logger.
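
The AccessLog helper implemented next serialises a record either as an Apache-style line or as JSON (the default). As a hedged, editorial sketch (not part of the patch), a caller might fill in a record like this; all request values are placeholders.

```golang
package main

import (
	"time"

	"github.com/astaxie/beego/logs"
)

func main() {
	logs.SetLogger(logs.AdapterConsole, "")

	record := &logs.AccessLogRecord{
		RemoteAddr:     "203.0.113.7",
		RequestTime:    time.Now(),
		RequestMethod:  "GET",
		Request:        "GET /health HTTP/1.1",
		ServerProtocol: "HTTP/1.1",
		Host:           "example.internal",
		Status:         200,
		BodyBytesSent:  512,
		ElapsedTime:    3 * time.Millisecond,
		HTTPUserAgent:  "curl/7.64.1",
	}

	// "APACHE_FORMAT" selects the Apache-style line; any other value,
	// including "JSON_FORMAT", falls through to the JSON encoding.
	logs.AccessLog(record, "JSON_FORMAT")
}
```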
+func AccessLog(r *AccessLogRecord, format string) { + var msg string + switch format { + case apacheFormat: + timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") + msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, + r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) + case jsonFormat: + fallthrough + default: + jsonData, err := r.json() + if err != nil { + msg = fmt.Sprintf(`{"Error": "%s"}`, err) + } else { + msg = string(jsonData) + } + } + beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) +} diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go new file mode 100644 index 00000000..867ff4cb --- /dev/null +++ b/pkg/logs/alils/alils.go @@ -0,0 +1,186 @@ +package alils + +import ( + "encoding/json" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/logs" + "github.com/gogo/protobuf/proto" +) + +const ( + // CacheSize set the flush size + CacheSize int = 64 + // Delimiter define the topic delimiter + Delimiter string = "##" +) + +// Config is the Config for Ali Log +type Config struct { + Project string `json:"project"` + Endpoint string `json:"endpoint"` + KeyID string `json:"key_id"` + KeySecret string `json:"key_secret"` + LogStore string `json:"log_store"` + Topics []string `json:"topics"` + Source string `json:"source"` + Level int `json:"level"` + FlushWhen int `json:"flush_when"` +} + +// aliLSWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type aliLSWriter struct { + store *LogStore + group []*LogGroup + withMap bool + groupMap map[string]*LogGroup + lock *sync.Mutex + Config +} + +// NewAliLS create a new Logger +func NewAliLS() logs.Logger { + alils := new(aliLSWriter) + alils.Level = logs.LevelTrace + return alils +} + +// Init parse config and init struct +func (c *aliLSWriter) Init(jsonConfig string) (err error) { + + json.Unmarshal([]byte(jsonConfig), c) + + if c.FlushWhen > CacheSize { + c.FlushWhen = CacheSize + } + + prj := &LogProject{ + Name: c.Project, + Endpoint: c.Endpoint, + AccessKeyID: c.KeyID, + AccessKeySecret: c.KeySecret, + } + + c.store, err = prj.GetLogStore(c.LogStore) + if err != nil { + return err + } + + // Create default Log Group + c.group = append(c.group, &LogGroup{ + Topic: proto.String(""), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + }) + + // Create other Log Group + c.groupMap = make(map[string]*LogGroup) + for _, topic := range c.Topics { + + lg := &LogGroup{ + Topic: proto.String(topic), + Source: proto.String(c.Source), + Logs: make([]*Log, 0, c.FlushWhen), + } + + c.group = append(c.group, lg) + c.groupMap[topic] = lg + } + + if len(c.group) == 1 { + c.withMap = false + } else { + c.withMap = true + } + + c.lock = &sync.Mutex{} + + return nil +} + +// WriteMsg write message in connection. +// if connection is down, try to re-connect. 
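
Before the WriteMsg implementation below, a hedged sketch of wiring this adapter up through the logs facade. The JSON keys mirror the Config struct above; the project, endpoint and credential values are placeholders, and the import paths follow the pre-move layout used elsewhere in this file (editorial, not part of the patch).

```golang
package main

import (
	"github.com/astaxie/beego/logs"
	// The blank import runs init(), which registers the adapter under logs.AdapterAliLS.
	_ "github.com/astaxie/beego/logs/alils"
)

func main() {
	// Placeholder SLS project and credentials; keys match the Config JSON tags.
	err := logs.SetLogger(logs.AdapterAliLS, `{
		"project":    "my-project",
		"endpoint":   "cn-hangzhou.log.aliyuncs.com",
		"key_id":     "<access-key-id>",
		"key_secret": "<access-key-secret>",
		"log_store":  "app-logs",
		"topics":     ["web", "job"],
		"source":     "host-01",
		"level":      7,
		"flush_when": 32
	}`)
	if err != nil {
		panic(err)
	}

	// The word immediately before "##" is treated as the topic; messages
	// without the delimiter go to the default (empty-topic) LogGroup.
	logs.Informational("web## request handled")
}
```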
+func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { + + if level > c.Level { + return nil + } + + var topic string + var content string + var lg *LogGroup + if c.withMap { + + // Topic,LogGroup + strs := strings.SplitN(msg, Delimiter, 2) + if len(strs) == 2 { + pos := strings.LastIndex(strs[0], " ") + topic = strs[0][pos+1 : len(strs[0])] + content = strs[0][0:pos] + strs[1] + lg = c.groupMap[topic] + } + + // send to empty Topic + if lg == nil { + content = msg + lg = c.group[0] + } + } else { + content = msg + lg = c.group[0] + } + + c1 := &LogContent{ + Key: proto.String("msg"), + Value: proto.String(content), + } + + l := &Log{ + Time: proto.Uint32(uint32(when.Unix())), + Contents: []*LogContent{ + c1, + }, + } + + c.lock.Lock() + lg.Logs = append(lg.Logs, l) + c.lock.Unlock() + + if len(lg.Logs) >= c.FlushWhen { + c.flush(lg) + } + + return nil +} + +// Flush implementing method. empty. +func (c *aliLSWriter) Flush() { + + // flush all group + for _, lg := range c.group { + c.flush(lg) + } +} + +// Destroy destroy connection writer and close tcp listener. +func (c *aliLSWriter) Destroy() { +} + +func (c *aliLSWriter) flush(lg *LogGroup) { + + c.lock.Lock() + defer c.lock.Unlock() + err := c.store.PutLogs(lg) + if err != nil { + return + } + + lg.Logs = make([]*Log, 0, c.FlushWhen) +} + +func init() { + logs.Register(logs.AdapterAliLS, NewAliLS) +} diff --git a/pkg/logs/alils/config.go b/pkg/logs/alils/config.go new file mode 100755 index 00000000..e8c24448 --- /dev/null +++ b/pkg/logs/alils/config.go @@ -0,0 +1,13 @@ +package alils + +const ( + version = "0.5.0" // SDK version + signatureMethod = "hmac-sha1" // Signature method + + // OffsetNewest stands for the log head offset, i.e. the offset that will be + // assigned to the next message that will be produced to the shard. + OffsetNewest = "end" + // OffsetOldest stands for the oldest offset available on the logstore for a + // shard. + OffsetOldest = "begin" +) diff --git a/pkg/logs/alils/log.pb.go b/pkg/logs/alils/log.pb.go new file mode 100755 index 00000000..601b0d78 --- /dev/null +++ b/pkg/logs/alils/log.pb.go @@ -0,0 +1,1038 @@ +package alils + +import ( + "fmt" + "io" + "math" + + "github.com/gogo/protobuf/proto" + github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" +) + +// Reference imports to suppress errors if they are not otherwise used. 
+var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +var ( + // ErrInvalidLengthLog invalid proto + ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling") + // ErrIntOverflowLog overflow + ErrIntOverflowLog = fmt.Errorf("proto: integer overflow") +) + +// Log define the proto Log +type Log struct { + Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"` + Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset the Log +func (m *Log) Reset() { *m = Log{} } + +// String return the Compact Log +func (m *Log) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*Log) ProtoMessage() {} + +// GetTime return the Log's Time +func (m *Log) GetTime() uint32 { + if m != nil && m.Time != nil { + return *m.Time + } + return 0 +} + +// GetContents return the Log's Contents +func (m *Log) GetContents() []*LogContent { + if m != nil { + return m.Contents + } + return nil +} + +// LogContent define the Log content struct +type LogContent struct { + Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` + Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogContent +func (m *LogContent) Reset() { *m = LogContent{} } + +// String return the compact text +func (m *LogContent) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogContent) ProtoMessage() {} + +// GetKey return the Key +func (m *LogContent) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +// GetValue return the Value +func (m *LogContent) GetValue() string { + if m != nil && m.Value != nil { + return *m.Value + } + return "" +} + +// LogGroup define the logs struct +type LogGroup struct { + Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` + Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` + Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"` + Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroup +func (m *LogGroup) Reset() { *m = LogGroup{} } + +// String return the compact text +func (m *LogGroup) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroup) ProtoMessage() {} + +// GetLogs return the loggroup logs +func (m *LogGroup) GetLogs() []*Log { + if m != nil { + return m.Logs + } + return nil +} + +// GetReserved return Reserved +func (m *LogGroup) GetReserved() string { + if m != nil && m.Reserved != nil { + return *m.Reserved + } + return "" +} + +// GetTopic return Topic +func (m *LogGroup) GetTopic() string { + if m != nil && m.Topic != nil { + return *m.Topic + } + return "" +} + +// GetSource return Source +func (m *LogGroup) GetSource() string { + if m != nil && m.Source != nil { + return *m.Source + } + return "" +} + +// LogGroupList define the LogGroups +type LogGroupList struct { + LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` + XXXUnrecognized []byte `json:"-"` +} + +// Reset LogGroupList +func (m *LogGroupList) Reset() { *m = LogGroupList{} } + +// String return compact text +func (m *LogGroupList) String() string { return proto.CompactTextString(m) } + +// ProtoMessage not implemented +func (*LogGroupList) ProtoMessage() {} + +// 
GetLogGroups return the LogGroups +func (m *LogGroupList) GetLogGroups() []*LogGroup { + if m != nil { + return m.LogGroups + } + return nil +} + +// Marshal the logs to byte slice +func (m *Log) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo data +func (m *Log) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Time == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + data[i] = 0x8 + i++ + i = encodeVarintLog(data, i, uint64(*m.Time)) + if len(m.Contents) > 0 { + for _, msg := range m.Contents { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogContent +func (m *LogContent) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo logcontent to data +func (m *LogContent) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Key == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Key))) + i += copy(data[i:], *m.Key) + + if m.Value == nil { + return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Value))) + i += copy(data[i:], *m.Value) + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroup +func (m *LogGroup) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroup to data +func (m *LogGroup) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Logs) > 0 { + for _, msg := range m.Logs { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.Reserved != nil { + data[i] = 0x12 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Reserved))) + i += copy(data[i:], *m.Reserved) + } + if m.Topic != nil { + data[i] = 0x1a + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Topic))) + i += copy(data[i:], *m.Topic) + } + if m.Source != nil { + data[i] = 0x22 + i++ + i = encodeVarintLog(data, i, uint64(len(*m.Source))) + i += copy(data[i:], *m.Source) + } + if m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +// Marshal LogGroupList +func (m *LogGroupList) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +// MarshalTo LogGroupList to data +func (m *LogGroupList) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, msg := range m.LogGroups { + data[i] = 0xa + i++ + i = encodeVarintLog(data, i, uint64(msg.Size())) + n, err := msg.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if 
m.XXXUnrecognized != nil { + i += copy(data[i:], m.XXXUnrecognized) + } + return i, nil +} + +func encodeFixed64Log(data []byte, offset int, v uint64) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + data[offset+4] = uint8(v >> 32) + data[offset+5] = uint8(v >> 40) + data[offset+6] = uint8(v >> 48) + data[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Log(data []byte, offset int, v uint32) int { + data[offset] = uint8(v) + data[offset+1] = uint8(v >> 8) + data[offset+2] = uint8(v >> 16) + data[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintLog(data []byte, offset int, v uint64) int { + for v >= 1<<7 { + data[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + data[offset] = uint8(v) + return offset + 1 +} + +// Size return the log's size +func (m *Log) Size() (n int) { + var l int + _ = l + if m.Time != nil { + n += 1 + sovLog(uint64(*m.Time)) + } + if len(m.Contents) > 0 { + for _, e := range m.Contents { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogContent size based on Key and Value +func (m *LogContent) Size() (n int) { + var l int + _ = l + if m.Key != nil { + l = len(*m.Key) + n += 1 + l + sovLog(uint64(l)) + } + if m.Value != nil { + l = len(*m.Value) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroup size based on Logs +func (m *LogGroup) Size() (n int) { + var l int + _ = l + if len(m.Logs) > 0 { + for _, e := range m.Logs { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.Reserved != nil { + l = len(*m.Reserved) + n += 1 + l + sovLog(uint64(l)) + } + if m.Topic != nil { + l = len(*m.Topic) + n += 1 + l + sovLog(uint64(l)) + } + if m.Source != nil { + l = len(*m.Source) + n += 1 + l + sovLog(uint64(l)) + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +// Size return LogGroupList size +func (m *LogGroupList) Size() (n int) { + var l int + _ = l + if len(m.LogGroups) > 0 { + for _, e := range m.LogGroups { + l = e.Size() + n += 1 + l + sovLog(uint64(l)) + } + } + if m.XXXUnrecognized != nil { + n += len(m.XXXUnrecognized) + } + return n +} + +func sovLog(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozLog(x uint64) (n int) { + return sovLog((x << 1) ^ (x >> 63)) +} + +// Unmarshal data to log +func (m *Log) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Log: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) + } + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + v |= 
(uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Time = &v + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Contents = append(m.Contents, &LogContent{}) + if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogContent +func (m *LogContent) Unmarshal(data []byte) error { + var hasFields [1]uint64 + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Content: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Key = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Value = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return 
io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroup +func (m *LogGroup) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroup: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Logs = append(m.Logs, &Log{}) + if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Reserved = &s + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Topic = &s + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + intStringLen 
+ if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(data[iNdEx:postIndex]) + m.Source = &s + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +// Unmarshal data to LogGroupList +func (m *LogGroupList) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowLog + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthLog + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.LogGroups = append(m.LogGroups, &LogGroup{}) + if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipLog(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthLog + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skipLog(data []byte) (n int, err error) { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if data[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthLog + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowLog + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipLog(data[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} diff --git a/pkg/logs/alils/log_config.go b/pkg/logs/alils/log_config.go new file mode 100755 index 00000000..e8564efb --- /dev/null +++ b/pkg/logs/alils/log_config.go @@ -0,0 +1,42 @@ +package alils + +// InputDetail define log detail +type InputDetail struct { + LogType string `json:"logType"` + LogPath string `json:"logPath"` + FilePattern string `json:"filePattern"` + LocalStorage bool `json:"localStorage"` + TimeFormat string `json:"timeFormat"` + LogBeginRegex string `json:"logBeginRegex"` + Regex string `json:"regex"` + Keys []string `json:"key"` + FilterKeys []string `json:"filterKey"` + FilterRegex []string `json:"filterRegex"` + TopicFormat string `json:"topicFormat"` +} + +// OutputDetail define the output detail +type OutputDetail struct { + Endpoint string `json:"endpoint"` + LogStoreName string `json:"logstoreName"` +} + +// LogConfig define Log Config +type LogConfig struct { + Name string `json:"configName"` + InputType string `json:"inputType"` + InputDetail InputDetail `json:"inputDetail"` + OutputType string `json:"outputType"` + OutputDetail OutputDetail `json:"outputDetail"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// GetAppliedMachineGroup returns applied machine group of this config. +func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) { + groupNames, err = c.project.GetAppliedMachineGroups(c.Name) + return +} diff --git a/pkg/logs/alils/log_project.go b/pkg/logs/alils/log_project.go new file mode 100755 index 00000000..59db8cbf --- /dev/null +++ b/pkg/logs/alils/log_project.go @@ -0,0 +1,819 @@ +/* +Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS). 
+ +For more description about SLS, please read this article: +http://gitlab.alibaba-inc.com/sls/doc. +*/ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// Error message in SLS HTTP response. +type errorMessage struct { + Code string `json:"errorCode"` + Message string `json:"errorMessage"` +} + +// LogProject Define the Ali Project detail +type LogProject struct { + Name string // Project name + Endpoint string // IP or hostname of SLS endpoint + AccessKeyID string + AccessKeySecret string +} + +// NewLogProject creates a new SLS project. +func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) { + p = &LogProject{ + Name: name, + Endpoint: endpoint, + AccessKeyID: AccessKeyID, + AccessKeySecret: accessKeySecret, + } + return p, nil +} + +// ListLogStore returns all logstore names of project p. +func (p *LogProject) ListLogStore() (storeNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores") + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + LogStores []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + storeNames = body.LogStores + + return +} + +// GetLogStore returns logstore according by logstore name. +func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/logstores/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + s = &LogStore{} + err = json.Unmarshal(buf, s) + if err != nil { + return + } + s.project = p + return +} + +// CreateLogStore creates a new logstore in SLS, +// where name is logstore name, +// and ttl is time-to-live(in day) of logs, +// and shardCnt is the number of shards. 
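
CreateLogStore, implemented next, joins the other methods in this file to form a thin REST client for the SLS API. A hedged usage sketch follows (editorial, not part of the patch); the endpoint, credentials and logstore name are placeholders, and the import path follows the pre-move layout.

```golang
package main

import (
	"fmt"

	"github.com/astaxie/beego/logs/alils"
)

func main() {
	// Placeholder project name, endpoint and credentials.
	project, err := alils.NewLogProject("my-project", "cn-hangzhou.log.aliyuncs.com",
		"<access-key-id>", "<access-key-secret>")
	if err != nil {
		panic(err)
	}

	// Enumerate the project's logstores, then fetch one by name.
	names, err := project.ListLogStore()
	if err != nil {
		panic(err)
	}
	fmt.Println("logstores:", names)

	if _, err := project.GetLogStore("app-logs"); err != nil {
		panic(err)
	}
}
```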
+func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteLogStore deletes a logstore according by logstore name. +func (p *LogProject) DeleteLogStore(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/logstores/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// UpdateLogStore updates a logstore according by logstore name, +// obviously we can't modify the logstore name itself. +func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { + + type Body struct { + Name string `json:"logstoreName"` + TTL int `json:"ttl"` + ShardCount int `json:"shardCount"` + } + + store := &Body{ + Name: name, + TTL: ttl, + ShardCount: shardCnt, + } + + body, err := json.Marshal(store) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/logstores", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// ListMachineGroup returns machine group name list and the total number of machine groups. +// The offset starts from 0 and the size is the max number of machine groups could be returned. 
+func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 500 + } + + uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + MachineGroups []string + Count int + Total int + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + m = body.MachineGroups + total = body.Total + + return +} + +// GetMachineGroup retruns machine group according by machine group name. +func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get machine group:%v", name) + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + m = &MachineGroup{} + err = json.Unmarshal(buf, m) + if err != nil { + return + } + m.project = p + return +} + +// CreateMachineGroup creates a new machine group in SLS. +func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/machinegroups", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to create machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// UpdateMachineGroup updates a machine group. 
+func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { + + body, err := json.Marshal(m) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteMachineGroup deletes machine group according machine group name. +func (p *LogProject) DeleteMachineGroup(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// ListConfig returns config names list and the total number of configs. +// The offset starts from 0 and the size is the max number of configs could be returned. +func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + if size <= 0 { + size = 100 + } + + uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Total int + Configs []string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + cfgNames = body.Configs + total = body.Total + return +} + +// GetConfig returns config according by config name. +func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "GET", "/configs/"+name, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + c = &LogConfig{} + err = json.Unmarshal(buf, c) + if err != nil { + return + } + c.project = p + return +} + +// UpdateConfig updates a config. 
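+//
+// Editor's sketch, not part of the original patch: a common pattern is to fetch an
+// existing config, adjust it, and write it back. The config name "nginx-access" is
+// hypothetical and LogConfig's fields are defined in log_config.go.
+//
+//	c, err := p.GetConfig("nginx-access")
+//	if err != nil {
+//		// handle error
+//	}
+//	// ... adjust fields of c as needed ...
+//	if err := p.UpdateConfig(c); err != nil {
+//		// handle error
+//	}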
+func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "PUT", "/configs/"+c.Name, h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// CreateConfig creates a new config in SLS. +func (p *LogProject) CreateConfig(c *LogConfig) (err error) { + + body, err := json.Marshal(c) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/json", + "Accept-Encoding": "deflate", // TODO: support lz4 + } + + r, err := request(p, "POST", "/configs", h, body) + if err != nil { + return + } + + body, err = ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to update config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + return +} + +// DeleteConfig deletes a config according by config name. +func (p *LogProject) DeleteConfig(name string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + r, err := request(p, "DELETE", "/configs/"+name, h, nil) + if err != nil { + return + } + + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(body, errMsg) + if err != nil { + err = fmt.Errorf("failed to delete config") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetAppliedMachineGroups returns applied machine group names list according config name. +func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/configs/%v/machinegroups", confName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get applied machine groups") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Count int + Machinegroups []string + } + + body := &Body{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + groupNames = body.Machinegroups + return +} + +// GetAppliedConfigs returns applied config names list according machine group name groupName. 
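+//
+// Editor's sketch, not part of the original patch: an illustrative call with a
+// hypothetical group name, assuming p is a configured *LogProject.
+//
+//	configs, err := p.GetAppliedConfigs("prod-web")
+//	if err != nil {
+//		// handle error
+//	}
+//	fmt.Println("configs applied to prod-web:", configs)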
+func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs", groupName) + r, err := request(p, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to applied configs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Cfg struct { + Count int `json:"count"` + Configs []string `json:"configs"` + } + + body := &Cfg{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + confNames = body.Configs + return +} + +// ApplyConfigToMachineGroup applies config to machine group. +func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "PUT", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to apply config to machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// RemoveConfigFromMachineGroup removes config from machine group. +func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName) + r, err := request(p, "DELETE", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Printf("%s\n", dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} diff --git a/pkg/logs/alils/log_store.go b/pkg/logs/alils/log_store.go new file mode 100755 index 00000000..fa502736 --- /dev/null +++ b/pkg/logs/alils/log_store.go @@ -0,0 +1,271 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" + "strconv" + + lz4 "github.com/cloudflare/golz4" + "github.com/gogo/protobuf/proto" +) + +// LogStore Store the logs +type LogStore struct { + Name string `json:"logstoreName"` + TTL int + ShardCount int + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Shard define the Log Shard +type Shard struct { + ShardID int `json:"shardID"` +} + +// ListShards returns shard id list of this logstore. 
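+//
+// Editor's sketch, not part of the original patch: it assumes s is a *LogStore
+// obtained from its parent LogProject (log_project.go wires the unexported
+// project field).
+//
+//	ids, err := s.ListShards()
+//	if err != nil {
+//		// handle error
+//	}
+//	fmt.Println("shard IDs:", ids)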
+func (s *LogStore) ListShards() (shardIDs []int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards", s.Name) + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to list logstore") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + var shards []*Shard + err = json.Unmarshal(buf, &shards) + if err != nil { + return + } + + for _, v := range shards { + shardIDs = append(shardIDs, v.ShardID) + } + return +} + +// PutLogs put logs into logstore. +// The callers should transform user logs into LogGroup. +func (s *LogStore) PutLogs(lg *LogGroup) (err error) { + body, err := proto.Marshal(lg) + if err != nil { + return + } + + // Compresse body with lz4 + out := make([]byte, lz4.CompressBound(body)) + n, err := lz4.Compress(body, out) + if err != nil { + return + } + + h := map[string]string{ + "x-sls-compresstype": "lz4", + "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)), + "Content-Type": "application/x-protobuf", + } + + uri := fmt.Sprintf("/logstores/%v", s.Name) + r, err := request(s.project, "POST", uri, h, out[:n]) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to put logs") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + return +} + +// GetCursor gets log cursor of one shard specified by shardID. +// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end". +// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore +func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v", + s.Name, shardID, from) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + type Body struct { + Cursor string + } + body := &Body{} + + err = json.Unmarshal(buf, body) + if err != nil { + return + } + cursor = body.Cursor + return +} + +// GetLogsBytes gets logs binary data from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. 
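+//
+// Editor's sketch, not part of the original patch: a minimal read flow combining
+// GetCursor with GetLogs (defined below), which wraps GetLogsBytes and
+// LogsBytesDecode. Shard 0 and the batch size of 100 are arbitrary.
+//
+//	cursor, err := s.GetCursor(0, "begin")
+//	if err != nil {
+//		// handle error
+//	}
+//	gl, nextCursor, err := s.GetLogs(0, cursor, 100)
+//	if err != nil {
+//		// handle error
+//	}
+//	_ = gl         // *LogGroupList holding the decoded log groups
+//	_ = nextCursor // pass this back in to keep reading the shard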
+func (s *LogStore) GetLogsBytes(shardID int, cursor string, + logGroupMaxCount int) (out []byte, nextCursor string, err error) { + + h := map[string]string{ + "x-sls-bodyrawsize": "0", + "Accept": "application/x-protobuf", + "Accept-Encoding": "lz4", + } + + uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v", + s.Name, shardID, cursor, logGroupMaxCount) + + r, err := request(s.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to get cursor") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + v, ok := r.Header["X-Sls-Compresstype"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-compresstype' header") + return + } + if v[0] != "lz4" { + err = fmt.Errorf("unexpected compress type:%v", v[0]) + return + } + + v, ok = r.Header["X-Sls-Cursor"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-cursor' header") + return + } + nextCursor = v[0] + + v, ok = r.Header["X-Sls-Bodyrawsize"] + if !ok || len(v) == 0 { + err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") + return + } + bodyRawSize, err := strconv.Atoi(v[0]) + if err != nil { + return + } + + out = make([]byte, bodyRawSize) + err = lz4.Uncompress(buf, out) + if err != nil { + return + } + + return +} + +// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API +func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { + + gl = &LogGroupList{} + err = proto.Unmarshal(data, gl) + if err != nil { + return + } + + return +} + +// GetLogs gets logs from shard specified by shardID according cursor. +// The logGroupMaxCount is the max number of logGroup could be returned. +// The nextCursor is the next curosr can be used to read logs at next time. +func (s *LogStore) GetLogs(shardID int, cursor string, + logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) { + + out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount) + if err != nil { + return + } + + gl, err = LogsBytesDecode(out) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/machine_group.go b/pkg/logs/alils/machine_group.go new file mode 100755 index 00000000..b6c69a14 --- /dev/null +++ b/pkg/logs/alils/machine_group.go @@ -0,0 +1,91 @@ +package alils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httputil" +) + +// MachineGroupAttribute define the Attribute +type MachineGroupAttribute struct { + ExternalName string `json:"externalName"` + TopicName string `json:"groupTopic"` +} + +// MachineGroup define the machine Group +type MachineGroup struct { + Name string `json:"groupName"` + Type string `json:"groupType"` + MachineIDType string `json:"machineIdentifyType"` + MachineIDList []string `json:"machineList"` + + Attribute MachineGroupAttribute `json:"groupAttribute"` + + CreateTime uint32 + LastModifyTime uint32 + + project *LogProject +} + +// Machine define the Machine +type Machine struct { + IP string + UniqueID string `json:"machine-uniqueid"` + UserdefinedID string `json:"userdefined-id"` +} + +// MachineList define the Machine List +type MachineList struct { + Total int + Machines []*Machine +} + +// ListMachines returns machine list of this machine group. 
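+//
+// Editor's sketch, not part of the original patch: the group is fetched through
+// the project so its unexported project field is set; the group name is
+// hypothetical.
+//
+//	mg, err := p.GetMachineGroup("prod-web")
+//	if err != nil {
+//		// handle error
+//	}
+//	machines, total, err := mg.ListMachines()
+//	if err != nil {
+//		// handle error
+//	}
+//	fmt.Printf("%d of %d machines in the group\n", len(machines), total)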
+func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { + h := map[string]string{ + "x-sls-bodyrawsize": "0", + } + + uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name) + r, err := request(m.project, "GET", uri, h, nil) + if err != nil { + return + } + + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + + if r.StatusCode != http.StatusOK { + errMsg := &errorMessage{} + err = json.Unmarshal(buf, errMsg) + if err != nil { + err = fmt.Errorf("failed to remove config from machine group") + dump, _ := httputil.DumpResponse(r, true) + fmt.Println(dump) + return + } + err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message) + return + } + + body := &MachineList{} + err = json.Unmarshal(buf, body) + if err != nil { + return + } + + ms = body.Machines + total = body.Total + + return +} + +// GetAppliedConfigs returns applied configs of this machine group. +func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) { + confNames, err = m.project.GetAppliedConfigs(m.Name) + return +} diff --git a/pkg/logs/alils/request.go b/pkg/logs/alils/request.go new file mode 100755 index 00000000..50d9c43c --- /dev/null +++ b/pkg/logs/alils/request.go @@ -0,0 +1,62 @@ +package alils + +import ( + "bytes" + "crypto/md5" + "fmt" + "net/http" +) + +// request sends a request to SLS. +func request(project *LogProject, method, uri string, headers map[string]string, + body []byte) (resp *http.Response, err error) { + + // The caller should provide 'x-sls-bodyrawsize' header + if _, ok := headers["x-sls-bodyrawsize"]; !ok { + err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") + return + } + + // SLS public request headers + headers["Host"] = project.Name + "." + project.Endpoint + headers["Date"] = nowRFC1123() + headers["x-sls-apiversion"] = version + headers["x-sls-signaturemethod"] = signatureMethod + if body != nil { + bodyMD5 := fmt.Sprintf("%X", md5.Sum(body)) + headers["Content-MD5"] = bodyMD5 + + if _, ok := headers["Content-Type"]; !ok { + err = fmt.Errorf("Can't find 'Content-Type' header") + return + } + } + + // Calc Authorization + // Authorization = "SLS :" + digest, err := signature(project, method, uri, headers) + if err != nil { + return + } + auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest) + headers["Authorization"] = auth + + // Initialize http request + reader := bytes.NewReader(body) + urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri) + req, err := http.NewRequest(method, urlStr, reader) + if err != nil { + return + } + for k, v := range headers { + req.Header.Add(k, v) + } + + // Get ready to do request + resp, err = http.DefaultClient.Do(req) + if err != nil { + return + } + + return +} diff --git a/pkg/logs/alils/signature.go b/pkg/logs/alils/signature.go new file mode 100755 index 00000000..2d611307 --- /dev/null +++ b/pkg/logs/alils/signature.go @@ -0,0 +1,111 @@ +package alils + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net/url" + "sort" + "strings" + "time" +) + +// GMT location +var gmtLoc = time.FixedZone("GMT", 0) + +// NowRFC1123 returns now time in RFC1123 format with GMT timezone, +// eg. "Mon, 02 Jan 2006 15:04:05 GMT". +func nowRFC1123() string { + return time.Now().In(gmtLoc).Format(time.RFC1123) +} + +// signature calculates a request's signature digest. 
+func signature(project *LogProject, method, uri string, + headers map[string]string) (digest string, err error) { + var contentMD5, contentType, date, canoHeaders, canoResource string + var slsHeaderKeys sort.StringSlice + + // SignString = VERB + "\n" + // + CONTENT-MD5 + "\n" + // + CONTENT-TYPE + "\n" + // + DATE + "\n" + // + CanonicalizedSLSHeaders + "\n" + // + CanonicalizedResource + + if val, ok := headers["Content-MD5"]; ok { + contentMD5 = val + } + + if val, ok := headers["Content-Type"]; ok { + contentType = val + } + + date, ok := headers["Date"] + if !ok { + err = fmt.Errorf("Can't find 'Date' header") + return + } + + // Calc CanonicalizedSLSHeaders + slsHeaders := make(map[string]string, len(headers)) + for k, v := range headers { + l := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(l, "x-sls-") { + slsHeaders[l] = strings.TrimSpace(v) + slsHeaderKeys = append(slsHeaderKeys, l) + } + } + + sort.Sort(slsHeaderKeys) + for i, k := range slsHeaderKeys { + canoHeaders += k + ":" + slsHeaders[k] + if i+1 < len(slsHeaderKeys) { + canoHeaders += "\n" + } + } + + // Calc CanonicalizedResource + u, err := url.Parse(uri) + if err != nil { + return + } + + canoResource += url.QueryEscape(u.Path) + if u.RawQuery != "" { + var keys sort.StringSlice + + vals := u.Query() + for k := range vals { + keys = append(keys, k) + } + + sort.Sort(keys) + canoResource += "?" + for i, k := range keys { + if i > 0 { + canoResource += "&" + } + + for _, v := range vals[k] { + canoResource += k + "=" + v + } + } + } + + signStr := method + "\n" + + contentMD5 + "\n" + + contentType + "\n" + + date + "\n" + + canoHeaders + "\n" + + canoResource + + // Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret)) + mac := hmac.New(sha1.New, []byte(project.AccessKeySecret)) + _, err = mac.Write([]byte(signStr)) + if err != nil { + return + } + digest = base64.StdEncoding.EncodeToString(mac.Sum(nil)) + return +} diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go new file mode 100644 index 00000000..74c458ab --- /dev/null +++ b/pkg/logs/conn.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "io" + "net" + "time" +) + +// connWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. +type connWriter struct { + lg *logWriter + innerWriter io.WriteCloser + ReconnectOnMsg bool `json:"reconnectOnMsg"` + Reconnect bool `json:"reconnect"` + Net string `json:"net"` + Addr string `json:"addr"` + Level int `json:"level"` +} + +// NewConn create new ConnWrite returning as LoggerInterface. +func NewConn() Logger { + conn := new(connWriter) + conn.Level = LevelTrace + return conn +} + +// Init init connection writer with json config. +// json config only need key "level". +func (c *connWriter) Init(jsonConfig string) error { + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in connection. 
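+//
+// Editor's sketch, not part of the original patch: wiring this writer through the
+// logger, mirroring conn_test.go; the address is hypothetical.
+//
+//	log := NewLogger(1000)
+//	log.SetLogger(AdapterConn, `{"net":"tcp","addr":"127.0.0.1:7020","reconnect":true,"level":6}`)
+//	log.Informational("sent over the keep-alive TCP connection")
+//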
+// if connection is down, try to re-connect. +func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.needToConnectOnMsg() { + err := c.connect() + if err != nil { + return err + } + } + + if c.ReconnectOnMsg { + defer c.innerWriter.Close() + } + + _, err := c.lg.writeln(when, msg) + if err != nil { + return err + } + return nil +} + +// Flush implementing method. empty. +func (c *connWriter) Flush() { + +} + +// Destroy destroy connection writer and close tcp listener. +func (c *connWriter) Destroy() { + if c.innerWriter != nil { + c.innerWriter.Close() + } +} + +func (c *connWriter) connect() error { + if c.innerWriter != nil { + c.innerWriter.Close() + c.innerWriter = nil + } + + conn, err := net.Dial(c.Net, c.Addr) + if err != nil { + return err + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + } + + c.innerWriter = conn + c.lg = newLogWriter(conn) + return nil +} + +func (c *connWriter) needToConnectOnMsg() bool { + if c.Reconnect { + return true + } + + if c.innerWriter == nil { + return true + } + + return c.ReconnectOnMsg +} + +func init() { + Register(AdapterConn, NewConn) +} diff --git a/pkg/logs/conn_test.go b/pkg/logs/conn_test.go new file mode 100644 index 00000000..bb377d41 --- /dev/null +++ b/pkg/logs/conn_test.go @@ -0,0 +1,79 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package logs + +import ( + "net" + "os" + "testing" +) + +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + +func TestConn(t *testing.T) { + log := NewLogger(1000) + log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) + log.Informational("informational") +} + +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 2 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} diff --git a/pkg/logs/console.go b/pkg/logs/console.go new file mode 100644 index 00000000..3dcaee1d --- /dev/null +++ b/pkg/logs/console.go @@ -0,0 +1,99 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "os" + "strings" + "time" + + "github.com/shiena/ansicolor" +) + +// brush is a color join function +type brush func(string) string + +// newBrush return a fix color Brush +func newBrush(color string) brush { + pre := "\033[" + reset := "\033[0m" + return func(text string) string { + return pre + color + "m" + text + reset + } +} + +var colors = []brush{ + newBrush("1;37"), // Emergency white + newBrush("1;36"), // Alert cyan + newBrush("1;35"), // Critical magenta + newBrush("1;31"), // Error red + newBrush("1;33"), // Warning yellow + newBrush("1;32"), // Notice green + newBrush("1;34"), // Informational blue + newBrush("1;44"), // Debug Background blue +} + +// consoleWriter implements LoggerInterface and writes messages to terminal. +type consoleWriter struct { + lg *logWriter + Level int `json:"level"` + Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color +} + +// NewConsole create ConsoleWriter returning as LoggerInterface. +func NewConsole() Logger { + cw := &consoleWriter{ + lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), + Level: LevelDebug, + Colorful: true, + } + return cw +} + +// Init init console logger. +// jsonConfig like '{"level":LevelTrace}'. 
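+//
+// Editor's sketch, not part of the original patch, mirroring console_test.go:
+// level 3 is LevelError, and "color":false disables the ANSI brushes above.
+//
+//	log := NewLogger(100)
+//	log.SetLogger(AdapterConsole, `{"level":3,"color":false}`)
+//	log.Error("printed")
+//	log.Debug("suppressed, because 7 > 3")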
+func (c *consoleWriter) Init(jsonConfig string) error { + if len(jsonConfig) == 0 { + return nil + } + return json.Unmarshal([]byte(jsonConfig), c) +} + +// WriteMsg write message in console. +func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > c.Level { + return nil + } + if c.Colorful { + msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) + } + c.lg.writeln(when, msg) + return nil +} + +// Destroy implementing method. empty. +func (c *consoleWriter) Destroy() { + +} + +// Flush implementing method. empty. +func (c *consoleWriter) Flush() { + +} + +func init() { + Register(AdapterConsole, NewConsole) +} diff --git a/pkg/logs/console_test.go b/pkg/logs/console_test.go new file mode 100644 index 00000000..4bc45f57 --- /dev/null +++ b/pkg/logs/console_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" +) + +// Try each log level in decreasing order of priority. +func testConsoleCalls(bl *BeeLogger) { + bl.Emergency("emergency") + bl.Alert("alert") + bl.Critical("critical") + bl.Error("error") + bl.Warning("warning") + bl.Notice("notice") + bl.Informational("informational") + bl.Debug("debug") +} + +// Test console logging by visually comparing the lines being output with and +// without a log level specification. 
+func TestConsole(t *testing.T) { + log1 := NewLogger(10000) + log1.EnableFuncCallDepth(true) + log1.SetLogger("console", "") + testConsoleCalls(log1) + + log2 := NewLogger(100) + log2.SetLogger("console", `{"level":3}`) + testConsoleCalls(log2) +} + +// Test console without color +func TestConsoleNoColor(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console", `{"color":false}`) + testConsoleCalls(log) +} + +// Test console async +func TestConsoleAsync(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console") + log.Async() + //log.Close() + testConsoleCalls(log) + for len(log.msgChan) != 0 { + time.Sleep(1 * time.Millisecond) + } +} diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go new file mode 100644 index 00000000..2b7b1710 --- /dev/null +++ b/pkg/logs/es/es.go @@ -0,0 +1,102 @@ +package es + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/elastic/go-elasticsearch/v6" + "github.com/elastic/go-elasticsearch/v6/esapi" + + "github.com/astaxie/beego/logs" +) + +// NewES return a LoggerInterface +func NewES() logs.Logger { + cw := &esLogger{ + Level: logs.LevelDebug, + } + return cw +} + +// esLogger will log msg into ES +// before you using this implementation, +// please import this package +// usually means that you can import this package in your main package +// for example, anonymous: +// import _ "github.com/astaxie/beego/logs/es" +type esLogger struct { + *elasticsearch.Client + DSN string `json:"dsn"` + Level int `json:"level"` +} + +// {"dsn":"http://localhost:9200/","level":1} +func (el *esLogger) Init(jsonconfig string) error { + err := json.Unmarshal([]byte(jsonconfig), el) + if err != nil { + return err + } + if el.DSN == "" { + return errors.New("empty dsn") + } else if u, err := url.Parse(el.DSN); err != nil { + return err + } else if u.Path == "" { + return errors.New("missing prefix") + } else { + conn, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{el.DSN}, + }) + if err != nil { + return err + } + el.Client = conn + } + return nil +} + +// WriteMsg will write the msg and level into es +func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { + if level > el.Level { + return nil + } + + idx := LogDocument{ + Timestamp: when.Format(time.RFC3339), + Msg: msg, + } + + body, err := json.Marshal(idx) + if err != nil { + return err + } + req := esapi.IndexRequest{ + Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), + DocumentType: "logs", + Body: strings.NewReader(string(body)), + } + _, err = req.Do(context.Background(), el.Client) + return err +} + +// Destroy is a empty method +func (el *esLogger) Destroy() { +} + +// Flush is a empty method +func (el *esLogger) Flush() { + +} + +type LogDocument struct { + Timestamp string `json:"timestamp"` + Msg string `json:"msg"` +} + +func init() { + logs.Register(logs.AdapterEs, NewES) +} diff --git a/pkg/logs/file.go b/pkg/logs/file.go new file mode 100644 index 00000000..222db989 --- /dev/null +++ b/pkg/logs/file.go @@ -0,0 +1,409 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" +) + +// fileLogWriter implements LoggerInterface. +// It writes messages by lines limit, file size limit, or time frequency. +type fileLogWriter struct { + sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + // The opened file + Filename string `json:"filename"` + fileWriter *os.File + + // Rotate at line + MaxLines int `json:"maxlines"` + maxLinesCurLines int + + MaxFiles int `json:"maxfiles"` + MaxFilesCurFiles int + + // Rotate at size + MaxSize int `json:"maxsize"` + maxSizeCurSize int + + // Rotate daily + Daily bool `json:"daily"` + MaxDays int64 `json:"maxdays"` + dailyOpenDate int + dailyOpenTime time.Time + + // Rotate hourly + Hourly bool `json:"hourly"` + MaxHours int64 `json:"maxhours"` + hourlyOpenDate int + hourlyOpenTime time.Time + + Rotate bool `json:"rotate"` + + Level int `json:"level"` + + Perm string `json:"perm"` + + RotatePerm string `json:"rotateperm"` + + fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix +} + +// newFileWriter create a FileLogWriter returning as LoggerInterface. +func newFileWriter() Logger { + w := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Hourly: false, + MaxHours: 168, + Rotate: true, + RotatePerm: "0440", + Level: LevelTrace, + Perm: "0660", + MaxLines: 10000000, + MaxFiles: 999, + MaxSize: 1 << 28, + } + return w +} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":10000, +// "maxsize":1024, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":"0600" +// } +func (w *fileLogWriter) Init(jsonConfig string) error { + err := json.Unmarshal([]byte(jsonConfig), w) + if err != nil { + return err + } + if len(w.Filename) == 0 { + return errors.New("jsonconfig must have filename") + } + w.suffix = filepath.Ext(w.Filename) + w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix) + if w.suffix == "" { + w.suffix = ".log" + } + err = w.startLogger() + return err +} + +// start file logger. create log file and set to locker-inside file writer. +func (w *fileLogWriter) startLogger() error { + file, err := w.createLogFile() + if err != nil { + return err + } + if w.fileWriter != nil { + w.fileWriter.Close() + } + w.fileWriter = file + return w.initFd() +} + +func (w *fileLogWriter) needRotateDaily(size int, day int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Daily && day != w.dailyOpenDate) +} + +func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { + return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || + (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || + (w.Hourly && hour != w.hourlyOpenDate) + +} + +// WriteMsg write logger message into file. 
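+//
+// Editor's sketch, not part of the original patch: a typical configuration for
+// this writer, following the JSON keys documented on Init above; the filename is
+// hypothetical.
+//
+//	log := NewLogger(10000)
+//	log.SetLogger(AdapterFile, `{"filename":"logs/app.log","daily":true,"maxdays":7,"rotate":true,"perm":"0660"}`)
+//	log.Informational("written to logs/app.log and rotated daily")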
+func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > w.Level { + return nil + } + hd, d, h := formatTimeHeader(when) + msg = string(hd) + msg + "\n" + if w.Rotate { + w.RLock() + if w.needRotateHourly(len(msg), h) { + w.RUnlock() + w.Lock() + if w.needRotateHourly(len(msg), h) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else if w.needRotateDaily(len(msg), d) { + w.RUnlock() + w.Lock() + if w.needRotateDaily(len(msg), d) { + if err := w.doRotate(when); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() + } else { + w.RUnlock() + } + } + + w.Lock() + _, err := w.fileWriter.Write([]byte(msg)) + if err == nil { + w.maxLinesCurLines++ + w.maxSizeCurSize += len(msg) + } + w.Unlock() + return err +} + +func (w *fileLogWriter) createLogFile() (*os.File, error) { + // Open the log file + perm, err := strconv.ParseInt(w.Perm, 8, 64) + if err != nil { + return nil, err + } + + filepath := path.Dir(w.Filename) + os.MkdirAll(filepath, os.FileMode(perm)) + + fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) + if err == nil { + // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask + os.Chmod(w.Filename, os.FileMode(perm)) + } + return fd, err +} + +func (w *fileLogWriter) initFd() error { + fd := w.fileWriter + fInfo, err := fd.Stat() + if err != nil { + return fmt.Errorf("get stat err: %s", err) + } + w.maxSizeCurSize = int(fInfo.Size()) + w.dailyOpenTime = time.Now() + w.dailyOpenDate = w.dailyOpenTime.Day() + w.hourlyOpenTime = time.Now() + w.hourlyOpenDate = w.hourlyOpenTime.Hour() + w.maxLinesCurLines = 0 + if w.Hourly { + go w.hourlyRotate(w.hourlyOpenTime) + } else if w.Daily { + go w.dailyRotate(w.dailyOpenTime) + } + if fInfo.Size() > 0 && w.MaxLines > 0 { + count, err := w.lines() + if err != nil { + return err + } + w.maxLinesCurLines = count + } + return nil +} + +func (w *fileLogWriter) dailyRotate(openTime time.Time) { + y, m, d := openTime.Add(24 * time.Hour).Date() + nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateDaily(0, time.Now().Day()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) hourlyRotate(openTime time.Time) { + y, m, d := openTime.Add(1 * time.Hour).Date() + h, _, _ := openTime.Add(1 * time.Hour).Clock() + nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location()) + tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) + <-tm.C + w.Lock() + if w.needRotateHourly(0, time.Now().Hour()) { + if err := w.doRotate(time.Now()); err != nil { + fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) + } + } + w.Unlock() +} + +func (w *fileLogWriter) lines() (int, error) { + fd, err := os.Open(w.Filename) + if err != nil { + return 0, err + } + defer fd.Close() + + buf := make([]byte, 32768) // 32k + count := 0 + lineSep := []byte{'\n'} + + for { + c, err := fd.Read(buf) + if err != nil && err != io.EOF { + return count, err + } + + count += bytes.Count(buf[:c], lineSep) + + if err == io.EOF { + break + } + } + + return count, nil +} + +// DoRotate means it need to write file in new file. 
+// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) +func (w *fileLogWriter) doRotate(logTime time.Time) error { + // file exists + // Find the next available number + num := w.MaxFilesCurFiles + 1 + fName := "" + format := "" + var openTime time.Time + rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) + if err != nil { + return err + } + + _, err = os.Lstat(w.Filename) + if err != nil { + //even if the file is not exist or other ,we should RESTART the logger + goto RESTART_LOGGER + } + + if w.Hourly { + format = "2006010215" + openTime = w.hourlyOpenTime + } else if w.Daily { + format = "2006-01-02" + openTime = w.dailyOpenTime + } + + // only when one of them be setted, then the file would be splited + if w.MaxLines > 0 || w.MaxSize > 0 { + for ; err == nil && num <= w.MaxFiles; num++ { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + } + } else { + fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix) + _, err = os.Lstat(fName) + w.MaxFilesCurFiles = num + } + + // return error if the last file checked still existed + if err == nil { + return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) + } + + // close fileWriter before rename + w.fileWriter.Close() + + // Rename the file to its new found name + // even if occurs error,we MUST guarantee to restart new logger + err = os.Rename(w.Filename, fName) + if err != nil { + goto RESTART_LOGGER + } + + err = os.Chmod(fName, os.FileMode(rotatePerm)) + +RESTART_LOGGER: + + startLoggerErr := w.startLogger() + go w.deleteOldLog() + + if startLoggerErr != nil { + return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr) + } + if err != nil { + return fmt.Errorf("Rotate: %s", err) + } + return nil +} + +func (w *fileLogWriter) deleteOldLog() { + dir := filepath.Dir(w.Filename) + absolutePath, err := filepath.EvalSymlinks(w.Filename) + if err == nil { + dir = filepath.Dir(absolutePath) + } + filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r) + } + }() + + if info == nil { + return + } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } + return + }) +} + +// Destroy close the file description, close file writer. +func (w *fileLogWriter) Destroy() { + w.fileWriter.Close() +} + +// Flush flush file logger. +// there are no buffering messages in file logger in memory. +// flush file means sync file from disk. +func (w *fileLogWriter) Flush() { + w.fileWriter.Sync() +} + +func init() { + Register(AdapterFile, newFileWriter) +} diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go new file mode 100644 index 00000000..e7c2ca9a --- /dev/null +++ b/pkg/logs/file_test.go @@ -0,0 +1,420 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bufio" + "fmt" + "io/ioutil" + "os" + "strconv" + "testing" + "time" +) + +func TestFilePerm(t *testing.T) { + log := NewLogger(10000) + // use 0666 as test perm cause the default umask is 022 + log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + file, err := os.Stat("test.log") + if err != nil { + t.Fatal(err) + } + if file.Mode() != 0666 { + t.Fatal("unexpected log file permission") + } + os.Remove("test.log") +} + +func TestFile1(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test.log"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + f, err := os.Open("test.log") + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lineNum++ + } + } + var expected = LevelDebug + 1 + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") + } + os.Remove("test.log") +} + +func TestFile2(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError)) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + f, err := os.Open("test2.log") + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lineNum++ + } + } + var expected = LevelError + 1 + if lineNum != expected { + t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") + } + os.Remove("test2.log") +} + +func TestFileDailyRotate_01(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + b, err := exists(rotateName) + if !b || err != nil { + os.Remove("test3.log") + t.Fatal("rotate not generated") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileDailyRotate_02(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2, true, false) +} + +func TestFileDailyRotate_03(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." 
+ time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileRotate(t, fn1, fn2, true, false) + os.Remove(fn) +} + +func TestFileDailyRotate_04(t *testing.T) { + fn1 := "rotate_day.log" + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) +} + +func TestFileDailyRotate_05(t *testing.T) { + fn1 := "rotate_day.log" + fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" + os.Create(fn) + fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" + testFileDailyRotate(t, fn1, fn2) + os.Remove(fn) +} +func TestFileDailyRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileHourlyRotate_01(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" + b, err := exists(rotateName) + if !b || err != nil { + os.Remove("test3.log") + t.Fatal("rotate not generated") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func TestFileHourlyRotate_02(t *testing.T) { + fn1 := "rotate_hour.log" + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileRotate(t, fn1, fn2, false, true) +} + +func TestFileHourlyRotate_03(t *testing.T) { + fn1 := "rotate_hour.log" + fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" + os.Create(fn) + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileRotate(t, fn1, fn2, false, true) + os.Remove(fn) +} + +func TestFileHourlyRotate_04(t *testing.T) { + fn1 := "rotate_hour.log" + fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileHourlyRotate(t, fn1, fn2) +} + +func TestFileHourlyRotate_05(t *testing.T) { + fn1 := "rotate_hour.log" + fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log" + os.Create(fn) + fn2 := "rotate_hour." 
+ time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log" + testFileHourlyRotate(t, fn1, fn2) + os.Remove(fn) +} + +func TestFileHourlyRotate_06(t *testing.T) { //test file mode + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.Debug("debug") + log.Info("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" + s, _ := os.Lstat(rotateName) + if s.Mode() != 0440 { + os.Remove(rotateName) + os.Remove("test3.log") + t.Fatal("rotate file mode error") + } + os.Remove(rotateName) + os.Remove("test3.log") +} + +func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { + fw := &fileLogWriter{ + Daily: daily, + MaxDays: 7, + Hourly: hourly, + MaxHours: 168, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + + if daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } + + if hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } + + fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.Log(err) + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileDailyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Daily: true, + MaxDays: 7, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location()) + today = today.Add(-1 * time.Second) + fw.dailyRotate(today) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} + +func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { + fw := &fileLogWriter{ + Hourly: true, + MaxHours: 168, + Rotate: true, + Level: LevelTrace, + Perm: "0660", + RotatePerm: "0440", + } + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() + hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location()) + hour = hour.Add(-1 * time.Second) + fw.hourlyRotate(hour) + for _, file := range []string{fn1, fn2} { + _, err := os.Stat(file) + if err != nil { + t.FailNow() + } + content, err := ioutil.ReadFile(file) + if err != nil { + t.FailNow() + } + if len(content) > 0 { + t.FailNow() + } + os.Remove(file) + } + fw.Destroy() +} +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func BenchmarkFile(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + 
os.Remove("test4.log") +} + +func BenchmarkFileAsynchronous(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileAsynchronousCallDepth(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + log.EnableFuncCallDepth(true) + log.SetLogFuncCallDepth(2) + log.Async() + for i := 0; i < b.N; i++ { + log.Debug("debug") + } + os.Remove("test4.log") +} + +func BenchmarkFileOnGoroutine(b *testing.B) { + log := NewLogger(100000) + log.SetLogger("file", `{"filename":"test4.log"}`) + for i := 0; i < b.N; i++ { + go log.Debug("debug") + } + os.Remove("test4.log") +} diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go new file mode 100644 index 00000000..88ba0f9a --- /dev/null +++ b/pkg/logs/jianliao.go @@ -0,0 +1,72 @@ +package logs + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" +) + +// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type JLWriter struct { + AuthorName string `json:"authorname"` + Title string `json:"title"` + WebhookURL string `json:"webhookurl"` + RedirectURL string `json:"redirecturl,omitempty"` + ImageURL string `json:"imageurl,omitempty"` + Level int `json:"level"` +} + +// newJLWriter create jiaoliao writer. +func newJLWriter() Logger { + return &JLWriter{Level: LevelTrace} +} + +// Init JLWriter with json config string +func (s *JLWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) + + form := url.Values{} + form.Add("authorName", s.AuthorName) + form.Add("title", s.Title) + form.Add("text", text) + if s.RedirectURL != "" { + form.Add("redirectUrl", s.RedirectURL) + } + if s.ImageURL != "" { + form.Add("imageUrl", s.ImageURL) + } + + resp, err := http.PostForm(s.WebhookURL, form) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *JLWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *JLWriter) Destroy() { +} + +func init() { + Register(AdapterJianLiao, newJLWriter) +} diff --git a/pkg/logs/log.go b/pkg/logs/log.go new file mode 100644 index 00000000..39c006d2 --- /dev/null +++ b/pkg/logs/log.go @@ -0,0 +1,669 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package logs provide a general log interface +// Usage: +// +// import "github.com/astaxie/beego/logs" +// +// log := NewLogger(10000) +// log.SetLogger("console", "") +// +// > the first params stand for how many channel +// +// Use it like this: +// +// log.Trace("trace") +// log.Info("info") +// log.Warn("warning") +// log.Debug("debug") +// log.Critical("critical") +// +// more docs http://beego.me/docs/module/logs.md +package logs + +import ( + "fmt" + "log" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +// RFC5424 log message levels. +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +// levelLogLogger is defined to implement log.Logger +// the real log level will be LevelEmergency +const levelLoggerImpl = -1 + +// Name for adapter with beego official support +const ( + AdapterConsole = "console" + AdapterFile = "file" + AdapterMultiFile = "multifile" + AdapterMail = "smtp" + AdapterConn = "conn" + AdapterEs = "es" + AdapterJianLiao = "jianliao" + AdapterSlack = "slack" + AdapterAliLS = "alils" +) + +// Legacy log level constants to ensure backwards compatibility. +const ( + LevelInfo = LevelInformational + LevelTrace = LevelDebug + LevelWarn = LevelWarning +) + +type newLoggerFunc func() Logger + +// Logger defines the behavior of a log provider. +type Logger interface { + Init(config string) error + WriteMsg(when time.Time, msg string, level int) error + Destroy() + Flush() +} + +var adapters = make(map[string]newLoggerFunc) +var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} + +// Register makes a log provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, log newLoggerFunc) { + if log == nil { + panic("logs: Register provide is nil") + } + if _, dup := adapters[name]; dup { + panic("logs: Register called twice for provider " + name) + } + adapters[name] = log +} + +// BeeLogger is default logger in beego application. +// it can contain several providers and log message into all providers. +type BeeLogger struct { + lock sync.Mutex + level int + init bool + enableFuncCallDepth bool + loggerFuncCallDepth int + asynchronous bool + prefix string + msgChanLen int64 + msgChan chan *logMsg + signalChan chan string + wg sync.WaitGroup + outputs []*nameLogger +} + +const defaultAsyncMsgLen = 1e3 + +type nameLogger struct { + Logger + name string +} + +type logMsg struct { + level int + msg string + when time.Time +} + +var logMsgPool *sync.Pool + +// NewLogger returns a new BeeLogger. +// channelLen means the number of messages in chan(used where asynchronous is true). +// if the buffering chan is full, logger adapters write to file or other way. 
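A hedged sketch of plugging a custom adapter into the registry defined above: the writer implements the Logger interface and is registered under its own name; the adapter itself ("jsonconsole") is hypothetical and only illustrates the Init/WriteMsg/Destroy/Flush contract and Register.

package main

import (
	"encoding/json"
	"fmt"
	"os"
	"time"

	"github.com/astaxie/beego/logs"
)

// jsonConsoleWriter is a hypothetical adapter shown only to illustrate the
// Logger interface; it is not part of the package.
type jsonConsoleWriter struct {
	Level int `json:"level"`
}

func (w *jsonConsoleWriter) Init(config string) error {
	w.Level = logs.LevelDebug
	return json.Unmarshal([]byte(config), w)
}

func (w *jsonConsoleWriter) WriteMsg(when time.Time, msg string, level int) error {
	if level > w.Level {
		return nil
	}
	_, err := fmt.Fprintf(os.Stdout, "{\"time\":%q,\"level\":%d,\"msg\":%q}\n",
		when.Format(time.RFC3339), level, msg)
	return err
}

func (w *jsonConsoleWriter) Destroy() {}
func (w *jsonConsoleWriter) Flush()   {}

func main() {
	// Register panics if the same name is registered twice.
	logs.Register("jsonconsole", func() logs.Logger { return &jsonConsoleWriter{} })

	l := logs.NewLogger()
	l.SetLogger("jsonconsole", `{"level":7}`)
	l.Info("structured hello")
}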
+func NewLogger(channelLens ...int64) *BeeLogger { + bl := new(BeeLogger) + bl.level = LevelDebug + bl.loggerFuncCallDepth = 2 + bl.msgChanLen = append(channelLens, 0)[0] + if bl.msgChanLen <= 0 { + bl.msgChanLen = defaultAsyncMsgLen + } + bl.signalChan = make(chan string, 1) + bl.setLogger(AdapterConsole) + return bl +} + +// Async set the log to asynchronous and start the goroutine +func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { + bl.lock.Lock() + defer bl.lock.Unlock() + if bl.asynchronous { + return bl + } + bl.asynchronous = true + if len(msgLen) > 0 && msgLen[0] > 0 { + bl.msgChanLen = msgLen[0] + } + bl.msgChan = make(chan *logMsg, bl.msgChanLen) + logMsgPool = &sync.Pool{ + New: func() interface{} { + return &logMsg{} + }, + } + bl.wg.Add(1) + go bl.startLogger() + return bl +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { + config := append(configs, "{}")[0] + for _, l := range bl.outputs { + if l.name == adapterName { + return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName) + } + } + + logAdapter, ok := adapters[adapterName] + if !ok { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + + lg := logAdapter() + err := lg.Init(config) + if err != nil { + fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) + return err + } + bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) + return nil +} + +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. +func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + if !bl.init { + bl.outputs = []*nameLogger{} + bl.init = true + } + return bl.setLogger(adapterName, configs...) +} + +// DelLogger remove a logger adapter in BeeLogger. +func (bl *BeeLogger) DelLogger(adapterName string) error { + bl.lock.Lock() + defer bl.lock.Unlock() + outputs := []*nameLogger{} + for _, lg := range bl.outputs { + if lg.name == adapterName { + lg.Destroy() + } else { + outputs = append(outputs, lg) + } + } + if len(outputs) == len(bl.outputs) { + return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) + } + bl.outputs = outputs + return nil +} + +func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { + for _, l := range bl.outputs { + err := l.WriteMsg(when, msg, level) + if err != nil { + fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) + } + } +} + +func (bl *BeeLogger) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + // writeMsg will always add a '\n' character + if p[len(p)-1] == '\n' { + p = p[0 : len(p)-1] + } + // set levelLoggerImpl to ensure all log message will be write out + err = bl.writeMsg(levelLoggerImpl, string(p)) + if err == nil { + return len(p), err + } + return 0, err +} + +func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { + if !bl.init { + bl.lock.Lock() + bl.setLogger(AdapterConsole) + bl.lock.Unlock() + } + + if len(v) > 0 { + msg = fmt.Sprintf(msg, v...) + } + + msg = bl.prefix + " " + msg + + when := time.Now() + if bl.enableFuncCallDepth { + _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" 
+ line = 0 + } + _, filename := path.Split(file) + msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg + } + + //set level info in front of filename info + if logLevel == levelLoggerImpl { + // set to emergency to ensure all log will be print out correctly + logLevel = LevelEmergency + } else { + msg = levelPrefix[logLevel] + " " + msg + } + + if bl.asynchronous { + lm := logMsgPool.Get().(*logMsg) + lm.level = logLevel + lm.msg = msg + lm.when = when + if bl.outputs != nil { + bl.msgChan <- lm + } else { + logMsgPool.Put(lm) + } + } else { + bl.writeToLoggers(when, msg, logLevel) + } + return nil +} + +// SetLevel Set log message level. +// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), +// log providers will not even be sent the message. +func (bl *BeeLogger) SetLevel(l int) { + bl.level = l +} + +// GetLevel Get Current log message level. +func (bl *BeeLogger) GetLevel() int { + return bl.level +} + +// SetLogFuncCallDepth set log funcCallDepth +func (bl *BeeLogger) SetLogFuncCallDepth(d int) { + bl.loggerFuncCallDepth = d +} + +// GetLogFuncCallDepth return log funcCallDepth for wrapper +func (bl *BeeLogger) GetLogFuncCallDepth() int { + return bl.loggerFuncCallDepth +} + +// EnableFuncCallDepth enable log funcCallDepth +func (bl *BeeLogger) EnableFuncCallDepth(b bool) { + bl.enableFuncCallDepth = b +} + +// set prefix +func (bl *BeeLogger) SetPrefix(s string) { + bl.prefix = s +} + +// start logger chan reading. +// when chan is not empty, write logs. +func (bl *BeeLogger) startLogger() { + gameOver := false + for { + select { + case bm := <-bl.msgChan: + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + case sg := <-bl.signalChan: + // Now should only send "flush" or "close" to bl.signalChan + bl.flush() + if sg == "close" { + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil + gameOver = true + } + bl.wg.Done() + } + if gameOver { + break + } + } +} + +// Emergency Log EMERGENCY level message. +func (bl *BeeLogger) Emergency(format string, v ...interface{}) { + if LevelEmergency > bl.level { + return + } + bl.writeMsg(LevelEmergency, format, v...) +} + +// Alert Log ALERT level message. +func (bl *BeeLogger) Alert(format string, v ...interface{}) { + if LevelAlert > bl.level { + return + } + bl.writeMsg(LevelAlert, format, v...) +} + +// Critical Log CRITICAL level message. +func (bl *BeeLogger) Critical(format string, v ...interface{}) { + if LevelCritical > bl.level { + return + } + bl.writeMsg(LevelCritical, format, v...) +} + +// Error Log ERROR level message. +func (bl *BeeLogger) Error(format string, v ...interface{}) { + if LevelError > bl.level { + return + } + bl.writeMsg(LevelError, format, v...) +} + +// Warning Log WARNING level message. +func (bl *BeeLogger) Warning(format string, v ...interface{}) { + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) +} + +// Notice Log NOTICE level message. +func (bl *BeeLogger) Notice(format string, v ...interface{}) { + if LevelNotice > bl.level { + return + } + bl.writeMsg(LevelNotice, format, v...) +} + +// Informational Log INFORMATIONAL level message. +func (bl *BeeLogger) Informational(format string, v ...interface{}) { + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) +} + +// Debug Log DEBUG level message. +func (bl *BeeLogger) Debug(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) 
+} + +// Warn Log WARN level message. +// compatibility alias for Warning() +func (bl *BeeLogger) Warn(format string, v ...interface{}) { + if LevelWarn > bl.level { + return + } + bl.writeMsg(LevelWarn, format, v...) +} + +// Info Log INFO level message. +// compatibility alias for Informational() +func (bl *BeeLogger) Info(format string, v ...interface{}) { + if LevelInfo > bl.level { + return + } + bl.writeMsg(LevelInfo, format, v...) +} + +// Trace Log TRACE level message. +// compatibility alias for Debug() +func (bl *BeeLogger) Trace(format string, v ...interface{}) { + if LevelDebug > bl.level { + return + } + bl.writeMsg(LevelDebug, format, v...) +} + +// Flush flush all chan data. +func (bl *BeeLogger) Flush() { + if bl.asynchronous { + bl.signalChan <- "flush" + bl.wg.Wait() + bl.wg.Add(1) + return + } + bl.flush() +} + +// Close close logger, flush all chan data and destroy all adapters in BeeLogger. +func (bl *BeeLogger) Close() { + if bl.asynchronous { + bl.signalChan <- "close" + bl.wg.Wait() + close(bl.msgChan) + } else { + bl.flush() + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil + } + close(bl.signalChan) +} + +// Reset close all outputs, and set bl.outputs to nil +func (bl *BeeLogger) Reset() { + bl.Flush() + for _, l := range bl.outputs { + l.Destroy() + } + bl.outputs = nil +} + +func (bl *BeeLogger) flush() { + if bl.asynchronous { + for { + if len(bl.msgChan) > 0 { + bm := <-bl.msgChan + bl.writeToLoggers(bm.when, bm.msg, bm.level) + logMsgPool.Put(bm) + continue + } + break + } + } + for _, l := range bl.outputs { + l.Flush() + } +} + +// beeLogger references the used application logger. +var beeLogger = NewLogger() + +// GetBeeLogger returns the default BeeLogger +func GetBeeLogger() *BeeLogger { + return beeLogger +} + +var beeLoggerMap = struct { + sync.RWMutex + logs map[string]*log.Logger +}{ + logs: map[string]*log.Logger{}, +} + +// GetLogger returns the default BeeLogger +func GetLogger(prefixes ...string) *log.Logger { + prefix := append(prefixes, "")[0] + if prefix != "" { + prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix)) + } + beeLoggerMap.RLock() + l, ok := beeLoggerMap.logs[prefix] + if ok { + beeLoggerMap.RUnlock() + return l + } + beeLoggerMap.RUnlock() + beeLoggerMap.Lock() + defer beeLoggerMap.Unlock() + l, ok = beeLoggerMap.logs[prefix] + if !ok { + l = log.New(beeLogger, prefix, 0) + beeLoggerMap.logs[prefix] = l + } + return l +} + +// Reset will remove all the adapter +func Reset() { + beeLogger.Reset() +} + +// Async set the beelogger with Async mode and hold msglen messages +func Async(msgLen ...int64) *BeeLogger { + return beeLogger.Async(msgLen...) +} + +// SetLevel sets the global log level used by the simple logger. +func SetLevel(l int) { + beeLogger.SetLevel(l) +} + +// SetPrefix sets the prefix +func SetPrefix(s string) { + beeLogger.SetPrefix(s) +} + +// EnableFuncCallDepth enable log funcCallDepth +func EnableFuncCallDepth(b bool) { + beeLogger.enableFuncCallDepth = b +} + +// SetLogFuncCall set the CallDepth, default is 4 +func SetLogFuncCall(b bool) { + beeLogger.EnableFuncCallDepth(b) + beeLogger.SetLogFuncCallDepth(4) +} + +// SetLogFuncCallDepth set log funcCallDepth +func SetLogFuncCallDepth(d int) { + beeLogger.loggerFuncCallDepth = d +} + +// SetLogger sets a new logger. +func SetLogger(adapter string, config ...string) error { + return beeLogger.SetLogger(adapter, config...) +} + +// Emergency logs a message at emergency level. 
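A short sketch of the asynchronous lifecycle implemented above: Async switches the logger onto a buffered channel, Flush drains it, and Close flushes and then destroys every adapter. The file name is a placeholder.

package main

import "github.com/astaxie/beego/logs"

func main() {
	l := logs.NewLogger()
	l.SetLogger(logs.AdapterFile, `{"filename":"app.log"}`) // placeholder file name
	l.Async(1000)                                           // buffer up to 1000 messages in the channel

	for i := 0; i < 100; i++ {
		l.Debug("handled request %d", i)
	}

	l.Flush() // block until the buffered messages have been written
	l.Close() // flush again, then destroy all adapters
}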
+func Emergency(f interface{}, v ...interface{}) { + beeLogger.Emergency(formatLog(f, v...)) +} + +// Alert logs a message at alert level. +func Alert(f interface{}, v ...interface{}) { + beeLogger.Alert(formatLog(f, v...)) +} + +// Critical logs a message at critical level. +func Critical(f interface{}, v ...interface{}) { + beeLogger.Critical(formatLog(f, v...)) +} + +// Error logs a message at error level. +func Error(f interface{}, v ...interface{}) { + beeLogger.Error(formatLog(f, v...)) +} + +// Warning logs a message at warning level. +func Warning(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Warn compatibility alias for Warning() +func Warn(f interface{}, v ...interface{}) { + beeLogger.Warn(formatLog(f, v...)) +} + +// Notice logs a message at notice level. +func Notice(f interface{}, v ...interface{}) { + beeLogger.Notice(formatLog(f, v...)) +} + +// Informational logs a message at info level. +func Informational(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Info compatibility alias for Warning() +func Info(f interface{}, v ...interface{}) { + beeLogger.Info(formatLog(f, v...)) +} + +// Debug logs a message at debug level. +func Debug(f interface{}, v ...interface{}) { + beeLogger.Debug(formatLog(f, v...)) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +func Trace(f interface{}, v ...interface{}) { + beeLogger.Trace(formatLog(f, v...)) +} + +func formatLog(f interface{}, v ...interface{}) string { + var msg string + switch f.(type) { + case string: + msg = f.(string) + if len(v) == 0 { + return msg + } + if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { + //format string + } else { + //do not contain format char + msg += strings.Repeat(" %v", len(v)) + } + default: + msg = fmt.Sprint(f) + if len(v) == 0 { + return msg + } + msg += strings.Repeat(" %v", len(v)) + } + return fmt.Sprintf(msg, v...) +} diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go new file mode 100644 index 00000000..a28bff6f --- /dev/null +++ b/pkg/logs/logger.go @@ -0,0 +1,176 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
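For the package-level helpers above, formatLog decides how the first argument is rendered: a string containing a format verb is used as a format string, otherwise the extra arguments are appended with %v. A brief illustration (the default BeeLogger behind these helpers writes to the console adapter):

package main

import "github.com/astaxie/beego/logs"

func main() {
	logs.SetLevel(logs.LevelInformational)

	logs.Info("user %s logged in", "alice")   // first arg contains a verb: treated as a format string
	logs.Warn("cache miss", "key", "user:42") // no verb: rendered as "cache miss key user:42" via %v
	logs.Error(42, "items failed")            // non-string first arg: fmt.Sprint(42) plus " %v" per extra arg
}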
+ +package logs + +import ( + "io" + "runtime" + "sync" + "time" +) + +type logWriter struct { + sync.Mutex + writer io.Writer +} + +func newLogWriter(wr io.Writer) *logWriter { + return &logWriter{writer: wr} +} + +func (lg *logWriter) writeln(when time.Time, msg string) (int, error) { + lg.Lock() + h, _, _ := formatTimeHeader(when) + n, err := lg.writer.Write(append(append(h, msg...), '\n')) + lg.Unlock() + return n, err +} + +const ( + y1 = `0123456789` + y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` + y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999` + y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789` + mo1 = `000000000111` + mo2 = `123456789012` + d1 = `0000000001111111111222222222233` + d2 = `1234567890123456789012345678901` + h1 = `000000000011111111112222` + h2 = `012345678901234567890123` + mi1 = `000000000011111111112222222222333333333344444444445555555555` + mi2 = `012345678901234567890123456789012345678901234567890123456789` + s1 = `000000000011111111112222222222333333333344444444445555555555` + s2 = `012345678901234567890123456789012345678901234567890123456789` + ns1 = `0123456789` +) + +func formatTimeHeader(when time.Time) ([]byte, int, int) { + y, mo, d := when.Date() + h, mi, s := when.Clock() + ns := when.Nanosecond() / 1000000 + //len("2006/01/02 15:04:05.123 ")==24 + var buf [24]byte + + buf[0] = y1[y/1000%10] + buf[1] = y2[y/100] + buf[2] = y3[y-y/100*100] + buf[3] = y4[y-y/100*100] + buf[4] = '/' + buf[5] = mo1[mo-1] + buf[6] = mo2[mo-1] + buf[7] = '/' + buf[8] = d1[d-1] + buf[9] = d2[d-1] + buf[10] = ' ' + buf[11] = h1[h] + buf[12] = h2[h] + buf[13] = ':' + buf[14] = mi1[mi] + buf[15] = mi2[mi] + buf[16] = ':' + buf[17] = s1[s] + buf[18] = s2[s] + buf[19] = '.' 
+ buf[20] = ns1[ns/100] + buf[21] = ns1[ns%100/10] + buf[22] = ns1[ns%10] + + buf[23] = ' ' + + return buf[0:], d, h +} + +var ( + green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) + white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) + yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) + red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) + blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) + magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) + cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) + + w32Green = string([]byte{27, 91, 52, 50, 109}) + w32White = string([]byte{27, 91, 52, 55, 109}) + w32Yellow = string([]byte{27, 91, 52, 51, 109}) + w32Red = string([]byte{27, 91, 52, 49, 109}) + w32Blue = string([]byte{27, 91, 52, 52, 109}) + w32Magenta = string([]byte{27, 91, 52, 53, 109}) + w32Cyan = string([]byte{27, 91, 52, 54, 109}) + + reset = string([]byte{27, 91, 48, 109}) +) + +var once sync.Once +var colorMap map[string]string + +func initColor() { + if runtime.GOOS == "windows" { + green = w32Green + white = w32White + yellow = w32Yellow + red = w32Red + blue = w32Blue + magenta = w32Magenta + cyan = w32Cyan + } + colorMap = map[string]string{ + //by color + "green": green, + "white": white, + "yellow": yellow, + "red": red, + //by method + "GET": blue, + "POST": cyan, + "PUT": yellow, + "DELETE": red, + "PATCH": green, + "HEAD": magenta, + "OPTIONS": white, + } +} + +// ColorByStatus return color by http code +// 2xx return Green +// 3xx return White +// 4xx return Yellow +// 5xx return Red +func ColorByStatus(code int) string { + once.Do(initColor) + switch { + case code >= 200 && code < 300: + return colorMap["green"] + case code >= 300 && code < 400: + return colorMap["white"] + case code >= 400 && code < 500: + return colorMap["yellow"] + default: + return colorMap["red"] + } +} + +// ColorByMethod return color by http code +func ColorByMethod(method string) string { + once.Do(initColor) + if c := colorMap[method]; c != "" { + return c + } + return reset +} + +// ResetColor return reset color +func ResetColor() string { + return reset +} diff --git a/pkg/logs/logger_test.go b/pkg/logs/logger_test.go new file mode 100644 index 00000000..15be500d --- /dev/null +++ b/pkg/logs/logger_test.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
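A small sketch of how the color helpers above might be used to tint an access-log line on a terminal; the status, method, and path values are arbitrary samples:

package main

import (
	"fmt"

	"github.com/astaxie/beego/logs"
)

func main() {
	status, method, path := 404, "GET", "/missing" // arbitrary sample values

	// Wrap each colored field with ResetColor so the rest of the line keeps the default color.
	fmt.Printf("%s %3d %s| %s %s %s| %s\n",
		logs.ColorByStatus(status), status, logs.ResetColor(),
		logs.ColorByMethod(method), method, logs.ResetColor(),
		path)
}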
+ +package logs + +import ( + "testing" + "time" +) + +func TestFormatHeader_0(t *testing.T) { + tm := time.Now() + if tm.Year() >= 2100 { + t.FailNow() + } + dur := time.Second + for { + if tm.Year() >= 2100 { + break + } + h, _, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + dur *= 2 + } +} + +func TestFormatHeader_1(t *testing.T) { + tm := time.Now() + year := tm.Year() + dur := time.Second + for { + if tm.Year() >= year+1 { + break + } + h, _, _ := formatTimeHeader(tm) + if tm.Format("2006/01/02 15:04:05.000 ") != string(h) { + t.Log(tm) + t.FailNow() + } + tm = tm.Add(dur) + } +} diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go new file mode 100644 index 00000000..90168274 --- /dev/null +++ b/pkg/logs/multifile.go @@ -0,0 +1,119 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "time" +) + +// A filesLogWriter manages several fileLogWriter +// filesLogWriter will write logs to the file in json configuration and write the same level log to correspond file +// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log +// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log +// the rotate attribute also acts like fileLogWriter +type multiFileLogWriter struct { + writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter + fullLogWriter *fileLogWriter + Separate []string `json:"separate"` +} + +var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"} + +// Init file logger with json config. +// jsonConfig like: +// { +// "filename":"logs/beego.log", +// "maxLines":0, +// "maxsize":0, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":0600, +// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], +// } + +func (f *multiFileLogWriter) Init(config string) error { + writer := newFileWriter().(*fileLogWriter) + err := writer.Init(config) + if err != nil { + return err + } + f.fullLogWriter = writer + f.writers[LevelDebug+1] = writer + + //unmarshal "separate" field to f.Separate + json.Unmarshal([]byte(config), f) + + jsonMap := map[string]interface{}{} + json.Unmarshal([]byte(config), &jsonMap) + + for i := LevelEmergency; i < LevelDebug+1; i++ { + for _, v := range f.Separate { + if v == levelNames[i] { + jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." 
+ levelNames[i] + f.fullLogWriter.suffix + jsonMap["level"] = i + bs, _ := json.Marshal(jsonMap) + writer = newFileWriter().(*fileLogWriter) + err := writer.Init(string(bs)) + if err != nil { + return err + } + f.writers[i] = writer + } + } + } + + return nil +} + +func (f *multiFileLogWriter) Destroy() { + for i := 0; i < len(f.writers); i++ { + if f.writers[i] != nil { + f.writers[i].Destroy() + } + } +} + +func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { + if f.fullLogWriter != nil { + f.fullLogWriter.WriteMsg(when, msg, level) + } + for i := 0; i < len(f.writers)-1; i++ { + if f.writers[i] != nil { + if level == f.writers[i].Level { + f.writers[i].WriteMsg(when, msg, level) + } + } + } + return nil +} + +func (f *multiFileLogWriter) Flush() { + for i := 0; i < len(f.writers); i++ { + if f.writers[i] != nil { + f.writers[i].Flush() + } + } +} + +// newFilesWriter create a FileLogWriter returning as LoggerInterface. +func newFilesWriter() Logger { + return &multiFileLogWriter{} +} + +func init() { + Register(AdapterMultiFile, newFilesWriter) +} diff --git a/pkg/logs/multifile_test.go b/pkg/logs/multifile_test.go new file mode 100644 index 00000000..57b96094 --- /dev/null +++ b/pkg/logs/multifile_test.go @@ -0,0 +1,78 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "bufio" + "os" + "strconv" + "strings" + "testing" +) + +func TestFiles_1(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + fns := []string{""} + fns = append(fns, levelNames[0:]...) + name := "test" + suffix := ".log" + for _, fn := range fns { + + file := name + suffix + if fn != "" { + file = name + "." 
+ fn + suffix + } + f, err := os.Open(file) + if err != nil { + t.Fatal(err) + } + b := bufio.NewReader(f) + lineNum := 0 + lastLine := "" + for { + line, _, err := b.ReadLine() + if err != nil { + break + } + if len(line) > 0 { + lastLine = string(line) + lineNum++ + } + } + var expected = 1 + if fn == "" { + expected = LevelDebug + 1 + } + if lineNum != expected { + t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines") + } + if lineNum == 1 { + if !strings.Contains(lastLine, fn) { + t.Fatal(file + " " + lastLine + " not contains the log msg " + fn) + } + } + os.Remove(file) + } + +} diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go new file mode 100644 index 00000000..1cd2e5ae --- /dev/null +++ b/pkg/logs/slack.go @@ -0,0 +1,60 @@ +package logs + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" +) + +// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type SLACKWriter struct { + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` +} + +// newSLACKWriter create jiaoliao writer. +func newSLACKWriter() Logger { + return &SLACKWriter{Level: LevelTrace} +} + +// Init SLACKWriter with json config string +func (s *SLACKWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) + + form := url.Values{} + form.Add("payload", text) + + resp, err := http.PostForm(s.WebhookURL, form) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *SLACKWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SLACKWriter) Destroy() { +} + +func init() { + Register(AdapterSlack, newSLACKWriter) +} diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go new file mode 100644 index 00000000..6208d7b8 --- /dev/null +++ b/pkg/logs/smtp.go @@ -0,0 +1,149 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net" + "net/smtp" + "strings" + "time" +) + +// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. +type SMTPWriter struct { + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Subject string `json:"subject"` + FromAddress string `json:"fromAddress"` + RecipientAddresses []string `json:"sendTos"` + Level int `json:"level"` +} + +// NewSMTPWriter create smtp writer. 
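A hedged configuration sketch for the multiFileLogWriter shown earlier: per its Init doc comment, "separate" names the levels that get their own file, so with the config below an error line would land in both project.log and project.error.log. File names and rotation settings are placeholders.

package main

import "github.com/astaxie/beego/logs"

func main() {
	l := logs.NewLogger()
	// "project.log" is a placeholder; the adapter derives project.error.log and
	// project.debug.log from it and routes messages by level.
	l.SetLogger(logs.AdapterMultiFile,
		`{"filename":"project.log","separate":["error","debug"],"daily":true,"maxDays":7}`)

	l.Error("goes to project.log and project.error.log")
	l.Debug("goes to project.log and project.debug.log")
}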
+func newSMTPWriter() Logger { + return &SMTPWriter{Level: LevelTrace} +} + +// Init smtp writer with json config. +// config like: +// { +// "username":"example@gmail.com", +// "password:"password", +// "host":"smtp.gmail.com:465", +// "subject":"email title", +// "fromAddress":"from@example.com", +// "sendTos":["email1","email2"], +// "level":LevelError +// } +func (s *SMTPWriter) Init(jsonconfig string) error { + return json.Unmarshal([]byte(jsonconfig), s) +} + +func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { + if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { + return nil + } + return smtp.PlainAuth( + "", + s.Username, + s.Password, + host, + ) +} + +func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { + client, err := smtp.Dial(hostAddressWithPort) + if err != nil { + return err + } + + host, _, _ := net.SplitHostPort(hostAddressWithPort) + tlsConn := &tls.Config{ + InsecureSkipVerify: true, + ServerName: host, + } + if err = client.StartTLS(tlsConn); err != nil { + return err + } + + if auth != nil { + if err = client.Auth(auth); err != nil { + return err + } + } + + if err = client.Mail(fromAddress); err != nil { + return err + } + + for _, rec := range recipients { + if err = client.Rcpt(rec); err != nil { + return err + } + } + + w, err := client.Data() + if err != nil { + return err + } + _, err = w.Write(msgContent) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + + return client.Quit() +} + +// WriteMsg write message in smtp writer. +// it will send an email with subject and only this message. +func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { + if level > s.Level { + return nil + } + + hp := strings.Split(s.Host, ":") + + // Set up authentication information. + auth := s.getSMTPAuth(hp[0]) + + // Connect to the server, authenticate, set the sender and recipient, + // and send the email all in one step. + contentType := "Content-Type: text/plain" + "; charset=UTF-8" + mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) + + return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) +} + +// Flush implementing method. empty. +func (s *SMTPWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SMTPWriter) Destroy() { +} + +func init() { + Register(AdapterMail, newSMTPWriter) +} diff --git a/pkg/logs/smtp_test.go b/pkg/logs/smtp_test.go new file mode 100644 index 00000000..28e762d2 --- /dev/null +++ b/pkg/logs/smtp_test.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
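A minimal sketch of configuring the SMTP adapter defined above, following the keys listed in its Init doc comment; all credentials and addresses are placeholders, and "level":3 limits mail to error-and-above.

package main

import "github.com/astaxie/beego/logs"

func main() {
	l := logs.NewLogger()
	// All addresses and credentials below are placeholders.
	l.SetLogger(logs.AdapterMail, `{
		"username":    "alerts@example.com",
		"password":    "app-password",
		"host":        "smtp.example.com:587",
		"subject":     "production error",
		"fromAddress": "alerts@example.com",
		"sendTos":     ["oncall@example.com"],
		"level":       3
	}`)

	l.Error("payment service unreachable")           // mailed: level 3 passes the filter
	l.Info("this is filtered out by the smtp level") // not mailed
}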
+ +package logs + +import ( + "testing" + "time" +) + +func TestSmtp(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) + log.Critical("sendmail critical") + time.Sleep(time.Second * 30) +} diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go new file mode 100644 index 00000000..7722240b --- /dev/null +++ b/pkg/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/logs" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": beego.BConfig.ServerName, + "env": beego.BConfig.RunMode, + "appname": beego.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": beego.BConfig.AppName, + "build_version": beego.BuildVersion, + "build_revision": beego.BuildGitRevision, + "build_status": beego.BuildStatus, + "build_tag": beego.BuildTag, + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "go_version": beego.GoVersion, + "git_branch": beego.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := beego.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. 
So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/metric/prometheus_test.go b/pkg/metric/prometheus_test.go new file mode 100644 index 00000000..d82a6dec --- /dev/null +++ b/pkg/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} diff --git a/pkg/migration/ddl.go b/pkg/migration/ddl.go new file mode 100644 index 00000000..cd2c1c49 --- /dev/null +++ b/pkg/migration/ddl.go @@ -0,0 +1,395 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
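A sketch of wiring PrometheusMiddleWare around a handler and exposing the registered collectors. It assumes the standard client_golang promhttp handler and that the import path follows the new pkg/ layout in this patch; since the middleware reads beego.BConfig and beego.BeeApp, it is intended to run inside a beego application, so this only shows the wiring shape.

package main

import (
	"net/http"

	"github.com/prometheus/client_golang/prometheus/promhttp"

	"github.com/astaxie/beego/pkg/metric" // import path assumes the new pkg/ layout
)

func main() {
	mux := http.NewServeMux()
	// Expose everything registered with prometheus.MustRegister, including the
	// summary and build-info collectors created by the middleware.
	mux.Handle("/metrics", promhttp.Handler())
	mux.Handle("/", metric.PrometheusMiddleWare(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})))

	http.ListenAndServe(":8080", mux)
}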
+ +package migration + +import ( + "fmt" + + "github.com/astaxie/beego/logs" +) + +// Index struct defines the structure of Index Columns +type Index struct { + Name string +} + +// Unique struct defines a single unique key combination +type Unique struct { + Definition string + Columns []*Column +} + +//Column struct defines a single column of a table +type Column struct { + Name string + Inc string + Null string + Default string + Unsign string + DataType string + remove bool + Modify bool +} + +// Foreign struct defines a single foreign relationship +type Foreign struct { + ForeignTable string + ForeignColumn string + OnDelete string + OnUpdate string + Column +} + +// RenameColumn struct allows renaming of columns +type RenameColumn struct { + OldName string + OldNull string + OldDefault string + OldUnsign string + OldDataType string + NewName string + Column +} + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + m.TableName = tablename + m.Engine = engine + m.Charset = charset + m.ModifyType = "create" +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + m.TableName = tablename + m.ModifyType = "alter" +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + return col +} + +//PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + m.AddPrimary(col) + return col +} + +//UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + col := &Column{Name: name} + m.AddColumns(col) + + uniqueOriginal := &Unique{} + + for _, unique := range m.Uniques { + if unique.Definition == uni { + unique.AddColumnsToUnique(col) + uniqueOriginal = unique + } + } + if uniqueOriginal.Definition == "" { + unique := &Unique{Definition: uni} + unique.AddColumnsToUnique(col) + m.AddUnique(unique) + } + + return col +} + +//ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + + foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} + foreign.Name = colname + m.AddForeign(foreign) + return foreign +} + +//SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + foreign.OnDelete = "ON DELETE" + del + return foreign +} + +//SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + foreign.OnUpdate = "ON UPDATE" + update + return foreign +} + +//Remove marks the columns to be removed. +//it allows reverse m to create the column. 
+func (c *Column) Remove() { + c.remove = true +} + +//SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + if inc { + c.Inc = "auto_increment" + } + return c +} + +//SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + if null { + c.Null = "" + + } else { + c.Null = "NOT NULL" + } + return c +} + +//SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + c.Default = "DEFAULT " + def + return c +} + +//SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + if unsign { + c.Unsign = "UNSIGNED" + } + return c +} + +//SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + c.DataType = dataType + return c +} + +//SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + if null { + c.OldNull = "" + + } else { + c.OldNull = "NOT NULL" + } + return c +} + +//SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + c.OldDefault = def + return c +} + +//SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + if unsign { + c.OldUnsign = "UNSIGNED" + } + return c +} + +//SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + c.OldDataType = dataType + return c +} + +//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + m.Primary = append(m.Primary, c) + return c +} + +//AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + + unique.Columns = append(unique.Columns, columns...) + + return unique +} + +//AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + + m.Columns = append(m.Columns, columns...) 
+ + return m +} + +//AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + m.Primary = append(m.Primary, primary) + return m +} + +//AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + m.Uniques = append(m.Uniques, unique) + return m +} + +//AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + m.Foreigns = append(m.Foreigns, foreign) + return m +} + +//AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + m.Indexes = append(m.Indexes, index) + return m +} + +//RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + rename := &RenameColumn{OldName: from, NewName: to} + m.Renames = append(m.Renames, rename) + return rename +} + +//GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + sql = "" + switch m.ModifyType { + case "create": + { + sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName) + for index, column := range m.Columns { + sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + if len(m.Columns) > index+1 { + sql += "," + } + } + + if len(m.Primary) > 0 { + sql += fmt.Sprintf(",\n PRIMARY KEY( ") + } + for index, column := range m.Primary { + sql += fmt.Sprintf(" `%s`", column.Name) + if len(m.Primary) > index+1 { + sql += "," + } + + } + if len(m.Primary) > 0 { + sql += fmt.Sprintf(")") + } + + for _, unique := range m.Uniques { + sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition) + for index, column := range unique.Columns { + sql += fmt.Sprintf(" `%s`", column.Name) + if len(unique.Columns) > index+1 { + sql += "," + } + } + sql += fmt.Sprintf(")") + } + for _, foreign := range m.Foreigns { + sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) + sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name) + sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) + + } + sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset) + break + } + case "alter": + { + sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName) + for index, column := range m.Columns { + if !column.remove { + logs.Info("col") + sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + } else { + sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) + } + + if len(m.Columns) > index+1 { + sql += "," + } + } + for index, column := range m.Renames { + sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + if len(m.Renames) > index+1 { + sql += "," + } + } + + for index, foreign := range m.Foreigns { + sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) + sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name) + sql += fmt.Sprintf(",\n ADD 
CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate) + if len(m.Foreigns) > index+1 { + sql += "," + } + } + sql += ";" + + break + } + case "reverse": + { + + sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName) + for index, column := range m.Columns { + if column.remove { + sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default) + } else { + sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) + } + if len(m.Columns) > index+1 { + sql += "," + } + } + + if len(m.Primary) > 0 { + sql += fmt.Sprintf("\n DROP PRIMARY KEY,") + } + + for index, unique := range m.Uniques { + sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) + if len(m.Uniques) > index+1 { + sql += "," + } + + } + for index, column := range m.Renames { + sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) + if len(m.Renames) > index+1 { + sql += "," + } + } + + for _, foreign := range m.Foreigns { + sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) + sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name) + sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name) + } + sql += ";" + } + case "delete": + { + sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName) + } + } + + return +} diff --git a/pkg/migration/doc.go b/pkg/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go new file mode 100644 index 00000000..5ddfd972 --- /dev/null +++ b/pkg/migration/migration.go @@ -0,0 +1,330 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
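A hedged example of the DDL builder above, assembling a CREATE TABLE migration and inspecting the SQL that GetSQL generates; the table and column names are arbitrary, and the import path assumes the new pkg/ layout.

package main

import (
	"fmt"

	"github.com/astaxie/beego/pkg/migration" // import path assumes the new pkg/ layout
)

func main() {
	m := &migration.Migration{}

	// Arbitrary table and column names, used only to show the builder API.
	m.CreateTable("users", "InnoDB", "utf8")
	m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true)
	m.NewCol("name").SetDataType("VARCHAR(255)").SetNullable(false)
	m.UniCol("uq_email", "email").SetDataType("VARCHAR(255)").SetNullable(false)

	fmt.Println(m.GetSQL()) // prints the CREATE TABLE statement assembled in ddl.go
}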
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "errors" + "sort" + "strings" + "time" + + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/orm" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +//Migration defines the migrations by either SQL or DDL +type Migration struct { + sqls []string + Created string + TableName string + Engine string + Charset string + ModifyType string + Columns []*Column + Indexes []*Index + Primary []*Column + Uniques []*Unique + Foreigns []*Foreign + Renames []*RenameColumn + RemoveColumns []*Column + RemoveIndexes []*Index + RemoveUniques []*Unique + RemoveForeigns []*Foreign +} + +var ( + migrationMap map[string]Migrationer +) + +func init() { + migrationMap = make(map[string]Migrationer) +} + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + + switch m.ModifyType { + case "reverse": + m.ModifyType = "alter" + case "delete": + m.ModifyType = "create" + } + m.sqls = append(m.sqls, m.GetSQL()) +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + + switch m.ModifyType { + case "alter": + m.ModifyType = "reverse" + case "create": + m.ModifyType = "delete" + } + m.sqls = append(m.sqls, m.GetSQL()) +} + +//Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + m.ModifyType = migrationType + m.sqls = append(m.sqls, m.GetSQL()) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + m.sqls = append(m.sqls, sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + m.sqls = make([]string, 0) +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + o := orm.NewOrm() + for _, s := range m.sqls { + logs.Info("exec sql:", s) + r := o.Raw(s) + _, err := r.Exec() + if err != nil { + return err + } + } + return m.addOrUpdateRecord(name, status) +} + +func (m *Migration) addOrUpdateRecord(name, status string) error { + o := orm.NewOrm() + if status == 
"down" { + status = "rollback" + p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() + if err != nil { + return nil + } + _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) + return err + } + status = "update" + p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare() + if err != nil { + return err + } + _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status) + return err +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + t, err := time.Parse(DateFormat, m.Created) + if err != nil { + return 0 + } + return t.Unix() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + if _, ok := migrationMap[name]; ok { + return errors.New("already exist name:" + name) + } + migrationMap[name] = m + return nil +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + sm := sortMap(migrationMap) + i := 0 + migs, _ := getAllMigrations() + for _, v := range sm { + if _, ok := migs[v.name]; !ok { + logs.Info("start upgrade", v.name) + v.m.Reset() + v.m.Up() + err := v.m.Exec(v.name, "up") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + logs.Info("end upgrade:", v.name) + i++ + } + } + logs.Info("total success upgrade:", i, " migration") + time.Sleep(2 * time.Second) + return nil +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + if v, ok := migrationMap[name]; ok { + logs.Info("start rollback") + v.Reset() + v.Down() + err := v.Exec(name, "down") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + logs.Info("end rollback") + time.Sleep(2 * time.Second) + return nil + } + logs.Error("not exist the migrationMap name:" + name) + time.Sleep(2 * time.Second) + return errors.New("not exist the migrationMap name:" + name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + sm := sortMap(migrationMap) + i := 0 + for j := len(sm) - 1; j >= 0; j-- { + v := sm[j] + if isRollBack(v.name) { + logs.Info("skip the", v.name) + time.Sleep(1 * time.Second) + continue + } + logs.Info("start reset:", v.name) + v.m.Reset() + v.m.Down() + err := v.m.Exec(v.name, "down") + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + i++ + logs.Info("end reset:", v.name) + } + logs.Info("total success reset:", i, " migration") + time.Sleep(2 * time.Second) + return nil +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + err := Reset() + if err != nil { + logs.Error("execute error:", err) + time.Sleep(2 * time.Second) + return err + } + err = Upgrade(0) + return err +} + +type dataSlice []data + +type data struct { + created int64 + name string + m Migrationer +} + +// Len is part of sort.Interface. +func (d dataSlice) Len() int { + return len(d) +} + +// Swap is part of sort.Interface. +func (d dataSlice) Swap(i, j int) { + d[i], d[j] = d[j], d[i] +} + +// Less is part of sort.Interface. 
We use count as the value to sort by +func (d dataSlice) Less(i, j int) bool { + return d[i].created < d[j].created +} + +func sortMap(m map[string]Migrationer) dataSlice { + s := make(dataSlice, 0, len(m)) + for k, v := range m { + d := data{} + d.created = v.GetCreated() + d.name = k + d.m = v + s = append(s, d) + } + sort.Sort(s) + return s +} + +func isRollBack(name string) bool { + o := orm.NewOrm() + var maps []orm.Params + num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return false + } + if num <= 0 { + return false + } + if maps[0]["status"] == "rollback" { + return true + } + return false +} +func getAllMigrations() (map[string]string, error) { + o := orm.NewOrm() + var maps []orm.Params + migs := make(map[string]string) + num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return migs, err + } + if num > 0 { + for _, v := range maps { + name := v["name"].(string) + migs[name] = v["status"].(string) + } + } + return migs, nil +} diff --git a/pkg/mime.go b/pkg/mime.go new file mode 100644 index 00000000..ca2878ab --- /dev/null +++ b/pkg/mime.go @@ -0,0 +1,556 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
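+
+// This file defines mimemaps, a table of file-extension to content-type
+// mappings that supplements Go's built-in MIME registry. As an illustrative
+// sketch only (an assumption about how such a table is typically consumed,
+// not something added by this patch), each entry can be registered with the
+// standard library at start-up:
+//
+//	import "mime"
+//
+//	for ext, typ := range mimemaps {
+//		_ = mime.AddExtensionType(ext, typ) // best-effort: registration errors are ignored
+//	}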
+ +package beego + +var mimemaps = map[string]string{ + ".3dm": "x-world/x-3dmf", + ".3dmf": "x-world/x-3dmf", + ".7z": "application/x-7z-compressed", + ".a": "application/octet-stream", + ".aab": "application/x-authorware-bin", + ".aam": "application/x-authorware-map", + ".aas": "application/x-authorware-seg", + ".abc": "text/vndabc", + ".ace": "application/x-ace-compressed", + ".acgi": "text/html", + ".afl": "video/animaflex", + ".ai": "application/postscript", + ".aif": "audio/aiff", + ".aifc": "audio/aiff", + ".aiff": "audio/aiff", + ".aim": "application/x-aim", + ".aip": "text/x-audiosoft-intra", + ".alz": "application/x-alz-compressed", + ".ani": "application/x-navi-animation", + ".aos": "application/x-nokia-9000-communicator-add-on-software", + ".aps": "application/mime", + ".apk": "application/vnd.android.package-archive", + ".arc": "application/x-arc-compressed", + ".arj": "application/arj", + ".art": "image/x-jg", + ".asf": "video/x-ms-asf", + ".asm": "text/x-asm", + ".asp": "text/asp", + ".asx": "application/x-mplayer2", + ".au": "audio/basic", + ".avi": "video/x-msvideo", + ".avs": "video/avs-video", + ".bcpio": "application/x-bcpio", + ".bin": "application/mac-binary", + ".bmp": "image/bmp", + ".boo": "application/book", + ".book": "application/book", + ".boz": "application/x-bzip2", + ".bsh": "application/x-bsh", + ".bz2": "application/x-bzip2", + ".bz": "application/x-bzip", + ".c++": "text/plain", + ".c": "text/x-c", + ".cab": "application/vnd.ms-cab-compressed", + ".cat": "application/vndms-pkiseccat", + ".cc": "text/x-c", + ".ccad": "application/clariscad", + ".cco": "application/x-cocoa", + ".cdf": "application/cdf", + ".cer": "application/pkix-cert", + ".cha": "application/x-chat", + ".chat": "application/x-chat", + ".chrt": "application/vnd.kde.kchart", + ".class": "application/java", + ".com": "text/plain", + ".conf": "text/plain", + ".cpio": "application/x-cpio", + ".cpp": "text/x-c", + ".cpt": "application/mac-compactpro", + ".crl": "application/pkcs-crl", + ".crt": "application/pkix-cert", + ".crx": "application/x-chrome-extension", + ".csh": "text/x-scriptcsh", + ".css": "text/css", + ".csv": "text/csv", + ".cxx": "text/plain", + ".dar": "application/x-dar", + ".dcr": "application/x-director", + ".deb": "application/x-debian-package", + ".deepv": "application/x-deepv", + ".def": "text/plain", + ".der": "application/x-x509-ca-cert", + ".dif": "video/x-dv", + ".dir": "application/x-director", + ".divx": "video/divx", + ".dl": "video/dl", + ".dmg": "application/x-apple-diskimage", + ".doc": "application/msword", + ".dot": "application/msword", + ".dp": "application/commonground", + ".drw": "application/drafting", + ".dump": "application/octet-stream", + ".dv": "video/x-dv", + ".dvi": "application/x-dvi", + ".dwf": "drawing/x-dwf=(old)", + ".dwg": "application/acad", + ".dxf": "application/dxf", + ".dxr": "application/x-director", + ".el": "text/x-scriptelisp", + ".elc": "application/x-bytecodeelisp=(compiled=elisp)", + ".eml": "message/rfc822", + ".env": "application/x-envoy", + ".eps": "application/postscript", + ".es": "application/x-esrehber", + ".etx": "text/x-setext", + ".evy": "application/envoy", + ".exe": "application/octet-stream", + ".f77": "text/x-fortran", + ".f90": "text/x-fortran", + ".f": "text/x-fortran", + ".fdf": "application/vndfdf", + ".fif": "application/fractals", + ".fli": "video/fli", + ".flo": "image/florian", + ".flv": "video/x-flv", + ".flx": "text/vndfmiflexstor", + ".fmf": "video/x-atomic3d-feature", + ".for": "text/x-fortran", + ".fpx": 
"image/vndfpx", + ".frl": "application/freeloader", + ".funk": "audio/make", + ".g3": "image/g3fax", + ".g": "text/plain", + ".gif": "image/gif", + ".gl": "video/gl", + ".gsd": "audio/x-gsm", + ".gsm": "audio/x-gsm", + ".gsp": "application/x-gsp", + ".gss": "application/x-gss", + ".gtar": "application/x-gtar", + ".gz": "application/x-compressed", + ".gzip": "application/x-gzip", + ".h": "text/x-h", + ".hdf": "application/x-hdf", + ".help": "application/x-helpfile", + ".hgl": "application/vndhp-hpgl", + ".hh": "text/x-h", + ".hlb": "text/x-script", + ".hlp": "application/hlp", + ".hpg": "application/vndhp-hpgl", + ".hpgl": "application/vndhp-hpgl", + ".hqx": "application/binhex", + ".hta": "application/hta", + ".htc": "text/x-component", + ".htm": "text/html", + ".html": "text/html", + ".htmls": "text/html", + ".htt": "text/webviewhtml", + ".htx": "text/html", + ".ice": "x-conference/x-cooltalk", + ".ico": "image/x-icon", + ".ics": "text/calendar", + ".icz": "text/calendar", + ".idc": "text/plain", + ".ief": "image/ief", + ".iefs": "image/ief", + ".iges": "application/iges", + ".igs": "application/iges", + ".ima": "application/x-ima", + ".imap": "application/x-httpd-imap", + ".inf": "application/inf", + ".ins": "application/x-internett-signup", + ".ip": "application/x-ip2", + ".isu": "video/x-isvideo", + ".it": "audio/it", + ".iv": "application/x-inventor", + ".ivr": "i-world/i-vrml", + ".ivy": "application/x-livescreen", + ".jam": "audio/x-jam", + ".jav": "text/x-java-source", + ".java": "text/x-java-source", + ".jcm": "application/x-java-commerce", + ".jfif-tbnl": "image/jpeg", + ".jfif": "image/jpeg", + ".jnlp": "application/x-java-jnlp-file", + ".jpe": "image/jpeg", + ".jpeg": "image/jpeg", + ".jpg": "image/jpeg", + ".jps": "image/x-jps", + ".js": "application/javascript", + ".json": "application/json", + ".jut": "image/jutvision", + ".kar": "audio/midi", + ".karbon": "application/vnd.kde.karbon", + ".kfo": "application/vnd.kde.kformula", + ".flw": "application/vnd.kde.kivio", + ".kml": "application/vnd.google-earth.kml+xml", + ".kmz": "application/vnd.google-earth.kmz", + ".kon": "application/vnd.kde.kontour", + ".kpr": "application/vnd.kde.kpresenter", + ".kpt": "application/vnd.kde.kpresenter", + ".ksp": "application/vnd.kde.kspread", + ".kwd": "application/vnd.kde.kword", + ".kwt": "application/vnd.kde.kword", + ".ksh": "text/x-scriptksh", + ".la": "audio/nspaudio", + ".lam": "audio/x-liveaudio", + ".latex": "application/x-latex", + ".lha": "application/lha", + ".lhx": "application/octet-stream", + ".list": "text/plain", + ".lma": "audio/nspaudio", + ".log": "text/plain", + ".lsp": "text/x-scriptlisp", + ".lst": "text/plain", + ".lsx": "text/x-la-asf", + ".ltx": "application/x-latex", + ".lzh": "application/octet-stream", + ".lzx": "application/lzx", + ".m1v": "video/mpeg", + ".m2a": "audio/mpeg", + ".m2v": "video/mpeg", + ".m3u": "audio/x-mpegurl", + ".m": "text/x-m", + ".man": "application/x-troff-man", + ".manifest": "text/cache-manifest", + ".map": "application/x-navimap", + ".mar": "text/plain", + ".mbd": "application/mbedlet", + ".mc$": "application/x-magic-cap-package-10", + ".mcd": "application/mcad", + ".mcf": "text/mcf", + ".mcp": "application/netmc", + ".me": "application/x-troff-me", + ".mht": "message/rfc822", + ".mhtml": "message/rfc822", + ".mid": "application/x-midi", + ".midi": "application/x-midi", + ".mif": "application/x-frame", + ".mime": "message/rfc822", + ".mjf": "audio/x-vndaudioexplosionmjuicemediafile", + ".mjpg": "video/x-motion-jpeg", + ".mm": 
"application/base64", + ".mme": "application/base64", + ".mod": "audio/mod", + ".moov": "video/quicktime", + ".mov": "video/quicktime", + ".movie": "video/x-sgi-movie", + ".mp2": "audio/mpeg", + ".mp3": "audio/mpeg3", + ".mp4": "video/mp4", + ".mpa": "audio/mpeg", + ".mpc": "application/x-project", + ".mpe": "video/mpeg", + ".mpeg": "video/mpeg", + ".mpg": "video/mpeg", + ".mpga": "audio/mpeg", + ".mpp": "application/vndms-project", + ".mpt": "application/x-project", + ".mpv": "application/x-project", + ".mpx": "application/x-project", + ".mrc": "application/marc", + ".ms": "application/x-troff-ms", + ".mv": "video/x-sgi-movie", + ".my": "audio/make", + ".mzz": "application/x-vndaudioexplosionmzz", + ".nap": "image/naplps", + ".naplps": "image/naplps", + ".nc": "application/x-netcdf", + ".ncm": "application/vndnokiaconfiguration-message", + ".nif": "image/x-niff", + ".niff": "image/x-niff", + ".nix": "application/x-mix-transfer", + ".nsc": "application/x-conference", + ".nvd": "application/x-navidoc", + ".o": "application/octet-stream", + ".oda": "application/oda", + ".odb": "application/vnd.oasis.opendocument.database", + ".odc": "application/vnd.oasis.opendocument.chart", + ".odf": "application/vnd.oasis.opendocument.formula", + ".odg": "application/vnd.oasis.opendocument.graphics", + ".odi": "application/vnd.oasis.opendocument.image", + ".odm": "application/vnd.oasis.opendocument.text-master", + ".odp": "application/vnd.oasis.opendocument.presentation", + ".ods": "application/vnd.oasis.opendocument.spreadsheet", + ".odt": "application/vnd.oasis.opendocument.text", + ".oga": "audio/ogg", + ".ogg": "audio/ogg", + ".ogv": "video/ogg", + ".omc": "application/x-omc", + ".omcd": "application/x-omcdatamaker", + ".omcr": "application/x-omcregerator", + ".otc": "application/vnd.oasis.opendocument.chart-template", + ".otf": "application/vnd.oasis.opendocument.formula-template", + ".otg": "application/vnd.oasis.opendocument.graphics-template", + ".oth": "application/vnd.oasis.opendocument.text-web", + ".oti": "application/vnd.oasis.opendocument.image-template", + ".otm": "application/vnd.oasis.opendocument.text-master", + ".otp": "application/vnd.oasis.opendocument.presentation-template", + ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", + ".ott": "application/vnd.oasis.opendocument.text-template", + ".p10": "application/pkcs10", + ".p12": "application/pkcs-12", + ".p7a": "application/x-pkcs7-signature", + ".p7c": "application/pkcs7-mime", + ".p7m": "application/pkcs7-mime", + ".p7r": "application/x-pkcs7-certreqresp", + ".p7s": "application/pkcs7-signature", + ".p": "text/x-pascal", + ".part": "application/pro_eng", + ".pas": "text/pascal", + ".pbm": "image/x-portable-bitmap", + ".pcl": "application/vndhp-pcl", + ".pct": "image/x-pict", + ".pcx": "image/x-pcx", + ".pdb": "chemical/x-pdb", + ".pdf": "application/pdf", + ".pfunk": "audio/make", + ".pgm": "image/x-portable-graymap", + ".pic": "image/pict", + ".pict": "image/pict", + ".pkg": "application/x-newton-compatible-pkg", + ".pko": "application/vndms-pkipko", + ".pl": "text/x-scriptperl", + ".plx": "application/x-pixclscript", + ".pm4": "application/x-pagemaker", + ".pm5": "application/x-pagemaker", + ".pm": "text/x-scriptperl-module", + ".png": "image/png", + ".pnm": "application/x-portable-anymap", + ".pot": "application/mspowerpoint", + ".pov": "model/x-pov", + ".ppa": "application/vndms-powerpoint", + ".ppm": "image/x-portable-pixmap", + ".pps": "application/mspowerpoint", + ".ppt": "application/mspowerpoint", + ".ppz": 
"application/mspowerpoint", + ".pre": "application/x-freelance", + ".prt": "application/pro_eng", + ".ps": "application/postscript", + ".psd": "application/octet-stream", + ".pvu": "paleovu/x-pv", + ".pwz": "application/vndms-powerpoint", + ".py": "text/x-scriptphyton", + ".pyc": "application/x-bytecodepython", + ".qcp": "audio/vndqcelp", + ".qd3": "x-world/x-3dmf", + ".qd3d": "x-world/x-3dmf", + ".qif": "image/x-quicktime", + ".qt": "video/quicktime", + ".qtc": "video/x-qtc", + ".qti": "image/x-quicktime", + ".qtif": "image/x-quicktime", + ".ra": "audio/x-pn-realaudio", + ".ram": "audio/x-pn-realaudio", + ".rar": "application/x-rar-compressed", + ".ras": "application/x-cmu-raster", + ".rast": "image/cmu-raster", + ".rexx": "text/x-scriptrexx", + ".rf": "image/vndrn-realflash", + ".rgb": "image/x-rgb", + ".rm": "application/vndrn-realmedia", + ".rmi": "audio/mid", + ".rmm": "audio/x-pn-realaudio", + ".rmp": "audio/x-pn-realaudio", + ".rng": "application/ringing-tones", + ".rnx": "application/vndrn-realplayer", + ".roff": "application/x-troff", + ".rp": "image/vndrn-realpix", + ".rpm": "audio/x-pn-realaudio-plugin", + ".rt": "text/vndrn-realtext", + ".rtf": "text/richtext", + ".rtx": "text/richtext", + ".rv": "video/vndrn-realvideo", + ".s": "text/x-asm", + ".s3m": "audio/s3m", + ".s7z": "application/x-7z-compressed", + ".saveme": "application/octet-stream", + ".sbk": "application/x-tbook", + ".scm": "text/x-scriptscheme", + ".sdml": "text/plain", + ".sdp": "application/sdp", + ".sdr": "application/sounder", + ".sea": "application/sea", + ".set": "application/set", + ".sgm": "text/x-sgml", + ".sgml": "text/x-sgml", + ".sh": "text/x-scriptsh", + ".shar": "application/x-bsh", + ".shtml": "text/x-server-parsed-html", + ".sid": "audio/x-psid", + ".skd": "application/x-koan", + ".skm": "application/x-koan", + ".skp": "application/x-koan", + ".skt": "application/x-koan", + ".sit": "application/x-stuffit", + ".sitx": "application/x-stuffitx", + ".sl": "application/x-seelogo", + ".smi": "application/smil", + ".smil": "application/smil", + ".snd": "audio/basic", + ".sol": "application/solids", + ".spc": "text/x-speech", + ".spl": "application/futuresplash", + ".spr": "application/x-sprite", + ".sprite": "application/x-sprite", + ".spx": "audio/ogg", + ".src": "application/x-wais-source", + ".ssi": "text/x-server-parsed-html", + ".ssm": "application/streamingmedia", + ".sst": "application/vndms-pkicertstore", + ".step": "application/step", + ".stl": "application/sla", + ".stp": "application/step", + ".sv4cpio": "application/x-sv4cpio", + ".sv4crc": "application/x-sv4crc", + ".svf": "image/vnddwg", + ".svg": "image/svg+xml", + ".svr": "application/x-world", + ".swf": "application/x-shockwave-flash", + ".t": "application/x-troff", + ".talk": "text/x-speech", + ".tar": "application/x-tar", + ".tbk": "application/toolbook", + ".tcl": "text/x-scripttcl", + ".tcsh": "text/x-scripttcsh", + ".tex": "application/x-tex", + ".texi": "application/x-texinfo", + ".texinfo": "application/x-texinfo", + ".text": "text/plain", + ".tgz": "application/gnutar", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".tr": "application/x-troff", + ".tsi": "audio/tsp-audio", + ".tsp": "application/dsptype", + ".tsv": "text/tab-separated-values", + ".turbot": "image/florian", + ".txt": "text/plain", + ".uil": "text/x-uil", + ".uni": "text/uri-list", + ".unis": "text/uri-list", + ".unv": "application/i-deas", + ".uri": "text/uri-list", + ".uris": "text/uri-list", + ".ustar": "application/x-ustar", + ".uu": "text/x-uuencode", + 
".uue": "text/x-uuencode", + ".vcd": "application/x-cdlink", + ".vcf": "text/x-vcard", + ".vcard": "text/x-vcard", + ".vcs": "text/x-vcalendar", + ".vda": "application/vda", + ".vdo": "video/vdo", + ".vew": "application/groupwise", + ".viv": "video/vivo", + ".vivo": "video/vivo", + ".vmd": "application/vocaltec-media-desc", + ".vmf": "application/vocaltec-media-file", + ".voc": "audio/voc", + ".vos": "video/vosaic", + ".vox": "audio/voxware", + ".vqe": "audio/x-twinvq-plugin", + ".vqf": "audio/x-twinvq", + ".vql": "audio/x-twinvq-plugin", + ".vrml": "application/x-vrml", + ".vrt": "x-world/x-vrt", + ".vsd": "application/x-visio", + ".vst": "application/x-visio", + ".vsw": "application/x-visio", + ".w60": "application/wordperfect60", + ".w61": "application/wordperfect61", + ".w6w": "application/msword", + ".wav": "audio/wav", + ".wb1": "application/x-qpro", + ".wbmp": "image/vnd.wap.wbmp", + ".web": "application/vndxara", + ".wiz": "application/msword", + ".wk1": "application/x-123", + ".wmf": "windows/metafile", + ".wml": "text/vnd.wap.wml", + ".wmlc": "application/vnd.wap.wmlc", + ".wmls": "text/vnd.wap.wmlscript", + ".wmlsc": "application/vnd.wap.wmlscriptc", + ".word": "application/msword", + ".wp5": "application/wordperfect", + ".wp6": "application/wordperfect", + ".wp": "application/wordperfect", + ".wpd": "application/wordperfect", + ".wq1": "application/x-lotus", + ".wri": "application/mswrite", + ".wrl": "application/x-world", + ".wrz": "model/vrml", + ".wsc": "text/scriplet", + ".wsrc": "application/x-wais-source", + ".wtk": "application/x-wintalk", + ".x-png": "image/png", + ".xbm": "image/x-xbitmap", + ".xdr": "video/x-amt-demorun", + ".xgz": "xgl/drawing", + ".xif": "image/vndxiff", + ".xl": "application/excel", + ".xla": "application/excel", + ".xlb": "application/excel", + ".xlc": "application/excel", + ".xld": "application/excel", + ".xlk": "application/excel", + ".xll": "application/excel", + ".xlm": "application/excel", + ".xls": "application/excel", + ".xlt": "application/excel", + ".xlv": "application/excel", + ".xlw": "application/excel", + ".xm": "audio/xm", + ".xml": "text/xml", + ".xmz": "xgl/movie", + ".xpix": "application/x-vndls-xpix", + ".xpm": "image/x-xpixmap", + ".xsr": "video/x-amt-showrun", + ".xwd": "image/x-xwd", + ".xyz": "chemical/x-pdb", + ".z": "application/x-compress", + ".zip": "application/zip", + ".zoo": "application/octet-stream", + ".zsh": "text/x-scriptzsh", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".docm": "application/vnd.ms-word.document.macroEnabled.12", + ".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + ".dotm": "application/vnd.ms-word.template.macroEnabled.12", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12", + ".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template", + ".xltm": "application/vnd.ms-excel.template.macroEnabled.12", + ".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12", + ".xlam": "application/vnd.ms-excel.addin.macroEnabled.12", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12", + ".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow", + ".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12", + ".potx": 
"application/vnd.openxmlformats-officedocument.presentationml.template", + ".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12", + ".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12", + ".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide", + ".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12", + ".thmx": "application/vnd.ms-officetheme", + ".onetoc": "application/onenote", + ".onetoc2": "application/onenote", + ".onetmp": "application/onenote", + ".onepkg": "application/onenote", + ".key": "application/x-iwork-keynote-sffkey", + ".kth": "application/x-iwork-keynote-sffkth", + ".nmbtemplate": "application/x-iwork-numbers-sfftemplate", + ".numbers": "application/x-iwork-numbers-sffnumbers", + ".pages": "application/x-iwork-pages-sffpages", + ".template": "application/x-iwork-pages-sfftemplate", + ".xpi": "application/x-xpinstall", + ".oex": "application/x-opera-extension", + ".mustache": "text/html", +} diff --git a/pkg/namespace.go b/pkg/namespace.go new file mode 100644 index 00000000..4952c9d5 --- /dev/null +++ b/pkg/namespace.go @@ -0,0 +1,396 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "strings" + + beecontext "github.com/astaxie/beego/context" +) + +type namespaceCond func(*beecontext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace struct { + prefix string + handlers *ControllerRegister +} + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + ns := &Namespace{ + prefix: prefix, + handlers: NewControllerRegister(), + } + for _, p := range params { + p(ns) + } + return ns +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + fn := func(ctx *beecontext.Context) { + if !cond(ctx) { + exception("405", ctx) + } + } + if v := n.handlers.filters[BeforeRouter]; len(v) > 0 { + mr := new(FilterRouter) + mr.tree = NewTree() + mr.pattern = "*" + mr.filterFunc = fn + mr.tree.AddRouter("*", true) + n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...) 
+ } else { + n.handlers.InsertFilter("*", BeforeRouter, fn) + } + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + var a int + if action == "before" { + a = BeforeRouter + } else if action == "after" { + a = FinishRouter + } + for _, f := range filter { + n.handlers.InsertFilter("*", a, f) + } + return n +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + n.handlers.Add(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + n.handlers.AddAuto(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + n.handlers.AddAutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + n.handlers.Get(rootpath, f) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + n.handlers.Post(rootpath, f) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + n.handlers.Delete(rootpath, f) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + n.handlers.Put(rootpath, f) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + n.handlers.Head(rootpath, f) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + n.handlers.Options(rootpath, f) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + n.handlers.Patch(rootpath, f) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + n.handlers.Any(rootpath, f) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + n.handlers.Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + n.handlers.Include(cList...) + return n +} + +// Namespace add nest Namespace +// usage: +//ns := beego.NewNamespace(“/v1”). 
+//Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +//) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + for _, ni := range ns { + for k, v := range ni.handlers.routers { + if _, ok := n.handlers.routers[k]; ok { + addPrefix(v, ni.prefix) + n.handlers.routers[k].AddTree(ni.prefix, v) + } else { + t := NewTree() + t.AddTree(ni.prefix, v) + addPrefix(t, ni.prefix) + n.handlers.routers[k] = t + } + } + if ni.handlers.enableFilter { + for pos, filterList := range ni.handlers.filters { + for _, mr := range filterList { + t := NewTree() + t.AddTree(ni.prefix, mr.tree) + mr.tree = t + n.handlers.insertFilterRouter(pos, mr) + } + } + } + } + return n +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + for _, n := range nl { + for k, v := range n.handlers.routers { + if _, ok := BeeApp.Handlers.routers[k]; ok { + addPrefix(v, n.prefix) + BeeApp.Handlers.routers[k].AddTree(n.prefix, v) + } else { + t := NewTree() + t.AddTree(n.prefix, v) + addPrefix(t, n.prefix) + BeeApp.Handlers.routers[k] = t + } + } + if n.handlers.enableFilter { + for pos, filterList := range n.handlers.filters { + for _, mr := range filterList { + t := NewTree() + t.AddTree(n.prefix, mr.tree) + mr.tree = t + BeeApp.Handlers.insertFilterRouter(pos, mr) + } + } + } + } +} + +func addPrefix(t *Tree, prefix string) { + for _, v := range t.fixrouters { + addPrefix(v, prefix) + } + if t.wildcard != nil { + addPrefix(t.wildcard, prefix) + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if !strings.HasPrefix(c.pattern, prefix) { + c.pattern = prefix + c.pattern + } + } + } +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(ns *Namespace) { + ns.Cond(cond) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Filter("before", filterList...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Filter("after", filterList...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.Include(cList...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(ns *Namespace) { + ns.Router(rootpath, c, mappingMethods...) 
+ } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Get(rootpath, f) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Post(rootpath, f) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Head(rootpath, f) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Put(rootpath, f) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Delete(rootpath, f) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Any(rootpath, f) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Options(rootpath, f) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + ns.Patch(rootpath, f) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.AutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + ns.AutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + n := NewNamespace(prefix, params...) + ns.Namespace(n) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + ns.Handler(rootpath, h) + } +} diff --git a/pkg/namespace_test.go b/pkg/namespace_test.go new file mode 100644 index 00000000..b3f20dff --- /dev/null +++ b/pkg/namespace_test.go @@ -0,0 +1,168 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
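+
+// The tests in this file register their namespaces into the global
+// BeeApp.Handlers via AddNamespace and drive requests directly through
+// ServeHTTP, so no network listener is needed.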
+ +package beego + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/astaxie/beego/context" +) + +func TestNamespaceGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("v1_user")) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "v1_user" { + t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespacePost(t *testing.T) { + r, _ := http.NewRequest("POST", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNest(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order", func(ctx *context.Context) { + ctx.Output.Body([]byte("order")) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "order" { + t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNestParam(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/api/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Router("/api/list", &TestController{}, "*:List") + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestNamespaceFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Filter("before", func(ctx *context.Context) { + ctx.Output.Body([]byte("this is Filter")) + }). + Get("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "this is Filter" { + t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCond(t *testing.T) { + r, _ := http.NewRequest("GET", "/v2/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v2") + ns.Cond(func(ctx *context.Context) bool { + return ctx.Input.Domain() == "beego.me" + }). 
+ AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Code != 405 { + t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) + } +} + +func TestNamespaceInside(t *testing.T) { + r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) + w := httptest.NewRecorder() + ns := NewNamespace("/v3", + NSAutoRouter(&TestController{}), + NSNamespace("/shop", + NSGet("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) + } +} diff --git a/pkg/parser.go b/pkg/parser.go new file mode 100644 index 00000000..3a311894 --- /dev/null +++ b/pkg/parser.go @@ -0,0 +1,591 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "encoding/json" + "errors" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "unicode" + + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var globalRouterTemplate = `package {{.routersDir}} + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context/param"{{.globalimport}} +) + +func init() { +{{.globalinfo}} +} +` + +var ( + lastupdateFilename = "lastupdate.tmp" + commentFilename string + pkgLastupdate map[string]int64 + genInfoList map[string][]ControllerComments + + routerHooks = map[string]int{ + "beego.BeforeStatic": BeforeStatic, + "beego.BeforeRouter": BeforeRouter, + "beego.BeforeExec": BeforeExec, + "beego.AfterExec": AfterExec, + "beego.FinishRouter": FinishRouter, + } + + routerHooksMapping = map[int]string{ + BeforeStatic: "beego.BeforeStatic", + BeforeRouter: "beego.BeforeRouter", + BeforeExec: "beego.BeforeExec", + AfterExec: "beego.AfterExec", + FinishRouter: "beego.FinishRouter", + } +) + +const commentPrefix = "commentsRouter_" + +func init() { + pkgLastupdate = make(map[string]int64) +} + +func parserPkg(pkgRealpath, pkgpath string) error { + rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") + commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) + commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" + if !compareFile(pkgRealpath) { + logs.Info(pkgRealpath + " no changed") + return nil + } + genInfoList = make(map[string][]ControllerComments) + fileSet := token.NewFileSet() + astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { + name := info.Name() + return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") + }, parser.ParseComments) + + if err != nil { + return err + } + for _, pkg := range astPkgs { + for _, fl := range pkg.Files { + for _, d := range fl.Decls { + switch specDecl := d.(type) { + case *ast.FuncDecl: + if specDecl.Recv != nil { + exp, 
ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser + if ok { + parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) + } + } + } + } + } + } + genRouterCode(pkgRealpath) + savetoFile(pkgRealpath) + return nil +} + +type parsedComment struct { + routerPath string + methods []string + params map[string]parsedParam + filters []parsedFilter + imports []parsedImport +} + +type parsedImport struct { + importPath string + importAlias string +} + +type parsedFilter struct { + pattern string + pos int + filter string + params []bool +} + +type parsedParam struct { + name string + datatype string + location string + defValue string + required bool +} + +func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { + if f.Doc != nil { + parsedComments, err := parseComment(f.Doc.List) + if err != nil { + return err + } + for _, parsedComment := range parsedComments { + if parsedComment.routerPath != "" { + key := pkgpath + ":" + controllerName + cc := ControllerComments{} + cc.Method = f.Name.String() + cc.Router = parsedComment.routerPath + cc.AllowHTTPMethods = parsedComment.methods + cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) + cc.FilterComments = buildFilters(parsedComment.filters) + cc.ImportComments = buildImports(parsedComment.imports) + genInfoList[key] = append(genInfoList[key], cc) + } + } + } + return nil +} + +func buildImports(pis []parsedImport) []*ControllerImportComments { + var importComments []*ControllerImportComments + + for _, pi := range pis { + importComments = append(importComments, &ControllerImportComments{ + ImportPath: pi.importPath, + ImportAlias: pi.importAlias, + }) + } + + return importComments +} + +func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { + var filterComments []*ControllerFilterComments + + for _, pf := range pfs { + var ( + returnOnOutput bool + resetParams bool + ) + + if len(pf.params) >= 1 { + returnOnOutput = pf.params[0] + } + + if len(pf.params) >= 2 { + resetParams = pf.params[1] + } + + filterComments = append(filterComments, &ControllerFilterComments{ + Filter: pf.filter, + Pattern: pf.pattern, + Pos: pf.pos, + ReturnOnOutput: returnOnOutput, + ResetParams: resetParams, + }) + } + + return filterComments +} + +func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { + result := make([]*param.MethodParam, 0, len(funcParams)) + for _, fparam := range funcParams { + for _, pName := range fparam.Names { + methodParam := buildMethodParam(fparam, pName.Name, pc) + result = append(result, methodParam) + } + } + return result +} + +func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { + options := []param.MethodParamOption{} + if cparam, ok := pc.params[name]; ok { + //Build param from comment info + name = cparam.name + if cparam.required { + options = append(options, param.IsRequired) + } + switch cparam.location { + case "body": + options = append(options, param.InBody) + case "header": + options = append(options, param.InHeader) + case "path": + options = append(options, param.InPath) + } + if cparam.defValue != "" { + options = append(options, param.Default(cparam.defValue)) + } + } else { + if paramInPath(name, pc.routerPath) { + options = append(options, param.InPath) + } + } + return param.New(name, options...) 
+} + +func paramInPath(name, route string) bool { + return strings.HasSuffix(route, ":"+name) || + strings.Contains(route, ":"+name+"/") +} + +var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) + +func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { + pcs = []*parsedComment{} + params := map[string]parsedParam{} + filters := []parsedFilter{} + imports := []parsedImport{} + + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Param") { + pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) + if len(pv) < 4 { + logs.Error("Invalid @Param format. Needs at least 4 parameters") + } + p := parsedParam{} + names := strings.SplitN(pv[0], "=>", 2) + p.name = names[0] + funcParamName := p.name + if len(names) > 1 { + funcParamName = names[1] + } + p.location = pv[1] + p.datatype = pv[2] + switch len(pv) { + case 5: + p.required, _ = strconv.ParseBool(pv[3]) + case 6: + p.defValue = pv[3] + p.required, _ = strconv.ParseBool(pv[4]) + } + params[funcParamName] = p + } + } + + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Import") { + iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) + if len(iv) == 0 || len(iv) > 2 { + logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") + continue + } + + p := parsedImport{} + p.importPath = iv[0] + + if len(iv) == 2 { + p.importAlias = iv[1] + } + + imports = append(imports, p) + } + } + +filterLoop: + for _, c := range lines { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@Filter") { + fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) + if len(fv) < 3 { + logs.Error("Invalid @Filter format. 
Needs at least 3 parameters") + continue filterLoop + } + + p := parsedFilter{} + p.pattern = fv[0] + posName := fv[1] + if pos, exists := routerHooks[posName]; exists { + p.pos = pos + } else { + logs.Error("Invalid @Filter pos: ", posName) + continue filterLoop + } + + p.filter = fv[2] + fvParams := fv[3:] + for _, fvParam := range fvParams { + switch fvParam { + case "true": + p.params = append(p.params, true) + case "false": + p.params = append(p.params, false) + default: + logs.Error("Invalid @Filter param: ", fvParam) + continue filterLoop + } + } + + filters = append(filters, p) + } + } + + for _, c := range lines { + var pc = &parsedComment{} + pc.params = params + pc.filters = filters + pc.imports = imports + + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + if strings.HasPrefix(t, "@router") { + t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) + matches := routeRegex.FindStringSubmatch(t) + if len(matches) == 3 { + pc.routerPath = matches[1] + methods := matches[2] + if methods == "" { + pc.methods = []string{"get"} + //pc.hasGet = true + } else { + pc.methods = strings.Split(methods, ",") + //pc.hasGet = strings.Contains(methods, "get") + } + pcs = append(pcs, pc) + } else { + return nil, errors.New("Router information is missing") + } + } + } + return +} + +// direct copy from bee\g_docs.go +// analysis params return []string +// @Param query form string true "The email for login" +// [query form string true "The email for login"] +func getparams(str string) []string { + var s []rune + var j int + var start bool + var r []string + var quoted int8 + for _, c := range str { + if unicode.IsSpace(c) && quoted == 0 { + if !start { + continue + } else { + start = false + j++ + r = append(r, string(s)) + s = make([]rune, 0) + continue + } + } + + start = true + if c == '"' { + quoted ^= 1 + continue + } + s = append(s, c) + } + if len(s) > 0 { + r = append(r, string(s)) + } + return r +} + +func genRouterCode(pkgRealpath string) { + os.Mkdir(getRouterDir(pkgRealpath), 0755) + logs.Info("generate router from comments") + var ( + globalinfo string + globalimport string + sortKey []string + ) + for k := range genInfoList { + sortKey = append(sortKey, k) + } + sort.Strings(sortKey) + for _, k := range sortKey { + cList := genInfoList[k] + sort.Sort(ControllerCommentsSlice(cList)) + for _, c := range cList { + allmethod := "nil" + if len(c.AllowHTTPMethods) > 0 { + allmethod = "[]string{" + for _, m := range c.AllowHTTPMethods { + allmethod += `"` + m + `",` + } + allmethod = strings.TrimRight(allmethod, ",") + "}" + } + + params := "nil" + if len(c.Params) > 0 { + params = "[]map[string]string{" + for _, p := range c.Params { + for k, v := range p { + params = params + `map[string]string{` + k + `:"` + v + `"},` + } + } + params = strings.TrimRight(params, ",") + "}" + } + + methodParams := "param.Make(" + if len(c.MethodParams) > 0 { + lines := make([]string, 0, len(c.MethodParams)) + for _, m := range c.MethodParams { + lines = append(lines, fmt.Sprint(m)) + } + methodParams += "\n " + + strings.Join(lines, ",\n ") + + ",\n " + } + methodParams += ")" + + imports := "" + if len(c.ImportComments) > 0 { + for _, i := range c.ImportComments { + var s string + if i.ImportAlias != "" { + s = fmt.Sprintf(` + %s "%s"`, i.ImportAlias, i.ImportPath) + } else { + s = fmt.Sprintf(` + "%s"`, i.ImportPath) + } + if !strings.Contains(globalimport, s) { + imports += s + } + } + } + + filters := "" + if len(c.FilterComments) > 0 { + for _, f := range c.FilterComments { + filters += 
fmt.Sprintf(` &beego.ControllerFilter{ + Pattern: "%s", + Pos: %s, + Filter: %s, + ReturnOnOutput: %v, + ResetParams: %v, + },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) + } + } + + if filters == "" { + filters = "nil" + } else { + filters = fmt.Sprintf(`[]*beego.ControllerFilter{ +%s + }`, filters) + } + + globalimport += imports + + globalinfo = globalinfo + ` + beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], + beego.ControllerComments{ + Method: "` + strings.TrimSpace(c.Method) + `", + ` + `Router: "` + c.Router + `"` + `, + AllowHTTPMethods: ` + allmethod + `, + MethodParams: ` + methodParams + `, + Filters: ` + filters + `, + Params: ` + params + `}) +` + } + } + + if globalinfo != "" { + f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) + if err != nil { + panic(err) + } + defer f.Close() + + routersDir := AppConfig.DefaultString("routersdir", "routers") + content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) + content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) + content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) + f.WriteString(content) + } +} + +func compareFile(pkgRealpath string) bool { + if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { + return true + } + if utils.FileExists(lastupdateFilename) { + content, err := ioutil.ReadFile(lastupdateFilename) + if err != nil { + return true + } + json.Unmarshal(content, &pkgLastupdate) + lastupdate, err := getpathTime(pkgRealpath) + if err != nil { + return true + } + if v, ok := pkgLastupdate[pkgRealpath]; ok { + if lastupdate <= v { + return false + } + } + } + return true +} + +func savetoFile(pkgRealpath string) { + lastupdate, err := getpathTime(pkgRealpath) + if err != nil { + return + } + pkgLastupdate[pkgRealpath] = lastupdate + d, err := json.Marshal(pkgLastupdate) + if err != nil { + return + } + ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) +} + +func getpathTime(pkgRealpath string) (lastupdate int64, err error) { + fl, err := ioutil.ReadDir(pkgRealpath) + if err != nil { + return lastupdate, err + } + for _, f := range fl { + if lastupdate < f.ModTime().UnixNano() { + lastupdate = f.ModTime().UnixNano() + } + } + return lastupdate, nil +} + +func getRouterDir(pkgRealpath string) string { + dir := filepath.Dir(pkgRealpath) + for { + routersDir := AppConfig.DefaultString("routersdir", "routers") + d := filepath.Join(dir, routersDir) + if utils.FileExists(d) { + return d + } + + if r, _ := filepath.Rel(dir, AppPath); r == "." { + return d + } + // Parent dir. + dir = filepath.Dir(dir) + } +} diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..10e25f3f --- /dev/null +++ b/pkg/plugins/apiauth/apiauth.go @@ -0,0 +1,165 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
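+
+// Illustrative client-side sketch (an assumption, not part of the package API
+// described below): a caller would typically prepare the signed query as
+//
+//	v := url.Values{}
+//	v.Set("appid", "my-app-id") // hypothetical application id
+//	v.Set("timestamp", time.Now().Format("2006-01-02 15:04:05"))
+//	sig := apiauth.Signature("my-app-secret", "GET", v, "/api/resource")
+//	v.Set("signature", sig) // url.QueryEscape the value when building the URL
+//
+// "my-app-id", "my-app-secret" and "/api/resource" are placeholders.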
+ +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" + "sort" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret func(string) string + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + ft := func(aid string) string { + if aid == appid { + return appkey + } + return "" + } + return APISecretAuth(ft, 300) +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + return func(ctx *context.Context) { + if ctx.Input.Query("appid") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: appid") + return + } + appsecret := f(ctx.Input.Query("appid")) + if appsecret == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("not exist this appid") + return + } + if ctx.Input.Query("signature") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: signature") + return + } + if ctx.Input.Query("timestamp") == "" { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("miss query param: timestamp") + return + } + u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) + if err != nil { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") + return + } + t := time.Now() + if t.Sub(u).Seconds() > float64(timeout) { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("timeout! 
the request time is long ago, please try again") + return + } + if ctx.Input.Query("signature") != + Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { + ctx.ResponseWriter.WriteHeader(403) + ctx.WriteString("auth failed") + } + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { + var b bytes.Buffer + keys := make([]string, len(params)) + pa := make(map[string]string) + for k, v := range params { + pa[k] = v[0] + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, key := range keys { + if key == "signature" { + continue + } + + val := pa[key] + if key != "" && val != "" { + b.WriteString(key) + b.WriteString(val) + } + } + + stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL) + + sha256 := sha256.New + hash := hmac.New(sha256, []byte(appsecret)) + hash.Write([]byte(stringToSign)) + return base64.StdEncoding.EncodeToString(hash.Sum(nil)) +} diff --git a/pkg/plugins/apiauth/apiauth_test.go b/pkg/plugins/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/plugins/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/plugins/auth/basic.go b/pkg/plugins/auth/basic.go new file mode 100644 index 00000000..c478044a --- /dev/null +++ b/pkg/plugins/auth/basic.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. 
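+// Credentials are checked by a SecretProvider callback; when the check fails
+// the filter replies with 401 and a WWW-Authenticate challenge (see
+// RequireAuth below).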
+// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "encoding/base64" + "net/http" + "strings" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +var defaultRealm = "Authorization Required" + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + secrets := func(user, pass string) bool { + return user == username && pass == password + } + return NewBasicAuthenticator(secrets, defaultRealm) +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuth{Secrets: secrets, Realm: Realm} + if username := a.CheckAuth(ctx.Request); username == "" { + a.RequireAuth(ctx.ResponseWriter, ctx.Request) + } + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider func(user, pass string) bool + +// BasicAuth store the SecretProvider and Realm +type BasicAuth struct { + Secrets SecretProvider + Realm string +} + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 || s[0] != "Basic" { + return "" + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return "" + } + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return "" + } + + if a.Secrets(pair[0], pair[1]) { + return pair[0] + } + return "" +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) + w.WriteHeader(401) + w.Write([]byte("401 Unauthorized\n")) +} diff --git a/pkg/plugins/authz/authz.go b/pkg/plugins/authz/authz.go new file mode 100644 index 00000000..9dc0db76 --- /dev/null +++ b/pkg/plugins/authz/authz.go @@ -0,0 +1,86 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. 
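+// The actual policy decision is delegated to a casbin.Enforcer; the subject
+// is the HTTP basic-auth user name taken from the request (see GetUserName
+// below).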
+// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/casbin/casbin" + "net/http" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + return func(ctx *context.Context) { + a := &BasicAuthorizer{enforcer: e} + + if !a.CheckPermission(ctx.Request) { + a.RequirePermission(ctx.ResponseWriter) + } + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer struct { + enforcer *casbin.Enforcer +} + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + username, _, _ := r.BasicAuth() + return username +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + user := a.GetUserName(r) + method := r.Method + path := r.URL.Path + return a.enforcer.Enforce(user, path, method) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + w.WriteHeader(403) + w.Write([]byte("403 Forbidden\n")) +} diff --git a/pkg/plugins/authz/authz_model.conf b/pkg/plugins/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/plugins/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/plugins/authz/authz_policy.csv b/pkg/plugins/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/plugins/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/plugins/authz/authz_test.go b/pkg/plugins/authz/authz_test.go new file mode 100644 index 00000000..49aed84c --- /dev/null +++ b/pkg/plugins/authz/authz_test.go @@ -0,0 +1,107 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
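
To make authz_model.conf and authz_policy.csv above concrete, a small sketch of the enforcement calls that CheckPermission delegates to casbin; the outcomes follow from the policy file:

package main

import (
	"fmt"

	"github.com/casbin/casbin"
)

func main() {
	e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")

	// matcher: g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")
	fmt.Println(e.Enforce("alice", "/dataset1/resource1", "GET"))  // true:  "p, alice, /dataset1/*, GET"
	fmt.Println(e.Enforce("alice", "/dataset1/resource2", "POST")) // false: no policy line matches
	fmt.Println(e.Enforce("cathy", "/dataset1/item", "DELETE"))    // true:  "g, cathy, dataset1_admin" grants "/dataset1/*, *"
}
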
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/plugins/auth" + "github.com/casbin/casbin" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := beego.NewControllerRegister() + + handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. 
+ testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/plugins/cors/cors.go b/pkg/plugins/cors/cors.go new file mode 100644 index 00000000..45c327ab --- /dev/null +++ b/pkg/plugins/cors/cors.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + "net/http" + "regexp" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +const ( + headerAllowOrigin = "Access-Control-Allow-Origin" + headerAllowCredentials = "Access-Control-Allow-Credentials" + headerAllowHeaders = "Access-Control-Allow-Headers" + headerAllowMethods = "Access-Control-Allow-Methods" + headerExposeHeaders = "Access-Control-Expose-Headers" + headerMaxAge = "Access-Control-Max-Age" + + headerOrigin = "Origin" + headerRequestMethod = "Access-Control-Request-Method" + headerRequestHeaders = "Access-Control-Request-Headers" +) + +var ( + defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} + // Regex patterns are generated from AllowOrigins. These are used and generated internally. + allowOriginPatterns = []string{} +) + +// Options represents Access Control options. +type Options struct { + // If set, all origins are allowed. + AllowAllOrigins bool + // A list of allowed origins. Wild cards and FQDNs are supported. + AllowOrigins []string + // If set, allows to share auth credentials such as cookies. + AllowCredentials bool + // A list of allowed HTTP methods. 
+ AllowMethods []string + // A list of allowed HTTP headers. + AllowHeaders []string + // A list of exposed HTTP headers. + ExposeHeaders []string + // Max age of the CORS headers. + MaxAge time.Duration +} + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + headers = make(map[string]string) + // if origin is not allowed, don't extend the headers + // with CORS headers. + if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { + return + } + + // add allow origin + if o.AllowAllOrigins { + headers[headerAllowOrigin] = "*" + } else { + headers[headerAllowOrigin] = origin + } + + // add allow credentials + headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) + + // add allow methods + if len(o.AllowMethods) > 0 { + headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") + } + + // add allow headers + if len(o.AllowHeaders) > 0 { + headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") + } + + // add exposed header + if len(o.ExposeHeaders) > 0 { + headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") + } + // add a max age header + if o.MaxAge > time.Duration(0) { + headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) + } + return +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + headers = make(map[string]string) + if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { + return + } + // verify if requested method is allowed + for _, method := range o.AllowMethods { + if method == rMethod { + headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") + break + } + } + + // verify if requested headers are allowed + var allowed []string + for _, rHeader := range strings.Split(rHeaders, ",") { + rHeader = strings.TrimSpace(rHeader) + lookupLoop: + for _, allowedHeader := range o.AllowHeaders { + if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { + allowed = append(allowed, rHeader) + break lookupLoop + } + } + } + + headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) + // add allow origin + if o.AllowAllOrigins { + headers[headerAllowOrigin] = "*" + } else { + headers[headerAllowOrigin] = origin + } + + // add allowed headers + if len(allowed) > 0 { + headers[headerAllowHeaders] = strings.Join(allowed, ",") + } + + // add exposed headers + if len(o.ExposeHeaders) > 0 { + headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") + } + // add a max age header + if o.MaxAge > time.Duration(0) { + headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) + } + return +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) (allowed bool) { + for _, pattern := range allowOriginPatterns { + allowed, _ = regexp.MatchString(pattern, origin) + if allowed { + return + } + } + return +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + // Allow default headers if nothing is specified. 
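
As an illustration of Options.Header above and of the origin compilation done by Allow (whose body continues below); the import path, origin, and option values are assumptions:

package main

import (
	"fmt"
	"time"

	"github.com/astaxie/beego"
	"github.com/astaxie/beego/plugins/cors" // assumed import path
)

func main() {
	opts := &cors.Options{
		AllowOrigins:     []string{"https://*.foo.com"},
		AllowMethods:     []string{"PUT", "PATCH"},
		AllowCredentials: true,
		MaxAge:           5 * time.Minute,
	}
	// Allow compiles "https://*.foo.com" into the anchored pattern
	// ^https://.*\.foo\.com$ (regexp.QuoteMeta plus "\*" -> ".*"), so
	// https://bar.foo.com matches while https://bar.foo.com.evil.com does not.
	// It also fills the empty AllowHeaders with the package defaults.
	beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(opts))

	// For an allowed origin, Header now yields:
	//   Access-Control-Allow-Origin:      https://bar.foo.com
	//   Access-Control-Allow-Credentials: true
	//   Access-Control-Allow-Methods:     PUT,PATCH
	//   Access-Control-Allow-Headers:     Origin,Accept,Content-Type,Authorization
	//   Access-Control-Max-Age:           300
	for k, v := range opts.Header("https://bar.foo.com") {
		fmt.Println(k+":", v)
	}
}
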
+ if len(opts.AllowHeaders) == 0 { + opts.AllowHeaders = defaultAllowHeaders + } + + for _, origin := range opts.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") + } + + return func(ctx *context.Context) { + var ( + origin = ctx.Input.Header(headerOrigin) + requestedMethod = ctx.Input.Header(headerRequestMethod) + requestedHeaders = ctx.Input.Header(headerRequestHeaders) + // additional headers to be added + // to the response. + headers map[string]string + ) + + if ctx.Input.Method() == "OPTIONS" && + (requestedMethod != "" || requestedHeaders != "") { + headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) + for key, value := range headers { + ctx.Output.Header(key, value) + } + ctx.ResponseWriter.WriteHeader(http.StatusOK) + return + } + headers = opts.Header(origin) + + for key, value := range headers { + ctx.Output.Header(key, value) + } + } +} diff --git a/pkg/plugins/cors/cors_test.go b/pkg/plugins/cors/cors_test.go new file mode 100644 index 00000000..34039143 --- /dev/null +++ b/pkg/plugins/cors/cors_test.go @@ -0,0 +1,253 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cors + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/context" +) + +// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header +type HTTPHeaderGuardRecorder struct { + *httptest.ResponseRecorder + savedHeaderMap http.Header +} + +// NewRecorder return HttpHeaderGuardRecorder +func NewRecorder() *HTTPHeaderGuardRecorder { + return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} +} + +func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { + gr.ResponseRecorder.WriteHeader(code) + gr.savedHeaderMap = gr.ResponseRecorder.Header() +} + +func (gr *HTTPHeaderGuardRecorder) Header() http.Header { + if gr.savedHeaderMap != nil { + // headers were written. 
clone so we don't get updates + clone := make(http.Header) + for k, v := range gr.savedHeaderMap { + clone[k] = v + } + return clone + } + return gr.ResponseRecorder.Header() +} + +func Test_AllowAll(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { + t.Errorf("Allow-Origin header should be *") + } +} + +func Test_AllowRegexMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://bar.foo.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != origin { + t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) + } +} + +func Test_AllowRegexNoMatch(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowOrigins: []string{"https://*.foo.com"}, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + origin := "https://ww.foo.com.evil.com" + r, _ := http.NewRequest("PUT", "/foo", nil) + r.Header.Add("Origin", origin) + handler.ServeHTTP(recorder, r) + + headerValue := recorder.HeaderMap.Get(headerAllowOrigin) + if headerValue != "" { + t.Errorf("Allow-Origin header should not exist, found %v", headerValue) + } +} + +func Test_OtherHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + ExposeHeaders: []string{"Content-Length", "Hello"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) + methodsVal := recorder.HeaderMap.Get(headerAllowMethods) + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) + maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) + + if credentialsVal != "true" { + t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) + } + + if methodsVal != "PATCH,GET" { + t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) + } + + if headersVal != "Origin,X-whatever" { + t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) + } + + if exposedHeadersVal != "Content-Length,Hello" { + t.Errorf("Expose-Headers are expected to be Content-Length,Hello. 
Found %v", exposedHeadersVal) + } + + if maxAgeVal != "300" { + t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) + } +} + +func Test_DefaultAllowHeaders(t *testing.T) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + + r, _ := http.NewRequest("PUT", "/foo", nil) + handler.ServeHTTP(recorder, r) + + headersVal := recorder.HeaderMap.Get(headerAllowHeaders) + if headersVal != "Origin,Accept,Content-Type,Authorization" { + t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) + } +} + +func Test_Preflight(t *testing.T) { + recorder := NewRecorder() + handler := beego.NewControllerRegister() + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowMethods: []string{"PUT", "PATCH"}, + AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, + })) + + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + r, _ := http.NewRequest("OPTIONS", "/foo", nil) + r.Header.Add(headerRequestMethod, "PUT") + r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") + handler.ServeHTTP(recorder, r) + + headers := recorder.Header() + methodsVal := headers.Get(headerAllowMethods) + headersVal := headers.Get(headerAllowHeaders) + originVal := headers.Get(headerAllowOrigin) + + if methodsVal != "PUT,PATCH" { + t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) + } + + if !strings.Contains(headersVal, "X-whatever") { + t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) + } + + if !strings.Contains(headersVal, "x-casesensitive") { + t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) + } + + if originVal != "*" { + t.Errorf("Allow-Origin is expected to be *, found %v", originVal) + } + + if recorder.Code != http.StatusOK { + t.Errorf("Status code is expected to be 200, found %d", recorder.Code) + } +} + +func Benchmark_WithoutCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} + +func Benchmark_WithCORS(b *testing.B) { + recorder := httptest.NewRecorder() + handler := beego.NewControllerRegister() + beego.BConfig.RunMode = beego.PROD + handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + AllowAllOrigins: true, + AllowCredentials: true, + AllowMethods: []string{"PATCH", "GET"}, + AllowHeaders: []string{"Origin", "X-whatever"}, + MaxAge: 5 * time.Minute, + })) + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + for i := 0; i < b.N; i++ { + handler.ServeHTTP(recorder, r) + } +} diff --git a/pkg/policy.go b/pkg/policy.go new file mode 100644 index 00000000..ab23f927 --- /dev/null +++ b/pkg/policy.go @@ -0,0 +1,97 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "strings" + + "github.com/astaxie/beego/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + var urlPath = cont.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := cont.Input.Method() + isWildcard := false + // Find policy for current method + t, ok := p.policies[httpMethod] + // If not found - find policy for whole controller + if !ok { + t, ok = p.policies["*"] + isWildcard = true + } + if ok { + runObjects := t.Match(urlPath, cont) + if r, ok := runObjects.([]PolicyFunc); ok { + return r + } else if !isWildcard { + // If no policies found and we checked not for "*" method - try to find it + t, ok = p.policies["*"] + if ok { + runObjects = t.Match(urlPath, cont) + if r, ok = runObjects.([]PolicyFunc); ok { + return r + } + } + } + } + return nil +} + +func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { + method = strings.ToUpper(method) + p.enablePolicy = true + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.policies[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.policies[method] = t + } +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + BeeApp.Handlers.addToPolicy(method, pattern, policy...) +} + +// Find policies and execute if were found +func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) { + if !p.enablePolicy { + return false + } + // Find Policy for method + policyList := p.FindPolicy(cont) + if len(policyList) > 0 { + // Run policies + for _, runPolicy := range policyList { + runPolicy(cont) + if cont.ResponseWriter.Started { + return true + } + } + return false + } + return false +} diff --git a/pkg/router.go b/pkg/router.go new file mode 100644 index 00000000..6a8ac6f7 --- /dev/null +++ b/pkg/router.go @@ -0,0 +1,1052 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
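
A brief usage sketch for the policy hooks above; the route pattern and the token check are illustrative only:

package main

import (
	"github.com/astaxie/beego"
	"github.com/astaxie/beego/context"
)

func main() {
	// Policies registered here are matched by FindPolicy and run by execPolicy
	// before the controller; once a policy writes a body, execution stops.
	beego.Policy("/api/*", "GET", func(ctx *context.Context) {
		if ctx.Input.Header("X-Token") == "" { // purely illustrative check
			ctx.Output.SetStatus(401)
			ctx.Output.Body([]byte("missing token"))
		}
	})
	beego.Run()
}
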
+ +package beego + +import ( + "errors" + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "reflect" + "strconv" + "strings" + "sync" + "time" + + beecontext "github.com/astaxie/beego/context" + "github.com/astaxie/beego/context/param" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/toolbox" + "github.com/astaxie/beego/utils" +) + +// default filter execution points +const ( + BeforeStatic = iota + BeforeRouter + BeforeExec + AfterExec + FinishRouter +) + +const ( + routerTypeBeego = iota + routerTypeRESTFul + routerTypeHandler +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = map[string]bool{ + "GET": true, + "POST": true, + "PUT": true, + "DELETE": true, + "PATCH": true, + "OPTIONS": true, + "HEAD": true, + "TRACE": true, + "CONNECT": true, + "MKCOL": true, + "COPY": true, + "MOVE": true, + "PROPFIND": true, + "PROPPATCH": true, + "LOCK": true, + "UNLOCK": true, + } + // these beego.Controller's methods shouldn't reflect to AutoRouter + exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", + "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", + "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", + "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", + "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", + "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", + "GetControllerAndAction", "ServeFormatted"} + + urlPlaceholder = "{{placeholder}}" + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &logFilter{} +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +// default log filter static file will not show +type logFilter struct { +} + +func (l *logFilter) Filter(ctx *beecontext.Context) bool { + requestPath := path.Clean(ctx.Request.URL.Path) + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + return true + } + for prefix := range BConfig.WebConfig.StaticDir { + if strings.HasPrefix(requestPath, prefix) { + return true + } + } + return false +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + exceptMethod = append(exceptMethod, action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo struct { + pattern string + controllerType reflect.Type + methods map[string]string + handler http.Handler + runFunction FilterFunc + routerType int + initialize func() ControllerInterface + methodParams []*param.MethodParam +} + +func (c *ControllerInfo) GetPattern() string { + return c.pattern +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister struct { + routers map[string]*Tree + enablePolicy bool + policies map[string]*Tree + enableFilter bool + filters [FinishRouter + 1][]*FilterRouter + pool sync.Pool +} + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return &ControllerRegister{ + routers: make(map[string]*Tree), + policies: make(map[string]*Tree), + pool: sync.Pool{ + New: func() interface{} { + return beecontext.NewContext() + }, + }, + } +} + +// Add controller handler and pattern rules to ControllerRegister. 
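
For the FilterHandler interface and DefaultAccessLogFilter above, a hedged sketch of plugging in a custom access-log filter; the health-check path is a made-up example:

package main

import (
	"github.com/astaxie/beego"
	"github.com/astaxie/beego/context"
)

// quietHealthChecks is an illustrative FilterHandler: requests for which Filter
// returns true are skipped by LogAccess (unless EnableStaticLogs is set).
type quietHealthChecks struct{}

func (quietHealthChecks) Filter(ctx *context.Context) bool {
	return ctx.Request.URL.Path == "/healthz" // hypothetical health endpoint
}

func main() {
	beego.DefaultAccessLogFilter = quietHealthChecks{}
	beego.Run()
}
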
+// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + p.addWithMethodParams(pattern, c, nil, mappingMethods...) +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + methods := make(map[string]string) + if len(mappingMethods) > 0 { + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + } else { + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) + } + } else { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) + } + } + } + } + + route := &ControllerInfo{} + route.pattern = pattern + route.methods = methods + route.routerType = routerTypeBeego + route.controllerType = t + route.initialize = func() ControllerInterface { + vc := reflect.New(route.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + + elemVal := reflect.ValueOf(c).Elem() + elemType := reflect.TypeOf(c).Elem() + execElem := reflect.ValueOf(execController).Elem() + + numOfFields := elemVal.NumField() + for i := 0; i < numOfFields; i++ { + fieldType := elemType.Field(i) + elemField := execElem.FieldByName(fieldType.Name) + if elemField.CanSet() { + fieldVal := elemVal.Field(i) + elemField.Set(fieldVal) + } + } + + return execController + } + + route.methodParams = methodParams + if len(methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + for k := range methods { + if k == "*" { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } + } +} + +func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { + if !BConfig.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.routers[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.routers[method] = t + } +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + if BConfig.RunMode == DEV { + skip := make(map[string]bool, 10) + wgopath := utils.GetGOPATHs() + go111module := os.Getenv(`GO111MODULE`) + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + // for go modules + if go111module == `on` { + pkgpath := filepath.Join(WorkPath, "..", t.PkgPath()) + if 
utils.FileExists(pkgpath) { + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath, t.PkgPath()) + } + } + } + } else { + if len(wgopath) == 0 { + panic("you are in dev mode. So please set gopath") + } + pkgpath := "" + for _, wg := range wgopath { + wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) + if utils.FileExists(wg) { + pkgpath = wg + break + } + } + if pkgpath != "" { + if _, ok := skip[pkgpath]; !ok { + skip[pkgpath] = true + parserPkg(pkgpath, t.PkgPath()) + } + } + } + } + } + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + key := t.PkgPath() + ":" + t.Name() + if comm, ok := GlobalControllerRouter[key]; ok { + for _, a := range comm { + for _, f := range a.Filters { + p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + } + + p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + } + } + } +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return p.pool.Get().(*beecontext.Context) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + p.pool.Put(ctx) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + p.AddMethod("get", pattern, f) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + p.AddMethod("post", pattern, f) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + p.AddMethod("put", pattern, f) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + p.AddMethod("delete", pattern, f) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + p.AddMethod("head", pattern, f) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + p.AddMethod("patch", pattern, f) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + p.AddMethod("options", pattern, f) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + p.AddMethod("*", pattern, f) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// 
ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.runFunction = f + methods := make(map[string]string) + if method == "*" { + for val := range HTTPMETHOD { + methods[val] = val + } + } else { + methods[method] = method + } + route.methods = methods + for k := range methods { + if k == "*" { + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } + } else { + p.addToRouter(k, pattern, route) + } + } +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.handler = h + if len(options) > 0 { + if _, ok := options[0].(bool); ok { + pattern = path.Join(pattern, "?:all(.*)") + } + } + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + p.AddAutoPrefix("/", c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + controllerName := strings.TrimSuffix(ct.Name(), "Controller") + for i := 0; i < rt.NumMethod(); i++ { + if !utils.InSlice(rt.Method(i).Name, exceptMethod) { + route := &ControllerInfo{} + route.routerType = routerTypeBeego + route.methods = map[string]string{"*": rt.Method(i).Name} + route.controllerType = ct + pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") + patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") + patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) + patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) + route.pattern = pattern + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + p.addToRouter(m, patternInit, route) + p.addToRouter(m, patternFix, route) + p.addToRouter(m, patternFixInit, route) + } + } + } +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. 
+func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + returnOnOutput: true, + } + if !BConfig.RouterCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + paramsLen := len(params) + if paramsLen > 0 { + mr.returnOnOutput = params[0] + } + if paramsLen > 1 { + mr.resetParams = params[1] + } + mr.tree.AddRouter(pattern, true) + return p.insertFilterRouter(pos, mr) +} + +// add Filter into +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { + if pos < BeforeStatic || pos > FinishRouter { + return errors.New("can not find your filter position") + } + p.enableFilter = true + p.filters[pos] = append(p.filters[pos], mr) + return nil +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + paths := strings.Split(endpoint, ".") + if len(paths) <= 1 { + logs.Warn("urlfor endpoint must like path.controller.method") + return "" + } + if len(values)%2 != 0 { + logs.Warn("urlfor params must key-value pair") + return "" + } + params := make(map[string]string) + if len(values) > 0 { + key := "" + for k, v := range values { + if k%2 == 0 { + key = fmt.Sprint(v) + } else { + params[key] = fmt.Sprint(v) + } + } + } + controllerName := strings.Join(paths[:len(paths)-1], "/") + methodName := paths[len(paths)-1] + for m, t := range p.routers { + ok, url := p.getURL(t, "/", controllerName, methodName, params, m) + if ok { + return url + } + } + return "" +} + +func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { + for _, subtree := range t.fixrouters { + u := path.Join(url, subtree.prefix) + ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + if t.wildcard != nil { + u := path.Join(url, urlPlaceholder) + ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if c.routerType == routerTypeBeego && + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + find := false + if HTTPMETHOD[strings.ToUpper(methodName)] { + if len(c.methods) == 0 { + find = true + } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { + find = true + } else if m, ok = c.methods["*"]; ok && m == methodName { + find = true + } + } + if !find { + for m, md := range c.methods { + if (m == "*" || m == httpMethod) && md == methodName { + find = true + } + } + } + if find { + if l.regexps == nil { + if len(l.wildcards) == 0 { + return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) + } + if len(l.wildcards) == 1 { + if v, ok := params[l.wildcards[0]]; ok { + delete(params, l.wildcards[0]) + return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) + } + return false, "" + } + if len(l.wildcards) == 3 && l.wildcards[0] == "." 
{ + if p, ok := params[":path"]; ok { + if e, isok := params[":ext"]; isok { + delete(params, ":path") + delete(params, ":ext") + return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) + } + } + } + canSkip := false + for _, v := range l.wildcards { + if v == ":" { + canSkip = true + continue + } + if u, ok := params[v]; ok { + delete(params, v) + url = strings.Replace(url, urlPlaceholder, u, 1) + } else { + if canSkip { + canSkip = false + continue + } + return false, "" + } + } + return true, url + toURL(params) + } + var i int + var startReg bool + regURL := "" + for _, v := range strings.Trim(l.regexps.String(), "^$") { + if v == '(' { + startReg = true + continue + } else if v == ')' { + startReg = false + if v, ok := params[l.wildcards[i]]; ok { + delete(params, l.wildcards[i]) + regURL = regURL + v + i++ + } else { + break + } + } else if !startReg { + regURL = string(append([]rune(regURL), v)) + } + } + if l.regexps.MatchString(regURL) { + ps := strings.Split(regURL, "/") + for _, p := range ps { + url = strings.Replace(url, urlPlaceholder, p, 1) + } + return true, url + toURL(params) + } + } + } + } + } + + return false, "" +} + +func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { + var preFilterParams map[string]string + for _, filterR := range p.filters[pos] { + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + if filterR.resetParams { + preFilterParams = context.Input.Params() + } + if ok := filterR.ValidRouter(urlPath, context); ok { + filterR.filterFunc(context) + if filterR.resetParams { + context.Input.ResetParams() + for k, v := range preFilterParams { + context.Input.SetParam(k, v) + } + } + } + if filterR.returnOnOutput && context.ResponseWriter.Started { + return true + } + } + return false +} + +// Implement http.Handler interface. 
+func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + startTime := time.Now() + var ( + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + ) + context := p.GetContext() + + context.Reset(rw, r) + + defer p.GiveBackContext(context) + if BConfig.RecoverFunc != nil { + defer BConfig.RecoverFunc(context) + } + + context.Output.EnableGzip = BConfig.EnableGzip + + if BConfig.RunMode == DEV { + context.Output.Header("Server", BConfig.ServerName) + } + + var urlPath = r.URL.Path + + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + + // filter wrong http method + if !HTTPMETHOD[r.Method] { + exception("405", context) + goto Admin + } + + // filter for static file + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { + goto Admin + } + + serverStaticRouter(context) + + if context.ResponseWriter.Started { + findRouter = true + goto Admin + } + + if r.Method != http.MethodGet && r.Method != http.MethodHead { + if BConfig.CopyRequestBody && !context.Input.IsUpload() { + // connection will close if the incoming data are larger (RFC 7231, 6.5.11) + if r.ContentLength > BConfig.MaxMemory { + logs.Error(errors.New("payload too large")) + exception("413", context) + goto Admin + } + context.Input.CopyBody(BConfig.MaxMemory) + } + context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) + } + + // session init + if BConfig.WebConfig.Session.SessionOn { + var err error + context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + if err != nil { + logs.Error(err) + exception("503", context) + goto Admin + } + defer func() { + if context.Input.CruSession != nil { + context.Input.CruSession.SessionRelease(rw) + } + }() + } + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { + goto Admin + } + // User can define RunController and RunMethod in filter + if context.Input.RunController != nil && context.Input.RunMethod != "" { + findRouter = true + runMethod = context.Input.RunMethod + runRouter = context.Input.RunController + } else { + routerInfo, findRouter = p.FindRouter(context) + } + + // if no matches to url, throw a not found exception + if !findRouter { + exception("404", context) + goto Admin + } + if splat := context.Input.Param(":splat"); splat != "" { + for k, v := range strings.Split(splat, "/") { + context.Input.SetParam(strconv.Itoa(k), v) + } + } + + if routerInfo != nil { + // store router pattern into context + context.Input.SetData("RouterPattern", routerInfo.pattern) + } + + // execute middleware filters + if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { + goto Admin + } + + // check policies + if p.execPolicy(context, urlPath) { + goto Admin + } + + if routerInfo != nil { + if routerInfo.routerType == routerTypeRESTFul { + if _, ok := routerInfo.methods[r.Method]; ok { + isRunnable = true + routerInfo.runFunction(context) + } else { + exception("405", context) + goto Admin + } + } else if routerInfo.routerType == routerTypeHandler { + isRunnable = true + routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request) + } else { + runRouter = routerInfo.controllerType + methodParams = routerInfo.methodParams + method := r.Method + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut { + method = http.MethodPut + } + if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { 
+ method = http.MethodDelete + } + if m, ok := routerInfo.methods[method]; ok { + runMethod = m + } else if m, ok = routerInfo.methods["*"]; ok { + runMethod = m + } else { + runMethod = method + } + } + } + + // also defined runRouter & runMethod from filter + if !isRunnable { + // Invoke the request handler + var execController ControllerInterface + if routerInfo != nil && routerInfo.initialize != nil { + execController = routerInfo.initialize() + } else { + vc := reflect.New(runRouter) + var ok bool + execController, ok = vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + } + + // call the controller init function + execController.Init(context, runRouter.Name(), runMethod, execController) + + // call prepare function + execController.Prepare() + + // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf + if BConfig.WebConfig.EnableXSRF { + execController.XSRFToken() + if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || + (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { + execController.CheckXSRFCookie() + } + } + + execController.URLMapping() + + if !context.ResponseWriter.Started { + // exec main logic + switch runMethod { + case http.MethodGet: + execController.Get() + case http.MethodPost: + execController.Post() + case http.MethodDelete: + execController.Delete() + case http.MethodPut: + execController.Put() + case http.MethodHead: + execController.Head() + case http.MethodPatch: + execController.Patch() + case http.MethodOptions: + execController.Options() + case http.MethodTrace: + execController.Trace() + default: + if !execController.HandlerFunc(runMethod) { + vc := reflect.ValueOf(execController) + method := vc.MethodByName(runMethod) + in := param.ConvertParams(methodParams, method.Type(), context) + out := method.Call(in) + + // For backward compatibility we only handle response if we had incoming methodParams + if methodParams != nil { + p.handleParamResponse(context, execController, out) + } + } + } + + // render template + if !context.ResponseWriter.Started && context.Output.Status == 0 { + if BConfig.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + logs.Error(err) + } + } + } + } + + // finish all runRouter. 
release resource + execController.Finish() + } + + // execute middleware filters + if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { + goto Admin + } + + if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { + goto Admin + } + +Admin: + // admin module record QPS + + statusCode := context.ResponseWriter.Status + if statusCode == 0 { + statusCode = 200 + } + + LogAccess(context, &startTime, statusCode) + + timeDur := time.Since(startTime) + context.ResponseWriter.Elapsed = timeDur + if BConfig.Listen.EnableAdmin { + pattern := "" + if routerInfo != nil { + pattern = routerInfo.pattern + } + + if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { + routerName := "" + if runRouter != nil { + routerName = runRouter.Name() + } + go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) + } + } + + if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { + match := map[bool]string{true: "match", false: "nomatch"} + devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", + context.Input.IP(), + logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), + timeDur.String(), + match[findRouter], + logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), + r.URL.Path) + if routerInfo != nil { + devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) + } + + logs.Debug(devInfo) + } + // Call WriteHeader if status code has been set changed + if context.Output.Status != 0 { + context.ResponseWriter.WriteHeader(context.Output.Status) + } +} + +func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { + // looping in reverse order for the case when both error and value are returned and error sets the response status code + for i := len(results) - 1; i >= 0; i-- { + result := results[i] + if result.Kind() != reflect.Interface || !result.IsNil() { + resultValue := result.Interface() + context.RenderMethodResult(resultValue) + } + } + if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { + context.Output.SetStatus(200) + } +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + var urlPath = context.Input.URL() + if !BConfig.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := context.Input.Method() + if t, ok := p.routers[httpMethod]; ok { + runObject := t.Match(urlPath, context) + if r, ok := runObject.(*ControllerInfo); ok { + return r, true + } + } + return +} + +func toURL(params map[string]string) string { + if len(params) == 0 { + return "" + } + u := "?" 
+ for k, v := range params { + u += k + "=" + v + "&" + } + return strings.TrimRight(u, "&") +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + // Skip logging if AccessLogs config is false + if !BConfig.Log.AccessLogs { + return + } + // Skip logging static requests unless EnableStaticLogs config is true + if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { + return + } + var ( + requestTime time.Time + elapsedTime time.Duration + r = ctx.Request + ) + if startTime != nil { + requestTime = *startTime + elapsedTime = time.Since(*startTime) + } + record := &logs.AccessLogRecord{ + RemoteAddr: ctx.Input.IP(), + RequestTime: requestTime, + RequestMethod: r.Method, + Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + ServerProtocol: r.Proto, + Host: r.Host, + Status: statusCode, + ElapsedTime: elapsedTime, + HTTPReferrer: r.Header.Get("Referer"), + HTTPUserAgent: r.Header.Get("User-Agent"), + RemoteUser: r.Header.Get("Remote-User"), + BodyBytesSent: r.ContentLength, + } + logs.AccessLog(record, BConfig.Log.AccessLogsFormat) +} diff --git a/pkg/router_test.go b/pkg/router_test.go new file mode 100644 index 00000000..8ec7927a --- /dev/null +++ b/pkg/router_test.go @@ -0,0 +1,732 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
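
ServeHTTP above also honours a _method query override for POST requests (mapping them to PUT or DELETE before method dispatch), which the tests below do not exercise; a hedged sketch in the same style, assuming the same package and imports as router_test.go and an illustrative controller:

type OverrideController struct {
	Controller
}

func (c *OverrideController) Put() {
	c.Ctx.Output.Body([]byte("updated"))
}

func TestMethodOverride(t *testing.T) {
	// Routed as POST, then dispatched to Put because of ?_method=PUT.
	r, _ := http.NewRequest("POST", "/user/123?_method=PUT", nil)
	w := httptest.NewRecorder()

	handler := NewControllerRegister()
	handler.Add("/user/:id", &OverrideController{})
	handler.ServeHTTP(w, r)

	if w.Body.String() != "updated" {
		t.Errorf("expected the _method override to dispatch to Put, got %q", w.Body.String())
	}
}
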
+ +package beego + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" +) + +type TestController struct { + Controller +} + +func (tc *TestController) Get() { + tc.Data["Username"] = "astaxie" + tc.Ctx.Output.Body([]byte("ok")) +} + +func (tc *TestController) Post() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) +} + +func (tc *TestController) Param() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name"))) +} + +func (tc *TestController) List() { + tc.Ctx.Output.Body([]byte("i am list")) +} + +func (tc *TestController) Params() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2"))) +} + +func (tc *TestController) Myext() { + tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext"))) +} + +func (tc *TestController) GetURL() { + tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext"))) +} + +func (tc *TestController) GetParams() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" + + tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn")) +} + +func (tc *TestController) GetManyRouter() { + tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page")) +} + +func (tc *TestController) GetEmptyBody() { + var res []byte + tc.Ctx.Output.Body(res) +} + +type JSONController struct { + Controller +} + +func (jc *JSONController) Prepare() { + jc.Data["json"] = "prepare" + jc.ServeJSON(true) +} + +func (jc *JSONController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/person/:last/:first", &TestController{}, "*:Param") + if a := handler.URLFor("TestController.List"); a != "/api/list" { + logs.Info(a) + t.Errorf("TestController.List must equal to /api/list") + } + if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" { + t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a) + } +} + +func TestUrlFor3(t *testing.T) { + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { + t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) + } + if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { + t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) + } +} + +func TestUrlFor2(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") + handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) + if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { + logs.Info(handler.URLFor("TestController.GetURL")) + t.Errorf("TestController.List must equal to /v1/astaxie/edit") + } + + if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") != + "/v1/za/cms_12_123.html" { + logs.Info(handler.URLFor("TestController.List")) + t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html") + } + if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") != + 
"/v1/za_cms/ttt_12_123.html" { + logs.Info(handler.URLFor("TestController.Param")) + t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html") + } + if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11", + ":title", "aaaa", ":entid", "aaaa") != + "/1111/11/aaaa/aaaa" { + logs.Info(handler.URLFor("TestController.Get")) + t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa") + } +} + +func TestUserFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/api/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/api/list", &TestController{}, "*:List") + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestPostFunc(t *testing.T) { + r, _ := http.NewRequest("POST", "/astaxie", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/:name", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "astaxie" { + t.Errorf("post func should astaxie") + } +} + +func TestAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestAutoFunc2(t *testing.T) { + r, _ := http.NewRequest("GET", "/Test/List", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestAutoFuncParams(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "20091112" { + t.Errorf("user define func can't run") + } +} + +func TestAutoExtFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/test/myext.json", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAuto(&TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "json" { + t.Errorf("user define func can't run") + } +} + +func TestRouteOk(t *testing.T) { + + r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.ServeHTTP(w, r) + body := w.Body.String() + if body != "anderson+thomas+kungfu" { + t.Errorf("url param set to [%s];", body) + } +} + +func TestManyRoute(t *testing.T) { + + r, _ := http.NewRequest("GET", "/beego32-12.html", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.ServeHTTP(w, r) + + body := w.Body.String() + + if body != "3212" { + t.Errorf("url param set to [%s];", body) + } +} + +// Test for issue #1669 +func TestEmptyResponse(t *testing.T) { + + r, _ := http.NewRequest("GET", "/beego-empty.html", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.ServeHTTP(w, r) + + if body := w.Body.String(); body != "" { + t.Error("want empty body") + } +} + +func TestNotFound(t *testing.T) { + r, _ := http.NewRequest("GET", "/", 
nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound) + } +} + +// TestStatic tests the ability to serve static +// content from the filesystem +func TestStatic(t *testing.T) { + r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.ServeHTTP(w, r) + + if w.Code != 404 { + t.Errorf("handler.Static failed to serve file") + } +} + +func TestPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/json/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/json/list", &JSONController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != `"prepare"` { + t.Errorf(w.Body.String() + "user define func can't run") + } +} + +func TestAutoPrefix(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.AddAutoPrefix("/admin", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestAutoPrefix can't run") + } +} + +func TestRouterGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("Get userlist")) + }) + handler.ServeHTTP(w, r) + if w.Body.String() != "Get userlist" { + t.Errorf("TestRouterGet can't run") + } +} + +func TestRouterPost(t *testing.T) { + r, _ := http.NewRequest("POST", "/user/123", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestRouterPost can't run") + } +} + +func sayhello(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sayhello")) +} + +func TestRouterHandler(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello)) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + +func TestRouterHandlerAll(t *testing.T) { + r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Handler("/sayhi", http.HandlerFunc(sayhello), true) + handler.ServeHTTP(w, r) + if w.Body.String() != "sayhello" { + t.Errorf("TestRouterHandler can't run") + } +} + +// +// Benchmarks NewApp: +// + +func beegoFilterFunc(ctx *context.Context) { + ctx.WriteString("hello") +} + +type AdminController struct { + Controller +} + +func (a *AdminController) Get() { + a.Ctx.WriteString("hello") +} + +func TestRouterFunc(t *testing.T) { + mux := NewControllerRegister() + mux.Get("/action", beegoFilterFunc) + mux.Post("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + mux.ServeHTTP(rw, r) + if rw.Body.String() != "hello" { + t.Errorf("TestRouterFunc can't run") + } +} + +func BenchmarkFunc(b *testing.B) { + mux := NewControllerRegister() + mux.Get("/action", beegoFilterFunc) + rw, r := testRequest("GET", "/action") + b.ResetTimer() + for i := 0; i < b.N; i++ { + mux.ServeHTTP(rw, r) + } +} + +func BenchmarkController(b *testing.B) { + mux := 
NewControllerRegister() + mux.Add("/action", &AdminController{}) + rw, r := testRequest("GET", "/action") + b.ResetTimer() + for i := 0; i < b.N; i++ { + mux.ServeHTTP(rw, r) + } +} + +func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) { + request, _ := http.NewRequest(method, path, nil) + recorder := httptest.NewRecorder() + + return recorder, request +} + +// Expectation: A Filter with the correct configuration should be created given +// specific parameters. +func TestInsertFilter(t *testing.T) { + testName := "TestInsertFilter" + + mux := NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + if !mux.filters[BeforeRouter][0].returnOnOutput { + t.Errorf( + "%s: passing no variadic params should set returnOnOutput to true", + testName) + } + if mux.filters[BeforeRouter][0].resetParams { + t.Errorf( + "%s: passing no variadic params should set resetParams to false", + testName) + } + + mux = NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + if mux.filters[BeforeRouter][0].returnOnOutput { + t.Errorf( + "%s: passing false as 1st variadic param should set returnOnOutput to false", + testName) + } + + mux = NewControllerRegister() + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + if !mux.filters[BeforeRouter][0].resetParams { + t.Errorf( + "%s: passing true as 2nd variadic param should set resetParams to true", + testName) + } +} + +// Expectation: the second variadic arg should cause the execution of the filter +// to preserve the parameters from before its execution. +func TestParamResetFilter(t *testing.T) { + testName := "TestParamResetFilter" + route := "/beego/*" // splat + path := "/beego/routes/routes" + + mux := NewControllerRegister() + + mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + + mux.Get(route, beegoHandleResetParams) + + rw, r := testRequest("GET", path) + mux.ServeHTTP(rw, r) + + // The two functions, `beegoResetParams` and `beegoHandleResetParams` add + // a response header of `Splat`. The expectation here is that that Header + // value should match what the _request's_ router set, not the filter's. + + headers := rw.Result().Header + if len(headers["Splat"]) != 1 { + t.Errorf( + "%s: There was an error in the test. 
Splat param not set in Header", + testName) + } + if headers["Splat"][0] != "routes/routes" { + t.Errorf( + "%s: expected `:splat` param to be [routes/routes] but it was [%s]", + testName, headers["Splat"][0]) + } +} + +// Execution point: BeforeRouter +// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle +func TestFilterBeforeRouter(t *testing.T) { + testName := "TestFilterBeforeRouter" + url := "/beforeRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "BeforeRouter1") { + t.Errorf(testName + " BeforeRouter did not run") + } + if strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " BeforeRouter did not return properly") + } +} + +// Execution point: BeforeExec +// expectation: only BeforeExec function is executed, match as router determines route only +func TestFilterBeforeExec(t *testing.T) { + testName := "TestFilterBeforeExec" + url := "/beforeExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "BeforeExec1") { + t.Errorf(testName + " BeforeExec did not run") + } + if strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " BeforeExec did not return properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } +} + +// Execution point: AfterExec +// expectation: only AfterExec function is executed, match as router handles +func TestFilterAfterExec(t *testing.T) { + testName := "TestFilterAfterExec" + url := "/afterExec" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "AfterExec1") { + t.Errorf(testName + " AfterExec did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only FinishRouter function is executed, match as router handles +func TestFilterFinishRouter(t *testing.T) { + testName := "TestFilterFinishRouter" + url := "/finishRouter" + + mux := NewControllerRegister() + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if strings.Contains(rw.Body.String(), "AfterExec1") { + t.Errorf(testName 
+ " AfterExec ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeRouter") { + t.Errorf(testName + " BeforeRouter ran in error") + } + if strings.Contains(rw.Body.String(), "BeforeExec") { + t.Errorf(testName + " BeforeExec ran in error") + } +} + +// Execution point: FinishRouter +// expectation: only first FinishRouter function is executed, match as router handles +func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { + testName := "TestFilterFinishRouterMultiFirstOnly" + url := "/finishRouterMultiFirstOnly" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter1 did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + // not expected in body + if strings.Contains(rw.Body.String(), "FinishRouter2") { + t.Errorf(testName + " FinishRouter2 did run") + } +} + +// Execution point: FinishRouter +// expectation: both FinishRouter functions execute, match as router handles +func TestFilterFinishRouterMulti(t *testing.T) { + testName := "TestFilterFinishRouterMulti" + url := "/finishRouterMulti" + + mux := NewControllerRegister() + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + + mux.Get(url, beegoFilterFunc) + + rw, r := testRequest("GET", url) + mux.ServeHTTP(rw, r) + + if !strings.Contains(rw.Body.String(), "FinishRouter1") { + t.Errorf(testName + " FinishRouter1 did not run") + } + if !strings.Contains(rw.Body.String(), "hello") { + t.Errorf(testName + " handler did not run properly") + } + if !strings.Contains(rw.Body.String(), "FinishRouter2") { + t.Errorf(testName + " FinishRouter2 did not run properly") + } +} + +func beegoFilterNoOutput(ctx *context.Context) { +} + +func beegoBeforeRouter1(ctx *context.Context) { + ctx.WriteString("|BeforeRouter1") +} + +func beegoBeforeExec1(ctx *context.Context) { + ctx.WriteString("|BeforeExec1") +} + +func beegoAfterExec1(ctx *context.Context) { + ctx.WriteString("|AfterExec1") +} + +func beegoFinishRouter1(ctx *context.Context) { + ctx.WriteString("|FinishRouter1") +} + +func beegoFinishRouter2(ctx *context.Context) { + ctx.WriteString("|FinishRouter2") +} + +func beegoResetParams(ctx *context.Context) { + ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) +} + +func beegoHandleResetParams(ctx *context.Context) { + ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) +} + +// YAML +type YAMLController struct { + Controller +} + +func (jc *YAMLController) Prepare() { + jc.Data["yaml"] = "prepare" + jc.ServeYAML() +} + +func (jc *YAMLController) Get() { + jc.Data["Username"] = "astaxie" + jc.Ctx.Output.Body([]byte("ok")) +} + +func TestYAMLPrepare(t *testing.T) { + r, _ := http.NewRequest("GET", "/yaml/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Add("/yaml/list", &YAMLController{}) + handler.ServeHTTP(w, r) + if strings.TrimSpace(w.Body.String()) != "prepare" { + t.Errorf(w.Body.String()) + } +} + +func TestRouterEntityTooLargeCopyBody(t *testing.T) { + _MaxMemory := BConfig.MaxMemory + _CopyRequestBody := BConfig.CopyRequestBody + BConfig.CopyRequestBody = true + BConfig.MaxMemory = 20 + + b := 
bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar"))
+	r, _ := http.NewRequest("POST", "/user/123", b)
+	w := httptest.NewRecorder()
+
+	handler := NewControllerRegister()
+	handler.Post("/user/:id", func(ctx *context.Context) {
+		ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+	})
+	handler.ServeHTTP(w, r)
+
+	BConfig.CopyRequestBody = _CopyRequestBody
+	BConfig.MaxMemory = _MaxMemory
+
+	if w.Code != http.StatusRequestEntityTooLarge {
+		t.Errorf("TestRouterRequestEntityTooLarge can't run")
+	}
+}
diff --git a/pkg/session/README.md b/pkg/session/README.md
new file mode 100644
index 00000000..6d0a297e
--- /dev/null
+++ b/pkg/session/README.md
@@ -0,0 +1,114 @@
+session
+==============
+
+session is a Go session manager. It supports many session providers, following the same pattern as `database/sql` and `database/sql/driver`.
+
+## How to install?
+
+	go get github.com/astaxie/beego/session
+
+
+## What providers are supported?
+
+As of now this session manager supports memory, file, cookie, Redis and MySQL, among others.
+
+
+## How to use it?
+
+First, import the package:
+
+	import (
+		"github.com/astaxie/beego/session"
+	)
+
+Then initialize a global session manager in your web app:
+
+	var globalSessions *session.Manager
+
+* Use **memory** as the provider:
+
+	func init() {
+		globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
+		go globalSessions.GC()
+	}
+
+* Use **file** as the provider; `ProviderConfig` is the directory where session files are stored:
+
+	func init() {
+		globalSessions, _ = session.NewManager("file", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
+		go globalSessions.GC()
+	}
+
+* Use **Redis** as the provider; `ProviderConfig` is the Redis connection address, pool size and password:
+
+	func init() {
+		globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
+		go globalSessions.GC()
+	}
+
+* Use **MySQL** as the provider; `ProviderConfig` is the DSN, see [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name):
+
+	func init() {
+		globalSessions, _ = session.NewManager(
+			"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
+		go globalSessions.GC()
+	}
+
+* Use **Cookie** as the provider:
+
+	func init() {
+		globalSessions, _ = session.NewManager(
+			"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
+		go globalSessions.GC()
+	}
+
+
+Finally, use it in a handler func like this:
+
+	func login(w http.ResponseWriter, r *http.Request) {
+		sess := globalSessions.SessionStart(w, r)
+		defer sess.SessionRelease(w)
+		username := sess.Get("username")
+		fmt.Println(username)
+		if r.Method == "GET" {
+			t, _ := template.ParseFiles("login.gtpl")
+			t.Execute(w, nil)
+		} else {
+			fmt.Println("username:", r.Form["username"])
+			sess.Set("username", r.Form["username"])
+			fmt.Println("password:", r.Form["password"])
+		}
+	}
+
+
+## How to write your own provider?
+
+When developing a web app, you may need to write your own provider to meet specific requirements.
+
+Writing a provider is straightforward: define two struct types
+(a session store and a provider) that satisfy the interfaces shown at the end of this section.
+The built-in **memory** provider is a good example to study.
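+
+As a rough illustration only — not part of beego, assuming the registration API
+(`session.Store`, `session.Register`) used by the providers in this package, and with the
+package name `simplesession` and the provider name `simple` invented for the example —
+a minimal in-process provider could look like the following sketch:
+
+	package simplesession
+
+	import (
+		"net/http"
+		"sync"
+
+		"github.com/astaxie/beego/session"
+	)
+
+	// memStore holds the values of one session in process memory.
+	type memStore struct {
+		sid    string
+		lock   sync.RWMutex
+		values map[interface{}]interface{}
+	}
+
+	func (s *memStore) Set(key, value interface{}) error {
+		s.lock.Lock()
+		defer s.lock.Unlock()
+		s.values[key] = value
+		return nil
+	}
+
+	func (s *memStore) Get(key interface{}) interface{} {
+		s.lock.RLock()
+		defer s.lock.RUnlock()
+		return s.values[key]
+	}
+
+	func (s *memStore) Delete(key interface{}) error {
+		s.lock.Lock()
+		defer s.lock.Unlock()
+		delete(s.values, key)
+		return nil
+	}
+
+	func (s *memStore) SessionID() string                     { return s.sid }
+	func (s *memStore) SessionRelease(w http.ResponseWriter) {} // nothing to persist in memory
+
+	func (s *memStore) Flush() error {
+		s.lock.Lock()
+		defer s.lock.Unlock()
+		s.values = make(map[interface{}]interface{})
+		return nil
+	}
+
+	// memProvider keeps all stores in a map keyed by session id.
+	type memProvider struct {
+		lock     sync.RWMutex
+		sessions map[string]*memStore
+	}
+
+	func (p *memProvider) SessionInit(gclifetime int64, config string) error {
+		p.sessions = make(map[string]*memStore)
+		return nil
+	}
+
+	func (p *memProvider) SessionRead(sid string) (session.Store, error) {
+		p.lock.Lock()
+		defer p.lock.Unlock()
+		if s, ok := p.sessions[sid]; ok {
+			return s, nil
+		}
+		s := &memStore{sid: sid, values: make(map[interface{}]interface{})}
+		p.sessions[sid] = s
+		return s, nil
+	}
+
+	func (p *memProvider) SessionExist(sid string) bool {
+		p.lock.RLock()
+		defer p.lock.RUnlock()
+		_, ok := p.sessions[sid]
+		return ok
+	}
+
+	func (p *memProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+		p.lock.Lock()
+		if s, ok := p.sessions[oldsid]; ok {
+			// Move the existing store to the new id; SessionRead then returns it.
+			delete(p.sessions, oldsid)
+			s.sid = sid
+			p.sessions[sid] = s
+		}
+		p.lock.Unlock()
+		return p.SessionRead(sid)
+	}
+
+	func (p *memProvider) SessionDestroy(sid string) error {
+		p.lock.Lock()
+		defer p.lock.Unlock()
+		delete(p.sessions, sid)
+		return nil
+	}
+
+	func (p *memProvider) SessionGC()      {} // nothing expires in this sketch
+	func (p *memProvider) SessionAll() int { return len(p.sessions) }
+
+	func init() {
+		// The name "simple" is arbitrary; select it with session.NewManager("simple", ...).
+		session.Register("simple", &memProvider{})
+	}
+
+For reference, the two interfaces a provider must satisfy are: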
+ + type SessionStore interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data + } + + type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (SessionStore, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (SessionStore, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() + } + + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..707d042c --- /dev/null +++ b/pkg/session/couchbase/sess_couchbase.go @@ -0,0 +1,247 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "net/http" + "strings" + "sync" + + couchbase "github.com/couchbase/go-couchbase" + + "github.com/astaxie/beego/session" +) + +var couchbpder = &Provider{} + +// SessionStore store each session +type SessionStore struct { + b *couchbase.Bucket + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Provider couchabse provided +type Provider struct { + maxlifetime int64 + savePath string + pool string + bucket string + b *couchbase.Bucket +} + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values[key] = value + return nil +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + cs.lock.RLock() + defer cs.lock.RUnlock() + if v, ok := cs.values[key]; ok { + return v + } + return nil +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + cs.lock.Lock() + defer cs.lock.Unlock() + delete(cs.values, key) + return nil +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + cs.lock.Lock() + defer cs.lock.Unlock() + cs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) 
SessionID() string { + return cs.sid +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + defer cs.b.Close() + + bo, err := session.EncodeGob(cs.values) + if err != nil { + return + } + + cs.b.Set(cs.sid, int(cs.maxlifetime), bo) +} + +func (cp *Provider) getBucket() *couchbase.Bucket { + c, err := couchbase.Connect(cp.savePath) + if err != nil { + return nil + } + + pool, err := c.GetPool(cp.pool) + if err != nil { + return nil + } + + bucket, err := pool.GetBucket(cp.bucket) + if err != nil { + return nil + } + + return bucket +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + cp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + cp.savePath = configs[0] + } + if len(configs) > 1 { + cp.pool = configs[1] + } + if len(configs) > 2 { + cp.bucket = configs[2] + } + + return nil +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + cp.b = cp.getBucket() + + var ( + kv map[interface{}]interface{} + err error + doc []byte + ) + + err = cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } else if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + cp.b = cp.getBucket() + defer cp.b.Close() + + var doc []byte + + if err := cp.b.Get(sid, &doc); err != nil || doc == nil { + return false + } + return true +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + cp.b = cp.getBucket() + + var doc []byte + if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil { + cp.b.Set(sid, int(cp.maxlifetime), "") + } else { + err := cp.b.Delete(oldsid) + if err != nil { + return nil, err + } + _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc) + } + + err := cp.b.Get(sid, &doc) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if doc == nil { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(doc) + if err != nil { + return nil, err + } + } + + cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime} + return cs, nil +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + cp.b = cp.getBucket() + defer cp.b.Close() + + cp.b.Delete(sid) + return nil +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("couchbase", couchbpder) +} diff --git a/pkg/session/ledis/ledis_session.go b/pkg/session/ledis/ledis_session.go new file mode 100644 index 00000000..ee81df67 --- /dev/null +++ b/pkg/session/ledis/ledis_session.go @@ -0,0 +1,173 @@ +// Package ledis provide session Provider +package ledis + +import ( + "net/http" + "strconv" + "strings" + "sync" + + "github.com/ledisdb/ledisdb/config" + "github.com/ledisdb/ledisdb/ledis" + + 
"github.com/astaxie/beego/session" +) + +var ( + ledispder = &Provider{} + c *ledis.DB +) + +// SessionStore ledis session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values[key] = value + return nil +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + ls.lock.RLock() + defer ls.lock.RUnlock() + if v, ok := ls.values[key]; ok { + return v + } + return nil +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + delete(ls.values, key) + return nil +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return ls.sid +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(ls.values) + if err != nil { + return + } + c.Set([]byte(ls.sid), b) + c.Expire([]byte(ls.sid), ls.maxlifetime) +} + +// Provider ledis session provider +type Provider struct { + maxlifetime int64 + savePath string + db int +} + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + var err error + lp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) == 1 { + lp.savePath = configs[0] + } else if len(configs) == 2 { + lp.savePath = configs[0] + lp.db, err = strconv.Atoi(configs[1]) + if err != nil { + return err + } + } + cfg := new(config.Config) + cfg.DataDir = lp.savePath + + var ledisInstance *ledis.Ledis + ledisInstance, err = ledis.Open(cfg) + if err != nil { + return err + } + c, err = ledisInstance.Select(lp.db) + return err +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + var ( + kv map[interface{}]interface{} + err error + ) + + kvs, _ := c.Get([]byte(sid)) + + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob(kvs); err != nil { + return nil, err + } + } + + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + return ls, nil +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + count, _ := c.Exists([]byte(sid)) + return count != 0 +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + count, _ := c.Exists([]byte(sid)) + if count == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set([]byte(sid), []byte("")) + c.Expire([]byte(sid), lp.maxlifetime) + } else { + data, _ := c.Get([]byte(oldsid)) + c.Set([]byte(sid), data) + c.Expire([]byte(sid), lp.maxlifetime) + } + return lp.SessionRead(sid) +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + c.Del([]byte(sid)) + return nil +} + +// SessionGC Impelment method, no used. 
+func (lp *Provider) SessionGC() { +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return 0 +} +func init() { + session.Register("ledis", ledispder) +} diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/session/memcache/sess_memcache.go new file mode 100644 index 00000000..85a2d815 --- /dev/null +++ b/pkg/session/memcache/sess_memcache.go @@ -0,0 +1,230 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "net/http" + "strings" + "sync" + + "github.com/astaxie/beego/session" + + "github.com/bradfitz/gomemcache/memcache" +) + +var mempder = &MemProvider{} +var client *memcache.Client + +// SessionStore memcache session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} + client.Set(&item) +} + +// MemProvider memcache session provider +type MemProvider struct { + maxlifetime int64 + conninfo []string + poolsize int + password string +} + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + rp.conninfo = strings.Split(savePath, ";") + client = memcache.New(rp.conninfo...) 
+ return nil +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + if client == nil { + if err := rp.connectInit(); err != nil { + return nil, err + } + } + item, err := client.Get(sid) + if err != nil { + if err == memcache.ErrCacheMiss { + rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} + return rs, nil + } + return nil, err + } + var kv map[interface{}]interface{} + if len(item.Value) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(item.Value) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + if client == nil { + if err := rp.connectInit(); err != nil { + return false + } + } + if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + if client == nil { + if err := rp.connectInit(); err != nil { + return nil, err + } + } + var contain []byte + if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + item.Key = sid + item.Value = []byte("") + item.Expiration = int32(rp.maxlifetime) + client.Set(item) + } else { + client.Delete(oldsid) + item.Key = sid + item.Expiration = int32(rp.maxlifetime) + client.Set(item) + contain = item.Value + } + + var kv map[interface{}]interface{} + if len(contain) == 0 { + kv = make(map[interface{}]interface{}) + } else { + var err error + kv, err = session.DecodeGob(contain) + if err != nil { + return nil, err + } + } + + rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + if client == nil { + if err := rp.connectInit(); err != nil { + return err + } + } + + return client.Delete(sid) +} + +func (rp *MemProvider) connectInit() error { + client = memcache.New(rp.conninfo...) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return 0 +} + +func init() { + session.Register("memcache", mempder) +} diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/session/mysql/sess_mysql.go new file mode 100644 index 00000000..301353ab --- /dev/null +++ b/pkg/session/mysql/sess_mysql.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "database/sql" + "net/http" + "sync" + "time" + + "github.com/astaxie/beego/session" + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = "session" + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return st.sid +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + defer st.c.Close() + b, err := session.EncodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", + b, time.Now().Unix(), st.sid) +} + +// Provider mysql session provider +type Provider struct { + maxlifetime int64 + savePath string +} + +// connect to mysql +func (mp *Provider) connectInit() *sql.DB { + db, e := sql.Open("mysql", mp.savePath) + if e != nil { + return nil + } + return db +} + +// SessionInit init mysql session. +// savepath is the connection string of mysql. 
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", + sid, "", time.Now().Unix()) + } + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + c := mp.connectInit() + defer c.Close() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + return err != sql.ErrNoRows +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) + } + c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) + c.Close() + return nil +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) + c.Close() +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + c := mp.connectInit() + defer c.Close() + var total int + err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) + if err != nil { + return 0 + } + return total +} + +func init() { + session.Register("mysql", mysqlpder) +} diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..0b8b9645 --- /dev/null +++ b/pkg/session/postgres/sess_postgresql.go @@ -0,0 +1,243 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "database/sql" + "net/http" + "sync" + "time" + + "github.com/astaxie/beego/session" + // import postgresql Driver + _ "github.com/lib/pq" +) + +var postgresqlpder = &Provider{} + +// SessionStore postgresql session store +type SessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return st.sid +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + defer st.c.Close() + b, err := session.EncodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", + b, time.Now().Format(time.RFC3339), st.sid) + +} + +// Provider postgresql session provider +type Provider struct { + maxlifetime int64 + savePath string +} + +// connect to postgresql +func (mp *Provider) connectInit() *sql.DB { + db, e := sql.Open("postgres", mp.savePath) + if e != nil { + return nil + } + return db +} + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. 
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=$1", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + _, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", + sid, "", time.Now().Format(time.RFC3339)) + + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + c := mp.connectInit() + defer c.Close() + row := c.QueryRow("select session_data from session where session_key=$1", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + return err != sql.ErrNoRows +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=$1", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)", + oldsid, "", time.Now().Format(time.RFC3339)) + } + c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &SessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM session where session_key=$1", sid) + c.Close() + return nil +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) + c.Close() +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + c := mp.connectInit() + defer c.Close() + var total int + err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) + if err != nil { + return 0 + } + return total +} + +func init() { + session.Register("postgresql", postgresqlpder) +} diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go new file mode 100644 index 00000000..5c382d61 --- /dev/null +++ b/pkg/session/redis/sess_redis.go @@ -0,0 +1,261 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/session" + + "github.com/gomodule/redigo/redis" +) + +var redispder = &Provider{} + +// MaxPoolSize redis max pool size +var MaxPoolSize = 100 + +// SessionStore redis session store +type SessionStore struct { + p *redis.Pool + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p.Get() + defer c.Close() + c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) +} + +// Provider redis session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Pool +} + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 
127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + var idleTimeout time.Duration = 0 + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + idleTimeout = time.Duration(timeout) * time.Second + } + } + rp.poollist = &redis.Pool{ + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", rp.savePath) + if err != nil { + return nil, err + } + if rp.password != "" { + if _, err = c.Do("AUTH", rp.password); err != nil { + c.Close() + return nil, err + } + } + // some redis proxy such as twemproxy is not support select command + if rp.dbNum > 0 { + _, err = c.Do("SELECT", rp.dbNum) + if err != nil { + c.Close() + return nil, err + } + } + return c, err + }, + MaxIdle: rp.poolsize, + } + + rp.poollist.IdleTimeout = idleTimeout + + return rp.poollist.Get().Err() +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + c := rp.poollist.Get() + defer c.Close() + + var kv map[interface{}]interface{} + + kvs, err := redis.String(c.Do("GET", sid)) + if err != nil && err != redis.ErrNil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist.Get() + defer c.Close() + + if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist.Get() + defer c.Close() + + if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Do("SET", sid, "", "EX", rp.maxlifetime) + } else { + c.Do("RENAME", oldsid, sid) + c.Do("EXPIRE", sid, rp.maxlifetime) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist.Get() + defer c.Close() + + c.Do("DEL", sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis", redispder) +} diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..2fe300df --- /dev/null +++ b/pkg/session/redis_cluster/redis_cluster.go @@ -0,0 +1,220 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster +import ( + "net/http" + "strconv" + "strings" + "sync" + "github.com/astaxie/beego/session" + rediss "github.com/go-redis/redis" + "time" +) + +var redispder = &Provider{} + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = 1000 + +// SessionStore redis_cluster session store +type SessionStore struct { + p *rediss.ClusterClient + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) +} + +// Provider redis_cluster session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *rediss.ClusterClient +} + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 
127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = MaxPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + }) + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != rediss.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_cluster", redispder) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..6ecb2977 --- /dev/null +++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
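The cluster provider above uses the same comma-separated form, except that the first field is itself a semicolon-separated list of node addresses (the db number is parsed but never handed to the cluster client, since ClusterOptions carries no DB field). A small sketch, assuming a local three-node cluster on illustrative ports:

package main

import (
	"github.com/astaxie/beego/session"
	_ "github.com/astaxie/beego/session/redis_cluster"
)

func main() {
	// node1;node2;node3, pool size, password, db number
	globalSessions, err := session.NewManager("redis_cluster", &session.ManagerConfig{
		CookieName:     "gosessionid",
		Gclifetime:     3600,
		ProviderConfig: "127.0.0.1:7000;127.0.0.1:7001;127.0.0.1:7002,100,,0",
	})
	if err != nil {
		panic(err)
	}
	go globalSessions.GC()
	// hand globalSessions.SessionStart(w, r) to your handlers as in the redis sketch above
}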
+ +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "github.com/astaxie/beego/session" + "github.com/go-redis/redis" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +var redispder = &Provider{} + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = 100 + +// SessionStore redis_sentinel session store +type SessionStore struct { + p *redis.Client + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return rs.sid +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(rs.values) + if err != nil { + return + } + c := rs.p + c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) +} + +// Provider redis_sentinel session provider +type Provider struct { + maxlifetime int64 + savePath string + poolsize int + password string + dbNum int + poollist *redis.Client + masterName string +} + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 
127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.savePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.poolsize = DefaultPoolSize + } else { + rp.poolsize = poolsize + } + } else { + rp.poolsize = DefaultPoolSize + } + if len(configs) > 2 { + rp.password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.dbNum = 0 + } else { + rp.dbNum = dbnum + } + } else { + rp.dbNum = 0 + } + if len(configs) > 4 { + if configs[4] != "" { + rp.masterName = configs[4] + } else { + rp.masterName = "mymaster" + } + } else { + rp.masterName = "mymaster" + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.savePath, ";"), + Password: rp.password, + PoolSize: rp.poolsize, + DB: rp.dbNum, + MasterName: rp.masterName, + }) + + return rp.poollist.Ping().Err() +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(sid).Result() + if err != nil && err != redis.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + c := rp.poollist + if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { + return false + } + return true +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) + } else { + c.Rename(oldsid, sid) + c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + c := rp.poollist + c.Del(sid) + return nil +} + +// SessionGC Impelment method, no used. 
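+// Expired sessions are reclaimed by Redis itself via the key TTLs set in
+// SessionRelease and SessionRegenerate, so there is nothing to collect here.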
+func (rp *Provider) SessionGC() { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return 0 +} + +func init() { + session.Register("redis_sentinel", redispder) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..fd4155c6 --- /dev/null +++ b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + //todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(w) + + // SET AND GET + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete("username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set("password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get("username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get("password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush() + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get("username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get("password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(w) + +} diff --git a/pkg/session/sess_cookie.go b/pkg/session/sess_cookie.go new file mode 100644 index 00000000..6ad5debc --- /dev/null +++ b/pkg/session/sess_cookie.go @@ -0,0 +1,180 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/json" + "net/http" + "net/url" + "sync" +) + +var cookiepder = &CookieProvider{} + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore struct { + sid string + values map[interface{}]interface{} // session data + lock sync.RWMutex +} + +// Set value to cookie session. 
+// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } + return nil +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return st.sid +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + st.lock.Lock() + encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) + st.lock.Unlock() + if err == nil { + cookie := &http.Cookie{Name: cookiepder.config.CookieName, + Value: url.QueryEscape(encodedCookie), + Path: "/", + HttpOnly: true, + Secure: cookiepder.config.Secure, + MaxAge: cookiepder.config.Maxage} + http.SetCookie(w, cookie) + } +} + +type cookieConfig struct { + SecurityKey string `json:"securityKey"` + BlockKey string `json:"blockKey"` + SecurityName string `json:"securityName"` + CookieName string `json:"cookieName"` + Secure bool `json:"secure"` + Maxage int `json:"maxage"` +} + +// CookieProvider Cookie session provider +type CookieProvider struct { + maxlifetime int64 + config *cookieConfig + block cipher.Block +} + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + pder.config = &cookieConfig{} + err := json.Unmarshal([]byte(config), pder.config) + if err != nil { + return err + } + if pder.config.BlockKey == "" { + pder.config.BlockKey = string(generateRandomKey(16)) + } + if pder.config.SecurityName == "" { + pder.config.SecurityName = string(generateRandomKey(20)) + } + pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) + if err != nil { + return err + } + pder.maxlifetime = maxlifetime + return nil +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + maps, _ := decodeCookie(pder.block, + pder.config.SecurityKey, + pder.config.SecurityName, + sid, pder.maxlifetime) + if maps == nil { + maps = make(map[interface{}]interface{}) + } + rs := &CookieSessionStore{sid: sid, values: maps} + return rs, nil +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + return true +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + return nil, nil +} + +// SessionDestroy Implement method, no used. 
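+// Cookie sessions live entirely in the client's cookie, so there is no
+// server-side state to remove; the method exists only to satisfy Provider.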
+func (pder *CookieProvider) SessionDestroy(sid string) error { + return nil +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return 0 +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return nil +} + +func init() { + Register("cookie", cookiepder) +} diff --git a/pkg/session/sess_cookie_test.go b/pkg/session/sess_cookie_test.go new file mode 100644 index 00000000..b6726005 --- /dev/null +++ b/pkg/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . 
+ r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID() != session.SessionID() { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID() == session.SessionID() { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/session/sess_file.go b/pkg/session/sess_file.go new file mode 100644 index 00000000..47ad54a7 --- /dev/null +++ b/pkg/session/sess_file.go @@ -0,0 +1,315 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "path" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + filepder = &FileProvider{} + gcmaxlifetime int64 +) + +// FileSessionStore File session store +type FileSessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values[key] = value + return nil +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + fs.lock.RLock() + defer fs.lock.RUnlock() + if v, ok := fs.values[key]; ok { + return v + } + return nil +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + delete(fs.values, key) + return nil +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return fs.sid +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + filepder.lock.Lock() + defer filepder.lock.Unlock() + b, err := EncodeGob(fs.values) + if err != nil { + SLogger.Println(err) + return + } + _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + if err != nil { + SLogger.Println(err) + return + } + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(filepder.savePath, 
string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + if err != nil { + SLogger.Println(err) + return + } + } else { + return + } + f.Truncate(0) + f.Seek(0, 0) + f.Write(b) + f.Close() +} + +// FileProvider File session provider +type FileProvider struct { + lock sync.RWMutex + maxlifetime int64 + savePath string +} + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + fp.maxlifetime = maxlifetime + fp.savePath = savePath + return nil +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + invalidChars := "./" + if strings.ContainsAny(sid, invalidChars) { + return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) + } + if len(sid) < 2 { + return nil, errors.New("length of the sid is less than 2") + } + filepder.lock.Lock() + defer filepder.lock.Unlock() + + err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755) + if err != nil { + SLogger.Println(err.Error()) + } + _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + } else { + return nil, err + } + + defer f.Close() + + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + var kv map[interface{}]interface{} + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + if len(sid) < 2 { + SLogger.Println("min length of session id is 2", sid) + return false + } + + _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return err == nil +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + filepder.lock.Lock() + defer filepder.lock.Unlock() + os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return nil +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + gcmaxlifetime = fp.maxlifetime + filepath.Walk(fp.savePath, gcpath) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + a := &activeSession{} + err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { + return a.visit(path, f, err) + }) + if err != nil { + SLogger.Printf("filepath.Walk() returned %v\n", err) + return 0 + } + return a.total +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. 
+func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + filepder.lock.Lock() + defer filepder.lock.Unlock() + + oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) + oldSidFile := path.Join(oldPath, oldsid) + newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) + newSidFile := path.Join(newPath, sid) + + // new sid file is exist + _, err := os.Stat(newSidFile) + if err == nil { + return nil, fmt.Errorf("newsid %s exist", newSidFile) + } + + err = os.MkdirAll(newPath, 0755) + if err != nil { + SLogger.Println(err.Error()) + } + + // if old sid file exist + // 1.read and parse file content + // 2.write content to new sid file + // 3.remove old sid file, change new sid file atime and ctime + // 4.return FileSessionStore + _, err = os.Stat(oldSidFile) + if err == nil { + b, err := ioutil.ReadFile(oldSidFile) + if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ioutil.WriteFile(newSidFile, b, 0777) + os.Remove(oldSidFile) + os.Chtimes(newSidFile, time.Now(), time.Now()) + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil + } + + // if old sid file not exist, just create new sid file and return + newf, err := os.Create(newSidFile) + if err != nil { + return nil, err + } + newf.Close() + ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} + return ss, nil +} + +// remove file in save path if expired +func gcpath(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() { + os.Remove(path) + } + return nil +} + +type activeSession struct { + total int +} + +func (as *activeSession) visit(paths string, f os.FileInfo, err error) error { + if err != nil { + return err + } + if f.IsDir() { + return nil + } + as.total = as.total + 1 + return nil +} + +func init() { + Register("file", filepder) +} diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go new file mode 100644 index 00000000..0cf021db --- /dev/null +++ b/pkg/session/sess_file_test.go @@ -0,0 +1,387 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
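A minimal sketch of wiring up the file provider above (the directory name is illustrative): ProviderConfig is simply the save path, and each session is written as one gob-encoded file under savePath/<sid[0]>/<sid[1]>/<sid>.

package main

import (
	"fmt"

	"github.com/astaxie/beego/session"
)

func main() {
	globalSessions, err := session.NewManager("file", &session.ManagerConfig{
		CookieName:     "gosessionid",
		Gclifetime:     3600, // also used as maxlifetime, since Maxlifetime is left at 0
		ProviderConfig: "./sessions",
	})
	if err != nil {
		panic(err)
	}
	// SessionGC removes files whose modification time is older than maxlifetime;
	// GC reschedules itself every Gclifetime seconds.
	go globalSessions.GC()
	fmt.Println("session files live under ./sessions/<sid[0]>/<sid[1]>/<sid>")
}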
+ +package session + +import ( + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + if fp.SessionExist(sid) { + t.Error() + } + + if fp.SessionExist("") { + t.Error() + } + + if fp.SessionExist("1") { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + _ = s.Set("sessionValue", 18975) + v := s.Get("sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead("") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead("1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll() != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + _, err = fp.SessionRegenerate(sid, sidNew) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } + + if !fp.SessionExist(sidNew) { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + _, err := fp.SessionRead(sid) + if err != nil { + t.Error(err) + } + + if !fp.SessionExist(sid) { + t.Error() + } + + err = fp.SessionDestroy(sid) + if err != nil { + t.Error(err) + } + + if fp.SessionExist(sid) { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + 
os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC() + if fp.SessionAll() != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + + v := s.Get(i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + s, _ := fp.SessionRead(sid) + s.Set("1", 1) + + if s.Get("1") == nil { + t.Error() + } + + s.Delete("1") + + if s.Get("1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(i, i) + } + + _ = s.Flush() + + for i := 1; i <= sessionCount; i++ { + if s.Get(i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStore_SessionRelease(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + + s.Set(i,i) + s.SessionRelease(nil) + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(i).(int) != i { + t.Error() + } + } +} \ No newline at end of file diff --git a/pkg/session/sess_mem.go b/pkg/session/sess_mem.go new file mode 100644 index 00000000..64d8b056 --- /dev/null +++ b/pkg/session/sess_mem.go @@ -0,0 +1,196 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "container/list" + "net/http" + "sync" + "time" +) + +var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore struct { + sid string //session id + timeAccessed time.Time //last access time + value map[interface{}]interface{} //session store + lock sync.RWMutex +} + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.value[key] = value + return nil +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.value[key]; ok { + return v + } + return nil +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.value, key) + return nil +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.value = make(map[interface{}]interface{}) + return nil +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return st.sid +} + +// SessionRelease Implement method, no used. 
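+// Values are mutated in place on the shared in-memory store, so there is
+// nothing to persist when the request finishes.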
+func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +} + +// MemProvider Implement the provider interface +type MemProvider struct { + lock sync.RWMutex // locker + sessions map[string]*list.Element // map in memory + list *list.List // for gc + maxlifetime int64 + savePath string +} + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + pder.maxlifetime = maxlifetime + pder.savePath = savePath + return nil +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + pder.lock.RLock() + if element, ok := pder.sessions[sid]; ok { + go pder.SessionUpdate(sid) + pder.lock.RUnlock() + return element.Value.(*MemSessionStore), nil + } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushFront(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + pder.lock.RLock() + defer pder.lock.RUnlock() + if _, ok := pder.sessions[sid]; ok { + return true + } + return false +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + pder.lock.RLock() + if element, ok := pder.sessions[oldsid]; ok { + go pder.SessionUpdate(oldsid) + pder.lock.RUnlock() + pder.lock.Lock() + element.Value.(*MemSessionStore).sid = sid + pder.sessions[sid] = element + delete(pder.sessions, oldsid) + pder.lock.Unlock() + return element.Value.(*MemSessionStore), nil + } + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushFront(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + delete(pder.sessions, sid) + pder.list.Remove(element) + return nil + } + return nil +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + pder.lock.RLock() + for { + element := pder.list.Back() + if element == nil { + break + } + if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { + pder.lock.RUnlock() + pder.lock.Lock() + pder.list.Remove(element) + delete(pder.sessions, element.Value.(*MemSessionStore).sid) + pder.lock.Unlock() + pder.lock.RLock() + } else { + break + } + } + pder.lock.RUnlock() +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return pder.list.Len() +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + element.Value.(*MemSessionStore).timeAccessed = time.Now() + pder.list.MoveToFront(element) + return nil + } + return nil +} + +func init() { + Register("memory", mempder) +} diff --git a/pkg/session/sess_mem_test.go b/pkg/session/sess_mem_test.go new file mode 100644 index 00000000..2e8934b8 --- /dev/null +++ 
b/pkg/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(w) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/session/sess_test.go b/pkg/session/sess_test.go new file mode 100644 index 00000000..906abec2 --- /dev/null +++ b/pkg/session/sess_test.go @@ -0,0 +1,131 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
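A small sketch of the memory provider's eviction behaviour: stores that have not been touched for longer than maxlifetime are dropped from the back of the provider's list on the next GC pass (the session id below is arbitrary).

package main

import (
	"fmt"
	"time"

	"github.com/astaxie/beego/session"
)

func main() {
	globalSessions, err := session.NewManager("memory", &session.ManagerConfig{
		CookieName: "gosessionid",
		Gclifetime: 1, // Maxlifetime defaults to Gclifetime when left at 0
	})
	if err != nil {
		panic(err)
	}

	store, _ := globalSessions.GetSessionStore("demo-session-id")
	store.Set("username", "astaxie")
	fmt.Println(globalSessions.GetProvider().SessionAll()) // 1

	time.Sleep(2 * time.Second)
	globalSessions.GetProvider().SessionGC()               // evicts the idle store
	fmt.Println(globalSessions.GetProvider().SessionAll()) // 0
}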
+ +package session + +import ( + "crypto/aes" + "encoding/json" + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} + +func TestGenerate(t *testing.T) { + str := generateRandomKey(20) + if len(str) != 20 { + t.Fatal("generate length is not equal to 20") + } +} + +func TestCookieEncodeDecode(t *testing.T) { + hashKey := "testhashKey" + blockkey := generateRandomKey(16) + block, err := aes.NewCipher(blockkey) + if err != nil { + t.Fatal("NewCipher:", err) + } + securityName := string(generateRandomKey(20)) + val := make(map[interface{}]interface{}) + val["name"] = "astaxie" + val["gender"] = "male" + str, err := encodeCookie(block, hashKey, securityName, val) + if err != nil { + t.Fatal("encodeCookie:", err) + } + dst, err := decodeCookie(block, hashKey, securityName, str, 3600) + if err != nil { + t.Fatal("decodeCookie", err) + } + if dst["name"] != "astaxie" { + t.Fatal("dst get map error") + } + if dst["gender"] != "male" { + t.Fatal("dst get map error") + } +} + +func TestParseConfig(t *testing.T) { + s := `{"cookieName":"gosessionid","gclifetime":3600}` + cf := new(ManagerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(s), cf) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + + cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + cf2 := new(ManagerConfig) + cf2.EnableSetCookie = true + err = json.Unmarshal([]byte(cc), cf2) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf2.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf2.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + if cf2.EnableSetCookie { + t.Fatal("parseconfig get enableSetCookie error") + } + cconfig := new(cookieConfig) + err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) + if err != nil { + t.Fatal("parse ProviderConfig err,", err) + } + if cconfig.CookieName != "gosessionid" { + t.Fatal("ProviderConfig get cookieName error") + } + if cconfig.SecurityKey != "beegocookiehashkey" { + t.Fatal("ProviderConfig get securityKey error") + } +} diff --git a/pkg/session/sess_utils.go b/pkg/session/sess_utils.go new file mode 100644 index 00000000..20915bb6 --- /dev/null +++ b/pkg/session/sess_utils.go @@ -0,0 +1,207 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "encoding/gob" + "errors" + "fmt" + "io" + "strconv" + "time" + + "github.com/astaxie/beego/utils" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + for _, v := range obj { + gob.Register(v) + } + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return []byte(""), err + } + return buf.Bytes(), nil +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + buf := bytes.NewBuffer(encoded) + dec := gob.NewDecoder(buf) + var out map[interface{}]interface{} + err := dec.Decode(&out) + if err != nil { + return nil, err + } + return out, nil +} + +// generateRandomKey creates a random key with the given strength. +func generateRandomKey(strength int) []byte { + k := make([]byte, strength) + if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil { + return utils.RandomCreateBytes(strength) + } + return k +} + +// Encryption ----------------------------------------------------------------- + +// encrypt encrypts a value using the given block in counter mode. +// +// A random initialization vector (http://goo.gl/zF67k) with the length of the +// block size is prepended to the resulting ciphertext. +func encrypt(block cipher.Block, value []byte) ([]byte, error) { + iv := generateRandomKey(block.BlockSize()) + if iv == nil { + return nil, errors.New("encrypt: failed to generate random iv") + } + // Encrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + // Return iv + ciphertext. + return append(iv, value...), nil +} + +// decrypt decrypts a value using the given block in counter mode. +// +// The value to be decrypted must be prepended by a initialization vector +// (http://goo.gl/zF67k) with the length of the block size. +func decrypt(block cipher.Block, value []byte) ([]byte, error) { + size := block.BlockSize() + if len(value) > size { + // Extract iv. + iv := value[:size] + // Extract ciphertext. + value = value[size:] + // Decrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + return value, nil + } + return nil, errors.New("decrypt: the value could not be decrypted") +} + +func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { + var err error + var b []byte + // 1. EncodeGob. + if b, err = EncodeGob(value); err != nil { + return "", err + } + // 2. Encrypt (optional). + if b, err = encrypt(block, b); err != nil { + return "", err + } + b = encode(b) + // 3. Create MAC for "name|date|value". Extra pipe to be used later. 
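+	// The MAC is computed over "name|date|value|"; the leading "name|" is then
+	// stripped, so the final base64-encoded cookie carries "date|value|mac".
+	// decodeCookie re-prepends the name before verifying the signature.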
+ b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) + h := hmac.New(sha256.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + // Append mac, remove name. + b = append(b, sig...)[len(name)+1:] + // 4. Encode to base64. + b = encode(b) + // Done. + return string(b), nil +} + +func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { + // 1. Decode from base64. + b, err := decode([]byte(value)) + if err != nil { + return nil, err + } + // 2. Verify MAC. Value is "date|value|mac". + parts := bytes.SplitN(b, []byte("|"), 3) + if len(parts) != 3 { + return nil, errors.New("Decode: invalid value format") + } + + b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) + h := hmac.New(sha256.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { + return nil, errors.New("Decode: the value is not valid") + } + // 3. Verify date ranges. + var t1 int64 + if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { + return nil, errors.New("Decode: invalid timestamp") + } + t2 := time.Now().UTC().Unix() + if t1 > t2 { + return nil, errors.New("Decode: timestamp is too new") + } + if t1 < t2-gcmaxlifetime { + return nil, errors.New("Decode: expired timestamp") + } + // 4. Decrypt (optional). + b, err = decode(parts[1]) + if err != nil { + return nil, err + } + if b, err = decrypt(block, b); err != nil { + return nil, err + } + // 5. DecodeGob. + dst, err := DecodeGob(b) + if err != nil { + return nil, err + } + return dst, nil +} + +// Encoding ------------------------------------------------------------------- + +// encode encodes a value using base64. +func encode(value []byte) []byte { + encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) + base64.URLEncoding.Encode(encoded, value) + return encoded +} + +// decode decodes a cookie using base64. +func decode(value []byte) ([]byte, error) { + decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) + b, err := base64.URLEncoding.Decode(decoded, value) + if err != nil { + return nil, err + } + return decoded[:b], nil +} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 00000000..eb85360a --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,377 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
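The gob helpers in sess_utils.go above are what every provider in this patch uses to serialise session values; a minimal round-trip sketch with a custom struct (the struct and field names here are illustrative):

package main

import (
	"fmt"

	"github.com/astaxie/beego/session"
)

type profile struct {
	Nickname string
}

func main() {
	// EncodeGob registers the concrete type of every value before encoding, so a
	// custom struct round-trips as long as it was registered in the same process.
	in := map[interface{}]interface{}{
		"username": "astaxie",
		"profile":  profile{Nickname: "asta"},
	}
	b, err := session.EncodeGob(in)
	if err != nil {
		panic(err)
	}
	out, err := session.DecodeGob(b)
	if err != nil {
		panic(err)
	}
	fmt.Println(out["username"], out["profile"].(profile).Nickname) // astaxie asta
}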
+ +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/textproto" + "net/url" + "os" + "time" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() +} + +var provides = make(map[string]Provider) + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + if provide == nil { + panic("session: Register provide is nil") + } + if _, dup := provides[name]; dup { + panic("session: Register called twice for provider " + name) + } + provides[name] = provide +} + +//GetProvider +func GetProvider(name string) (Provider, error) { + provider, ok := provides[name] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name) + } + return provider, nil +} + +// ManagerConfig define the session config +type ManagerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + SessionIDPrefix string `json:"sessionIDPrefix"` +} + +// Manager contains Provider and its configuration. +type Manager struct { + provider Provider + config *ManagerConfig +} + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. 
maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + provider, ok := provides[provideName] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) + } + + if cf.Maxlifetime == 0 { + cf.Maxlifetime = cf.Gclifetime + } + + if cf.EnableSidInHTTPHeader { + if cf.SessionNameInHTTPHeader == "" { + panic(errors.New("SessionNameInHTTPHeader is empty")) + } + + strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader) + if cf.SessionNameInHTTPHeader != strMimeHeader { + strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader + panic(errors.New(strErrMsg)) + } + } + + err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + if err != nil { + return nil, err + } + + if cf.SessionIDLength == 0 { + cf.SessionIDLength = 16 + } + + return &Manager{ + provider, + cf, + }, nil +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return manager.provider +} + +// getSid retrieves session identifier from HTTP Request. +// First try to retrieve id by reading from cookie, session cookie name is configurable, +// if not exist, then retrieve id from querying parameters. +// +// error is not nil when there is anything wrong. +// sid is empty when need to generate a new session id +// otherwise return an valid session id. +func (manager *Manager) getSid(r *http.Request) (string, error) { + cookie, errs := r.Cookie(manager.config.CookieName) + if errs != nil || cookie.Value == "" { + var sid string + if manager.config.EnableSidInURLQuery { + errs := r.ParseForm() + if errs != nil { + return "", errs + } + + sid = r.FormValue(manager.config.CookieName) + } + + // if not found in Cookie / param, then read it from request headers + if manager.config.EnableSidInHTTPHeader && sid == "" { + sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader] + if isFound && len(sids) != 0 { + return sids[0], nil + } + } + + return sid, nil + } + + // HTTP Request contains cookie for sessionid info. + return url.QueryUnescape(cookie.Value) +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. 
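// Editor's note: a minimal wiring sketch for the Manager defined above, using the built-in
// "memory" provider (registered by sess_mem.go in this package). The import path shown is
// the pre-move one; under this patch the package lives in pkg/session.
package main

import (
	"fmt"
	"net/http"

	"github.com/astaxie/beego/session"
)

var globalSessions *session.Manager

func main() {
	cfg := &session.ManagerConfig{
		CookieName:      "gosessionid",
		EnableSetCookie: true,
		Gclifetime:      3600, // Maxlifetime falls back to Gclifetime when left at zero
	}
	var err error
	if globalSessions, err = session.NewManager("memory", cfg); err != nil {
		panic(err)
	}
	go globalSessions.GC() // reschedules itself every Gclifetime seconds

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		sess, err := globalSessions.SessionStart(w, r)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		defer sess.SessionRelease(w)

		n, _ := sess.Get("count").(int)
		sess.Set("count", n+1)
		fmt.Fprintf(w, "visits in this session: %d\n", n+1)
	})
	http.ListenAndServe(":8080", nil)
}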
+func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) { + sid, errs := manager.getSid(r) + if errs != nil { + return nil, errs + } + + if sid != "" && manager.provider.SessionExist(sid) { + return manager.provider.SessionRead(sid) + } + + // Generate a new session + sid, errs = manager.sessionID() + if errs != nil { + return nil, errs + } + + session, err = manager.provider.SessionRead(sid) + if err != nil { + return nil, err + } + cookie := &http.Cookie{ + Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) + } + + return +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + if manager.config.EnableSidInHTTPHeader { + r.Header.Del(manager.config.SessionNameInHTTPHeader) + w.Header().Del(manager.config.SessionNameInHTTPHeader) + } + + cookie, err := r.Cookie(manager.config.CookieName) + if err != nil || cookie.Value == "" { + return + } + + sid, _ := url.QueryUnescape(cookie.Value) + manager.provider.SessionDestroy(sid) + if manager.config.EnableSetCookie { + expiration := time.Now() + cookie = &http.Cookie{Name: manager.config.CookieName, + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Expires: expiration, + MaxAge: -1, + Domain: manager.config.Domain} + + http.SetCookie(w, cookie) + } +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { + sessions, err = manager.provider.SessionRead(sid) + return +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + manager.provider.SessionGC() + time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. 
+func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { + sid, err := manager.sessionID() + if err != nil { + return + } + cookie, err := r.Cookie(manager.config.CookieName) + if err != nil || cookie.Value == "" { + //delete old cookie + session, _ = manager.provider.SessionRead(sid) + cookie = &http.Cookie{Name: manager.config.CookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: !manager.config.DisableHTTPOnly, + Secure: manager.isSecure(r), + Domain: manager.config.Domain, + } + } else { + oldsid, _ := url.QueryUnescape(cookie.Value) + session, _ = manager.provider.SessionRegenerate(oldsid, sid) + cookie.Value = url.QueryEscape(sid) + cookie.HttpOnly = true + cookie.Path = "/" + } + if manager.config.CookieLifeTime > 0 { + cookie.MaxAge = manager.config.CookieLifeTime + cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) + } + r.AddCookie(cookie) + + if manager.config.EnableSidInHTTPHeader { + r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) + w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) + } + + return +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return manager.provider.SessionAll() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + manager.config.Secure = secure +} + +func (manager *Manager) sessionID() (string, error) { + b := make([]byte, manager.config.SessionIDLength) + n, err := rand.Read(b) + if n != len(b) || err != nil { + return "", fmt.Errorf("Could not successfully read from the system CSPRNG") + } + return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil +} + +// Set cookie with https. +func (manager *Manager) isSecure(req *http.Request) bool { + if !manager.config.Secure { + return false + } + if req.URL.Scheme != "" { + return req.URL.Scheme == "https" + } + if req.TLS == nil { + return false + } + return true +} + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +// NewSessionLog set io.Writer to create a Logger for session. 
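// Editor's note: the usual pattern around SessionRegenerateID and SessionDestroy above:
// hand out a fresh session id right after authentication (session fixation defence) and
// drop the session on logout. Assumes globalSessions was created as in the earlier sketch.
package main

import (
	"net/http"

	"github.com/astaxie/beego/session"
)

var globalSessions *session.Manager // initialized elsewhere via session.NewManager

func loginHandler(w http.ResponseWriter, r *http.Request) {
	// ... verify credentials first ...
	sess := globalSessions.SessionRegenerateID(w, r) // new sid; an existing session's values are carried over
	sess.Set("uid", 42)
	sess.SessionRelease(w)
	http.Redirect(w, r, "/", http.StatusFound)
}

func logoutHandler(w http.ResponseWriter, r *http.Request) {
	globalSessions.SessionDestroy(w, r) // destroys the store and expires the cookie
	http.Redirect(w, r, "/login", http.StatusFound)
}

func main() {
	http.HandleFunc("/login", loginHandler)
	http.HandleFunc("/logout", logoutHandler)
	http.ListenAndServe(":8080", nil)
}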
+func NewSessionLog(out io.Writer) *Log { + sl := new(Log) + sl.Logger = log.New(out, "[SESSION]", 1e9) + return sl +} diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..de0c6360 --- /dev/null +++ b/pkg/session/ssdb/sess_ssdb.go @@ -0,0 +1,199 @@ +package ssdb + +import ( + "errors" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/astaxie/beego/session" + "github.com/ssdb/gossdb/ssdb" +) + +var ssdbProvider = &Provider{} + +// Provider holds ssdb client and configs +type Provider struct { + client *ssdb.Client + host string + port int + maxLifetime int64 +} + +func (p *Provider) connectInit() error { + var err error + if p.host == "" || p.port == 0 { + return errors.New("SessionInit First") + } + p.client, err = ssdb.Connect(p.host, p.port) + return err +} + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + p.maxLifetime = maxLifetime + address := strings.Split(savePath, ":") + p.host = address[0] + + var err error + if p.port, err = strconv.Atoi(address[1]); err != nil { + return err + } + return p.connectInit() +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + var kv map[interface{}]interface{} + value, err := p.client.Get(sid) + if err != nil { + return nil, err + } + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + if p.client == nil { + if err := p.connectInit(); err != nil { + panic(err) + } + } + value, err := p.client.Get(sid) + if err != nil { + panic(err) + } + if value == nil || len(value.(string)) == 0 { + return false + } + return true +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + //conn.Do("setx", key, v, ttl) + if p.client == nil { + if err := p.connectInit(); err != nil { + return nil, err + } + } + value, err := p.client.Get(oldsid) + if err != nil { + return nil, err + } + var kv map[interface{}]interface{} + if value == nil || len(value.(string)) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(value.(string))) + if err != nil { + return nil, err + } + _, err = p.client.Del(oldsid) + if err != nil { + return nil, err + } + } + _, e := p.client.Do("setx", sid, value, p.maxLifetime) + if e != nil { + return nil, e + } + rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} + return rs, nil +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + if p.client == nil { + if err := p.connectInit(); err != nil { + return err + } + } + _, err := p.client.Del(sid) + return err +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return 0 +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore struct { + sid string + lock sync.RWMutex + values 
map[interface{}]interface{} + maxLifetime int64 + client *ssdb.Client +} + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + s.values[key] = value + return nil +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + s.lock.Lock() + defer s.lock.Unlock() + if value, ok := s.values[key]; ok { + return value + } + return nil +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.values, key) + return nil +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + s.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return s.sid +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + b, err := session.EncodeGob(s.values) + if err != nil { + return + } + s.client.Do("setx", s.sid, string(b), s.maxLifetime) +} + +func init() { + session.Register("ssdb", ssdbProvider) +} diff --git a/pkg/staticfile.go b/pkg/staticfile.go new file mode 100644 index 00000000..84e9aa7b --- /dev/null +++ b/pkg/staticfile.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "errors" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/hashicorp/golang-lru" +) + +var errNotStaticRequest = errors.New("request not a static file request") + +func serverStaticRouter(ctx *context.Context) { + if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" { + return + } + + forbidden, filePath, fileInfo, err := lookupFile(ctx) + if err == errNotStaticRequest { + return + } + + if forbidden { + exception("403", ctx) + return + } + + if filePath == "" || fileInfo == nil { + if BConfig.RunMode == DEV { + logs.Warn("Can't find/open the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + if fileInfo.IsDir() { + requestURL := ctx.Input.URL() + if requestURL[len(requestURL)-1] != '/' { + redirectURL := requestURL + "/" + if ctx.Request.URL.RawQuery != "" { + redirectURL = redirectURL + "?" 
+ ctx.Request.URL.RawQuery + } + ctx.Redirect(302, redirectURL) + } else { + //serveFile will list dir + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + } + return + } else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) { + //over size file serve with http module + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + return + } + + var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) + var acceptEncoding string + if enableCompress { + acceptEncoding = context.ParseEncoding(ctx.Request) + } + b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding) + if err != nil { + if BConfig.RunMode == DEV { + logs.Warn("Can't compress the file:", filePath, err) + } + http.NotFound(ctx.ResponseWriter, ctx.Request) + return + } + + if b { + ctx.Output.Header("Content-Encoding", n) + } else { + ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) + } + + http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader) +} + +type serveContentHolder struct { + data []byte + modTime time.Time + size int64 + originSize int64 //original file size:to judge file changed + encoding string +} + +type serveContentReader struct { + *bytes.Reader +} + +var ( + staticFileLruCache *lru.Cache + lruLock sync.RWMutex +) + +func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { + if staticFileLruCache == nil { + //avoid lru cache error + if BConfig.WebConfig.StaticCacheFileNum >= 1 { + staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum) + } else { + staticFileLruCache, _ = lru.New(1) + } + } + mapKey := acceptEncoding + ":" + filePath + lruLock.RLock() + var mapFile *serveContentHolder + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + lruLock.RUnlock() + if isOk(mapFile, fi) { + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil + } + lruLock.Lock() + defer lruLock.Unlock() + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + if !isOk(mapFile, fi) { + file, err := os.Open(filePath) + if err != nil { + return false, "", nil, nil, err + } + defer file.Close() + var bufferWriter bytes.Buffer + _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) + if err != nil { + return false, "", nil, nil, err + } + mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n} + if isOk(mapFile, fi) { + staticFileLruCache.Add(mapKey, mapFile) + } + } + + reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} + return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil +} + +func isOk(s *serveContentHolder, fi os.FileInfo) bool { + if s == nil { + return false + } else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) { + return false + } + return s.modTime == fi.ModTime() && s.originSize == fi.Size() +} + +// isStaticCompress detect static files +func isStaticCompress(filePath string) bool { + for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip { + if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) { + return true + } + } + return false +} + +// searchFile search the file by url path +// if none the static file prefix matches ,return notStaticRequestErr +func 
searchFile(ctx *context.Context) (string, os.FileInfo, error) { + requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path)) + // special processing : favicon.ico/robots.txt can be in any static dir + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + file := path.Join(".", requestPath) + if fi, _ := os.Stat(file); fi != nil { + return file, fi, nil + } + for _, staticDir := range BConfig.WebConfig.StaticDir { + filePath := path.Join(staticDir, requestPath) + if fi, _ := os.Stat(filePath); fi != nil { + return filePath, fi, nil + } + } + return "", nil, errNotStaticRequest + } + + for prefix, staticDir := range BConfig.WebConfig.StaticDir { + if !strings.Contains(requestPath, prefix) { + continue + } + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + continue + } + filePath := path.Join(staticDir, requestPath[len(prefix):]) + if fi, err := os.Stat(filePath); fi != nil { + return filePath, fi, err + } + } + return "", nil, errNotStaticRequest +} + +// lookupFile find the file to serve +// if the file is dir ,search the index.html as default file( MUST NOT A DIR also) +// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex +func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) { + fp, fi, err := searchFile(ctx) + if fp == "" || fi == nil { + return false, "", nil, err + } + if !fi.IsDir() { + return false, fp, fi, err + } + if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' { + ifp := filepath.Join(fp, "index.html") + if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() { + return false, ifp, ifi, err + } + } + return !BConfig.WebConfig.DirectoryIndex, fp, fi, err +} diff --git a/pkg/staticfile_test.go b/pkg/staticfile_test.go new file mode 100644 index 00000000..e46c13ec --- /dev/null +++ b/pkg/staticfile_test.go @@ -0,0 +1,99 @@ +package beego + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +var currentWorkDir, _ = os.Getwd() +var licenseFile = filepath.Join(currentWorkDir, "LICENSE") + +func testOpenFile(encoding string, content []byte, t *testing.T) { + fi, _ := os.Stat(licenseFile) + b, n, sch, reader, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Log(err) + t.Fail() + } + + t.Log("open static file encoding "+n, b) + + assetOpenFileAndContent(sch, reader, content, t) +} +func TestOpenStaticFile_1(t *testing.T) { + file, _ := os.Open(licenseFile) + content, _ := ioutil.ReadAll(file) + testOpenFile("", content, t) +} + +func TestOpenStaticFileGzip_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("gzip", content, t) +} +func TestOpenStaticFileDeflate_1(t *testing.T) { + file, _ := os.Open(licenseFile) + var zipBuf bytes.Buffer + fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) + io.Copy(fileWriter, file) + fileWriter.Close() + content, _ := ioutil.ReadAll(&zipBuf) + + testOpenFile("deflate", content, t) +} + +func TestStaticCacheWork(t *testing.T) { + encodings := []string{"", "gzip", "deflate"} + + fi, _ := os.Stat(licenseFile) + for _, encoding := range encodings { + _, _, first, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + _, _, second, _, err := 
openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + address1 := fmt.Sprintf("%p", first) + address2 := fmt.Sprintf("%p", second) + if address1 != address2 { + t.Errorf("encoding '%v' can not hit cache", encoding) + } + } +} + +func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { + t.Log(sch.size, len(content)) + if sch.size != int64(len(content)) { + t.Log("static content file size not same") + t.Fail() + } + bs, _ := ioutil.ReadAll(reader) + for i, v := range content { + if v != bs[i] { + t.Log("content not same") + t.Fail() + } + } + if staticFileLruCache.Len() == 0 { + t.Log("men map is empty") + t.Fail() + } +} diff --git a/pkg/swagger/swagger.go b/pkg/swagger/swagger.go new file mode 100644 index 00000000..a55676cd --- /dev/null +++ b/pkg/swagger/swagger.go @@ -0,0 +1,174 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +// Swagger list the resource +type Swagger struct { + SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` + Infos Information `json:"info" yaml:"info"` + Host string `json:"host,omitempty" yaml:"host,omitempty"` + BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Paths map[string]*Item `json:"paths" yaml:"paths"` + Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` + SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` +} + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. 
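// Editor's note: this sketch shows the configuration knobs read by serverStaticRouter and
// openFile in pkg/staticfile.go above. SetStaticPath is defined later in this patch
// (template.go); the values set here are only illustrative defaults.
package main

import "github.com/astaxie/beego"

func main() {
	// Map /assets/* onto the ./public directory (in addition to the default /static).
	beego.SetStaticPath("/assets", "public")

	// Serve compressed responses for these extensions when the client accepts gzip/deflate.
	beego.BConfig.EnableGzip = true
	beego.BConfig.WebConfig.StaticExtensionsToGzip = []string{".css", ".js", ".html"}

	// Files above this size bypass the LRU cache and are served via http.ServeFile.
	beego.BConfig.WebConfig.StaticCacheFileSize = 1 << 20 // 1 MiB
	beego.BConfig.WebConfig.StaticCacheFileNum = 1000     // max entries in staticFileLruCache

	beego.Run()
}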
+type Information struct { + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Version string `json:"version,omitempty" yaml:"version,omitempty"` + TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"` + + Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"` + License *License `json:"license,omitempty" yaml:"license,omitempty"` +} + +// Contact information for the exposed API. +type Contact struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` + EMail string `json:"email,omitempty" yaml:"email,omitempty"` +} + +// License information for the exposed API. +type License struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` +} + +// Item Describes the operations available on a single path. +type Item struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Get *Operation `json:"get,omitempty" yaml:"get,omitempty"` + Put *Operation `json:"put,omitempty" yaml:"put,omitempty"` + Post *Operation `json:"post,omitempty" yaml:"post,omitempty"` + Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"` + Options *Operation `json:"options,omitempty" yaml:"options,omitempty"` + Head *Operation `json:"head,omitempty" yaml:"head,omitempty"` + Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"` +} + +// Operation Describes a single API operation on a path. +type Operation struct { + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` + Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` + Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` + Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` + Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` + Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` + Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` + Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` +} + +// Parameter Describes a single operation parameter. +type Parameter struct { + In string `json:"in,omitempty" yaml:"in,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Required bool `json:"required,omitempty" yaml:"required,omitempty"` + Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` +} + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". 
+// http://swagger.io/specification/#itemsObject +type ParameterItems struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array. + CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"` + Default string `json:"default,omitempty" yaml:"default,omitempty"` +} + +// Schema Object allows the definition of input and output data types. +type Schema struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` + Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` +} + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie struct { + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` + Title string `json:"title,omitempty" yaml:"title,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Example interface{} `json:"example,omitempty" yaml:"example,omitempty"` + Required []string `json:"required,omitempty" yaml:"required,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` + Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` + Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"` + AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"` +} + +// Response as they are returned from executing this operation. +type Response struct { + Description string `json:"description" yaml:"description"` + Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` + Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` +} + +// Security Allows the definition of a security scheme that can be used by the operations +type Security struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2". + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Name string `json:"name,omitempty" yaml:"name,omitempty"` + In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header". + Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode". + AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"` + Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme. 
+} + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag struct { + Name string `json:"name,omitempty" yaml:"name,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` +} + +// ExternalDocs include Additional external documentation +type ExternalDocs struct { + Description string `json:"description,omitempty" yaml:"description,omitempty"` + URL string `json:"url,omitempty" yaml:"url,omitempty"` +} diff --git a/pkg/template.go b/pkg/template.go new file mode 100644 index 00000000..59875be7 --- /dev/null +++ b/pkg/template.go @@ -0,0 +1,406 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "html/template" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var ( + beegoTplFuncMap = make(template.FuncMap) + beeViewPathTemplateLocked = false + // beeViewPathTemplates caching map and supported template file extensions per view + beeViewPathTemplates = make(map[string]map[string]*template.Template) + templatesLock sync.RWMutex + // beeTemplateExt stores the template extension which will build + beeTemplateExt = []string{"tpl", "html", "gohtml"} + // beeTemplatePreprocessors stores associations of extension -> preprocessor handler + beeTemplateEngines = map[string]templatePreProcessor{} + beeTemplateFS = defaultFSFunc +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. 
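// Editor's note: a hand-built spec assembled from the swagger structs above and printed as
// JSON. In practice the bee tool generates this from controller comments, so the literal
// values here are purely illustrative.
package main

import (
	"encoding/json"
	"fmt"

	"github.com/astaxie/beego/swagger"
)

func main() {
	spec := swagger.Swagger{
		SwaggerVersion: "2.0",
		Infos:          swagger.Information{Title: "Example API", Version: "1.0.0"},
		BasePath:       "/v1",
		Paths: map[string]*swagger.Item{
			"/object": {
				Get: &swagger.Operation{
					Tags:        []string{"object"},
					Summary:     "list objects",
					OperationID: "ObjectController.GetAll",
					Responses:   map[string]swagger.Response{"200": {Description: "OK"}},
				},
			},
		},
	}
	out, _ := json.MarshalIndent(spec, "", "  ")
	fmt.Println(string(out))
}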
+func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + if BConfig.RunMode == DEV { + templatesLock.RLock() + defer templatesLock.RUnlock() + } + if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok { + if t, ok := beeTemplates[name]; ok { + var err error + if t.Lookup(name) != nil { + err = t.ExecuteTemplate(wr, name, data) + } else { + err = t.Execute(wr, data) + } + if err != nil { + logs.Trace("template Execute err:", err) + } + return err + } + panic("can't find templatefile in the path:" + viewPath + "/" + name) + } + panic("Unknown view path:" + viewPath) +} + +func init() { + beegoTplFuncMap["dateformat"] = DateFormat + beegoTplFuncMap["date"] = Date + beegoTplFuncMap["compare"] = Compare + beegoTplFuncMap["compare_not"] = CompareNot + beegoTplFuncMap["not_nil"] = NotNil + beegoTplFuncMap["not_null"] = NotNil + beegoTplFuncMap["substr"] = Substr + beegoTplFuncMap["html2str"] = HTML2str + beegoTplFuncMap["str2html"] = Str2html + beegoTplFuncMap["htmlquote"] = Htmlquote + beegoTplFuncMap["htmlunquote"] = Htmlunquote + beegoTplFuncMap["renderform"] = RenderForm + beegoTplFuncMap["assets_js"] = AssetsJs + beegoTplFuncMap["assets_css"] = AssetsCSS + beegoTplFuncMap["config"] = GetConfig + beegoTplFuncMap["map_get"] = MapGet + + // Comparisons + beegoTplFuncMap["eq"] = eq // == + beegoTplFuncMap["ge"] = ge // >= + beegoTplFuncMap["gt"] = gt // > + beegoTplFuncMap["le"] = le // <= + beegoTplFuncMap["lt"] = lt // < + beegoTplFuncMap["ne"] = ne // != + + beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + beegoTplFuncMap[key] = fn + return nil +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// visit will make the paths into two part,the first is subDir (without tf.root),the second is full path(without tf.root). +// if tf.root="views" and +// paths is "views/errors/404.html",the subDir will be "errors",the file will be "errors/404.html" +// paths is "views/admin/errors/404.html",the subDir will be "admin/errors",the file will be "admin/errors/404.html" +func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error { + if f == nil { + return err + } + if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 { + return nil + } + if !HasTemplateExt(paths) { + return nil + } + + replace := strings.NewReplacer("\\", "/") + file := strings.TrimLeft(replace.Replace(paths[len(tf.root):]), "/") + subDir := filepath.Dir(file) + + tf.files[subDir] = append(tf.files[subDir], file) + return nil +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + for _, v := range beeTemplateExt { + if strings.HasSuffix(paths, "."+v) { + return true + } + } + return false +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + for _, v := range beeTemplateExt { + if v == ext { + return + } + } + beeTemplateExt = append(beeTemplateExt, ext) +} + +// AddViewPath adds a new path to the supported view paths. 
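// Editor's note: a sketch of registering a custom template function and rendering a view
// through the per-viewPath cache built above. The ./views directory and greeting.tpl file
// are assumed to exist; AddViewPath (defined just below) parses them via BuildTemplate.
package main

import (
	"os"
	"strings"

	"github.com/astaxie/beego"
)

func main() {
	// Make {{ .Name | upper }} available in every template; must run before parsing.
	beego.AddFuncMap("upper", strings.ToUpper)

	// Parse every *.tpl/*.html/*.gohtml under ./views into beeViewPathTemplates.
	if err := beego.AddViewPath("views"); err != nil {
		panic(err)
	}

	// Render views/greeting.tpl straight to stdout, outside of a controller.
	data := map[string]string{"Name": "world"}
	if err := beego.ExecuteViewPathTemplate(os.Stdout, "greeting.tpl", "views", data); err != nil {
		panic(err)
	}
}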
+//Can later be used by setting a controller ViewPath to this folder +//will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + if beeViewPathTemplateLocked { + if _, exist := beeViewPathTemplates[viewPath]; exist { + return nil //Ignore if viewpath already exists + } + panic("Can not add new view paths after beego.Run()") + } + beeViewPathTemplates[viewPath] = make(map[string]*template.Template) + return BuildTemplate(viewPath) +} + +func lockViewPaths() { + beeViewPathTemplateLocked = true +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + var err error + fs := beeTemplateFS() + f, err := fs.Open(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return errors.New("dir open err") + } + defer f.Close() + + beeTemplates, ok := beeViewPathTemplates[dir] + if !ok { + panic("Unknown view path: " + dir) + } + self := &templateFile{ + root: dir, + files: make(map[string][]string), + } + err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { + return self.visit(path, f, err) + }) + if err != nil { + fmt.Printf("Walk() returned %v\n", err) + return err + } + buildAllFiles := len(files) == 0 + for _, v := range self.files { + for _, file := range v { + if buildAllFiles || utils.InSlice(file, files) { + templatesLock.Lock() + ext := filepath.Ext(file) + var t *template.Template + if len(ext) == 0 { + t, err = getTemplate(self.root, fs, file, v...) + } else if fn, ok := beeTemplateEngines[ext[1:]]; ok { + t, err = fn(self.root, file, beegoTplFuncMap) + } else { + t, err = getTemplate(self.root, fs, file, v...) + } + if err != nil { + logs.Error("parse template err:", file, err) + templatesLock.Unlock() + return err + } + beeTemplates[file] = t + templatesLock.Unlock() + } + } + } + return nil +} + +func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) { + var fileAbsPath string + var rParent string + var err error + if strings.HasPrefix(file, "../") { + rParent = filepath.Join(filepath.Dir(parent), file) + fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) + } else { + rParent = file + fileAbsPath = filepath.Join(root, file) + } + f, err := fs.Open(fileAbsPath) + if err != nil { + panic("can't find template file:" + file) + } + defer f.Close() + data, err := ioutil.ReadAll(f) + if err != nil { + return nil, [][]string{}, err + } + t, err = t.New(file).Parse(string(data)) + if err != nil { + return nil, [][]string{}, err + } + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"") + allSub := reg.FindAllStringSubmatch(string(data), -1) + for _, m := range allSub { + if len(m) == 2 { + tl := t.Lookup(m[1]) + if tl != nil { + continue + } + if !HasTemplateExt(m[1]) { + continue + } + _, _, err = getTplDeep(root, fs, m[1], rParent, t) + if err != nil { + return nil, [][]string{}, err + } + } + } + return t, allSub, nil +} + +func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) { + t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) + var subMods [][]string + t, subMods, err = getTplDeep(root, fs, file, "", t) + if err != nil { + return nil, err + } + t, err = _getTemplate(t, root, fs, subMods, others...) 
+ + if err != nil { + return nil, err + } + return +} + +func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) { + t = t0 + for _, m := range subMods { + if len(m) == 2 { + tpl := t.Lookup(m[1]) + if tpl != nil { + continue + } + //first check filename + for _, otherFile := range others { + if otherFile == m[1] { + var subMods1 [][]string + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) + if err != nil { + logs.Trace("template parse file err:", err) + } else if len(subMods1) > 0 { + t, err = _getTemplate(t, root, fs, subMods1, others...) + } + break + } + } + //second check define + for _, otherFile := range others { + var data []byte + fileAbsPath := filepath.Join(root, otherFile) + f, err := fs.Open(fileAbsPath) + if err != nil { + f.Close() + logs.Trace("template file parse error, not success open file:", err) + continue + } + data, err = ioutil.ReadAll(f) + f.Close() + if err != nil { + logs.Trace("template file parse error, not success read file:", err) + continue + } + reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") + allSub := reg.FindAllStringSubmatch(string(data), -1) + for _, sub := range allSub { + if len(sub) == 2 && sub[1] == m[1] { + var subMods1 [][]string + t, subMods1, err = getTplDeep(root, fs, otherFile, "", t) + if err != nil { + logs.Trace("template parse file err:", err) + } else if len(subMods1) > 0 { + t, err = _getTemplate(t, root, fs, subMods1, others...) + if err != nil { + logs.Trace("template parse file err:", err) + } + } + break + } + } + } + } + + } + return +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + beeTemplateFS = fnt +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + BConfig.WebConfig.ViewsPath = path + return BeeApp +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + if url != "/" { + url = strings.TrimRight(url, "/") + } + BConfig.WebConfig.StaticDir[url] = path + return BeeApp +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + if !strings.HasPrefix(url, "/") { + url = "/" + url + } + if url != "/" { + url = strings.TrimRight(url, "/") + } + delete(BConfig.WebConfig.StaticDir, url) + return BeeApp +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + AddTemplateExt(extension) + beeTemplateEngines[extension] = fn + return BeeApp +} diff --git a/pkg/template_test.go b/pkg/template_test.go new file mode 100644 index 00000000..287faadc --- /dev/null +++ b/pkg/template_test.go @@ -0,0 +1,316 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "bytes" + "github.com/astaxie/beego/testdata" + "github.com/elazarl/go-bindata-assetfs" + "net/http" + "os" + "path/filepath" + "testing" +) + +var header = `{{define "header"}} +
+<h1>Hello, astaxie!</h1>
+{{end}}` + +var index = ` + + + beego welcome template + + +{{template "block"}} +{{template "header"}} +{{template "blocks/block.tpl"}} + + +` + +var block = `{{define "block"}} +
+<h1>Hello, blocks!</h1>
+{{end}}` + +func TestTemplate(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "header.tpl", + "index.tpl", + "blocks/block.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(header) + } else if k == 1 { + f.WriteString(index) + } else if k == 2 { + f.WriteString(block) + } + + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", nil); err != nil { + t.Fatal(err) + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +var menu = ` +` +var user = ` + + + beego welcome template + + +{{template "../public/menu.tpl"}} + + +` + +func TestRelativeTemplate(t *testing.T) { + dir := "_beeTmp" + + //Just add dir to known viewPaths + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + + files := []string{ + "easyui/public/menu.tpl", + "easyui/rbac/user.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(menu) + } else if k == 1 { + f.WriteString(user) + } + f.Close() + } + } + if err := BuildTemplate(dir, files[1]); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { + t.Fatal(err) + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +var add = `{{ template "layout_blog.tpl" . }} +{{ define "css" }} + +{{ end}} + + +{{ define "content" }} +
+<h2>{{ .Title }}</h2>
+
+<p> This is SomeVar: {{ .SomeVar }}</p>
+{{ end }} + +{{ define "js" }} + +{{ end}}` + +var layoutBlog = ` + + + Lin Li + + + + + {{ block "css" . }}{{ end }} + + + +
+ {{ block "content" . }}{{ end }} +
+ + + {{ block "js" . }}{{ end }} + +` + +var output = ` + + + Lin Li + + + + + + + + + + +
+ +
+<h2>Hello</h2>
+
+<p> This is SomeVar: val</p>
+ +
+ + + + + + + + + + + + +` + +func TestTemplateLayout(t *testing.T) { + dir := "_beeTmp" + files := []string{ + "add.tpl", + "layout_blog.tpl", + } + if err := os.MkdirAll(dir, 0777); err != nil { + t.Fatal(err) + } + for k, name := range files { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + if k == 0 { + f.WriteString(add) + } else if k == 1 { + f.WriteString(layoutBlog) + } + f.Close() + } + } + if err := AddViewPath(dir); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 2 { + t.Fatalf("should be 2 but got %v", len(beeTemplates)) + } + out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + if out.String() != output { + t.Log(out.String()) + t.Fatal("Compare failed") + } + for _, name := range files { + os.RemoveAll(filepath.Join(dir, name)) + } + os.RemoveAll(dir) +} + +type TestingFileSystem struct { + assetfs *assetfs.AssetFS +} + +func (d TestingFileSystem) Open(name string) (http.File, error) { + return d.assetfs.Open(name) +} + +var outputBinData = ` + + + beego welcome template + + + + +
+<h1>Hello, blocks!</h1>
+ + +
+<h1>Hello, astaxie!</h1>
+ + + +
+<h2>Hello</h2>
+
+<p> This is SomeVar: val</p>
+ + +` + +func TestFsBinData(t *testing.T) { + SetTemplateFSFunc(func() http.FileSystem { + return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + }) + dir := "views" + if err := AddViewPath("views"); err != nil { + t.Fatal(err) + } + beeTemplates := beeViewPathTemplates[dir] + if len(beeTemplates) != 3 { + t.Fatalf("should be 3 but got %v", len(beeTemplates)) + } + if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + out := bytes.NewBufferString("") + if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { + t.Fatal(err) + } + + if out.String() != outputBinData { + t.Log(out.String()) + t.Fatal("Compare failed") + } +} diff --git a/pkg/templatefunc.go b/pkg/templatefunc.go new file mode 100644 index 00000000..ba1ec5eb --- /dev/null +++ b/pkg/templatefunc.go @@ -0,0 +1,780 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "errors" + "fmt" + "html" + "html/template" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + bt := []rune(s) + if start < 0 { + start = 0 + } + if start > len(bt) { + start = start % len(bt) + } + var end int + if (start + length) > (len(bt) - 1) { + end = len(bt) + } else { + end = start + length + } + return string(bt[start:end]) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + + re := regexp.MustCompile(`\<[\S\s]+?\>`) + html = re.ReplaceAllStringFunc(html, strings.ToLower) + + //remove STYLE + re = regexp.MustCompile(`\`) + html = re.ReplaceAllString(html, "") + + //remove SCRIPT + re = regexp.MustCompile(`\`) + html = re.ReplaceAllString(html, "") + + re = regexp.MustCompile(`\<[\S\s]+?\>`) + html = re.ReplaceAllString(html, "\n") + + re = regexp.MustCompile(`\s{2,}`) + html = re.ReplaceAllString(html, "\n") + + return strings.TrimSpace(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + datestring = t.Format(layout) + return +} + +// DateFormat pattern rules. 
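// Editor's note: Substr, HTML2str and DateFormat above are also registered as the template
// helpers "substr", "html2str" and "dateformat"; calling them directly shows the behaviour.
// Output comments are approximate.
package main

import (
	"fmt"
	"time"

	"github.com/astaxie/beego"
)

func main() {
	fmt.Println(beego.Substr("hello, beego", 0, 5))               // "hello"
	fmt.Println(beego.HTML2str("<p>Hi, <b>there</b></p>"))        // tags removed, text kept
	fmt.Println(beego.DateFormat(time.Now(), "2006-01-02 15:04")) // plain Go layout string
}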
+var datePatterns = []string{ + // year + "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 + "y", "06", //A two digit representation of a year Examples: 99 or 03 + + // month + "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 + "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 + "M", "Jan", // A short textual representation of a month, three letters Jan through Dec + "F", "January", // A full textual representation of a month, such as January or March January through December + + // day + "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 + "j", "2", // Day of the month without leading zeros 1 to 31 + + // week + "D", "Mon", // A textual representation of a day, three letters Mon through Sun + "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday + + // time + "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 + "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 + "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 + "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 + + "a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm + "A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM + + "i", "04", // Minutes with leading zeros 00 to 59 + "s", "05", // Seconds, with leading zeros 00 through 59 + + // time zone + "T", "MST", + "P", "-07:00", + "O", "-0700", + + // RFC 2822 + "r", time.RFC1123Z, +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return time.ParseInLocation(format, dateString, time.Local) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return t.Format(format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". 
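// Editor's note: Date and DateParse above translate PHP-style patterns through datePatterns
// into Go reference-time layouts; a quick round trip with the "Y-m-d H:i:s" pattern:
package main

import (
	"fmt"
	"time"

	"github.com/astaxie/beego"
)

func main() {
	now := time.Now()
	fmt.Println(beego.Date(now, "Y-m-d H:i:s")) // e.g. "2020-07-22 22:50:08"

	t, err := beego.DateParse("2020-07-22 22:50:08", "Y-m-d H:i:s")
	fmt.Println(t, err)
}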
+func Compare(a, b interface{}) (equal bool) { + equal = false + if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) { + equal = true + } + return +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return !Compare(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return CompareNot(a, nil) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { + switch returnType { + case "String": + value = AppConfig.String(key) + case "Bool": + value, err = AppConfig.Bool(key) + case "Int": + value, err = AppConfig.Int(key) + case "Int64": + value, err = AppConfig.Int64(key) + case "Float": + value, err = AppConfig.Float(key) + case "DIY": + value, err = AppConfig.DIY(key) + default: + err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") + } + + if err != nil { + if reflect.TypeOf(returnType) != reflect.TypeOf(defaultVal) { + err = errors.New("defaultVal type does not match returnType") + } else { + value, err = defaultVal, nil + } + } else if reflect.TypeOf(value).Kind() == reflect.String { + if value == "" { + if reflect.TypeOf(defaultVal).Kind() != reflect.String { + err = errors.New("defaultVal type must be a String if the returnType is a String") + } else { + value = defaultVal.(string) + } + } + } + + return +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return template.HTML(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + //HTML编码为实体符号 + /* + Encodes `text` for raw use in HTML. + >>> htmlquote("<'&\\">") + '<'&">' + */ + + text = html.EscapeString(text) + text = strings.NewReplacer( + `“`, "“", + `”`, "”", + ` `, " ", + ).Replace(text) + + return strings.TrimSpace(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + //实体符号解释为HTML + /* + Decodes `text` that's HTML quoted. + >>> htmlunquote('<'&">') + '<\\'&">' + */ + + text = html.UnescapeString(text) + + return strings.TrimSpace(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return BeeApp.Handlers.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +// Support for anonymous struct. 
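// Editor's note: a sketch of the struct-tag driven parsing implemented by parseFormToStruct
// below; the exported ParseForm wrapper (defined right after it) takes url.Values plus a
// pointer to the target struct. Field and form names here are made up for illustration.
package main

import (
	"fmt"
	"net/url"

	"github.com/astaxie/beego"
)

type SignupForm struct {
	Name    string `form:"username"`
	Age     int    `form:"age"`
	Agreed  bool   `form:"agreed"`
	Comment string `form:"-"` // "-" means: never filled from the form
}

func main() {
	form := url.Values{
		"username": {"astaxie"},
		"age":      {"30"},
		"agreed":   {"on"}, // "on", "1" and "yes" all parse to true
	}
	var f SignupForm
	if err := beego.ParseForm(form, &f); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", f)
}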
+func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() { + continue + } + + fieldT := objT.Field(i) + if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { + err := parseFormToStruct(form, fieldT.Type, fieldV) + if err != nil { + return err + } + continue + } + + tags := strings.Split(fieldT.Tag.Get("form"), ",") + var tag string + if len(tags) == 0 || len(tags[0]) == 0 { + tag = fieldT.Name + } else if tags[0] == "-" { + continue + } else { + tag = tags[0] + } + + formValues := form[tag] + var value string + if len(formValues) == 0 { + defaultValue := fieldT.Tag.Get("default") + if defaultValue != "" { + value = defaultValue + } else { + continue + } + } + if len(formValues) == 1 { + value = formValues[0] + if value == "" { + continue + } + } + + switch fieldT.Type.Kind() { + case reflect.Bool: + if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { + fieldV.SetBool(true) + continue + } + if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { + fieldV.SetBool(false) + continue + } + b, err := strconv.ParseBool(value) + if err != nil { + return err + } + fieldV.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + fieldV.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + fieldV.SetFloat(x) + case reflect.Interface: + fieldV.Set(reflect.ValueOf(value)) + case reflect.String: + fieldV.SetString(value) + case reflect.Struct: + switch fieldT.Type.String() { + case "time.Time": + var ( + t time.Time + err error + ) + if len(value) >= 25 { + value = value[:25] + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if strings.HasSuffix(strings.ToUpper(value), "Z") { + t, err = time.ParseInLocation(time.RFC3339, value, time.Local) + } else if len(value) >= 19 { + if strings.Contains(value, "T") { + value = value[:19] + t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) + } else { + value = value[:19] + t, err = time.ParseInLocation(formatDateTime, value, time.Local) + } + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + t, err = time.ParseInLocation(formatDate, value, time.Local) + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + t, err = time.ParseInLocation(formatTime, value, time.Local) + } + if err != nil { + return err + } + fieldV.Set(reflect.ValueOf(t)) + } + case reflect.Slice: + if fieldT.Type == sliceOfInts { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + val, err := strconv.Atoi(formVals[i]) + if err != nil { + return err + } + fieldV.Index(i).SetInt(int64(val)) + } + } else if fieldT.Type == sliceOfStrings { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + fieldV.Index(i).SetString(formVals[i]) + } + } + } + } + return nil +} + +// ParseForm will parse form values to 
struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return fmt.Errorf("%v must be a struct pointer", obj) + } + objT = objT.Elem() + objV = objV.Elem() + + return parseFormToStruct(form, objT, objV) +} + +var sliceOfInts = reflect.TypeOf([]int(nil)) +var sliceOfStrings = reflect.TypeOf([]string(nil)) + +var unKind = map[reflect.Kind]bool{ + reflect.Uintptr: true, + reflect.Complex64: true, + reflect.Complex128: true, + reflect.Array: true, + reflect.Chan: true, + reflect.Func: true, + reflect.Map: true, + reflect.Ptr: true, + reflect.Slice: true, + reflect.Struct: true, + reflect.UnsafePointer: true, +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return template.HTML("") + } + objT = objT.Elem() + objV = objV.Elem() + + var raw []string + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() || unKind[fieldV.Kind()] { + continue + } + + fieldT := objT.Field(i) + + label, name, fType, id, class, ignored, required := parseFormTag(fieldT) + if ignored { + continue + } + + raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required)) + } + return template.HTML(strings.Join(raw, "
")) +} + +// renderFormField returns a string containing HTML of a single form field. +func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { + if id != "" { + id = " id=\"" + id + "\"" + } + + if class != "" { + class = " class=\"" + class + "\"" + } + + requiredString := "" + if required { + requiredString = " required" + } + + if isValidForInput(fType) { + return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) + } + + return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) +} + +// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element. +func isValidForInput(fType string) bool { + validInputTypes := strings.Fields("text password checkbox radio submit reset hidden image file button search email url tel number range date month week time datetime datetime-local color") + for _, validType := range validInputTypes { + if fType == validType { + return true + } + } + return false +} + +// parseFormTag takes the stuct-tag of a StructField and parses the `form` value. +// returned are the form label, name-property, type and wether the field should be ignored. +func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { + tags := strings.Split(fieldT.Tag.Get("form"), ",") + label = fieldT.Name + ": " + name = fieldT.Name + fType = "text" + ignored = false + id = fieldT.Tag.Get("id") + class = fieldT.Tag.Get("class") + + required = false + requiredField := fieldT.Tag.Get("required") + if requiredField != "-" && requiredField != "" { + required, _ = strconv.ParseBool(requiredField) + } + + switch len(tags) { + case 1: + if tags[0] == "-" { + ignored = true + } + if len(tags[0]) > 0 { + name = tags[0] + } + case 2: + if len(tags[0]) > 0 { + name = tags[0] + } + if len(tags[1]) > 0 { + fType = tags[1] + } + case 3: + if len(tags[0]) > 0 { + name = tags[0] + } + if len(tags[1]) > 0 { + fType = tags[1] + } + if len(tags[2]) > 0 { + label = tags[2] + } + } + + return +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +// go1.2 added template funcs. begin +var ( + errBadComparisonType = errors.New("invalid type for comparison") + errBadComparison = errors.New("incompatible types for comparison") + errNoComparison = errors.New("missing argument for comparison") +) + +type kind int + +const ( + invalidKind kind = iota + boolKind + complexKind + intKind + floatKind + stringKind + uintKind +) + +func basicKind(v reflect.Value) (kind, error) { + switch v.Kind() { + case reflect.Bool: + return boolKind, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intKind, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintKind, nil + case reflect.Float32, reflect.Float64: + return floatKind, nil + case reflect.Complex64, reflect.Complex128: + return complexKind, nil + case reflect.String: + return stringKind, nil + } + return invalidKind, errBadComparisonType +} + +// eq evaluates the comparison a == b || a == c || ... 
+func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + if len(arg2) == 0 { + return false, errNoComparison + } + for _, arg := range arg2 { + v2 := reflect.ValueOf(arg) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind: + truth = v1.Bool() == v2.Bool() + case complexKind: + truth = v1.Complex() == v2.Complex() + case floatKind: + truth = v1.Float() == v2.Float() + case intKind: + truth = v1.Int() == v2.Int() + case stringKind: + truth = v1.String() == v2.String() + case uintKind: + truth = v1.Uint() == v2.Uint() + default: + panic("invalid kind") + } + if truth { + return true, nil + } + } + return false, nil +} + +// ne evaluates the comparison a != b. +func ne(arg1, arg2 interface{}) (bool, error) { + // != is the inverse of ==. + equal, err := eq(arg1, arg2) + return !equal, err +} + +// lt evaluates the comparison a < b. +func lt(arg1, arg2 interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + v2 := reflect.ValueOf(arg2) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + panic("invalid kind") + } + return truth, nil +} + +// le evaluates the comparison <= b. +func le(arg1, arg2 interface{}) (bool, error) { + // <= is < or ==. + lessThan, err := lt(arg1, arg2) + if lessThan || err != nil { + return lessThan, err + } + return eq(arg1, arg2) +} + +// gt evaluates the comparison a > b. +func gt(arg1, arg2 interface{}) (bool, error) { + // > is the inverse of <=. + lessOrEqual, err := le(arg1, arg2) + if err != nil { + return false, err + } + return !lessOrEqual, nil +} + +// ge evaluates the comparison a >= b. +func ge(arg1, arg2 interface{}) (bool, error) { + // >= is the inverse of <. 
+ lessThan, err := lt(arg1, arg2) + if err != nil { + return false, err + } + return !lessThan, nil +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + arg1Type := reflect.TypeOf(arg1) + arg1Val := reflect.ValueOf(arg1) + + if arg1Type.Kind() == reflect.Map && len(arg2) > 0 { + // check whether arg2[0] type equals to arg1 key type + // if they are different, make conversion + arg2Val := reflect.ValueOf(arg2[0]) + arg2Type := reflect.TypeOf(arg2[0]) + if arg2Type.Kind() != arg1Type.Key().Kind() { + // convert arg2Value to string + var arg2ConvertedVal interface{} + arg2String := fmt.Sprintf("%v", arg2[0]) + + // convert string representation to any other type + switch arg1Type.Key().Kind() { + case reflect.Bool: + arg2ConvertedVal, _ = strconv.ParseBool(arg2String) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + arg2ConvertedVal, _ = strconv.ParseInt(arg2String, 0, 64) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + arg2ConvertedVal, _ = strconv.ParseUint(arg2String, 0, 64) + case reflect.Float32, reflect.Float64: + arg2ConvertedVal, _ = strconv.ParseFloat(arg2String, 64) + case reflect.String: + arg2ConvertedVal = arg2String + default: + arg2ConvertedVal = arg2Val.Interface() + } + arg2Val = reflect.ValueOf(arg2ConvertedVal) + } + + storedVal := arg1Val.MapIndex(arg2Val) + + if storedVal.IsValid() { + var result interface{} + + switch arg1Type.Elem().Kind() { + case reflect.Bool: + result = storedVal.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + result = storedVal.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + result = storedVal.Uint() + case reflect.Float32, reflect.Float64: + result = storedVal.Float() + case reflect.String: + result = storedVal.String() + default: + result = storedVal.Interface() + } + + // if there is more keys, handle this recursively + if len(arg2) > 1 { + return MapGet(result, arg2[1:]...) + } + return result, nil + } + return nil, nil + + } + return nil, nil +} diff --git a/pkg/templatefunc_test.go b/pkg/templatefunc_test.go new file mode 100644 index 00000000..b4c19c2e --- /dev/null +++ b/pkg/templatefunc_test.go @@ -0,0 +1,380 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
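As a Go-side companion to the template usage shown in the MapGet comment above, a small sketch based on the TestMapGet case below; the import path is an assumption.

package main

import (
	"fmt"

	"github.com/astaxie/beego" // assumed import path
)

func main() {
	m := map[string]int64{"a": 1, "1": 2}

	v, _ := beego.MapGet(m, "a") // key type already matches
	fmt.Println(v)               // 1

	// An int key is converted to the map's string key type before the lookup.
	v, _ = beego.MapGet(m, 1)
	fmt.Println(v) // 2
}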
+ +package beego + +import ( + "html/template" + "net/url" + "reflect" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should 
equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
<input name="username" type="text" value="test"></br>` + + `年龄: <input name="age" type="text" value="0"></br>
` + + `Sex: <input name="Sex" type="text" value=""></br>
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestRenderFormField(t *testing.T) { + html := renderFormField("Label: ", "Name", "text", "Value", "", "", false) + if html != `Label: ` { + t.Errorf("Wrong html output for input[type=text]: %v ", html) + } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } + + html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true) + if html != `Label: ` { + t.Errorf("Wrong html output for textarea: %v ", html) + } +} + +func TestParseFormTag(t *testing.T) { + // create struct to contain field with different types of struct-tag `form` + type user struct { + All int `form:"name,text,年龄:"` + NoName int `form:",hidden,年龄:"` + OnlyLabel int `form:",,年龄:"` + OnlyName int `form:"name" id:"name" class:"form-name"` + Ignored int `form:"-"` + Required int `form:"name" required:"true"` + IgnoreRequired int `form:"name"` + NotRequired int `form:"name" required:"false"` + } + + objT := reflect.TypeOf(&user{}).Elem() + + label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0)) + if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) { + t.Errorf("Form Tag with name, label and type was not correctly parsed.") + } + + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1)) + if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) { + t.Errorf("Form Tag with label and type but without name was not correctly parsed.") + } + + label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2)) + if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) { + t.Errorf("Form Tag containing only label was not correctly parsed.") + } + + label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3)) + if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored && + id == "name" && class == "form-name") { + t.Errorf("Form Tag containing only name was not correctly parsed.") + } + + _, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4)) + if !ignored { + t.Errorf("Form Tag that should be ignored was not correctly parsed.") + } + + _, name, _, _, _, _, required := parseFormTag(objT.Field(5)) + if !(name == "name" && required) { + t.Errorf("Form Tag containing only name and required was not correctly parsed.") + } + + _, name, _, _, _, _, required = parseFormTag(objT.Field(6)) + if !(name == "name" && !required) { + t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.") + } + + _, name, _, _, _, _, required = parseFormTag(objT.Field(7)) + if !(name == "name" && !required) { + t.Errorf("Form Tag containing only name and not required was not correctly parsed.") + } + +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": 
map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/testdata/Makefile b/pkg/testdata/Makefile new file mode 100644 index 00000000..e80e8238 --- /dev/null +++ b/pkg/testdata/Makefile @@ -0,0 +1,2 @@ +build_view: + $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... \ No newline at end of file diff --git a/pkg/testdata/bindata.go b/pkg/testdata/bindata.go new file mode 100644 index 00000000..beade103 --- /dev/null +++ b/pkg/testdata/bindata.go @@ -0,0 +1,296 @@ +// Code generated by go-bindata. +// sources: +// views/blocks/block.tpl +// views/header.tpl +// views/index.tpl +// DO NOT EDIT! + +package testdata + +import ( + "bytes" + "compress/gzip" + "fmt" + "github.com/elazarl/go-bindata-assetfs" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +func bindataRead(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + clErr := gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + if clErr != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +type asset struct { + bytes []byte + info os.FileInfo +} + +type bindataFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time +} + +func (fi bindataFileInfo) Name() string { + return fi.name +} +func (fi bindataFileInfo) Size() int64 { + return fi.size +} +func (fi bindataFileInfo) Mode() os.FileMode { + return fi.mode +} +func (fi bindataFileInfo) ModTime() time.Time { + return fi.modTime +} +func (fi bindataFileInfo) IsDir() bool { + return false +} +func (fi bindataFileInfo) Sys() interface{} { + return nil +} + +var _viewsBlocksBlockTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\x4a\xca\xc9\x4f\xce\x56\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x00\x8b\x15\x2b\xda\xe8\x67\x18\xda\x71\x55\x57\xa7\xe6\xa5\xd4\xd6\x02\x02\x00\x00\xff\xff\xfd\xa1\x7a\xf6\x32\x00\x00\x00") + +func viewsBlocksBlockTplBytes() ([]byte, error) { + return bindataRead( + _viewsBlocksBlockTpl, + "views/blocks/block.tpl", + ) +} + +func viewsBlocksBlockTpl() (*asset, error) { + bytes, err := viewsBlocksBlockTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsHeaderTpl = 
[]byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\xca\x48\x4d\x4c\x49\x2d\x52\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x48\x2c\x2e\x49\xac\xc8\x4c\x55\xb4\xd1\xcf\x30\xb4\xe3\xaa\xae\x4e\xcd\x4b\xa9\xad\x05\x04\x00\x00\xff\xff\xe4\x12\x47\x01\x34\x00\x00\x00") + +func viewsHeaderTplBytes() ([]byte, error) { + return bindataRead( + _viewsHeaderTpl, + "views/header.tpl", + ) +} + +func viewsHeaderTpl() (*asset, error) { + bytes, err := viewsHeaderTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var _viewsIndexTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x64\x8f\xbd\x8a\xc3\x30\x10\x84\x6b\xeb\x29\xe6\xfc\x00\x16\xb8\x3c\x16\x35\x77\xa9\x13\x88\x09\xa4\xf4\xcf\x12\x99\x48\x48\xd8\x82\x10\x84\xde\x3d\xc8\x8a\x8b\x90\x6a\xa4\xd9\x6f\xd8\x59\xfa\xf9\x3f\xfe\x75\xd7\xd3\x01\x3a\x58\xa3\x04\x15\x01\x48\x73\x3f\xe5\x07\x40\x61\x0e\x86\xd5\xc0\x7c\x73\x78\xb0\x19\x9d\x65\x04\xb6\xde\xf4\x81\x49\x96\x69\x8e\xc8\x3d\x43\x83\x9b\x9e\x4a\x88\x2a\xc6\x9d\x43\x3d\x18\x37\xde\xeb\x94\x3e\xdd\x1c\xe1\xe5\xcb\xde\xe0\x55\x6e\xd2\x04\x6f\x32\x20\x2a\xd2\xad\x8a\x11\x4d\x97\x57\x22\x25\x92\xba\x55\xa2\x22\xaf\xd0\xe9\x79\xc5\xbc\xe2\xec\x2c\x5f\xfa\xe5\x17\x99\x7b\x7f\x36\xd2\x97\x8a\xa5\x19\xc9\x72\xe7\x2b\x00\x00\xff\xff\xb2\x39\xca\x9f\xff\x00\x00\x00") + +func viewsIndexTplBytes() ([]byte, error) { + return bindataRead( + _viewsIndexTpl, + "views/index.tpl", + ) +} + +func viewsIndexTpl() (*asset, error) { + bytes, err := viewsIndexTplBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) + } + return a.bytes, nil + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// MustAsset is like Asset but panics when Asset would return an error. +// It simplifies safe initialization of global variables. +func MustAsset(name string) []byte { + a, err := Asset(name) + if err != nil { + panic("asset: Asset(" + name + "): " + err.Error()) + } + + return a +} + +// AssetInfo loads and returns the asset info for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func AssetInfo(name string) (os.FileInfo, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) + } + return a.info, nil + } + return nil, fmt.Errorf("AssetInfo %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. 
+var _bindata = map[string]func() (*asset, error){ + "views/blocks/block.tpl": viewsBlocksBlockTpl, + "views/header.tpl": viewsHeaderTpl, + "views/index.tpl": viewsIndexTpl, +} + +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. +func AssetDir(name string) ([]string, error) { + node := _bintree + if len(name) != 0 { + cannonicalName := strings.Replace(name, "\\", "/", -1) + pathList := strings.Split(cannonicalName, "/") + for _, p := range pathList { + node = node.Children[p] + if node == nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + } + } + if node.Func != nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + rv := make([]string, 0, len(node.Children)) + for childName := range node.Children { + rv = append(rv, childName) + } + return rv, nil +} + +type bintree struct { + Func func() (*asset, error) + Children map[string]*bintree +} + +var _bintree = &bintree{nil, map[string]*bintree{ + "views": &bintree{nil, map[string]*bintree{ + "blocks": &bintree{nil, map[string]*bintree{ + "block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}}, + }}, + "header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}}, + "index.tpl": &bintree{viewsIndexTpl, map[string]*bintree{}}, + }}, +}} + +// RestoreAsset restores an asset under the given directory +func RestoreAsset(dir, name string) error { + data, err := Asset(name) + if err != nil { + return err + } + info, err := AssetInfo(name) + if err != nil { + return err + } + err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) + if err != nil { + return err + } + err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode()) + if err != nil { + return err + } + err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime()) + if err != nil { + return err + } + return nil +} + +// RestoreAssets restores an asset under the given directory recursively +func RestoreAssets(dir, name string) error { + children, err := AssetDir(name) + // File + if err != nil { + return RestoreAsset(dir, name) + } + // Dir + for _, child := range children { + err = RestoreAssets(dir, filepath.Join(name, child)) + if err != nil { + return err + } + } + return nil +} + +func _filePath(dir, name string) string { + cannonicalName := strings.Replace(name, "\\", "/", -1) + return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) +} + +func assetFS() *assetfs.AssetFS { + assetInfo := func(path string) (os.FileInfo, error) { + return os.Stat(path) + } + for k := range _bintree.Children { + return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} + } + panic("unreachable") +} diff --git a/pkg/testdata/views/blocks/block.tpl b/pkg/testdata/views/blocks/block.tpl new file mode 100644 index 00000000..2a9c57fc --- /dev/null +++ b/pkg/testdata/views/blocks/block.tpl @@ -0,0 +1,3 @@ +{{define "block"}} +

<h1>Hello, blocks!</h1>
+{{end}} \ No newline at end of file diff --git a/pkg/testdata/views/header.tpl b/pkg/testdata/views/header.tpl new file mode 100644 index 00000000..041fa403 --- /dev/null +++ b/pkg/testdata/views/header.tpl @@ -0,0 +1,3 @@ +{{define "header"}} +

<h1>Hello, astaxie!</h1>
+{{end}} \ No newline at end of file diff --git a/pkg/testdata/views/index.tpl b/pkg/testdata/views/index.tpl new file mode 100644 index 00000000..21b7fc06 --- /dev/null +++ b/pkg/testdata/views/index.tpl @@ -0,0 +1,15 @@ + + + + beego welcome template + + + + {{template "block"}} + {{template "header"}} + {{template "blocks/block.tpl"}} + +

<h1>{{ .Title }}</h1>
+

<p>This is SomeVar: {{ .SomeVar }}</p>
+ + diff --git a/pkg/testing/assertions.go b/pkg/testing/assertions.go new file mode 100644 index 00000000..96c5d4dd --- /dev/null +++ b/pkg/testing/assertions.go @@ -0,0 +1,15 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing diff --git a/pkg/testing/client.go b/pkg/testing/client.go new file mode 100644 index 00000000..c3737e9c --- /dev/null +++ b/pkg/testing/client.go @@ -0,0 +1,65 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/config" + "github.com/astaxie/beego/httplib" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest struct { + httplib.BeegoHTTPRequest +} + +func getPort() string { + if port == "" { + config, err := config.NewConfig("ini", "../conf/app.conf") + if err != nil { + return "8080" + } + port = config.String("httpport") + return port + } + return port +} + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Get(baseURL + getPort() + path)} +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Post(baseURL + getPort() + path)} +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Put(baseURL + getPort() + path)} +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Delete(baseURL + getPort() + path)} +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return &TestHTTPRequest{*httplib.Head(baseURL + getPort() + path)} +} diff --git a/pkg/toolbox/healthcheck.go b/pkg/toolbox/healthcheck.go new file mode 100644 index 00000000..e3544b3a --- /dev/null +++ b/pkg/toolbox/healthcheck.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +// AdminCheckList holds health checker map +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker interface { + Check() error +} + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/toolbox/profile.go b/pkg/toolbox/profile.go new file mode 100644 index 00000000..06e40ede --- /dev/null +++ b/pkg/toolbox/profile.go @@ -0,0 +1,184 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
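A minimal sketch of driving the profiling entry point defined just below; the commands are the ones handled by its switch, output goes to any io.Writer, and the import path is an assumption.

package main

import (
	"os"

	"github.com/astaxie/beego/toolbox" // assumed import path
)

func main() {
	// Dump the current goroutines, then a GC summary, to stdout.
	toolbox.ProcessInput("lookup goroutine", os.Stdout)
	toolbox.ProcessInput("gc summary", os.Stdout)
}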
+ +package toolbox + +import ( + "fmt" + "io" + "log" + "os" + "path" + "runtime" + "runtime/debug" + "runtime/pprof" + "strconv" + "time" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + switch input { + case "lookup goroutine": + p := pprof.Lookup("goroutine") + p.WriteTo(w, 2) + case "lookup heap": + p := pprof.Lookup("heap") + p.WriteTo(w, 2) + case "lookup threadcreate": + p := pprof.Lookup("threadcreate") + p.WriteTo(w, 2) + case "lookup block": + p := pprof.Lookup("block") + p.WriteTo(w, 2) + case "get cpuprof": + GetCPUProfile(w) + case "get memprof": + MemProf(w) + case "gc summary": + PrintGCSummary(w) + } +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + filename := "mem-" + strconv.Itoa(pid) + ".memprof" + if f, err := os.Create(filename); err != nil { + fmt.Fprintf(w, "create file %s error %s\n", filename, err.Error()) + log.Fatal("record heap profile failed: ", err) + } else { + runtime.GC() + pprof.WriteHeapProfile(f) + f.Close() + fmt.Fprintf(w, "create heap profile %s \n", filename) + _, fl := path.Split(os.Args[0]) + fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) + } +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + sec := 30 + filename := "cpu-" + strconv.Itoa(pid) + ".pprof" + f, err := os.Create(filename) + if err != nil { + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + log.Fatal("record cpu profile failed: ", err) + } + pprof.StartCPUProfile(f) + time.Sleep(time.Duration(sec) * time.Second) + pprof.StopCPUProfile() + + fmt.Fprintf(w, "create cpu profile %s \n", filename) + _, fl := path.Split(os.Args[0]) + fmt.Fprintf(w, "Now you can use this to check it: go tool pprof %s %s\n", fl, filename) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + memStats := &runtime.MemStats{} + runtime.ReadMemStats(memStats) + gcstats := &debug.GCStats{PauseQuantiles: make([]time.Duration, 100)} + debug.ReadGCStats(gcstats) + + printGC(memStats, gcstats, w) +} + +func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { + + if gcstats.NumGC > 0 { + lastPause := gcstats.Pause[0] + elapsed := time.Now().Sub(startTime) + overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 + allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() + + fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", + gcstats.NumGC, + toS(lastPause), + toS(avg(gcstats.Pause)), + overhead, + toH(memStats.Alloc), + toH(memStats.Sys), + toH(uint64(allocatedRate)), + toS(gcstats.PauseQuantiles[94]), + toS(gcstats.PauseQuantiles[98]), + toS(gcstats.PauseQuantiles[99])) + } else { + // while GC has disabled + elapsed := time.Now().Sub(startTime) + allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() + + fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", + toH(memStats.Alloc), + toH(memStats.Sys), + toH(uint64(allocatedRate))) + } +} + +func avg(items []time.Duration) time.Duration { + var sum time.Duration + for _, item := range items { + sum += item + } + return time.Duration(int64(sum) / int64(len(items))) +} + +// format bytes number friendly +func toH(bytes uint64) string { + switch { + case bytes < 1024: + return fmt.Sprintf("%dB", bytes) + case bytes < 1024*1024: + return fmt.Sprintf("%.2fK", 
float64(bytes)/1024) + case bytes < 1024*1024*1024: + return fmt.Sprintf("%.2fM", float64(bytes)/1024/1024) + default: + return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) + } +} + +// short string format +func toS(d time.Duration) string { + + u := uint64(d) + if u < uint64(time.Second) { + switch { + case u == 0: + return "0" + case u < uint64(time.Microsecond): + return fmt.Sprintf("%.2fns", float64(u)) + case u < uint64(time.Millisecond): + return fmt.Sprintf("%.2fus", float64(u)/1000) + default: + return fmt.Sprintf("%.2fms", float64(u)/1000/1000) + } + } else { + switch { + case u < uint64(time.Minute): + return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) + case u < uint64(time.Hour): + return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) + default: + return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) + } + } + +} diff --git a/pkg/toolbox/profile_test.go b/pkg/toolbox/profile_test.go new file mode 100644 index 00000000..07a20c4e --- /dev/null +++ b/pkg/toolbox/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/toolbox/statistics.go b/pkg/toolbox/statistics.go new file mode 100644 index 00000000..fd73dfb3 --- /dev/null +++ b/pkg/toolbox/statistics.go @@ -0,0 +1,149 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "fmt" + "sync" + "time" +) + +// Statistics struct +type Statistics struct { + RequestURL string + RequestController string + RequestNum int64 + MinTime time.Duration + MaxTime time.Duration + TotalTime time.Duration +} + +// URLMap contains several statistics struct to log different data +type URLMap struct { + lock sync.RWMutex + LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit + urlmap map[string]map[string]*Statistics +} + +// AddStatistics add statistics task. 
+// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + m.lock.Lock() + defer m.lock.Unlock() + if method, ok := m.urlmap[requestURL]; ok { + if s, ok := method[requestMethod]; ok { + s.RequestNum++ + if s.MaxTime < requesttime { + s.MaxTime = requesttime + } + if s.MinTime > requesttime { + s.MinTime = requesttime + } + s.TotalTime += requesttime + } else { + nb := &Statistics{ + RequestURL: requestURL, + RequestController: requestController, + RequestNum: 1, + MinTime: requesttime, + MaxTime: requesttime, + TotalTime: requesttime, + } + m.urlmap[requestURL][requestMethod] = nb + } + + } else { + if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { + return + } + methodmap := make(map[string]*Statistics) + nb := &Statistics{ + RequestURL: requestURL, + RequestController: requestController, + RequestNum: 1, + MinTime: requesttime, + MaxTime: requesttime, + TotalTime: requesttime, + } + methodmap[requestMethod] = nb + m.urlmap[requestURL] = methodmap + } +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + + var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} + + var resultLists [][]string + content := make(map[string]interface{}) + content["Fields"] = fields + + for k, v := range m.urlmap { + for kk, vv := range v { + result := []string{ + fmt.Sprintf("% -50s", k), + fmt.Sprintf("% -10s", kk), + fmt.Sprintf("% -16d", vv.RequestNum), + fmt.Sprintf("%d", vv.TotalTime), + fmt.Sprintf("% -16s", toS(vv.TotalTime)), + fmt.Sprintf("%d", vv.MaxTime), + fmt.Sprintf("% -16s", toS(vv.MaxTime)), + fmt.Sprintf("%d", vv.MinTime), + fmt.Sprintf("% -16s", toS(vv.MinTime)), + fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), + fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), + } + resultLists = append(resultLists, result) + } + } + content["Data"] = resultLists + return content +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + + var resultLists []map[string]interface{} + + for k, v := range m.urlmap { + for kk, vv := range v { + result := map[string]interface{}{ + "request_url": k, + "method": kk, + "times": vv.RequestNum, + "total_time": toS(vv.TotalTime), + "max_time": toS(vv.MaxTime), + "min_time": toS(vv.MinTime), + "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), + } + resultLists = append(resultLists, result) + } + } + return resultLists +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = &URLMap{ + urlmap: make(map[string]map[string]*Statistics), + } +} diff --git a/pkg/toolbox/statistics_test.go b/pkg/toolbox/statistics_test.go new file mode 100644 index 00000000..ac29476c --- /dev/null +++ b/pkg/toolbox/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/toolbox/task.go b/pkg/toolbox/task.go new file mode 100644 index 00000000..c902fdfc --- /dev/null +++ b/pkg/toolbox/task.go @@ -0,0 +1,640 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "log" + "math" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +// bounds provides a range of acceptable values (plus a map of name to value). +type bounds struct { + min, max uint + names map[string]uint +} + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker + taskLock sync.RWMutex + stop chan bool + changed chan bool + isstart bool + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + days = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ + "jan": 1, + "feb": 2, + "mar": 3, + "apr": 4, + "may": 5, + "jun": 6, + "jul": 7, + "aug": 8, + "sep": 9, + "oct": 10, + "nov": 11, + "dec": 12, + }} + weeks = bounds{0, 6, map[string]uint{ + "sun": 0, + "mon": 1, + "tue": 2, + "wed": 3, + "thu": 4, + "fri": 5, + "sat": 6, + }} +) + +const ( + // Set the top bit if a star was included in the expression. 
+ starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule struct { + Second uint64 + Minute uint64 + Hour uint64 + Day uint64 + Month uint64 + Week uint64 +} + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// It's not a thread-safe structure. +// Only nearest errors will be saved in ErrList +type Task struct { + Taskname string + Spec *Schedule + SpecStr string + DoFunc TaskFunc + Prev time.Time + Next time.Time + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := &Task{ + Taskname: tname, + DoFunc: f, + // Make configurable + ErrLimit: 100, + SpecStr: spec, + // we only store the pointer, so it won't use too many space + Errlist: make([]*taskerr, 100, 100), + } + task.SetCron(spec) + return task +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + return t.SpecStr +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + var str string + for _, v := range t.Errlist { + str += v.t.String() + ":" + v.errinfo + "
" + } + return str +} + +// Run run all tasks +func (t *Task) Run() error { + err := t.DoFunc() + if err != nil { + index := t.errCnt % t.ErrLimit + t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} + t.errCnt++ + } + return err +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.Next = t.Spec.Next(now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + return t.Next +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.Prev = now +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + return t.Prev +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//   -:duration +// /n : do as n times of time duration +///////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.Spec = t.parse(spec) +} + +func (t *Task) parse(spec string) *Schedule { + if len(spec) > 0 && spec[0] == '@' { + return t.parseSpec(spec) + } + // Split on whitespace. We require 5 or 6 fields. + // (second) (minute) (hour) (day of month) (month) (day of week, optional) + fields := strings.Fields(spec) + if len(fields) != 5 && len(fields) != 6 { + log.Panicf("Expected 5 or 6 fields, found %d: %s", len(fields), spec) + } + + // If a sixth field is not provided (DayOfWeek), then it is equivalent to star. 
+ if len(fields) == 5 { + fields = append(fields, "*") + } + + schedule := &Schedule{ + Second: getField(fields[0], seconds), + Minute: getField(fields[1], minutes), + Hour: getField(fields[2], hours), + Day: getField(fields[3], days), + Month: getField(fields[4], months), + Week: getField(fields[5], weeks), + } + + return schedule +} + +func (t *Task) parseSpec(spec string) *Schedule { + switch spec { + case "@yearly", "@annually": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: 1 << days.min, + Month: 1 << months.min, + Week: all(weeks), + } + + case "@monthly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: 1 << days.min, + Month: all(months), + Week: all(weeks), + } + + case "@weekly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: all(days), + Month: all(months), + Week: 1 << weeks.min, + } + + case "@daily", "@midnight": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: 1 << hours.min, + Day: all(days), + Month: all(months), + Week: all(weeks), + } + + case "@hourly": + return &Schedule{ + Second: 1 << seconds.min, + Minute: 1 << minutes.min, + Hour: all(hours), + Day: all(days), + Month: all(months), + Week: all(weeks), + } + } + log.Panicf("Unrecognized descriptor: %s", spec) + return nil +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + + // Start at the earliest possible time (the upcoming second). + t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) + + // This flag indicates whether a field has been incremented. + added := false + + // If no time is found within five years, return zero. + yearLimit := t.Year() + 5 + +WRAP: + if t.Year() > yearLimit { + return time.Time{} + } + + // Find the first applicable month. + // If it's this month, then do nothing. + for 1< 0 + dowMatch = 1< 0 + ) + + if s.Day&starBit > 0 || s.Week&starBit > 0 { + return domMatch && dowMatch + } + return domMatch || dowMatch +} + +// StartTask start all tasks +func StartTask() { + taskLock.Lock() + defer taskLock.Unlock() + if isstart { + //If already started, no need to start another goroutine. + return + } + isstart = true + go run() +} + +func run() { + now := time.Now().Local() + for _, t := range AdminTaskList { + t.SetNext(now) + } + + for { + // we only use RLock here because NewMapSorter copy the reference, do not change any thing + taskLock.RLock() + sortList := NewMapSorter(AdminTaskList) + taskLock.RUnlock() + sortList.Sort() + var effective time.Time + if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { + // If there are no entries yet, just sleep - it still handles new entries + // and stop requests. + effective = now.AddDate(10, 0, 0) + } else { + effective = sortList.Vals[0].GetNext() + } + select { + case now = <-time.After(effective.Sub(now)): + // Run every entry whose next time was this effective time. 
+ for _, e := range sortList.Vals { + if e.GetNext() != effective { + break + } + go e.Run() + e.SetPrev(e.GetNext()) + e.SetNext(effective) + } + continue + case <-changed: + now = time.Now().Local() + taskLock.Lock() + for _, t := range AdminTaskList { + t.SetNext(now) + } + taskLock.Unlock() + continue + case <-stop: + return + } + } +} + +// StopTask stop all tasks +func StopTask() { + taskLock.Lock() + defer taskLock.Unlock() + if isstart { + isstart = false + stop <- true + } + +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + taskLock.Lock() + defer taskLock.Unlock() + t.SetNext(time.Now().Local()) + AdminTaskList[taskname] = t + if isstart { + changed <- true + } +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + taskLock.Lock() + defer taskLock.Unlock() + delete(AdminTaskList, taskname) + if isstart { + changed <- true + } +} + +// MapSorter sort map for tasker +type MapSorter struct { + Keys []string + Vals []Tasker +} + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + ms := &MapSorter{ + Keys: make([]string, 0, len(m)), + Vals: make([]Tasker, 0, len(m)), + } + for k, v := range m { + ms.Keys = append(ms.Keys, k) + ms.Vals = append(ms.Vals, v) + } + return ms +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext().IsZero() { + return false + } + if ms.Vals[j].GetNext().IsZero() { + return true + } + return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func getField(field string, r bounds) uint64 { + // list = range {"," range} + var bits uint64 + ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' }) + for _, expr := range ranges { + bits |= getRange(expr, r) + } + return bits +} + +// getRange returns the bits indicated by the given expression: +// number | number "-" number [ "/" number ] +func getRange(expr string, r bounds) uint64 { + + var ( + start, end, step uint + rangeAndStep = strings.Split(expr, "/") + lowAndHigh = strings.Split(rangeAndStep[0], "-") + singleDigit = len(lowAndHigh) == 1 + ) + + var extrastar uint64 + if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" { + start = r.min + end = r.max + extrastar = starBit + } else { + start = parseIntOrName(lowAndHigh[0], r.names) + switch len(lowAndHigh) { + case 1: + end = start + case 2: + end = parseIntOrName(lowAndHigh[1], r.names) + default: + log.Panicf("Too many hyphens: %s", expr) + } + } + + switch len(rangeAndStep) { + case 1: + step = 1 + case 2: + step = mustParseInt(rangeAndStep[1]) + + // Special handling: "N/step" means "N-max/step". + if singleDigit { + end = r.max + } + default: + log.Panicf("Too many slashes: %s", expr) + } + + if start < r.min { + log.Panicf("Beginning of range (%d) below minimum (%d): %s", start, r.min, expr) + } + if end > r.max { + log.Panicf("End of range (%d) above maximum (%d): %s", end, r.max, expr) + } + if start > end { + log.Panicf("Beginning of range (%d) beyond end of range (%d): %s", start, end, expr) + } + + return getBits(start, end, step) | extrastar +} + +// parseIntOrName returns the (possibly-named) integer contained in expr. 
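// Worked sketch (not part of the patched file) of the bit encoding built by
// getField/getRange above, assuming the package's hour bounds of 0-23: the
// hour expression "8-20/3" walks 8, 11, 14, 17, 20 and returns the mask
//
//	1<<8 | 1<<11 | 1<<14 | 1<<17 | 1<<20
//
// while "*" (or "?") returns every bit in [min, max] ORed with starBit, which
// is how dayMatches can later tell an explicit full range apart from "*".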
+func parseIntOrName(expr string, names map[string]uint) uint { + if names != nil { + if namedInt, ok := names[strings.ToLower(expr)]; ok { + return namedInt + } + } + return mustParseInt(expr) +} + +// mustParseInt parses the given expression as an int or panics. +func mustParseInt(expr string) uint { + num, err := strconv.Atoi(expr) + if err != nil { + log.Panicf("Failed to parse int from %s: %s", expr, err) + } + if num < 0 { + log.Panicf("Negative number (%d) not allowed: %s", num, expr) + } + + return uint(num) +} + +// getBits sets all bits in the range [min, max], modulo the given step size. +func getBits(min, max, step uint) uint64 { + var bits uint64 + + // If step is 1, use shifts. + if step == 1 { + return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min) + } + + // Else, use a simple loop. + for i := min; i <= max; i += step { + bits |= 1 << i + } + return bits +} + +// all returns all bits within the given bounds. (plus the star bit) +func all(r bounds) uint64 { + return getBits(r.min, r.max, 1) | starBit +} + +func init() { + AdminTaskList = make(map[string]Tasker) + stop = make(chan bool) + changed = make(chan bool) +} diff --git a/pkg/toolbox/task_test.go b/pkg/toolbox/task_test.go new file mode 100644 index 00000000..3a4cce2f --- /dev/null +++ b/pkg/toolbox/task_test.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) + err := tk.Run() + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func() error { + cnt ++ + fmt.Printf("Hello, world! %d \n", cnt) + return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200 ; i ++ { + e := tk.Run() + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 
101", l[1].errinfo) +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/pkg/tree.go b/pkg/tree.go new file mode 100644 index 00000000..9e53003b --- /dev/null +++ b/pkg/tree.go @@ -0,0 +1,585 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "path" + "regexp" + "strings" + + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +var ( + allowSuffixExt = []string{".json", ".xml", ".html"} +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree struct { + //prefix set for static router + prefix string + //search fix route first + fixrouters []*Tree + //if set, failure to match fixrouters search then search wildcard + wildcard *Tree + //if set, failure to match wildcard search + leaves []*leafInfo +} + +// NewTree return a new Tree +func NewTree() *Tree { + return &Tree{} +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + t.addtree(splitPath(prefix), tree, nil, "") +} + +func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg string) { + if len(segments) == 0 { + panic("prefix should has path") + } + seg := segments[0] + iswild, params, regexpStr := splitSegment(seg) + // if it's ? meaning can igone this, so add one more rule for it + if len(params) > 0 && params[0] == ":" { + params = params[1:] + if len(segments[1:]) > 0 { + t.addtree(segments[1:], tree, append(wildcards, params...), reg) + } else { + filterTreeWithPrefix(tree, wildcards, reg) + } + } + //Rule: /login/*/access match /login/2009/11/access + //if already has *, and when loop the access, should as a regexpStr + if !iswild && utils.InSlice(":splat", wildcards) { + iswild = true + regexpStr = seg + } + //Rule: /user/:id/* + if seg == "*" && len(wildcards) > 0 && reg == "" { + regexpStr = "(.+)" + } + if len(segments) == 1 { + if iswild { + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "([^.]+).(.+)" + } else { + for _, w := range params { + if w == "." 
|| w == ":" { + continue + } + regexpStr = "([^/]+)/" + regexpStr + } + } + } + reg = strings.Trim(reg+"/"+regexpStr, "/") + filterTreeWithPrefix(tree, append(wildcards, params...), reg) + t.wildcard = tree + } else { + reg = strings.Trim(reg+"/"+regexpStr, "/") + filterTreeWithPrefix(tree, append(wildcards, params...), reg) + tree.prefix = seg + t.fixrouters = append(t.fixrouters, tree) + } + return + } + + if iswild { + if t.wildcard == nil { + t.wildcard = NewTree() + } + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "([^.]+).(.+)" + params = params[1:] + } else { + for range params { + regexpStr = "([^/]+)/" + regexpStr + } + } + } else { + if seg == "*.*" { + params = params[1:] + } + } + reg = strings.TrimRight(strings.TrimRight(reg, "/")+"/"+regexpStr, "/") + t.wildcard.addtree(segments[1:], tree, append(wildcards, params...), reg) + } else { + subTree := NewTree() + subTree.prefix = seg + t.fixrouters = append(t.fixrouters, subTree) + subTree.addtree(segments[1:], tree, append(wildcards, params...), reg) + } +} + +func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) { + for _, v := range t.fixrouters { + filterTreeWithPrefix(v, wildcards, reg) + } + if t.wildcard != nil { + filterTreeWithPrefix(t.wildcard, wildcards, reg) + } + for _, l := range t.leaves { + if reg != "" { + if l.regexps != nil { + l.wildcards = append(wildcards, l.wildcards...) + l.regexps = regexp.MustCompile("^" + reg + "/" + strings.Trim(l.regexps.String(), "^$") + "$") + } else { + for _, v := range l.wildcards { + if v == ":splat" { + reg = reg + "/(.+)" + } else { + reg = reg + "/([^/]+)" + } + } + l.regexps = regexp.MustCompile("^" + reg + "$") + l.wildcards = append(wildcards, l.wildcards...) + } + } else { + l.wildcards = append(wildcards, l.wildcards...) + if l.regexps != nil { + for _, w := range wildcards { + if w == ":splat" { + reg = "(.+)/" + reg + } else { + reg = "([^/]+)/" + reg + } + } + l.regexps = regexp.MustCompile("^" + reg + strings.Trim(l.regexps.String(), "^$") + "$") + } + } + } +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + t.addseg(splitPath(pattern), runObject, nil, "") +} + +// "/" +// "admin" -> +func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { + if len(segments) == 0 { + if reg != "" { + t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + } else { + t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + } + } else { + seg := segments[0] + iswild, params, regexpStr := splitSegment(seg) + // if it's ? 
meaning can igone this, so add one more rule for it + if len(params) > 0 && params[0] == ":" { + t.addseg(segments[1:], route, wildcards, reg) + params = params[1:] + } + //Rule: /login/*/access match /login/2009/11/access + //if already has *, and when loop the access, should as a regexpStr + if !iswild && utils.InSlice(":splat", wildcards) { + iswild = true + regexpStr = seg + } + //Rule: /user/:id/* + if seg == "*" && len(wildcards) > 0 && reg == "" { + regexpStr = "(.+)" + } + if iswild { + if t.wildcard == nil { + t.wildcard = NewTree() + } + if regexpStr != "" { + if reg == "" { + rr := "" + for _, w := range wildcards { + if w == ":splat" { + rr = rr + "(.+)/" + } else { + rr = rr + "([^/]+)/" + } + } + regexpStr = rr + regexpStr + } else { + regexpStr = "/" + regexpStr + } + } else if reg != "" { + if seg == "*.*" { + regexpStr = "/([^.]+).(.+)" + params = params[1:] + } else { + for range params { + regexpStr = "/([^/]+)" + regexpStr + } + } + } else { + if seg == "*.*" { + params = params[1:] + } + } + t.wildcard.addseg(segments[1:], route, append(wildcards, params...), reg+regexpStr) + } else { + var subTree *Tree + for _, sub := range t.fixrouters { + if sub.prefix == seg { + subTree = sub + break + } + } + if subTree == nil { + subTree = NewTree() + subTree.prefix = seg + t.fixrouters = append(t.fixrouters, subTree) + } + subTree.addseg(segments[1:], route, wildcards, reg) + } + } +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + if len(pattern) == 0 || pattern[0] != '/' { + return nil + } + w := make([]string, 0, 20) + return t.match(pattern[1:], pattern, w, ctx) +} + +func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { + if len(pattern) > 0 { + i := 0 + for ; i < len(pattern) && pattern[i] == '/'; i++ { + } + pattern = pattern[i:] + } + // Handle leaf nodes: + if len(pattern) == 0 { + for _, l := range t.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + if t.wildcard != nil { + for _, l := range t.wildcard.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + } + return nil + } + var seg string + i, l := 0, len(pattern) + for ; i < l && pattern[i] != '/'; i++ { + } + if i == 0 { + seg = pattern + pattern = "" + } else { + seg = pattern[:i] + pattern = pattern[i:] + } + for _, subTree := range t.fixrouters { + if subTree.prefix == seg { + if len(pattern) != 0 && pattern[0] == '/' { + treePattern = pattern[1:] + } else { + treePattern = pattern + } + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) + if runObject != nil { + break + } + } + } + if runObject == nil && len(t.fixrouters) > 0 { + // Filter the .json .xml .html extension + for _, str := range allowSuffixExt { + if strings.HasSuffix(seg, str) { + for _, subTree := range t.fixrouters { + if subTree.prefix == seg[:len(seg)-len(str)] { + runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) + if runObject != nil { + ctx.Input.SetParam(":ext", str[1:]) + } + } + } + } + } + } + if runObject == nil && t.wildcard != nil { + runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx) + } + + if runObject == nil && len(t.leaves) > 0 { + wildcardValues = append(wildcardValues, seg) + start, i := 0, 0 + for ; i < len(pattern); i++ { + if pattern[i] == '/' { + if i != 0 && start < len(pattern) { + wildcardValues = 
append(wildcardValues, pattern[start:i]) + } + start = i + 1 + continue + } + } + if start > 0 { + wildcardValues = append(wildcardValues, pattern[start:i]) + } + for _, l := range t.leaves { + if ok := l.match(treePattern, wildcardValues, ctx); ok { + return l.runObject + } + } + } + return runObject +} + +type leafInfo struct { + // names of wildcards that lead to this leaf. eg, ["id" "name"] for the wildcard ":id" and ":name" + wildcards []string + + // if the leaf is regexp + regexps *regexp.Regexp + + runObject interface{} +} + +func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { + //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + if leaf.regexps == nil { + if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path + return true + } + // match * + if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" { + ctx.Input.SetParam(":splat", treePattern) + return true + } + // match *.* or :id + if len(leaf.wildcards) >= 2 && leaf.wildcards[len(leaf.wildcards)-2] == ":path" && leaf.wildcards[len(leaf.wildcards)-1] == ":ext" { + if len(leaf.wildcards) == 2 { + lastone := wildcardValues[len(wildcardValues)-1] + strs := strings.SplitN(lastone, ".", 2) + if len(strs) == 2 { + ctx.Input.SetParam(":ext", strs[1]) + } + ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[:len(wildcardValues)-1]...), strs[0])) + return true + } else if len(wildcardValues) < 2 { + return false + } + var index int + for index = 0; index < len(leaf.wildcards)-2; index++ { + ctx.Input.SetParam(leaf.wildcards[index], wildcardValues[index]) + } + lastone := wildcardValues[len(wildcardValues)-1] + strs := strings.SplitN(lastone, ".", 2) + if len(strs) == 2 { + ctx.Input.SetParam(":ext", strs[1]) + } + if index > (len(wildcardValues) - 1) { + ctx.Input.SetParam(":path", "") + } else { + ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[index:len(wildcardValues)-1]...), strs[0])) + } + return true + } + // match :id + if len(leaf.wildcards) != len(wildcardValues) { + return false + } + for j, v := range leaf.wildcards { + ctx.Input.SetParam(v, wildcardValues[j]) + } + return true + } + + if !leaf.regexps.MatchString(path.Join(wildcardValues...)) { + return false + } + matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...)) + for i, match := range matches[1:] { + if i < len(leaf.wildcards) { + ctx.Input.SetParam(leaf.wildcards[i], match) + } + } + return true +} + +// "/" -> [] +// "/admin" -> ["admin"] +// "/admin/" -> ["admin"] +// "/admin/users" -> ["admin", "users"] +func splitPath(key string) []string { + key = strings.Trim(key, "/ ") + if key == "" { + return []string{} + } + return strings.Split(key, "/") +} + +// "admin" -> false, nil, "" +// ":id" -> true, [:id], "" +// "?:id" -> true, [: :id], "" : meaning can empty +// ":id:int" -> true, [:id], ([0-9]+) +// ":name:string" -> true, [:name], ([\w]+) +// ":id([0-9]+)" -> true, [:id], ([0-9]+) +// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+) +// "cms_:id_:page.html" -> true, [:id_ :page], cms_(.+)(.+).html +// "cms_:id(.+)_:page.html" -> true, [:id :page], cms_(.+)_(.+).html +// "*" -> true, [:splat], "" +// "*.*" -> true,[. :path :ext], "" . 
meaning separator +func splitSegment(key string) (bool, []string, string) { + if strings.HasPrefix(key, "*") { + if key == "*.*" { + return true, []string{".", ":path", ":ext"}, "" + } + return true, []string{":splat"}, "" + } + if strings.ContainsAny(key, ":") { + var paramsNum int + var out []rune + var start bool + var startexp bool + var param []rune + var expt []rune + var skipnum int + params := []string{} + reg := regexp.MustCompile(`[a-zA-Z0-9_]+`) + for i, v := range key { + if skipnum > 0 { + skipnum-- + continue + } + if start { + //:id:int and :name:string + if v == ':' { + if len(key) >= i+4 { + if key[i+1:i+4] == "int" { + out = append(out, []rune("([0-9]+)")...) + params = append(params, ":"+string(param)) + start = false + startexp = false + skipnum = 3 + param = make([]rune, 0) + paramsNum++ + continue + } + } + if len(key) >= i+7 { + if key[i+1:i+7] == "string" { + out = append(out, []rune(`([\w]+)`)...) + params = append(params, ":"+string(param)) + paramsNum++ + start = false + startexp = false + skipnum = 6 + param = make([]rune, 0) + continue + } + } + } + // params only support a-zA-Z0-9 + if reg.MatchString(string(v)) { + param = append(param, v) + continue + } + if v != '(' { + out = append(out, []rune(`(.+)`)...) + params = append(params, ":"+string(param)) + param = make([]rune, 0) + paramsNum++ + start = false + startexp = false + } + } + if startexp { + if v != ')' { + expt = append(expt, v) + continue + } + } + // Escape Sequence '\' + if i > 0 && key[i-1] == '\\' { + out = append(out, v) + } else if v == ':' { + param = make([]rune, 0) + start = true + } else if v == '(' { + startexp = true + start = false + if len(param) > 0 { + params = append(params, ":"+string(param)) + param = make([]rune, 0) + } + paramsNum++ + expt = make([]rune, 0) + expt = append(expt, '(') + } else if v == ')' { + startexp = false + expt = append(expt, ')') + out = append(out, expt...) + param = make([]rune, 0) + } else if v == '?' { + params = append(params, ":") + } else { + out = append(out, v) + } + } + if len(param) > 0 { + if paramsNum > 0 { + out = append(out, []rune(`(.+)`)...) + } + params = append(params, ":"+string(param)) + } + return true, params, string(out) + } + return false, nil, "" +} diff --git a/pkg/tree_test.go b/pkg/tree_test.go new file mode 100644 index 00000000..d412a348 --- /dev/null +++ b/pkg/tree_test.go @@ -0,0 +1,306 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
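// Illustrative sketch (not part of the patched file) of how the Tree above
// binds wildcard values to parameter names on a match; the pattern and URL are
// only examples, and the tests below exercise many more combinations.
//
//	tr := NewTree()
//	tr.AddRouter("/v1/shop/:id", "shop")
//	ctx := context.NewContext()
//	if obj := tr.Match("/v1/shop/123", ctx); obj != nil {
//		// the leaf stored wildcards [":id"] and match() collected the
//		// segment "123", so ctx.Input.Param(":id") == "123"
//	}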
+ +package beego + +import ( + "strings" + "testing" + + "github.com/astaxie/beego/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + 
routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset(ctx) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if 
ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset(ctx) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset(ctx) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + 
obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} + +func TestSplitPath(t *testing.T) { + a := splitPath("") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/admin") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin should retrun [admin]") + } + a = splitPath("/admin/") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin/ should retrun [admin]") + } + a = splitPath("/admin/users") + if len(a) != 2 || a[0] != "admin" || a[1] != "users" { + t.Fatal("/admin should retrun [admin users]") + } + a = splitPath("/admin/:id:int") + if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { + t.Fatal("/admin should retrun [admin :id:int]") + } +} + +func TestSplitSegment(t *testing.T) { + + items := map[string]struct { + isReg bool + params []string + regStr string + }{ + "admin": {false, nil, ""}, + "*": {true, []string{":splat"}, ""}, + "*.*": {true, []string{".", ":path", ":ext"}, ""}, + ":id": {true, []string{":id"}, ""}, + "?:id": {true, []string{":", ":id"}, ""}, + ":id:int": {true, []string{":id"}, "([0-9]+)"}, + ":name:string": {true, []string{":name"}, `([\w]+)`}, + ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, + ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, + ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, + "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, + `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, + `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, + } + + for pattern, v := range items { + b, w, r := splitSegment(pattern) + if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { + t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) + } + } +} diff --git a/pkg/unregroute_test.go b/pkg/unregroute_test.go new file mode 100644 index 00000000..08b1b77b --- /dev/null +++ b/pkg/unregroute_test.go @@ -0,0 +1,226 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// +// The unregroute_test.go contains tests for the unregister route +// functionality, that allows overriding route paths in children project +// that embed parent routers. 
+// + +const contentRootOriginal = "ok-original-root" +const contentLevel1Original = "ok-original-level1" +const contentLevel2Original = "ok-original-level2" + +const contentRootReplacement = "ok-replacement-root" +const contentLevel1Replacement = "ok-replacement-level1" +const contentLevel2Replacement = "ok-replacement-level2" + +// TestPreUnregController will supply content for the original routes, +// before unregistration +type TestPreUnregController struct { + Controller +} + +func (tc *TestPreUnregController) GetFixedRoot() { + tc.Ctx.Output.Body([]byte(contentRootOriginal)) +} +func (tc *TestPreUnregController) GetFixedLevel1() { + tc.Ctx.Output.Body([]byte(contentLevel1Original)) +} +func (tc *TestPreUnregController) GetFixedLevel2() { + tc.Ctx.Output.Body([]byte(contentLevel2Original)) +} + +// TestPostUnregController will supply content for the overriding routes, +// after the original ones are unregistered. +type TestPostUnregController struct { + Controller +} + +func (tc *TestPostUnregController) GetFixedRoot() { + tc.Ctx.Output.Body([]byte(contentRootReplacement)) +} +func (tc *TestPostUnregController) GetFixedLevel1() { + tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) +} +func (tc *TestPostUnregController) GetFixedLevel2() { + tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) +} + +// TestUnregisterFixedRouteRoot replaces just the root fixed route path. +// In this case, for a path like "/level1/level2" or "/level1", those actions +// should remain intact, and continue to serve the original content. +func TestUnregisterFixedRouteRoot(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, "Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, "Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, "Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the root path + findAndRemoveSingleTree(handler.routers[method]) + + // Replace the root path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + + // Test replacement root (expect change) + testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) + + // Test level 1 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) + + // Test level 2 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) + +} + +// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. +// In this case, for a path like "/level1/level2" or "/", those actions +// should remain intact, and continue to serve the original content. 
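// Condensed sketch (not part of the patched file) of the override recipe the
// tests below follow before re-registering a path; findAndRemoveTree,
// findAndRemoveSingleTree and splitPath are package helpers defined outside
// this hunk, and the controller types are the test doubles above.
//
//	handler := NewControllerRegister()
//	handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
//	// drop the existing fixed route, then register the replacement action
//	findAndRemoveTree(splitPath("/level1"), handler.routers["GET"], "GET")
//	handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1")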
+func TestUnregisterFixedRouteLevel1(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the level1 path + subPaths := splitPath("/level1") + if handler.routers[method].prefix == strings.Trim("/level1", "/ ") { + findAndRemoveSingleTree(handler.routers[method]) + } else { + findAndRemoveTree(subPaths, handler.routers[method], method) + } + + // Replace the "level1" path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + + // Test replacement root (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) + + // Test level 1 (expect change) + testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement) + + // Test level 2 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) + +} + +// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed +// route path. In this case, for a path like "/level1" or "/", those actions +// should remain intact, and continue to serve the original content. 
+func TestUnregisterFixedRouteLevel2(t *testing.T) { + + var method = "GET" + + handler := NewControllerRegister() + handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") + handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + + // Test original root + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original root", + method, "/", contentRootOriginal) + + // Test original level 1 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 1", + method, "/level1", contentLevel1Original) + + // Test original level 2 + testHelperFnContentCheck(t, handler, + "TestUnregisterFixedRouteLevel1.Test original level 2", + method, "/level1/level2", contentLevel2Original) + + // Remove only the level2 path + subPaths := splitPath("/level1/level2") + if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") { + findAndRemoveSingleTree(handler.routers[method]) + } else { + findAndRemoveTree(subPaths, handler.routers[method], method) + } + + // Replace the "/level1/level2" path TestPreUnregController action with the action from + // TestPostUnregController + handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + + // Test replacement root (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) + + // Test level 1 (expect no change from the original) + testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original) + + // Test level 2 (expect change) + testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) + +} + +func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, + testName, method, path, expectedBodyContent string) { + + r, err := http.NewRequest(method, path, nil) + if err != nil { + t.Errorf("httpRecorderBodyTest NewRequest error: %v", err) + return + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + body := w.Body.String() + if body != expectedBodyContent { + t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body) + } +} diff --git a/pkg/utils/caller.go b/pkg/utils/caller.go new file mode 100644 index 00000000..73c52a62 --- /dev/null +++ b/pkg/utils/caller.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "reflect" + "runtime" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} diff --git a/pkg/utils/caller_test.go b/pkg/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/utils/captcha/LICENSE b/pkg/utils/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/utils/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/utils/captcha/README.md b/pkg/utils/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
+<form action="/" method="post">
+	{{create_captcha}}
+	<input name="captcha" type="text">
+</form>
+``` diff --git a/pkg/utils/captcha/captcha.go b/pkg/utils/captcha/captcha.go new file mode 100644 index 00000000..42ac70d3 --- /dev/null +++ b/pkg/utils/captcha/captcha.go @@ -0,0 +1,270 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// <form action="/" method="post">
+// 	{{create_captcha}}
+// 	<input name="captcha" type="text">
+// </form>
+// ``` +package captcha + +import ( + "fmt" + "html/template" + "net/http" + "path" + "strings" + "time" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/logs" + "github.com/astaxie/beego/utils" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha struct { + // beego cache store + store cache.Cache + + // url prefix for captcha image + URLPrefix string + + // specify captcha id input field name + FieldIDName string + // specify captcha result input field name + FieldCaptchaName string + + // captcha image width and height + StdWidth int + StdHeight int + + // captcha chars nums + ChallengeNums int + + // captcha expiration seconds + Expiration time.Duration + + // cache key prefix + CachePrefix string +} + +// generate key string +func (c *Captcha) key(id string) string { + return c.CachePrefix + id +} + +// generate rand chars with default chars +func (c *Captcha) genRandChars() []byte { + return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) +} + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + var chars []byte + + id := path.Base(ctx.Request.RequestURI) + if i := strings.Index(id, "."); i != -1 { + id = id[:i] + } + + key := c.key(id) + + if len(ctx.Input.Query("reload")) > 0 { + chars = c.genRandChars() + if err := c.store.Put(key, chars, c.Expiration); err != nil { + ctx.Output.SetStatus(500) + ctx.WriteString("captcha reload error") + logs.Error("Reload Create Captcha Error:", err) + return + } + } else { + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + ctx.Output.SetStatus(404) + ctx.WriteString("captcha not found") + return + } + } + + img := NewImage(chars, c.StdWidth, c.StdHeight) + if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { + logs.Error("Write Captcha Image Error:", err) + } +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + value, err := c.CreateCaptcha() + if err != nil { + logs.Error("Create Captcha Error:", err) + return "" + } + + // create html + return template.HTML(fmt.Sprintf(``+ + ``+ + ``+ + ``, c.FieldIDName, value, c.URLPrefix, value, c.URLPrefix, value)) +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + // generate captcha id + id := string(utils.RandomCreateBytes(15)) + + // get the captcha chars + chars := c.genRandChars() + + // save to store + if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { + return "", err + } + + return id, nil +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + req.ParseForm() + return c.Verify(req.Form.Get(c.FieldIDName), req.Form.Get(c.FieldCaptchaName)) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + if len(challenge) == 0 || len(id) == 0 { + return + } + + var chars []byte + + key := c.key(id) + + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + return + } + + defer func() { + // finally remove it + c.store.Delete(key) + }() + + if len(chars) != len(challenge) { + return + } + // verify challenge + for i, c 
:= range chars { + if c != challenge[i]-48 { + return + } + } + + return true +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + cpt := &Captcha{} + cpt.store = store + cpt.FieldIDName = fieldIDName + cpt.FieldCaptchaName = fieldCaptchaName + cpt.ChallengeNums = challengeNums + cpt.Expiration = expiration + cpt.CachePrefix = cachePrefix + cpt.StdWidth = stdWidth + cpt.StdHeight = stdHeight + + if len(urlPrefix) == 0 { + urlPrefix = defaultURLPrefix + } + + if urlPrefix[len(urlPrefix)-1] != '/' { + urlPrefix += "/" + } + + cpt.URLPrefix = urlPrefix + + return cpt +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + cpt := NewCaptcha(urlPrefix, store) + + // create filter for serve captcha image + beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) + + // add to template func map + beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) + + return cpt +} diff --git a/pkg/utils/captcha/image.go b/pkg/utils/captcha/image.go new file mode 100644 index 00000000..c3c9a83a --- /dev/null +++ b/pkg/utils/captcha/image.go @@ -0,0 +1,501 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "bytes" + "image" + "image/color" + "image/png" + "io" + "math" +) + +const ( + fontWidth = 11 + fontHeight = 18 + blackChar = 1 + + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 + // Maximum absolute skew factor of a single digit. + maxSkew = 0.7 + // Number of background circles. 
+ circleCount = 20 +) + +var font = [][]byte{ + { // 0 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 1 + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 2 + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 3 + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 4 + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + }, + { // 5 + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 
+ 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 6 + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 7 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + }, + { // 8 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 9 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + }, +} + +// Image struct +type Image struct { + *image.Paletted + numWidth int + numHeight int + dotSize int +} + +var prng = &siprng{} + +// randIntn returns a pseudorandom non-negative int in range [0, n). +func randIntn(n int) int { + return prng.Intn(n) +} + +// randInt returns a pseudorandom int in range [from, to]. +func randInt(from, to int) int { + return prng.Intn(to+1-from) + from +} + +// randFloat returns a pseudorandom float64 in range [from, to]. 
+func randFloat(from, to float64) float64 { + return (to-from)*prng.Float64() + from +} + +func randomPalette() color.Palette { + p := make([]color.Color, circleCount+1) + // Transparent color. + p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00} + // Primary color. + prim := color.RGBA{ + uint8(randIntn(129)), + uint8(randIntn(129)), + uint8(randIntn(129)), + 0xFF, + } + p[1] = prim + // Circle colors. + for i := 2; i <= circleCount; i++ { + p[i] = randomBrightness(prim, 255) + } + return p +} + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + m := new(Image) + m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette()) + m.calculateSizes(width, height, len(digits)) + // Randomly position captcha inside the image. + maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize + maxy := height - m.numHeight - m.dotSize*2 + var border int + if width > height { + border = height / 5 + } else { + border = width / 5 + } + x := randInt(border, maxx-border) + y := randInt(border, maxy-border) + // Draw digits. + for _, n := range digits { + m.drawDigit(font[n], x, y) + x += m.numWidth + m.dotSize + } + // Draw strike-through line. + m.strikeThrough() + // Apply wave distortion. + m.distort(randFloat(5, 10), randFloat(100, 200)) + // Fill image with random circles. + m.fillWithCircles(circleCount, m.dotSize) + return m +} + +// encodedPNG encodes an image to PNG and returns +// the result as a byte slice. +func (m *Image) encodedPNG() []byte { + var buf bytes.Buffer + if err := png.Encode(&buf, m.Paletted); err != nil { + panic(err.Error()) + } + return buf.Bytes() +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.encodedPNG()) + return int64(n), err +} + +func (m *Image) calculateSizes(width, height, ncount int) { + // Goal: fit all digits inside the image. + var border int + if width > height { + border = height / 4 + } else { + border = width / 4 + } + // Convert everything to floats for calculations. + w := float64(width - border*2) + h := float64(height - border*2) + // fw takes into account 1-dot spacing between digits. + fw := float64(fontWidth + 1) + fh := float64(fontHeight) + nc := float64(ncount) + // Calculate the width of a single digit taking into account only the + // width of the image. + nw := w / nc + // Calculate the height of a digit from this width. + nh := nw * fh / fw + // Digit too high? + if nh > h { + // Fit digits based on height. + nh = h + nw = fw / fh * nh + } + // Calculate dot size. + m.dotSize = int(nh / fh) + if m.dotSize < 1 { + m.dotSize = 1 + } + // Save everything, making the actual width smaller by 1 dot to account + // for spacing between digits. 
+ m.numWidth = int(nw) - m.dotSize + m.numHeight = int(nh) +} + +func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) { + for x := fromX; x <= toX; x++ { + m.SetColorIndex(x, y, colorIdx) + } +} + +func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) { + f := 1 - radius + dfx := 1 + dfy := -2 * radius + xo := 0 + yo := radius + + m.SetColorIndex(x, y+radius, colorIdx) + m.SetColorIndex(x, y-radius, colorIdx) + m.drawHorizLine(x-radius, x+radius, y, colorIdx) + + for xo < yo { + if f >= 0 { + yo-- + dfy += 2 + f += dfy + } + xo++ + dfx += 2 + f += dfx + m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx) + m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx) + } +} + +func (m *Image) fillWithCircles(n, maxradius int) { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + for i := 0; i < n; i++ { + colorIdx := uint8(randInt(1, circleCount-1)) + r := randInt(1, maxradius) + m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx) + } +} + +func (m *Image) strikeThrough() { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + y := randInt(maxy/3, maxy-maxy/3) + amplitude := randFloat(5, 20) + period := randFloat(80, 180) + dx := 2.0 * math.Pi / period + for x := 0; x < maxx; x++ { + xo := amplitude * math.Cos(float64(y)*dx) + yo := amplitude * math.Sin(float64(x)*dx) + for yn := 0; yn < m.dotSize; yn++ { + r := randInt(0, m.dotSize) + m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1) + } + } +} + +func (m *Image) drawDigit(digit []byte, x, y int) { + skf := randFloat(-maxSkew, maxSkew) + xs := float64(x) + r := m.dotSize / 2 + y += randInt(-r, r) + for yo := 0; yo < fontHeight; yo++ { + for xo := 0; xo < fontWidth; xo++ { + if digit[yo*fontWidth+xo] != blackChar { + continue + } + m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1) + } + xs += skf + x = int(xs) + } +} + +func (m *Image) distort(amplude float64, period float64) { + w := m.Bounds().Max.X + h := m.Bounds().Max.Y + + oldm := m.Paletted + newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette) + + dx := 2.0 * math.Pi / period + for x := 0; x < w; x++ { + for y := 0; y < h; y++ { + xo := amplude * math.Sin(float64(y)*dx) + yo := amplude * math.Cos(float64(x)*dx) + newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo))) + } + } + m.Paletted = newm +} + +func randomBrightness(c color.RGBA, max uint8) color.RGBA { + minc := min3(c.R, c.G, c.B) + maxc := max3(c.R, c.G, c.B) + if maxc > max { + return c + } + n := randIntn(int(max-maxc)) - int(minc) + return color.RGBA{ + uint8(int(c.R) + n), + uint8(int(c.G) + n), + uint8(int(c.B) + n), + c.A, + } +} + +func min3(x, y, z uint8) (m uint8) { + m = x + if y < m { + m = y + } + if z < m { + m = z + } + return +} + +func max3(x, y, z uint8) (m uint8) { + m = x + if y > m { + m = y + } + if z > m { + m = z + } + return +} diff --git a/pkg/utils/captcha/image_test.go b/pkg/utils/captcha/image_test.go new file mode 100644 index 00000000..5e35b7f7 --- /dev/null +++ b/pkg/utils/captcha/image_test.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
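// Editor's note: illustrative usage sketch, not part of this patch. It shows how
// NewImage and WriteTo from image.go above might be combined to write a captcha
// PNG to disk; the digit slice and the 240x80 size are arbitrary example values,
// and the "os" import is assumed.
func exampleWriteCaptchaPNG() error {
	digits := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
	img := NewImage(digits, 240, 80)
	f, err := os.Create("captcha.png")
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = img.WriteTo(f)
	return err
}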
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/utils" +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/utils/captcha/siprng.go b/pkg/utils/captcha/siprng.go new file mode 100644 index 00000000..5e256cf9 --- /dev/null +++ b/pkg/utils/captcha/siprng.go @@ -0,0 +1,277 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "crypto/rand" + "encoding/binary" + "io" + "sync" +) + +// siprng is PRNG based on SipHash-2-4. +type siprng struct { + mu sync.Mutex + k0, k1, ctr uint64 +} + +// siphash implements SipHash-2-4, accepting a uint64 as a message. +func siphash(k0, k1, m uint64) uint64 { + // Initialization. + v0 := k0 ^ 0x736f6d6570736575 + v1 := k1 ^ 0x646f72616e646f6d + v2 := k0 ^ 0x6c7967656e657261 + v3 := k1 ^ 0x7465646279746573 + t := uint64(8) << 56 + + // Compression. + v3 ^= m + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= m + + // Compress last block. + v3 ^= t + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. 
+ v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= t + + // Finalization. + v2 ^= 0xff + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 3. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 4. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + return v0 ^ v1 ^ v2 ^ v3 +} + +// rekey sets a new PRNG key, which is read from crypto/rand. +func (p *siprng) rekey() { + var k [16]byte + if _, err := io.ReadFull(rand.Reader, k[:]); err != nil { + panic(err.Error()) + } + p.k0 = binary.LittleEndian.Uint64(k[0:8]) + p.k1 = binary.LittleEndian.Uint64(k[8:16]) + p.ctr = 1 +} + +// Uint64 returns a new pseudorandom uint64. +// It rekeys PRNG on the first call and every 64 MB of generated data. +func (p *siprng) Uint64() uint64 { + p.mu.Lock() + if p.ctr == 0 || p.ctr > 8*1024*1024 { + p.rekey() + } + v := siphash(p.k0, p.k1, p.ctr) + p.ctr++ + p.mu.Unlock() + return v +} + +func (p *siprng) Int63() int64 { + return int64(p.Uint64() & 0x7fffffffffffffff) +} + +func (p *siprng) Uint32() uint32 { + return uint32(p.Uint64()) +} + +func (p *siprng) Int31() int32 { + return int32(p.Uint32() & 0x7fffffff) +} + +func (p *siprng) Intn(n int) int { + if n <= 0 { + panic("invalid argument to Intn") + } + if n <= 1<<31-1 { + return int(p.Int31n(int32(n))) + } + return int(p.Int63n(int64(n))) +} + +func (p *siprng) Int63n(n int64) int64 { + if n <= 0 { + panic("invalid argument to Int63n") + } + max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) + v := p.Int63() + for v > max { + v = p.Int63() + } + return v % n +} + +func (p *siprng) Int31n(n int32) int32 { + if n <= 0 { + panic("invalid argument to Int31n") + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := p.Int31() + for v > max { + v = p.Int31() + } + return v % n +} + +func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) } diff --git a/pkg/utils/captcha/siprng_test.go b/pkg/utils/captcha/siprng_test.go new file mode 100644 index 00000000..189d3d3c --- /dev/null +++ b/pkg/utils/captcha/siprng_test.go @@ -0,0 +1,33 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
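// Editor's note: illustrative sketch, not part of this patch. The unexported
// siprng above is consumed through the package-level helpers defined in
// image.go; the bounds below are arbitrary example values.
func exampleDrawRandomValues() {
	n := randIntn(10)        // pseudorandom int in [0, 10)
	k := randInt(5, 15)      // pseudorandom int in [5, 15]
	f := randFloat(0.5, 2.0) // pseudorandom float64 in [0.5, 2.0]
	_, _, _ = n, k, f
}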
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import "testing" + +func TestSiphash(t *testing.T) { + good := uint64(0xe849e8bb6ffe2567) + cur := siphash(0, 0, 0) + if cur != good { + t.Fatalf("siphash: expected %x, got %x", good, cur) + } +} + +func BenchmarkSiprng(b *testing.B) { + b.SetBytes(8) + p := &siprng{} + for i := 0; i < b.N; i++ { + p.Uint64() + } +} diff --git a/pkg/utils/debug.go b/pkg/utils/debug.go new file mode 100644 index 00000000..93c27b70 --- /dev/null +++ b/pkg/utils/debug.go @@ -0,0 +1,478 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "fmt" + "log" + "reflect" + "runtime" +) + +var ( + dunno = []byte("???") + centerDot = []byte("·") + dot = []byte(".") +) + +type pointerInfo struct { + prev *pointerInfo + n int + addr uintptr + pos int + used []int +} + +// Display print the data in console +func Display(data ...interface{}) { + display(true, data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return display(false, data...) 
+} + +func display(displayed bool, data ...interface{}) string { + var pc, file, line, ok = runtime.Caller(2) + + if !ok { + return "" + } + + var buf = new(bytes.Buffer) + + fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) + + fmt.Fprintf(buf, "\n[Variables]\n") + + for i := 0; i < len(data); i += 2 { + var output = fomateinfo(len(data[i].(string))+3, data[i+1]) + fmt.Fprintf(buf, "%s = %s", data[i], output) + } + + if displayed { + log.Print(buf) + } + return buf.String() +} + +// return data dump and format bytes +func fomateinfo(headlen int, data ...interface{}) []byte { + var buf = new(bytes.Buffer) + + if len(data) > 1 { + fmt.Fprint(buf, " ") + + fmt.Fprint(buf, "[") + + fmt.Fprintln(buf) + } + + for k, v := range data { + var buf2 = new(bytes.Buffer) + var pointers *pointerInfo + var interfaces = make([]reflect.Value, 0, 10) + + printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) + + if k < len(data)-1 { + fmt.Fprint(buf2, ", ") + } + + fmt.Fprintln(buf2) + + buf.Write(buf2.Bytes()) + } + + if len(data) > 1 { + fmt.Fprintln(buf) + + fmt.Fprint(buf, " ") + + fmt.Fprint(buf, "]") + } + + return buf.Bytes() +} + +// check data is golang basic type +func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, interfaces *[]reflect.Value) bool { + switch kind { + case reflect.Bool: + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Complex64, reflect.Complex128: + return true + case reflect.String: + return true + case reflect.Chan: + return true + case reflect.Invalid: + return true + case reflect.Interface: + for _, in := range *interfaces { + if reflect.DeepEqual(in, val) { + return true + } + } + return false + case reflect.UnsafePointer: + if val.IsNil() { + return true + } + + var elem = val.Elem() + + if isSimpleType(elem, elem.Kind(), pointers, interfaces) { + return true + } + + var addr = val.Elem().UnsafeAddr() + + for p := *pointers; p != nil; p = p.prev { + if addr == p.addr { + return true + } + } + + return false + } + + return false +} + +// dump value +func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { + var t = val.Kind() + + switch t { + case reflect.Bool: + fmt.Fprint(buf, val.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fmt.Fprint(buf, val.Int()) + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + fmt.Fprint(buf, val.Uint()) + case reflect.Float32, reflect.Float64: + fmt.Fprint(buf, val.Float()) + case reflect.Complex64, reflect.Complex128: + fmt.Fprint(buf, val.Complex()) + case reflect.UnsafePointer: + fmt.Fprintf(buf, "unsafe.Pointer(0x%X)", val.Pointer()) + case reflect.Ptr: + if val.IsNil() { + fmt.Fprint(buf, "nil") + return + } + + var addr = val.Elem().UnsafeAddr() + + for p := *pointers; p != nil; p = p.prev { + if addr == p.addr { + p.used = append(p.used, buf.Len()) + fmt.Fprintf(buf, "0x%X", addr) + return + } + } + + *pointers = &pointerInfo{ + prev: *pointers, + addr: addr, + pos: buf.Len(), + used: make([]int, 0), + } + + fmt.Fprint(buf, "&") + + printKeyValue(buf, val.Elem(), pointers, interfaces, structFilter, formatOutput, indent, level) 
+ case reflect.String: + fmt.Fprint(buf, "\"", val.String(), "\"") + case reflect.Interface: + var value = val.Elem() + + if !value.IsValid() { + fmt.Fprint(buf, "nil") + } else { + for _, in := range *interfaces { + if reflect.DeepEqual(in, val) { + fmt.Fprint(buf, "repeat") + return + } + } + + *interfaces = append(*interfaces, val) + + printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) + } + case reflect.Struct: + var t = val.Type() + + fmt.Fprint(buf, t) + fmt.Fprint(buf, "{") + + for i := 0; i < val.NumField(); i++ { + if formatOutput { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + var name = t.Field(i).Name + + if formatOutput { + for ind := 0; ind < level; ind++ { + fmt.Fprint(buf, indent) + } + } + + fmt.Fprint(buf, name) + fmt.Fprint(buf, ": ") + + if structFilter != nil && structFilter(t.String(), name) { + fmt.Fprint(buf, "ignore") + } else { + printKeyValue(buf, val.Field(i), pointers, interfaces, structFilter, formatOutput, indent, level+1) + } + + fmt.Fprint(buf, ",") + } + + if formatOutput { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Array, reflect.Slice: + fmt.Fprint(buf, val.Type()) + fmt.Fprint(buf, "{") + + var allSimple = true + + for i := 0; i < val.Len(); i++ { + var elem = val.Index(i) + + var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + + if !isSimple { + allSimple = false + } + + if formatOutput && !isSimple { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + if formatOutput && !isSimple { + for ind := 0; ind < level; ind++ { + fmt.Fprint(buf, indent) + } + } + + printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) + + if i != val.Len()-1 || !allSimple { + fmt.Fprint(buf, ",") + } + } + + if formatOutput && !allSimple { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Map: + var t = val.Type() + var keys = val.MapKeys() + + fmt.Fprint(buf, t) + fmt.Fprint(buf, "{") + + var allSimple = true + + for i := 0; i < len(keys); i++ { + var elem = val.MapIndex(keys[i]) + + var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + + if !isSimple { + allSimple = false + } + + if formatOutput && !isSimple { + fmt.Fprintln(buf) + } else { + fmt.Fprint(buf, " ") + } + + if formatOutput && !isSimple { + for ind := 0; ind <= level; ind++ { + fmt.Fprint(buf, indent) + } + } + + printKeyValue(buf, keys[i], pointers, interfaces, structFilter, formatOutput, indent, level+1) + fmt.Fprint(buf, ": ") + printKeyValue(buf, elem, pointers, interfaces, structFilter, formatOutput, indent, level+1) + + if i != val.Len()-1 || !allSimple { + fmt.Fprint(buf, ",") + } + } + + if formatOutput && !allSimple { + fmt.Fprintln(buf) + + for ind := 0; ind < level-1; ind++ { + fmt.Fprint(buf, indent) + } + } else { + fmt.Fprint(buf, " ") + } + + fmt.Fprint(buf, "}") + case reflect.Chan: + fmt.Fprint(buf, val.Type()) + case reflect.Invalid: + fmt.Fprint(buf, "invalid") + default: + fmt.Fprint(buf, "unknow") + } +} + +// PrintPointerInfo dump pointer value +func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { + var anyused = false + var pointerNum = 0 + + for p := pointers; p != nil; p = p.prev { + if len(p.used) > 0 { + anyused = true + } + pointerNum++ + p.n = pointerNum + } + + if anyused { + 
var pointerBufs = make([][]rune, pointerNum+1) + + for i := 0; i < len(pointerBufs); i++ { + var pointerBuf = make([]rune, buf.Len()+headlen) + + for j := 0; j < len(pointerBuf); j++ { + pointerBuf[j] = ' ' + } + + pointerBufs[i] = pointerBuf + } + + for pn := 0; pn <= pointerNum; pn++ { + for p := pointers; p != nil; p = p.prev { + if len(p.used) > 0 && p.n >= pn { + if pn == p.n { + pointerBufs[pn][p.pos+headlen] = '└' + + var maxpos = 0 + + for i, pos := range p.used { + if i < len(p.used)-1 { + pointerBufs[pn][pos+headlen] = '┴' + } else { + pointerBufs[pn][pos+headlen] = '┘' + } + + maxpos = pos + } + + for i := 0; i < maxpos-p.pos-1; i++ { + if pointerBufs[pn][i+p.pos+headlen+1] == ' ' { + pointerBufs[pn][i+p.pos+headlen+1] = '─' + } + } + } else { + pointerBufs[pn][p.pos+headlen] = '│' + + for _, pos := range p.used { + if pointerBufs[pn][pos+headlen] == ' ' { + pointerBufs[pn][pos+headlen] = '│' + } else { + pointerBufs[pn][pos+headlen] = '┼' + } + } + } + } + } + + buf.WriteString(string(pointerBufs[pn]) + "\n") + } + } +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + var buf = new(bytes.Buffer) + + for i := skip; ; i++ { + var pc, file, line, ok = runtime.Caller(i) + + if !ok { + break + } + + buf.WriteString(indent) + + fmt.Fprintf(buf, "at %s() [%s:%d]\n", function(pc), file, line) + } + + return buf.Bytes() +} + +// return the name of the function containing the PC if possible, +func function(pc uintptr) []byte { + fn := runtime.FuncForPC(pc) + if fn == nil { + return dunno + } + name := []byte(fn.Name()) + // The name includes the path name to the package, which is unnecessary + // since the file name is already included. Plus, it has center dots. + // That is, we see + // runtime/debug.*T·ptrmethod + // and want + // *T.ptrmethod + if period := bytes.Index(name, dot); period >= 0 { + name = name[period+1:] + } + name = bytes.Replace(name, centerDot, dot, -1) + return name +} diff --git a/pkg/utils/debug_test.go b/pkg/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 00000000..6090eb17 --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bufio" + "errors" + "io" + "os" + "path/filepath" + "regexp" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + path, _ := filepath.Abs(os.Args[0]) + return path +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return filepath.Dir(SelfPath()) +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + if _, err := os.Stat(name); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + for _, path := range paths { + if fullpath = filepath.Join(path, filename); FileExists(fullpath) { + return + } + } + err = errors.New(fullpath + " not found in paths") + return +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + re, err := regexp.Compile(patten) + if err != nil { + return + } + + fd, err := os.Open(filename) + if err != nil { + return + } + lines = make([]string, 0) + reader := bufio.NewReader(fd) + prefix := "" + var isLongLine bool + for { + byteLine, isPrefix, er := reader.ReadLine() + if er != nil && er != io.EOF { + return nil, er + } + if er == io.EOF { + break + } + line := string(byteLine) + if isPrefix { + prefix += line + continue + } else { + isLongLine = true + } + + line = prefix + line + if isLongLine { + prefix = "" + } + if re.MatchString(line) { + lines = append(lines, line) + } + } + return lines, nil +} diff --git a/pkg/utils/file_test.go b/pkg/utils/file_test.go new file mode 100644 index 00000000..b2644157 --- /dev/null +++ b/pkg/utils/file_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
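// Editor's note: illustrative sketch, not part of this patch. It shows how
// SearchFile and GrepFile from file.go above might be combined; the file name,
// search paths and pattern are placeholder values, and the "fmt" import is
// assumed.
func exampleFindAndGrep() error {
	path, err := SearchFile("app.conf", ".", SelfDir(), "/etc/myapp")
	if err != nil {
		return err
	}
	lines, err := GrepFile(`^\s*[^#]+`, path) // keep non-blank, non-comment lines
	if err != nil {
		return err
	}
	fmt.Println(lines)
	return nil
}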
+ +package utils + +import ( + "path/filepath" + "reflect" + "testing" +) + +var noExistedFile = "/tmp/not_existed_file" + +func TestSelfPath(t *testing.T) { + path := SelfPath() + if path == "" { + t.Error("path cannot be empty") + } + t.Logf("SelfPath: %s", path) +} + +func TestSelfDir(t *testing.T) { + dir := SelfDir() + t.Logf("SelfDir: %s", dir) +} + +func TestFileExists(t *testing.T) { + if !FileExists("./file.go") { + t.Errorf("./file.go should exists, but it didn't") + } + + if FileExists(noExistedFile) { + t.Errorf("Weird, how could this file exists: %s", noExistedFile) + } +} + +func TestSearchFile(t *testing.T) { + path, err := SearchFile(filepath.Base(SelfPath()), SelfDir()) + if err != nil { + t.Error(err) + } + t.Log(path) + + _, err = SearchFile(noExistedFile, ".") + if err == nil { + t.Errorf("err shouldnt be nil, got path: %s", SelfDir()) + } +} + +func TestGrepFile(t *testing.T) { + _, err := GrepFile("", noExistedFile) + if err == nil { + t.Error("expect file-not-existed error, but got nothing") + } + + path := filepath.Join(".", "testdata", "grepe.test") + lines, err := GrepFile(`^\s*[^#]+`, path) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(lines, []string{"hello", "world"}) { + t.Errorf("expect [hello world], but receive %v", lines) + } +} diff --git a/pkg/utils/mail.go b/pkg/utils/mail.go new file mode 100644 index 00000000..80a366ca --- /dev/null +++ b/pkg/utils/mail.go @@ -0,0 +1,424 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/mail" + "net/smtp" + "net/textproto" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" +) + +const ( + maxLineLength = 76 + + upperhex = "0123456789ABCDEF" +) + +// Email is the type used for email messages +type Email struct { + Auth smtp.Auth + Identity string `json:"identity"` + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Port int `json:"port"` + From string `json:"from"` + To []string + Bcc []string + Cc []string + Subject string + Text string // Plaintext message (optional) + HTML string // Html message (optional) + Headers textproto.MIMEHeader + Attachments []*Attachment + ReadReceipt []string +} + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment struct { + Filename string + Header textproto.MIMEHeader + Content []byte +} + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. 
+func NewEMail(config string) *Email { + e := new(Email) + e.Headers = textproto.MIMEHeader{} + err := json.Unmarshal([]byte(config), e) + if err != nil { + return nil + } + return e +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + buff := &bytes.Buffer{} + w := multipart.NewWriter(buff) + // Set the appropriate headers (overwriting any conflicts) + // Leave out Bcc (only included in envelope headers) + e.Headers.Set("To", strings.Join(e.To, ",")) + if e.Cc != nil { + e.Headers.Set("Cc", strings.Join(e.Cc, ",")) + } + e.Headers.Set("From", e.From) + e.Headers.Set("Subject", e.Subject) + if len(e.ReadReceipt) != 0 { + e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ",")) + } + e.Headers.Set("MIME-Version", "1.0") + + // Write the envelope headers (including any custom headers) + if err := headerToBytes(buff, e.Headers); err != nil { + return nil, fmt.Errorf("Failed to render message headers: %s", err) + } + + e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + fmt.Fprintf(buff, "%s:", "Content-Type") + fmt.Fprintf(buff, " %s\r\n", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + + // Start the multipart/mixed part + fmt.Fprintf(buff, "--%s\r\n", w.Boundary()) + header := textproto.MIMEHeader{} + // Check to see if there is a Text or HTML field + if e.Text != "" || e.HTML != "" { + subWriter := multipart.NewWriter(buff) + // Create the multipart alternative part + header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary())) + // Write the header + if err := headerToBytes(buff, header); err != nil { + return nil, fmt.Errorf("Failed to render multipart message headers: %s", err) + } + // Create the body sections + if e.Text != "" { + header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.Text); err != nil { + return nil, err + } + } + if e.HTML != "" { + header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.HTML); err != nil { + return nil, err + } + } + if err := subWriter.Close(); err != nil { + return nil, err + } + } + // Create attachment part, if necessary + for _, a := range e.Attachments { + ap, err := w.CreatePart(a.Header) + if err != nil { + return nil, err + } + // Write the base64Wrapped content to the part + base64Wrap(ap, a.Content) + } + if err := w.Close(); err != nil { + return nil, err + } + return buff.Bytes(), nil +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { + if len(args) < 1 || len(args) > 2 { // change && to || + err = errors.New("Must specify a file name and number of parameters can not exceed at least two") + return + } + filename := args[0] + id := "" + if len(args) > 1 { + id = args[1] + } + f, err := os.Open(filename) + if err != nil { + return + } + defer f.Close() + ct := mime.TypeByExtension(filepath.Ext(filename)) + basename := path.Base(filename) + return e.Attach(f, basename, ct, id) +} + +// Attach is used to attach content from an io.Reader to the email. 
+// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { + if len(args) < 1 || len(args) > 2 { // change && to || + err = errors.New("Must specify the file type and number of parameters can not exceed at least two") + return + } + c := args[0] //Content-Type + id := "" + if len(args) > 1 { + id = args[1] //Content-ID + } + var buffer bytes.Buffer + if _, err = io.Copy(&buffer, r); err != nil { + return + } + at := &Attachment{ + Filename: filename, + Header: textproto.MIMEHeader{}, + Content: buffer.Bytes(), + } + // Get the Content-Type to be used in the MIMEHeader + if c != "" { + at.Header.Set("Content-Type", c) + } else { + // If the Content-Type is blank, set the Content-Type to "application/octet-stream" + at.Header.Set("Content-Type", "application/octet-stream") + } + if id != "" { + at.Header.Set("Content-Disposition", fmt.Sprintf("inline;\r\n filename=\"%s\"", filename)) + at.Header.Set("Content-ID", fmt.Sprintf("<%s>", id)) + } else { + at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) + } + at.Header.Set("Content-Transfer-Encoding", "base64") + e.Attachments = append(e.Attachments, at) + return at, nil +} + +// Send will send out the mail +func (e *Email) Send() error { + if e.Auth == nil { + e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host) + } + // Merge the To, Cc, and Bcc fields + to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc)) + to = append(append(append(to, e.To...), e.Cc...), e.Bcc...) + // Check to make sure there is at least one recipient and one "From" address + if len(to) == 0 { + return errors.New("Must specify at least one To address") + } + + // Use the username if no From is provided + if len(e.From) == 0 { + e.From = e.Username + } + + from, err := mail.ParseAddress(e.From) + if err != nil { + return err + } + + // use mail's RFC 2047 to encode any string + e.Subject = qEncode("utf-8", e.Subject) + + raw, err := e.Bytes() + if err != nil { + return err + } + return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw) +} + +// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045) +func quotePrintEncode(w io.Writer, s string) error { + var buf [3]byte + mc := 0 + for i := 0; i < len(s); i++ { + c := s[i] + // We're assuming Unix style text formats as input (LF line break), and + // quoted-printble uses CRLF line breaks. (Literal CRs will become + // "=0D", but probably shouldn't be there to begin with!) + if c == '\n' { + io.WriteString(w, "\r\n") + mc = 0 + continue + } + + var nextOut []byte + if isPrintable(c) { + nextOut = append(buf[:0], c) + } else { + nextOut = buf[:] + qpEscape(nextOut, c) + } + + // Add a soft line break if the next (encoded) byte would push this line + // to or past the limit. + if mc+len(nextOut) >= maxLineLength { + if _, err := io.WriteString(w, "=\r\n"); err != nil { + return err + } + mc = 0 + } + + if _, err := w.Write(nextOut); err != nil { + return err + } + mc += len(nextOut) + } + // No trailing end-of-line?? Soft line break, then. TODO: is this sane? + if mc > 0 { + io.WriteString(w, "=\r\n") + } + return nil +} + +// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise +func isPrintable(c byte) bool { + return (c >= '!' 
&& c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t') +} + +// qpEscape is a helper function for quotePrintEncode which escapes a +// non-printable byte. Expects len(dest) == 3. +func qpEscape(dest []byte, c byte) { + const nums = "0123456789ABCDEF" + dest[0] = '=' + dest[1] = nums[(c&0xf0)>>4] + dest[2] = nums[(c & 0xf)] +} + +// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer +func headerToBytes(w io.Writer, t textproto.MIMEHeader) error { + for k, v := range t { + // Write the header key + _, err := fmt.Fprintf(w, "%s:", k) + if err != nil { + return err + } + // Write each value in the header + for _, c := range v { + _, err := fmt.Fprintf(w, " %s\r\n", c) + if err != nil { + return err + } + } + } + return nil +} + +// base64Wrap encodes the attachment content, and wraps it according to RFC 2045 standards (every 76 chars) +// The output is then written to the specified io.Writer +func base64Wrap(w io.Writer, b []byte) { + // 57 raw bytes per 76-byte base64 line. + const maxRaw = 57 + // Buffer for each line, including trailing CRLF. + var buffer [maxLineLength + len("\r\n")]byte + copy(buffer[maxLineLength:], "\r\n") + // Process raw chunks until there's no longer enough to fill a line. + for len(b) >= maxRaw { + base64.StdEncoding.Encode(buffer[:], b[:maxRaw]) + w.Write(buffer[:]) + b = b[maxRaw:] + } + // Handle the last chunk of bytes. + if len(b) > 0 { + out := buffer[:base64.StdEncoding.EncodedLen(len(b))] + base64.StdEncoding.Encode(out, b) + out = append(out, "\r\n"...) + w.Write(out) + } +} + +// Encode returns the encoded-word form of s. If s is ASCII without special +// characters, it is returned unchanged. The provided charset is the IANA +// charset name of s. It is case insensitive. +// RFC 2047 encoded-word +func qEncode(charset, s string) string { + if !needsEncoding(s) { + return s + } + return encodeWord(charset, s) +} + +func needsEncoding(s string) bool { + for _, b := range s { + if (b < ' ' || b > '~') && b != '\t' { + return true + } + } + return false +} + +// encodeWord encodes a string into an encoded-word. +func encodeWord(charset, s string) string { + buf := getBuffer() + + buf.WriteString("=?") + buf.WriteString(charset) + buf.WriteByte('?') + buf.WriteByte('q') + buf.WriteByte('?') + + enc := make([]byte, 3) + for i := 0; i < len(s); i++ { + b := s[i] + switch { + case b == ' ': + buf.WriteByte('_') + case b <= '~' && b >= '!' && b != '=' && b != '?' && b != '_': + buf.WriteByte(b) + default: + enc[0] = '=' + enc[1] = upperhex[b>>4] + enc[2] = upperhex[b&0x0f] + buf.Write(enc) + } + } + buf.WriteString("?=") + + es := buf.String() + putBuffer(buf) + return es +} + +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +func getBuffer() *bytes.Buffer { + return bufPool.Get().(*bytes.Buffer) +} + +func putBuffer(buf *bytes.Buffer) { + if buf.Len() > 1024 { + return + } + buf.Reset() + bufPool.Put(buf) +} diff --git a/pkg/utils/mail_test.go b/pkg/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
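// Editor's note: illustrative sketch, not part of this patch, using the Email
// API defined in mail.go above. Host, credentials, addresses and the attachment
// path are placeholders only.
func exampleSendWithAttachment() error {
	e := NewEMail(`{"username":"user@example.com","password":"secret","host":"smtp.example.com","port":587}`)
	if e == nil {
		return errors.New("invalid mail config")
	}
	e.To = []string{"dest@example.com"}
	e.From = "user@example.com"
	e.Subject = "weekly report"
	e.Text = "See the attached file."
	if _, err := e.AttachFile("./report.pdf"); err != nil {
		return err
	}
	return e.Send()
}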
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

<h1>Fancy Html is supported, too!</h1>
" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/utils/pagination/controller.go b/pkg/utils/pagination/controller.go new file mode 100644 index 00000000..2f022d0c --- /dev/null +++ b/pkg/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/context" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { + paginator = NewPaginator(context.Request, per, nums) + context.Input.SetData("paginator", &paginator) + return +} diff --git a/pkg/utils/pagination/doc.go b/pkg/utils/pagination/doc.go new file mode 100644 index 00000000..9abc6d78 --- /dev/null +++ b/pkg/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/utils/pagination/paginator.go b/pkg/utils/pagination/paginator.go new file mode 100644 index 00000000..c6db31e0 --- /dev/null +++ b/pkg/utils/pagination/paginator.go @@ -0,0 +1,189 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "math" + "net/http" + "net/url" + "strconv" +) + +// Paginator within the state of a http request. +type Paginator struct { + Request *http.Request + PerPageNums int + MaxPages int + + nums int64 + pageRange []int + pageNums int + page int +} + +// PageNums Returns the total number of pages. 
+func (p *Paginator) PageNums() int { + if p.pageNums != 0 { + return p.pageNums + } + pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) + if p.MaxPages > 0 { + pageNums = math.Min(pageNums, float64(p.MaxPages)) + } + p.pageNums = int(pageNums) + return p.pageNums +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return p.nums +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + p.nums, _ = toInt64(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + if p.page != 0 { + return p.page + } + if p.Request.Form == nil { + p.Request.ParseForm() + } + p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) + if p.page > p.PageNums() { + p.page = p.PageNums() + } + if p.page <= 0 { + p.page = 1 + } + return p.page +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + if p.pageRange == nil && p.nums > 0 { + var pages []int + pageNums := p.PageNums() + page := p.Page() + switch { + case page >= pageNums-4 && pageNums > 9: + start := pageNums - 9 + 1 + pages = make([]int, 9) + for i := range pages { + pages[i] = start + i + } + case page >= 5 && pageNums > 9: + start := page - 5 + 1 + pages = make([]int, int(math.Min(9, float64(page+4+1)))) + for i := range pages { + pages[i] = start + i + } + default: + pages = make([]int, int(math.Min(9, float64(pageNums)))) + for i := range pages { + pages[i] = i + 1 + } + } + p.pageRange = pages + } + return p.pageRange +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + link, _ := url.ParseRequestURI(p.Request.URL.String()) + values := link.Query() + if page == 1 { + values.Del("p") + } else { + values.Set("p", strconv.Itoa(page)) + } + link.RawQuery = values.Encode() + return link.String() +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + if p.HasPrev() { + link = p.PageLink(p.Page() - 1) + } + return +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + if p.HasNext() { + link = p.PageLink(p.Page() + 1) + } + return +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return p.PageLink(1) +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return p.PageLink(p.PageNums()) +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return p.Page() > 1 +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return p.Page() < p.PageNums() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return p.Page() == page +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (p.Page() - 1) * p.PerPageNums +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return p.PageNums() > 1 +} + +// NewPaginator Instantiates a paginator struct for the current http request. 
+func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + p := Paginator{} + p.Request = req + if per <= 0 { + per = 10 + } + p.PerPageNums = per + p.SetNums(nums) + return &p +} diff --git a/pkg/utils/pagination/utils.go b/pkg/utils/pagination/utils.go new file mode 100644 index 00000000..686e68b0 --- /dev/null +++ b/pkg/utils/pagination/utils.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "fmt" + "reflect" +) + +// ToInt64 convert any numeric value to int64 +func toInt64(value interface{}) (d int64, err error) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + err = fmt.Errorf("ToInt64 need numeric not `%T`", value) + } + return +} diff --git a/pkg/utils/rand.go b/pkg/utils/rand.go new file mode 100644 index 00000000..344d1cd5 --- /dev/null +++ b/pkg/utils/rand.go @@ -0,0 +1,44 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "crypto/rand" + r "math/rand" + "time" +) + +var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + if len(alphabets) == 0 { + alphabets = alphaNum + } + var bytes = make([]byte, n) + var randBy bool + if num, err := rand.Read(bytes); num != n || err != nil { + r.Seed(time.Now().UnixNano()) + randBy = true + } + for i, b := range bytes { + if randBy { + bytes[i] = alphabets[r.Intn(len(alphabets))] + } else { + bytes[i] = alphabets[b%byte(len(alphabets))] + } + } + return bytes +} diff --git a/pkg/utils/rand_test.go b/pkg/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
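// Editor's note: illustrative sketch, not part of this patch. It exercises the
// Paginator from the pagination package above directly, outside of a beego
// controller; the request URL and item count are example values, and the
// "net/http" import is assumed.
func examplePaginate() {
	req, _ := http.NewRequest("GET", "/posts?p=2", nil)
	p := NewPaginator(req, 20, int64(95))
	_ = p.Page()      // 2, taken from the "p" query parameter
	_ = p.PageNums()  // 5, i.e. ceil(95 / 20)
	_ = p.Offset()    // 20, i.e. (page-1) * PerPageNums
	_ = p.PageLink(3) // "/posts?p=3"
}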
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/utils/safemap.go b/pkg/utils/safemap.go new file mode 100644 index 00000000..1793030a --- /dev/null +++ b/pkg/utils/safemap.go @@ -0,0 +1,91 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync" +) + +// BeeMap is a map with lock +type BeeMap struct { + lock *sync.RWMutex + bm map[interface{}]interface{} +} + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return &BeeMap{ + lock: new(sync.RWMutex), + bm: make(map[interface{}]interface{}), + } +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + if val, ok := m.bm[k]; ok { + return val + } + return nil +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + m.lock.Lock() + defer m.lock.Unlock() + if val, ok := m.bm[k]; !ok { + m.bm[k] = v + } else if val != v { + m.bm[k] = v + } else { + return false + } + return true +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + m.lock.RLock() + defer m.lock.RUnlock() + _, ok := m.bm[k] + return ok +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.bm, k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + m.lock.RLock() + defer m.lock.RUnlock() + r := make(map[interface{}]interface{}) + for k, v := range m.bm { + r[k] = v + } + return r +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + m.lock.RLock() + defer m.lock.RUnlock() + return len(m.bm) +} diff --git a/pkg/utils/safemap_test.go b/pkg/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
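// Editor's note: illustrative sketch, not part of this patch: a basic round
// trip through the BeeMap methods defined in safemap.go above.
func exampleBeeMap() {
	m := NewBeeMap()
	m.Set("answer", 42)
	if m.Check("answer") {
		_ = m.Get("answer") // 42
	}
	m.Delete("answer")
	_ = m.Count() // 0
}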
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/utils/slice.go b/pkg/utils/slice.go new file mode 100644 index 00000000..8f2cef98 --- /dev/null +++ b/pkg/utils/slice.go @@ -0,0 +1,170 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "math/rand" + "time" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + for _, vv := range sl { + if vv == v { + return true + } + } + return false +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + for _, vv := range sl { + if vv == v { + return true + } + } + return false +} + +// SliceRandList generate an int slice from min to max. 
+func SliceRandList(min, max int) []int { + if max < min { + min, max = max, min + } + length := max - min + 1 + t0 := time.Now() + rand.Seed(int64(t0.Nanosecond())) + list := rand.Perm(length) + for index := range list { + list[index] += min + } + return list +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + c = append(slice1, slice2...) + return +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + for _, v := range slice { + dslice = append(dslice, a(v)) + } + return +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + randnum := rand.Intn(len(a)) + b = a[randnum] + return +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + for _, v := range intslice { + sum += v + } + return +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + for _, v := range slice { + if a(v) { + ftslice = append(ftslice, v) + } + } + return +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + for _, v := range slice1 { + if !InSliceIface(v, slice2) { + diffslice = append(diffslice, v) + } + } + return +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + for _, v := range slice1 { + if InSliceIface(v, slice2) { + diffslice = append(diffslice, v) + } + } + return +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + if size >= len(slice) { + chunkslice = append(chunkslice, slice) + return + } + end := size + for i := 0; i <= (len(slice) - size); i += size { + chunkslice = append(chunkslice, slice[i:end]) + end += size + } + return +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + for i := start; i <= end; i += step { + intslice = append(intslice, i) + } + return +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + if size <= len(slice) { + return slice + } + for i := 0; i < (size - len(slice)); i++ { + slice = append(slice, val) + } + return slice +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + for _, v := range slice { + if !InSliceIface(v, uniqueslice) { + uniqueslice = append(uniqueslice, v) + } + } + return +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + for i := 0; i < len(slice); i++ { + a := rand.Intn(len(slice)) + b := rand.Intn(len(slice)) + slice[a], slice[b] = slice[b], slice[a] + } + return slice +} diff --git a/pkg/utils/slice_test.go b/pkg/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
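[Editor's note: a quick sketch of the []interface{} slice helpers defined above. Import path assumed from the pkg/ layout.]

    package main

    import (
        "fmt"

        "github.com/astaxie/beego/pkg/utils"
    )

    func main() {
        a := []interface{}{1, 2, 3, 4, 5, 6}
        b := []interface{}{3, 4, 5, 6}

        fmt.Println(utils.SliceDiff(a, b))      // [1 2]
        fmt.Println(utils.SliceIntersect(a, b)) // [3 4 5 6]
        fmt.Println(utils.SliceRange(0, 10, 5)) // [0 5 10]

        // Chunk into groups of two; note that a trailing remainder
        // (when len is not a multiple of size) is dropped by SliceChunk.
        fmt.Println(utils.SliceChunk(a, 2)) // [[1 2] [3 4] [5 6]]
    }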
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/utils/testdata/grepe.test b/pkg/utils/testdata/grepe.test new file mode 100644 index 00000000..6c014c40 --- /dev/null +++ b/pkg/utils/testdata/grepe.test @@ -0,0 +1,7 @@ +# empty lines + + + +hello +# comment +world diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 00000000..3874b803 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,89 @@ +package utils + +import ( + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + gopath := os.Getenv("GOPATH") + if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { + gopath = defaultGOPATH() + } + return filepath.SplitList(gopath) +} + +func compareGoVersion(a, b string) int { + reg := regexp.MustCompile("^\\d*") + + a = strings.TrimPrefix(a, "go") + b = strings.TrimPrefix(b, "go") + + versionsA := strings.Split(a, ".") + versionsB := strings.Split(b, ".") + + for i := 0; i < len(versionsA) && i < len(versionsB); i++ { + versionA := versionsA[i] + versionB := versionsB[i] + + vA, err := strconv.Atoi(versionA) + if err != nil { + str := reg.FindString(versionA) + if str != "" { + vA, _ = strconv.Atoi(str) + } else { + vA = -1 + } + } + + vB, err := strconv.Atoi(versionB) + if err != nil { + str := reg.FindString(versionB) + if str != "" { + vB, _ = strconv.Atoi(str) + } else { + vB = -1 + } + } + + if vA > vB { + // vA = 12, vB = 8 + return 1 + } else if vA < vB { + // vA = 6, vB = 8 + return -1 + } else if vA == -1 { + // vA = rc1, vB = rc3 + return strings.Compare(versionA, versionB) + } + + // vA = vB = 8 + continue + } + + if len(versionsA) > len(versionsB) { + return 1 + } else if len(versionsA) == len(versionsB) { + return 0 + } + + return -1 +} + +func defaultGOPATH() string { + env := "HOME" + if runtime.GOOS == "windows" { + env = "USERPROFILE" + } else if runtime.GOOS == "plan9" { + env = "home" + } + if home := os.Getenv(env); home != "" { + return filepath.Join(home, "go") + } + return "" +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go new file mode 100644 index 00000000..ced6f63f --- /dev/null +++ b/pkg/utils/utils_test.go @@ -0,0 +1,36 @@ +package utils + +import ( + "testing" +) + +func TestCompareGoVersion(t *testing.T) { + targetVersion := "go1.8" + if compareGoVersion("go1.12.4", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8.7", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7.6", targetVersion) != -1 { + t.Error("should be -1") + } + + if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8rc1", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7rc1", targetVersion) != -1 
{ + t.Error("should be -1") + } +} diff --git a/pkg/validation/README.md b/pkg/validation/README.md new file mode 100644 index 00000000..43373e47 --- /dev/null +++ b/pkg/validation/README.md @@ -0,0 +1,147 @@ +validation +============== + +validation is a form validation for a data validation and error collecting using Go. + +## Installation and tests + +Install: + + go get github.com/astaxie/beego/validation + +Test: + + go test github.com/astaxie/beego/validation + +## Example + +Direct Use: + + import ( + "github.com/astaxie/beego/validation" + "log" + ) + + type User struct { + Name string + Age int + } + + func main() { + u := User{"man", 40} + valid := validation.Validation{} + valid.Required(u.Name, "name") + valid.MaxSize(u.Name, 15, "nameMax") + valid.Range(u.Age, 0, 140, "age") + if valid.HasErrors() { + // validation does not pass + // print invalid message + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + } + // or use like this + if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { + log.Println(v.Error.Key, v.Error.Message) + } + } + +Struct Tag Use: + + import ( + "github.com/astaxie/beego/validation" + ) + + // validation function follow with "valid" tag + // functions divide with ";" + // parameters in parentheses "()" and divide with "," + // Match function's pattern string must in "//" + type user struct { + Id int + Name string `valid:"Required;Match(/^(test)?\\w*@;com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + func main() { + valid := validation.Validation{} + // ignore empty field valid + // see CanSkipFuncs + // valid := validation.Validation{RequiredFirst:true} + u := user{Name: "test", Age: 40} + b, err := valid.Valid(u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + } + } + +Use custom function: + + import ( + "github.com/astaxie/beego/validation" + ) + + type user struct { + Id int + Name string `valid:"Required;IsMe"` + Age int `valid:"Required;Range(1, 140)"` + } + + func IsMe(v *validation.Validation, obj interface{}, key string) { + name, ok:= obj.(string) + if !ok { + // wrong use case? + return + } + + if name != "me" { + // valid false + v.SetError("Name", "is not me!") + } + } + + func main() { + valid := validation.Validation{} + if err := validation.AddCustomFunc("IsMe", IsMe); err != nil { + // hadle error + } + u := user{Name: "test", Age: 40} + b, err := valid.Valid(u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + } + } + +Struct Tag Functions: + + Required + Min(min int) + Max(max int) + Range(min, max int) + MinSize(min int) + MaxSize(max int) + Length(length int) + Alpha + Numeric + AlphaNumeric + Match(pattern string) + AlphaDash + Email + IP + Base64 + Mobile + Tel + Phone + ZipCode + + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/pkg/validation/util.go b/pkg/validation/util.go new file mode 100644 index 00000000..82206f4f --- /dev/null +++ b/pkg/validation/util.go @@ -0,0 +1,298 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +const ( + // ValidTag struct tag + ValidTag = "valid" + + LabelTag = "label" + + wordsize = 32 << (^uint(0) >> 32 & 1) +) + +var ( + // key: function name + // value: the number of parameters + funcs = make(Funcs) + + // doesn't belong to validation functions + unFuncs = map[string]bool{ + "Clear": true, + "HasErrors": true, + "ErrorMap": true, + "Error": true, + "apply": true, + "Check": true, + "Valid": true, + "NoMatch": true, + } + // ErrInt64On32 show 32 bit platform not support int64 + ErrInt64On32 = fmt.Errorf("not support int64 on 32-bit platform") +) + +func init() { + v := &Validation{} + t := reflect.TypeOf(v) + for i := 0; i < t.NumMethod(); i++ { + m := t.Method(i) + if !unFuncs[m.Name] { + funcs[m.Name] = m.Func + } + } +} + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + if unFuncs[name] { + return fmt.Errorf("invalid function name: %s", name) + } + + funcs[name] = reflect.ValueOf(f) + return nil +} + +// ValidFunc Valid function type +type ValidFunc struct { + Name string + Params []interface{} +} + +// Funcs Validate function map +type Funcs map[string]reflect.Value + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + if _, ok := f[name]; !ok { + err = fmt.Errorf("%s does not exist", name) + return + } + if len(params) != f[name].Type().NumIn() { + err = fmt.Errorf("The number of params is not adapted") + return + } + in := make([]reflect.Value, len(params)) + for k, param := range params { + in[k] = reflect.ValueOf(param) + } + result = f[name].Call(in) + return +} + +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { + tag := f.Tag.Get(ValidTag) + label := f.Tag.Get(LabelTag) + if len(tag) == 0 { + return + } + if vfs, tag, err = getRegFuncs(tag, f.Name); err != nil { + return + } + fs := strings.Split(tag, ";") + for _, vfunc := range fs { + var vf ValidFunc + if len(vfunc) == 0 { + continue + } + vf, err = parseFunc(vfunc, f.Name, label) + if err != nil { + return + } + vfs = append(vfs, vf) + } + return +} + +// Get Match function +// May be get NoMatch function in the future +func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) { + tag = strings.TrimSpace(tag) + index := strings.Index(tag, "Match(/") + if index == -1 { + str = tag + return + } + end := strings.LastIndex(tag, "/)") + 
if end < index { + err = fmt.Errorf("invalid Match function") + return + } + reg, err := regexp.Compile(tag[index+len("Match(/") : end]) + if err != nil { + return + } + vfs = []ValidFunc{{"Match", []interface{}{reg, key + ".Match"}}} + str = strings.TrimSpace(tag[:index]) + strings.TrimSpace(tag[end+len("/)"):]) + return +} + +func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + + vfunc = strings.TrimSpace(vfunc) + start := strings.Index(vfunc, "(") + var num int + + // doesn't need parameter valid function + if start == -1 { + if num, err = numIn(vfunc); err != nil { + return + } + if num != 0 { + err = fmt.Errorf("%s require %d parameters", vfunc, num) + return + } + v = ValidFunc{vfunc, []interface{}{key + "." + vfunc + "." + label}} + return + } + + end := strings.Index(vfunc, ")") + if end == -1 { + err = fmt.Errorf("invalid valid function") + return + } + + name := strings.TrimSpace(vfunc[:start]) + if num, err = numIn(name); err != nil { + return + } + + params := strings.Split(vfunc[start+1:end], ",") + // the num of param must be equal + if num != len(params) { + err = fmt.Errorf("%s require %d parameters", name, num) + return + } + + tParams, err := trim(name, key+"."+ name + "." + label, params) + if err != nil { + return + } + v = ValidFunc{name, tParams} + return +} + +func numIn(name string) (num int, err error) { + fn, ok := funcs[name] + if !ok { + err = fmt.Errorf("doesn't exists %s valid function", name) + return + } + // sub *Validation obj and key + num = fn.Type().NumIn() - 3 + return +} + +func trim(name, key string, s []string) (ts []interface{}, err error) { + ts = make([]interface{}, len(s), len(s)+1) + fn, ok := funcs[name] + if !ok { + err = fmt.Errorf("doesn't exists %s valid function", name) + return + } + for i := 0; i < len(s); i++ { + var param interface{} + // skip *Validation and obj params + if param, err = parseParam(fn.Type().In(i+2), strings.TrimSpace(s[i])); err != nil { + return + } + ts[i] = param + } + ts = append(ts, key) + return +} + +// modify the parameters's type to adapt the function input parameters' type +func parseParam(t reflect.Type, s string) (i interface{}, err error) { + switch t.Kind() { + case reflect.Int: + i, err = strconv.Atoi(s) + case reflect.Int64: + if wordsize == 32 { + return nil, ErrInt64On32 + } + i, err = strconv.ParseInt(s, 10, 64) + case reflect.Int32: + var v int64 + v, err = strconv.ParseInt(s, 10, 32) + if err == nil { + i = int32(v) + } + case reflect.Int16: + var v int64 + v, err = strconv.ParseInt(s, 10, 16) + if err == nil { + i = int16(v) + } + case reflect.Int8: + var v int64 + v, err = strconv.ParseInt(s, 10, 8) + if err == nil { + i = int8(v) + } + case reflect.String: + i = s + case reflect.Ptr: + if t.Elem().String() != "regexp.Regexp" { + err = fmt.Errorf("not support %s", t.Elem().String()) + return + } + i, err = regexp.Compile(s) + default: + err = fmt.Errorf("not support %s", t.Kind().String()) + } + return +} + +func mergeParam(v *Validation, obj interface{}, params []interface{}) []interface{} { + return append([]interface{}{v, obj}, params...) +} diff --git a/pkg/validation/util_test.go b/pkg/validation/util_test.go new file mode 100644 index 00000000..58ca38db --- /dev/null +++ b/pkg/validation/util_test.go @@ -0,0 +1,128 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
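[Editor's note: the helpers above (getRegFuncs, parseFunc, trim, parseParam) turn a `valid` struct tag into ValidFunc values, with Match(/.../) pulled out first so its regexp can safely contain ';' or ','. A rough sketch of the result, written as if it lived inside package validation since these helpers are unexported; the demoUser type is made up for illustration.]

    package validation

    import (
        "fmt"
        "reflect"
    )

    type demoUser struct {
        Name string `valid:"Required;Match(/^\\w+$/);MaxSize(10)"`
    }

    func demoTagParsing() {
        f, _ := reflect.TypeOf(demoUser{}).FieldByName("Name")
        vfs, err := getValidFuncs(f)
        if err != nil {
            fmt.Println(err)
            return
        }
        for _, vf := range vfs {
            // Prints: Match 2, Required 1, MaxSize 2 -- Match comes first
            // because getRegFuncs extracts it before the ';' split.
            fmt.Println(vf.Name, len(vf.Params))
        }
    }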
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "log" + "reflect" + "testing" +) + +type user struct { + ID int + Tag string `valid:"Maxx(aa)"` + Name string `valid:"Required;"` + Age int `valid:"Required; Range(1, 140)"` + match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` +} + +func TestGetValidFuncs(t *testing.T) { + u := user{Name: "test", Age: 1} + tf := reflect.TypeOf(u) + var vfs []ValidFunc + var err error + + f, _ := tf.FieldByName("ID") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 0 { + t.Fatal("should get none ValidFunc") + } + + f, _ = tf.FieldByName("Tag") + if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" { + t.Fatal(err) + } + + f, _ = tf.FieldByName("Name") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 1 { + t.Fatal("should get 1 ValidFunc") + } + if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { + t.Error("Required funcs should be got") + } + + f, _ = tf.FieldByName("Age") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 2 { + t.Fatal("should get 2 ValidFunc") + } + if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 { + t.Error("Required funcs should be got") + } + if vfs[1].Name != "Range" && len(vfs[1].Params) != 2 { + t.Error("Range funcs should be got") + } + + f, _ = tf.FieldByName("match") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + if len(vfs) != 3 { + t.Fatal("should get 3 ValidFunc but now is", len(vfs)) + } +} + +type User struct { + Name string `valid:"Required;MaxSize(5)" ` + Sex string `valid:"Required;" label:"sex_label"` + Age int `valid:"Required;Range(1, 140);" label:"age_label"` +} + +func TestValidation(t *testing.T) { + u := User{"man1238888456", "", 1140} + valid := Validation{} + b, err := valid.Valid(&u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + if len(valid.Errors) != 3 { + t.Error("must be has 3 error") + } + } else { + t.Error("must be has 3 error") + } +} + +func TestCall(t *testing.T) { + u := user{Name: "test", Age: 180} + tf := reflect.TypeOf(u) + var vfs []ValidFunc + var err error + f, _ := tf.FieldByName("Age") + if vfs, err = getValidFuncs(f); err != nil { + t.Fatal(err) + } + valid := &Validation{} + vfs[1].Params = append([]interface{}{valid, u.Age}, vfs[1].Params...) + if _, err = funcs.Call(vfs[1].Name, vfs[1].Params...); err != nil { + t.Fatal(err) + } + if len(valid.Errors) != 1 { + t.Error("age out of range should be has an error") + } +} diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go new file mode 100644 index 00000000..190e0f0e --- /dev/null +++ b/pkg/validation/validation.go @@ -0,0 +1,456 @@ +// Copyright 2014 beego Author. All Rights Reserved. 
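[Editor's note: the `label` tag exercised in the test above ends up in the Field.Func.Label error key and prefixes the default message. A hedged sketch with a made-up Profile type; import path assumed from the pkg/ move.]

    package main

    import (
        "fmt"

        "github.com/astaxie/beego/pkg/validation"
    )

    type Profile struct {
        Age int `valid:"Required;Range(1, 140)" label:"age_label"`
    }

    func main() {
        valid := validation.Validation{}
        ok, err := valid.Valid(&Profile{Age: 200})
        if err != nil {
            fmt.Println(err)
            return
        }
        if !ok {
            for _, e := range valid.Errors {
                // e.g. Age.Range.age_label => age_label Range is 1 to 140
                fmt.Println(e.Key, "=>", e.Message)
            }
        }
    }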
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "reflect" + "regexp" + "strings" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error struct { + Message, Key, Name, Field, Tmpl string + Value interface{} + LimitValue interface{} +} + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result struct { + Error *Error + Ok bool +} + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation struct { + // if this field set true, in struct tag valid + // if the struct field vale is empty + // it will skip those valid functions, see CanSkipFuncs + RequiredFirst bool + + Errors []*Error + ErrorsMap map[string][]*Error +} + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + v.Errors = []*Error{} + v.ErrorsMap = nil +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return len(v.Errors) > 0 +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + return v.ErrorsMap +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + result := (&Result{ + Ok: false, + Error: &Error{}, + }).Message(message, args...) 
+ v.Errors = append(v.Errors, result.Error) + return result +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return v.apply(Required{key}, obj) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return v.apply(Min{min, key}, obj) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return v.apply(Max{max, key}, obj) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return v.apply(MinSize{min, key}, obj) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return v.apply(MaxSize{max, key}, obj) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return v.apply(Length{n, key}, obj) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return v.apply(Alpha{key}, obj) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return v.apply(Numeric{key}, obj) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return v.apply(AlphaNumeric{key}, obj) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return v.apply(Match{regex, key}, obj) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return v.apply(NoMatch{Match{Regexp: regex}, key}, obj) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return v.apply(Email{Match{Regexp: emailPattern}, key}, obj) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return v.apply(IP{Match{Regexp: ipPattern}, key}, obj) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + 
return v.apply(Tel{Match{Regexp: telPattern}, key}, obj) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, + Tel{Match: Match{Regexp: telPattern}}, key}, obj) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) +} + +func (v *Validation) apply(chk Validator, obj interface{}) *Result { + if nil == obj { + if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + } else if reflect.TypeOf(obj).Kind() == reflect.Ptr { + if reflect.ValueOf(obj).IsNil() { + if chk.IsSatisfied(nil) { + return &Result{Ok: true} + } + } else { + if chk.IsSatisfied(reflect.ValueOf(obj).Elem().Interface()) { + return &Result{Ok: true} + } + } + } else if chk.IsSatisfied(obj) { + return &Result{Ok: true} + } + + // Add the error to the validation context. + key := chk.GetKey() + Name := key + Field := "" + Label := "" + parts := strings.Split(key, ".") + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } + } + + err := &Error{ + Message: Label + " " + chk.DefaultMessage(), + Key: key, + Name: Name, + Field: Field, + Value: obj, + Tmpl: MessageTmpls[Name], + LimitValue: chk.GetLimitValue(), + } + v.setError(err) + + // Also return it in the result. + return &Result{ + Ok: false, + Error: err, + } +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + Name := key + Field := "" + + Label := "" + parts := strings.Split(key, ".") + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } + } + + err := &Error{ + Message: Label + " " + message, + Key: key, + Name: Name, + Field: Field, + } + v.setError(err) +} + +func (v *Validation) setError(err *Error) { + v.Errors = append(v.Errors, err) + if v.ErrorsMap == nil { + v.ErrorsMap = make(map[string][]*Error) + } + if _, ok := v.ErrorsMap[err.Field]; !ok { + v.ErrorsMap[err.Field] = []*Error{} + } + v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + err := &Error{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg} + v.setError(err) + return err +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + var result *Result + for _, check := range checks { + result = v.apply(check, obj) + if !result.Ok { + return result + } + } + return result +} + +// Valid Validate a struct. 
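[Editor's note: Check runs a group of Validator values in order and stops at the first failure, and the returned Result can re-key or re-word the collected error via Key and Message. A small sketch; import path assumed from the pkg/ layout.]

    package main

    import (
        "fmt"

        "github.com/astaxie/beego/pkg/validation"
    )

    func main() {
        valid := validation.Validation{}

        nick := "this nickname is far too long"
        if r := valid.Check(nick,
            validation.Required{Key: "nickname"},
            validation.MaxSize{Max: 10, Key: "nickname"},
        ); !r.Ok {
            // Override the default key and message on the collected error.
            r.Key("nickname.maxSize").Message("nickname must be at most %d characters", 10)
        }

        for _, err := range valid.Errors {
            fmt.Println(err.Key, err.Message) // nickname.maxSize nickname must be at most 10 characters
        }
    }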
+// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + switch { + case isStruct(objT): + case isStructPtr(objT): + objT = objT.Elem() + objV = objV.Elem() + default: + err = fmt.Errorf("%v must be a struct or a struct pointer", obj) + return + } + + for i := 0; i < objT.NumField(); i++ { + var vfs []ValidFunc + if vfs, err = getValidFuncs(objT.Field(i)); err != nil { + return + } + + var hasRequired bool + for _, vf := range vfs { + if vf.Name == "Required" { + hasRequired = true + } + + currentField := objV.Field(i).Interface() + if objV.Field(i).Kind() == reflect.Ptr { + if objV.Field(i).IsNil() { + currentField = "" + } else { + currentField = objV.Field(i).Elem().Interface() + } + } + + chk := Required{""}.IsSatisfied(currentField) + if !hasRequired && v.RequiredFirst && !chk { + if _, ok := CanSkipFuncs[vf.Name]; ok { + continue + } + } + + if _, err = funcs.Call(vf.Name, + mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil { + return + } + } + } + + if !v.HasErrors() { + if form, ok := obj.(ValidFormer); ok { + form.Valid(v) + } + } + + return !v.HasErrors(), nil +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + //Step 1: validate obj itself firstly + // fails if objc is not struct + pass, err := v.Valid(objc) + if err != nil || !pass { + return pass, err // Stop recursive validation + } + // Step 2: Validate struct's struct fields + objT := reflect.TypeOf(objc) + objV := reflect.ValueOf(objc) + + if isStructPtr(objT) { + objT = objT.Elem() + objV = objV.Elem() + } + + for i := 0; i < objT.NumField(); i++ { + + t := objT.Field(i).Type + + // Recursive applies to struct or pointer to structs fields + if isStruct(t) || isStructPtr(t) { + // Step 3: do the recursive validation + // Only valid the Public field recursively + if objV.Field(i).CanInterface() { + pass, err = v.RecursiveValid(objV.Field(i).Interface()) + } + } + } + return pass, err +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + if _, ok := CanSkipFuncs[skipFunc]; !ok { + CanSkipFuncs[skipFunc] = struct{}{} + } +} diff --git a/pkg/validation/validation_test.go b/pkg/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
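[Editor's note: RecursiveValid, defined above, validates the value itself first and then walks into exported struct (or struct pointer) fields. A sketch with made-up Customer/Address types; import path assumed from the pkg/ layout.]

    package main

    import (
        "fmt"

        "github.com/astaxie/beego/pkg/validation"
    )

    type Address struct {
        City string `valid:"Required"`
    }

    type Customer struct {
        Name string `valid:"Required"`
        Addr Address
    }

    func main() {
        valid := validation.Validation{}
        // Customer passes its own checks, then the nested Addr is validated.
        ok, err := valid.RecursiveValid(Customer{Name: "gopher"})
        if err != nil {
            fmt.Println(err)
            return
        }
        if !ok {
            for _, e := range valid.Errors {
                fmt.Println(e.Key, e.Message) // City.Required. City Can not be empty
            }
        }
    }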
+ +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if 
!valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if 
!valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + 
t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { 
+ t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/validation/validators.go b/pkg/validation/validators.go new file mode 100644 index 00000000..38b6f1aa --- /dev/null +++ b/pkg/validation/validators.go @@ -0,0 +1,738 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + "github.com/astaxie/beego/logs" + "reflect" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = map[string]struct{}{ + "Email": {}, + "IP": {}, + "Mobile": {}, + "Tel": {}, + "Phone": {}, + "ZipCode": {}, +} + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + if len(msg) == 0 { + return + } + + once.Do(func() { + for name := range msg { + MessageTmpls[name] = msg[name] + } + }) + logs.Warn(`you must 
SetDefaultMessage at once`) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required struct { + Key string +} + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + if obj == nil { + return false + } + + if str, ok := obj.(string); ok { + return len(strings.TrimSpace(str)) > 0 + } + if _, ok := obj.(bool); ok { + return true + } + if i, ok := obj.(int); ok { + return i != 0 + } + if i, ok := obj.(uint); ok { + return i != 0 + } + if i, ok := obj.(int8); ok { + return i != 0 + } + if i, ok := obj.(uint8); ok { + return i != 0 + } + if i, ok := obj.(int16); ok { + return i != 0 + } + if i, ok := obj.(uint16); ok { + return i != 0 + } + if i, ok := obj.(uint32); ok { + return i != 0 + } + if i, ok := obj.(int32); ok { + return i != 0 + } + if i, ok := obj.(int64); ok { + return i != 0 + } + if i, ok := obj.(uint64); ok { + return i != 0 + } + if t, ok := obj.(time.Time); ok { + return !t.IsZero() + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() > 0 + } + return true +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return MessageTmpls["Required"] +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return r.Key +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return nil +} + +// Min check struct +type Min struct { + Min int + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + var v int + switch obj.(type) { + case int64: + if wordsize == 32 { + return false + } + v = int(obj.(int64)) + case int: + v = obj.(int) + case int32: + v = int(obj.(int32)) + case int16: + v = int(obj.(int16)) + case int8: + v = int(obj.(int8)) + default: + return false + } + + return v >= m.Min +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Min"], m.Min) +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return m.Min +} + +// Max validate struct +type Max struct { + Max int + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + var v int + switch obj.(type) { + case int64: + if wordsize == 32 { + return false + } + v = int(obj.(int64)) + case int: + v = obj.(int) + case int32: + v = int(obj.(int32)) + case int16: + v = int(obj.(int16)) + case int8: + v = int(obj.(int8)) + default: + return false + } + + return v <= m.Max +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Max"], m.Max) +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return m.Max +} + +// Range Requires an integer to be within Min, Max inclusive. 
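[Editor's note: SetDefaultMessage, shown above, swaps entries in MessageTmpls but only on the first call per process (it is wrapped in sync.Once). A sketch of overriding two templates; import path assumed from the pkg/ layout.]

    package main

    import (
        "fmt"

        "github.com/astaxie/beego/pkg/validation"
    )

    func main() {
        // Only the first SetDefaultMessage call in a process has any effect.
        validation.SetDefaultMessage(map[string]string{
            "Required": "is required",
            "Range":    "must be between %d and %d",
        })

        valid := validation.Validation{}
        valid.Range(200, 0, 140, "age")
        for _, err := range valid.Errors {
            // The message now uses the overridden Range template.
            fmt.Println(err.Key, err.Message)
        }
    }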
+type Range struct { + Min + Max + Key string +} + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Range"], r.Min.Min, r.Max.Max) +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return r.Key +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return []int{r.Min.Min, r.Max.Max} +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize struct { + Min int + Key string +} + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) >= m.Min + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() >= m.Min + } + return false +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["MinSize"], m.Min) +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return m.Min +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize struct { + Max int + Key string +} + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) <= m.Max + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() <= m.Max + } + return false +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["MaxSize"], m.Max) +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return m.Max +} + +// Length Requires an array or string to be exactly a given length. 
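+//
+// A usage sketch (based only on the definitions in this file): Length, like
+// MinSize and MaxSize above, counts runes rather than bytes for strings, and
+// falls back to reflection for slices:
+//
+//	Length{N: 3}.IsSatisfied("héé")       // true: 3 runes, 5 bytes
+//	Length{N: 3}.IsSatisfied([]int{1, 2}) // false: slice length is 2
+//	Length{N: 3}.IsSatisfied(123)         // false: neither string nor slice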
+type Length struct { + N int + Key string +} + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + return utf8.RuneCountInString(str) == l.N + } + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Slice { + return v.Len() == l.N + } + return false +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Length"], l.N) +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return l.Key +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return l.N +} + +// Alpha check the alpha +type Alpha struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if ('Z' < v || v < 'A') && ('z' < v || v < 'a') { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return MessageTmpls["Alpha"] +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return nil +} + +// Numeric check number +type Numeric struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if '9' < v || v < '0' { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return MessageTmpls["Numeric"] +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return n.Key +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return nil +} + +// AlphaNumeric check alpha and number +type AlphaNumeric struct { + Key string +} + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + if str, ok := obj.(string); ok { + for _, v := range str { + if ('Z' < v || v < 'A') && ('z' < v || v < 'a') && ('9' < v || v < '0') { + return false + } + } + return true + } + return false +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return MessageTmpls["AlphaNumeric"] +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return nil +} + +// Match Requires a string to match a given regex. +type Match struct { + Regexp *regexp.Regexp + Key string +} + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return m.Regexp.MatchString(fmt.Sprintf("%v", obj)) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Match"], m.Regexp.String()) +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return m.Regexp.String() +} + +// NoMatch Requires a string to not match a given regex. 
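+//
+// A usage sketch (based only on the definitions in this file): NoMatch simply
+// negates the embedded Match validator, so both can share one compiled regexp:
+//
+//	m := Match{Regexp: regexp.MustCompile(`^[a-z]+$`)}
+//	m.IsSatisfied("hello")                 // true
+//	NoMatch{Match: m}.IsSatisfied("hello") // false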
+type NoMatch struct { + Match + Key string +} + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return !n.Match.IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["NoMatch"], n.Regexp.String()) +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return n.Key +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return n.Regexp.String() +} + +var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`) + +// AlphaDash check not Alpha +type AlphaDash struct { + NoMatch + Key string +} + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return MessageTmpls["AlphaDash"] +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return a.Key +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return nil +} + +var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`) + +// Email check struct +type Email struct { + Match + Key string +} + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return MessageTmpls["Email"] +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return e.Key +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return nil +} + +var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`) + +// IP check struct +type IP struct { + Match + Key string +} + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return MessageTmpls["IP"] +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return i.Key +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return nil +} + +var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) + +// Base64 check struct +type Base64 struct { + Match + Key string +} + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return MessageTmpls["Base64"] +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return b.Key +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return nil +} + +// just for chinese mobile phone number +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`) + +// Mobile check struct +type Mobile struct { + Match + Key string +} + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return MessageTmpls["Mobile"] +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return m.Key +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return nil +} + +// just for chinese telephone number +var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`) + +// Tel check telephone struct +type Tel struct { + Match + Key string +} + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return MessageTmpls["Tel"] +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return t.Key 
+}
+
+// GetLimitValue return the limit value
+func (t Tel) GetLimitValue() interface{} {
+	return nil
+}
+
+// Phone just for Chinese telephone or mobile phone number
+type Phone struct {
+	Mobile
+	Tel
+	Key string
+}
+
+// IsSatisfied judge whether obj is valid
+func (p Phone) IsSatisfied(obj interface{}) bool {
+	return p.Mobile.IsSatisfied(obj) || p.Tel.IsSatisfied(obj)
+}
+
+// DefaultMessage return the default Phone error message
+func (p Phone) DefaultMessage() string {
+	return MessageTmpls["Phone"]
+}
+
+// GetKey return the p.Key
+func (p Phone) GetKey() string {
+	return p.Key
+}
+
+// GetLimitValue return the limit value
+func (p Phone) GetLimitValue() interface{} {
+	return nil
+}
+
+// just for Chinese zipcode
+var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`)
+
+// ZipCode check the zip struct
+type ZipCode struct {
+	Match
+	Key string
+}
+
+// DefaultMessage return the default ZipCode error message
+func (z ZipCode) DefaultMessage() string {
+	return MessageTmpls["ZipCode"]
+}
+
+// GetKey return the z.Key
+func (z ZipCode) GetKey() string {
+	return z.Key
+}
+
+// GetLimitValue return the limit value
+func (z ZipCode) GetLimitValue() interface{} {
+	return nil
+}
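For anyone reviewing the moved package: the validators above are normally reached through struct tags and the Validation helpers in pkg/validation/validation.go rather than constructed by hand. A minimal sketch of that flow follows; the import path is assumed from the new pkg/ layout introduced by this patch, and the tag syntax follows pkg/validation/README.md.

	package main

	import (
		"fmt"

		"github.com/astaxie/beego/pkg/validation" // assumed post-move import path
	)

	type user struct {
		Name  string `valid:"Required;AlphaNumeric"`
		Age   int    `valid:"Range(1, 140)"`
		Email string `valid:"Email"`
	}

	func main() {
		// SetDefaultMessage is guarded by sync.Once, so only the first call
		// can override the templates in MessageTmpls.
		validation.SetDefaultMessage(map[string]string{
			"Range": "must be between %d and %d",
		})

		valid := validation.Validation{}
		ok, err := valid.Valid(&user{Name: "beego", Age: 200, Email: "dev@beego.me"})
		if err != nil {
			fmt.Println(err)
			return
		}
		if !ok {
			for _, e := range valid.Errors {
				fmt.Println(e.Key, e.Message) // e.g. "Age.Range must be between 1 and 140"
			}
		}
	}

Only the Age field violates Range(1, 140) here, so a single error is reported, rendered with the overridden Range template.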