diff --git a/.travis.yml b/.travis.yml index 63b31c52..67efe057 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,24 +7,53 @@ services: - mysql - postgresql - memcached + - docker env: global: - GO_REPO_FULLNAME="github.com/astaxie/beego" matrix: - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" + - ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" before_install: - # link the local repo with ${GOPATH}/src// - - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} - # relies on GOPATH to contain only one directory... - - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} - - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} - - cd ${GOPATH}/src/${GO_REPO_FULLNAME} - # get and build ssdb - - git clone git://github.com/ideawu/ssdb.git - - cd ssdb - - make - - cd .. + # link the local repo with ${GOPATH}/src// + - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} + # relies on GOPATH to contain only one directory... + - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} + - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} + - cd ${GOPATH}/src/${GO_REPO_FULLNAME} + # get and build ssdb + - git clone git://github.com/ideawu/ssdb.git + - cd ssdb + - make + - cd .. + # - prepare etcd + # - prepare for etcd unit tests + - rm -rf /tmp/etcd-data.tmp + - mkdir -p /tmp/etcd-data.tmp + - docker rmi gcr.io/etcd-development/etcd:v3.3.25 || true && + docker run -d + -p 2379:2379 + -p 2380:2380 + --mount type=bind,source=/tmp/etcd-data.tmp,destination=/etcd-data + --name etcd-gcr-v3.3.25 + gcr.io/etcd-development/etcd:v3.3.25 + /usr/local/bin/etcd + --name s1 + --data-dir /etcd-data + --listen-client-urls http://0.0.0.0:2379 + --advertise-client-urls http://0.0.0.0:2379 + --listen-peer-urls http://0.0.0.0:2380 + --initial-advertise-peer-urls http://0.0.0.0:2380 + --initial-cluster s1=http://0.0.0.0:2380 + --initial-cluster-token tkn + --initial-cluster-state new + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.float 1.23" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.bool true" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.int 11" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.string hello" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.serialize.name test" + - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put sub.sub.key1 sub.sub.key" install: - go get github.com/lib/pq - go get github.com/go-sql-driver/mysql @@ -51,7 +80,10 @@ install: - go get -u golang.org/x/lint/golint - go get -u github.com/go-redis/redis before_script: + + # - - psql --version + # - prepare for orm unit tests - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" @@ -70,4 +102,4 @@ script: - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - golint ./... addons: - postgresql: "9.6" + postgresql: "9.6" \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ee7e0b5a..5035ae94 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,6 +23,26 @@ cp ./githook/pre-commit ./.git/hooks/pre-commit ``` This will add git hooks into .git/hooks. Or you can add it manually. +## Prepare middleware + +Beego uses many middlewares, including MySQL, Redis, SSDB and so on. + +We provide docker compose file to start all middlewares. + +You can run: +```shell script +docker-compose -f scripts/test_docker_compose.yml up -d +``` +Unit tests read addressed from environment, here is an example: +```shell script +export ORM_DRIVER=mysql +export ORM_SOURCE="beego:test@tcp(192.168.0.105:13306)/orm_test?charset=utf8" +export MEMCACHE_ADDR="192.168.0.105:11211" +export REDIS_ADDR="192.168.0.105:6379" +export SSDB_ADDR="192.168.0.105:8888" +``` + + ## Contribution guidelines ### Pull requests diff --git a/go.mod b/go.mod index 91bd9aef..ab7f5e39 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,10 @@ require ( github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 github.com/casbin/casbin v1.7.0 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 + github.com/coreos/etcd v3.3.25+incompatible + github.com/coreos/go-semver v0.3.0 // indirect + github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect + github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 // indirect github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect @@ -16,13 +20,17 @@ require ( github.com/go-redis/redis v6.14.2+incompatible github.com/go-redis/redis/v7 v7.4.0 github.com/go-sql-driver/mysql v1.5.0 - github.com/gogo/protobuf v1.1.1 + github.com/gogo/protobuf v1.3.1 github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/gomodule/redigo v2.0.0+incompatible + github.com/google/go-cmp v0.5.0 // indirect + github.com/google/uuid v1.1.1 // indirect + github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 github.com/lib/pq v1.0.0 github.com/mattn/go-sqlite3 v2.0.3+incompatible + github.com/mitchellh/mapstructure v1.3.3 github.com/opentracing/opentracing-go v1.2.0 github.com/pelletier/go-toml v1.2.0 // indirect github.com/pkg/errors v0.9.1 @@ -32,9 +40,14 @@ require ( github.com/stretchr/testify v1.4.0 github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect - golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 + go.etcd.io/etcd v3.3.25+incompatible // indirect + go.uber.org/zap v1.15.0 // indirect + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect + golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 // indirect + golang.org/x/text v0.3.3 // indirect golang.org/x/tools v0.0.0-20200117065230-39095c1d176c - google.golang.org/grpc v1.31.0 // indirect + google.golang.org/grpc v1.26.0 gopkg.in/yaml.v2 v2.2.8 honnef.co/go/tools v0.0.1-2020.1.5 // indirect ) diff --git a/go.sum b/go.sum index 95babc92..545dbae5 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,15 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/coreos/etcd v0.5.0-alpha.5 h1:0Qi6Jzjk2CDuuGlIeecpu+em2nrjhOgz2wsIwCmQHmc= +github.com/coreos/etcd v3.3.25+incompatible h1:0GQEw6h3YnuOVdtwygkIfJ+Omx0tZ8/QkVyXI4LkbeY= +github.com/coreos/etcd v3.3.25+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= +github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbpBpLoyyu8B6e44T7hJy6potg= +github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d h1:OMrhQqj1QCyDT2sxHCDjE+k8aMdn2ngTCGG7g4wrdLo= github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 h1:8s2l8TVUwMXl6tZMe3+hPCRJ25nQXiA3d1x622JtOqc= @@ -46,6 +55,7 @@ github.com/elastic/go-elasticsearch/v6 v6.8.5/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk= github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= @@ -69,6 +79,8 @@ github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAf github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -80,6 +92,7 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -92,8 +105,13 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= @@ -101,6 +119,7 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= @@ -117,6 +136,8 @@ github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJK github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mitchellh/mapstructure v1.3.3 h1:SzB1nHZ2Xi+17FP0zVQBHIZqvwRN9408fJO8h+eeNA8= +github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -187,6 +208,16 @@ github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUM github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/gopher-lua v0.0.0-20171031051903-609c9cd26973/go.mod h1:aEV29XrmTYFr3CiRxZeGHpkvbwq+prZduBqMaascyCU= +go.etcd.io/etcd v0.5.0-alpha.5 h1:VOolFSo3XgsmnYDLozjvZ6JL6AAwIDu1Yx1y+4EYLDo= +go.etcd.io/etcd v3.3.25+incompatible h1:V1RzkZJj9LqsJRy+TUBgpWSbZXITLB819lstuTFoZOY= +go.etcd.io/etcd v3.3.25+incompatible/go.mod h1:yaeTdrJi5lOmYerz05bd8+V7KubZs8YSFZfzsF9A6aI= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM= +go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -198,6 +229,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= @@ -216,6 +248,8 @@ golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7 golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -236,15 +270,23 @@ golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= +golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= @@ -260,19 +302,34 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.26.0 h1:2dTRdpdFEEhJYQD8EMLB61nnrzSCTbG38PhqdhvOltg= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.0 h1:rRYRFMVgRv6E0D70Skyfsr28tDXIuuPZyWGMPdMcnXg= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.1 h1:SfXqXS5hkufcdZ/mHtYCh53P2b+92WQq/DZcKLgsFRs= +google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -293,5 +350,6 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= diff --git a/pkg/adapter/admin.go b/pkg/adapter/admin.go new file mode 100644 index 00000000..87e7259b --- /dev/null +++ b/pkg/adapter/admin.go @@ -0,0 +1,48 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + FilterMonitorFunc = web.FilterMonitorFunc +} + +// PrintTree prints all registered routers. +func PrintTree() M { + return (M)(web.PrintTree()) +} diff --git a/pkg/adapter/app.go b/pkg/adapter/app.go new file mode 100644 index 00000000..c1046c79 --- /dev/null +++ b/pkg/adapter/app.go @@ -0,0 +1,262 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + context2 "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" +) + +var ( + // BeeApp is an application instance + BeeApp *App +) + +func init() { + // create beego application + BeeApp = (*App)(web.BeeApp) +} + +// App defines beego application with a new PatternServeMux. +type App web.App + +// NewApp returns a new beego application. +func NewApp() *App { + return (*App)(web.NewApp()) +} + +// MiddleWare function for http.Handler +type MiddleWare web.MiddleWare + +// Run beego application. +func (app *App) Run(mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + (*web.App)(app).Run(newMws...) +} + +func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { + newMws := make([]web.MiddleWare, 0, len(mws)) + for _, old := range mws { + newMws = append(newMws, (web.MiddleWare)(old)) + } + return newMws +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of App.Router. +// usage: +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { + return (*App)(web.Router(rootpath, c, mappingMethods...)) +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func UnregisterFixedRoute(fixedRoute string, method string) *App { + return (*App)(web.UnregisterFixedRoute(fixedRoute, method)) +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +// } +// +// //@router /account/:id [get] +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// +// //@router /account/:id [post] +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func Include(cList ...ControllerInterface) *App { + newList := oldToNewCtrlIntfs(cList) + return (*App)(web.Include(newList...)) +} + +func oldToNewCtrlIntfs(cList []ControllerInterface) []web.ControllerInterface { + newList := make([]web.ControllerInterface, 0, len(cList)) + for _, c := range cList { + newList = append(newList, c) + } + return newList +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func RESTRouter(rootpath string, c ControllerInterface) *App { + return (*App)(web.RESTRouter(rootpath, c)) +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to App.AutoRouter. +// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func AutoRouter(c ControllerInterface) *App { + return (*App)(web.AutoRouter(c)) +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func AutoPrefix(prefix string, c ControllerInterface) *App { + return (*App)(web.AutoPrefix(prefix, c)) +} + +// Get used to register router for Get method +// usage: +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Get(rootpath string, f FilterFunc) *App { + return (*App)(web.Get(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Post used to register router for Post method +// usage: +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Post(rootpath string, f FilterFunc) *App { + return (*App)(web.Post(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Delete used to register router for Delete method +// usage: +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Delete(rootpath string, f FilterFunc) *App { + return (*App)(web.Delete(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Put used to register router for Put method +// usage: +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Put(rootpath string, f FilterFunc) *App { + return (*App)(web.Put(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Head used to register router for Head method +// usage: +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Head(rootpath string, f FilterFunc) *App { + return (*App)(web.Head(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Options used to register router for Options method +// usage: +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Options(rootpath string, f FilterFunc) *App { + return (*App)(web.Options(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Patch used to register router for Patch method +// usage: +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Patch(rootpath string, f FilterFunc) *App { + return (*App)(web.Patch(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Any used to register router for all methods +// usage: +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func Any(rootpath string, f FilterFunc) *App { + return (*App)(web.Any(rootpath, func(ctx *context.Context) { + f((*context2.Context)(ctx)) + })) +} + +// Handler used to register a Handler router +// usage: +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func Handler(rootpath string, h http.Handler, options ...interface{}) *App { + return (*App)(web.Handler(rootpath, h, options)) +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { + opts := oldToNewFilterOpts(params) + return (*App)(web.InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*context2.Context)(ctx)) + }, opts...)) +} diff --git a/pkg/adapter/beego.go b/pkg/adapter/beego.go new file mode 100644 index 00000000..efd2d4ea --- /dev/null +++ b/pkg/adapter/beego.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + // VERSION represent beego web framework version. + VERSION = web.VERSION + + // DEV is for develop + DEV = web.DEV + // PROD is for production + PROD = web.PROD +) + +// M is Map shortcut +type M web.M + +// Hook function to run +type hookfunc func() error + +var ( + hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc +) + +// AddAPPStartHook is used to register the hookfunc +// The hookfuncs will run in beego.Run() +// such as initiating session , starting middleware , building template, starting admin control and so on. +func AddAPPStartHook(hf ...hookfunc) { + for _, f := range hf { + web.AddAPPStartHook(func() error { + return f() + }) + } +} + +// Run beego application. +// beego.Run() default run on HttpPort +// beego.Run("localhost") +// beego.Run(":8089") +// beego.Run("127.0.0.1:8089") +func Run(params ...string) { + web.Run(params...) +} + +// RunWithMiddleWares Run beego application with middlewares. +func RunWithMiddleWares(addr string, mws ...MiddleWare) { + newMws := oldMiddlewareToNew(mws) + web.RunWithMiddleWares(addr, newMws...) +} + +// TestBeegoInit is for test package init +func TestBeegoInit(ap string) { + web.TestBeegoInit(ap) +} + +// InitBeegoBeforeTest is for test package init +func InitBeegoBeforeTest(appConfigPath string) { + web.InitBeegoBeforeTest(appConfigPath) +} diff --git a/pkg/adapter/build_info.go b/pkg/adapter/build_info.go new file mode 100644 index 00000000..1e8dacf0 --- /dev/null +++ b/pkg/adapter/build_info.go @@ -0,0 +1,27 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) diff --git a/pkg/adapter/cache/cache.go b/pkg/adapter/cache/cache.go new file mode 100644 index 00000000..21bb9141 --- /dev/null +++ b/pkg/adapter/cache/cache.go @@ -0,0 +1,85 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cache provide a Cache interface and some implement engine +// Usage: +// +// import( +// "github.com/astaxie/beego/cache" +// ) +// +// bm, err := cache.NewCache("memory", `{"interval":60}`) +// +// Use it like this: +// +// bm.Put("astaxie", 1, 10 * time.Second) +// bm.Get("astaxie") +// bm.IsExist("astaxie") +// bm.Delete("astaxie") +// +// more docs http://beego.me/docs/module/cache.md +package cache + +import ( + "fmt" + + "github.com/astaxie/beego/pkg/client/cache" +) + +// Cache interface contains all behaviors for cache adapter. +// usage: +// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. +// c,err := cache.NewCache("file","{....}") +// c.Put("key",value, 3600 * time.Second) +// v := c.Get("key") +// +// c.Incr("counter") // now is 1 +// c.Incr("counter") // now is 2 +// count := c.Get("counter").(int) +type Cache cache.Cache + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} diff --git a/pkg/adapter/config.go b/pkg/adapter/config.go new file mode 100644 index 00000000..1491722c --- /dev/null +++ b/pkg/adapter/config.go @@ -0,0 +1,179 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + context2 "context" + + "github.com/astaxie/beego/pkg/adapter/session" + newCfg "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/server/web" +) + +// Config is the main struct for BConfig +type Config web.Config + +// Listen holds for http and https related config +type Listen web.Listen + +// WebConfig holds web related config +type WebConfig web.WebConfig + +// SessionConfig holds session related config +type SessionConfig web.SessionConfig + +// LogConfig holds Log related config +type LogConfig web.LogConfig + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = (*Config)(web.BConfig) + AppPath = web.AppPath + + WorkPath = web.WorkPath + + AppConfig = &beegoAppConfig{innerConfig: (newCfg.Configer)(web.AppConfig)} +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + return web.LoadAppConfig(adapterName, configPath) +} + +type beegoAppConfig struct { + innerConfig newCfg.Configer +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(context2.Background(), BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(context2.Background(), key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) string { + if v, err := b.innerConfig.String(context2.Background(), BConfig.RunMode+"::"+key); v != "" && err != nil { + return v + } + res, _ := b.innerConfig.String(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Strings(key string) []string { + if v, err := b.innerConfig.Strings(context2.Background(), BConfig.RunMode+"::"+key); len(v) > 0 && err != nil { + return v + } + res, _ := b.innerConfig.Strings(context2.Background(), key) + return res +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int(context2.Background(), key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Int64(context2.Background(), key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Bool(context2.Background(), key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(context2.Background(), BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Float(context2.Background(), key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v := b.String(key); v != "" { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v := b.Strings(key); len(v) != 0 { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(context2.Background(), key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(context2.Background(), section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(context2.Background(), filename) +} diff --git a/pkg/adapter/config/adapter.go b/pkg/adapter/config/adapter.go new file mode 100644 index 00000000..f74b3ff9 --- /dev/null +++ b/pkg/adapter/config/adapter.go @@ -0,0 +1,193 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + + "github.com/pkg/errors" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +type newToOldConfigerAdapter struct { + delegate config.Configer +} + +func (c *newToOldConfigerAdapter) Set(key, val string) error { + return c.delegate.Set(context.Background(), key, val) +} + +func (c *newToOldConfigerAdapter) String(key string) string { + res, _ := c.delegate.String(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Strings(key string) []string { + res, _ := c.delegate.Strings(context.Background(), key) + return res +} + +func (c *newToOldConfigerAdapter) Int(key string) (int, error) { + return c.delegate.Int(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Int64(key string) (int64, error) { + return c.delegate.Int64(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Bool(key string) (bool, error) { + return c.delegate.Bool(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) Float(key string) (float64, error) { + return c.delegate.Float(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) DefaultString(key string, defaultVal string) string { + return c.delegate.DefaultString(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultStrings(key string, defaultVal []string) []string { + return c.delegate.DefaultStrings(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt(key string, defaultVal int) int { + return c.delegate.DefaultInt(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultInt64(key string, defaultVal int64) int64 { + return c.delegate.DefaultInt64(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultBool(key string, defaultVal bool) bool { + return c.delegate.DefaultBool(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DefaultFloat(key string, defaultVal float64) float64 { + return c.delegate.DefaultFloat(context.Background(), key, defaultVal) +} + +func (c *newToOldConfigerAdapter) DIY(key string) (interface{}, error) { + return c.delegate.DIY(context.Background(), key) +} + +func (c *newToOldConfigerAdapter) GetSection(section string) (map[string]string, error) { + return c.delegate.GetSection(context.Background(), section) +} + +func (c *newToOldConfigerAdapter) SaveConfigFile(filename string) error { + return c.delegate.SaveConfigFile(context.Background(), filename) +} + +type oldToNewConfigerAdapter struct { + delegate Configer +} + +func (o *oldToNewConfigerAdapter) Set(ctx context.Context, key, val string) error { + return o.delegate.Set(key, val) +} + +func (o *oldToNewConfigerAdapter) String(ctx context.Context, key string) (string, error) { + return o.delegate.String(key), nil +} + +func (o *oldToNewConfigerAdapter) Strings(ctx context.Context, key string) ([]string, error) { + return o.delegate.Strings(key), nil +} + +func (o *oldToNewConfigerAdapter) Int(ctx context.Context, key string) (int, error) { + return o.delegate.Int(key) +} + +func (o *oldToNewConfigerAdapter) Int64(ctx context.Context, key string) (int64, error) { + return o.delegate.Int64(key) +} + +func (o *oldToNewConfigerAdapter) Bool(ctx context.Context, key string) (bool, error) { + return o.delegate.Bool(key) +} + +func (o *oldToNewConfigerAdapter) Float(ctx context.Context, key string) (float64, error) { + return o.delegate.Float(key) +} + +func (o *oldToNewConfigerAdapter) DefaultString(ctx context.Context, key string, defaultVal string) string { + return o.delegate.DefaultString(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + return o.delegate.DefaultStrings(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt(ctx context.Context, key string, defaultVal int) int { + return o.delegate.DefaultInt(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + return o.delegate.DefaultInt64(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + return o.delegate.DefaultBool(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + return o.delegate.DefaultFloat(key, defaultVal) +} + +func (o *oldToNewConfigerAdapter) DIY(ctx context.Context, key string) (interface{}, error) { + return o.delegate.DIY(key) +} + +func (o *oldToNewConfigerAdapter) GetSection(ctx context.Context, section string) (map[string]string, error) { + return o.delegate.GetSection(section) +} + +func (o *oldToNewConfigerAdapter) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + return errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) Sub(ctx context.Context, key string) (config.Configer, error) { + return nil, errors.New("unsupported operation, please use actual config.Configer") +} + +func (o *oldToNewConfigerAdapter) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing +} + +func (o *oldToNewConfigerAdapter) SaveConfigFile(ctx context.Context, filename string) error { + return o.delegate.SaveConfigFile(filename) +} + +type oldToNewConfigAdapter struct { + delegate Config +} + +func (o *oldToNewConfigAdapter) Parse(key string) (config.Configer, error) { + old, err := o.delegate.Parse(key) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} + +func (o *oldToNewConfigAdapter) ParseData(data []byte) (config.Configer, error) { + old, err := o.delegate.ParseData(data) + if err != nil { + return nil, err + } + return &oldToNewConfigerAdapter{delegate: old}, nil +} diff --git a/pkg/adapter/config/config.go b/pkg/adapter/config/config.go new file mode 100644 index 00000000..c870a15a --- /dev/null +++ b/pkg/adapter/config/config.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config is used to parse config. +// Usage: +// import "github.com/astaxie/beego/config" +// Examples. +// +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error +// More docs http://beego.me/docs/module/config.md +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// Configer defines how to get and set value from configuration raw data. +type Configer interface { + Set(key, val string) error // support section::key type in given key when using ini type. + String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string // get string slice + Int(key string) (int, error) + Int64(key string) (int64, error) + Bool(key string) (bool, error) + Float(key string) (float64, error) + DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultStrings(key string, defaultVal []string) []string // get string slice + DefaultInt(key string, defaultVal int) int + DefaultInt64(key string, defaultVal int64) int64 + DefaultBool(key string, defaultVal bool) bool + DefaultFloat(key string, defaultVal float64) float64 + DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + SaveConfigFile(filename string) error +} + +// Config is the adapter interface for parsing config file to get raw data to Configer. +type Config interface { + Parse(key string) (Configer, error) + ParseData(data []byte) (Configer, error) +} + +var adapters = make(map[string]Config) + +// Register makes a config adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Config) { + config.Register(name, &oldToNewConfigAdapter{delegate: adapter}) +} + +// NewConfig adapterName is ini/json/xml/yaml. +// filename is the config file path. +func NewConfig(adapterName, filename string) (Configer, error) { + cfg, err := config.NewConfig(adapterName, filename) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// NewConfigData adapterName is ini/json/xml/yaml. +// data is the config data. +func NewConfigData(adapterName string, data []byte) (Configer, error) { + cfg, err := config.NewConfigData(adapterName, data) + if err != nil { + return nil, err + } + + // it was registered by using Register method + res, ok := cfg.(*oldToNewConfigerAdapter) + if ok { + return res.delegate, nil + } + + return &newToOldConfigerAdapter{ + delegate: cfg, + }, nil +} + +// ExpandValueEnvForMap convert all string value with environment variable. +func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { + return config.ExpandValueEnvForMap(m) +} + +// ExpandValueEnv returns value of convert with environment variable. +// +// Return environment variable if value start with "${" and end with "}". +// Return default value if environment variable is empty or not exist. +// +// It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". +// Examples: +// v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. +// v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". +// v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". +func ExpandValueEnv(value string) string { + return config.ExpandValueEnv(value) +} + +// ParseBool returns the boolean value represented by the string. +// +// It accepts 1, 1.0, t, T, TRUE, true, True, YES, yes, Yes,Y, y, ON, on, On, +// 0, 0.0, f, F, FALSE, false, False, NO, no, No, N,n, OFF, off, Off. +// Any other value returns an error. +func ParseBool(val interface{}) (value bool, err error) { + return config.ParseBool(val) +} + +// ToString converts values of any type to string. +func ToString(x interface{}) string { + return config.ToString(x) +} diff --git a/pkg/config/config_test.go b/pkg/adapter/config/config_test.go similarity index 100% rename from pkg/config/config_test.go rename to pkg/adapter/config/config_test.go diff --git a/pkg/adapter/config/env/env.go b/pkg/adapter/config/env/env.go new file mode 100644 index 00000000..77d7b53c --- /dev/null +++ b/pkg/adapter/config/env/env.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config/env" +) + +// Get returns a value by key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + return env.Get(key, defVal) +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + return env.MustGet(key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Set(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + return env.MustSet(key, value) +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + return env.GetAll() +} diff --git a/pkg/config/env/env_test.go b/pkg/adapter/config/env/env_test.go similarity index 100% rename from pkg/config/env/env_test.go rename to pkg/adapter/config/env/env_test.go diff --git a/pkg/adapter/config/fake.go b/pkg/adapter/config/fake.go new file mode 100644 index 00000000..fac96b41 --- /dev/null +++ b/pkg/adapter/config/fake.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +// NewFakeConfig return a fake Configer +func NewFakeConfig() Configer { + new := config.NewFakeConfig() + return &newToOldConfigerAdapter{delegate: new} +} diff --git a/pkg/config/ini_test.go b/pkg/adapter/config/ini_test.go similarity index 100% rename from pkg/config/ini_test.go rename to pkg/adapter/config/ini_test.go diff --git a/pkg/adapter/config/json.go b/pkg/adapter/config/json.go new file mode 100644 index 00000000..d0fe4d09 --- /dev/null +++ b/pkg/adapter/config/json.go @@ -0,0 +1,19 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/json" +) diff --git a/pkg/config/json/json_test.go b/pkg/adapter/config/json_test.go similarity index 97% rename from pkg/config/json/json_test.go rename to pkg/adapter/config/json_test.go index da87986f..16f42409 100644 --- a/pkg/config/json/json_test.go +++ b/pkg/adapter/config/json_test.go @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package json +package config import ( "fmt" "os" "testing" - - "github.com/astaxie/beego/pkg/config" ) func TestJsonStartsWithArray(t *testing.T) { @@ -45,7 +43,7 @@ func TestJsonStartsWithArray(t *testing.T) { } f.Close() defer os.Remove("testjsonWithArray.conf") - jsonconf, err := config.NewConfig("json", "testjsonWithArray.conf") + jsonconf, err := NewConfig("json", "testjsonWithArray.conf") if err != nil { t.Fatal(err) } @@ -145,7 +143,7 @@ func TestJson(t *testing.T) { } f.Close() defer os.Remove("testjson.conf") - jsonconf, err := config.NewConfig("json", "testjson.conf") + jsonconf, err := NewConfig("json", "testjson.conf") if err != nil { t.Fatal(err) } diff --git a/pkg/adapter/config/xml/xml.go b/pkg/adapter/config/xml/xml.go new file mode 100644 index 00000000..f96cdcd6 --- /dev/null +++ b/pkg/adapter/config/xml/xml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xml for config provider. +// +// depend on github.com/beego/x2j. +// +// go install github.com/beego/x2j. +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/xml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("xml", "config.xml") +// +// More docs http://beego.me/docs/module/config.md +package xml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/xml" +) diff --git a/pkg/config/xml/xml_test.go b/pkg/adapter/config/xml/xml_test.go similarity index 98% rename from pkg/config/xml/xml_test.go rename to pkg/adapter/config/xml/xml_test.go index b7828933..122c5027 100644 --- a/pkg/config/xml/xml_test.go +++ b/pkg/adapter/config/xml/xml_test.go @@ -19,7 +19,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/config" + "github.com/astaxie/beego/pkg/adapter/config" ) func TestXML(t *testing.T) { diff --git a/pkg/adapter/config/yaml/yaml.go b/pkg/adapter/config/yaml/yaml.go new file mode 100644 index 00000000..bc2398e9 --- /dev/null +++ b/pkg/adapter/config/yaml/yaml.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package yaml for config provider +// +// depend on github.com/beego/goyaml2 +// +// go install github.com/beego/goyaml2 +// +// Usage: +// import( +// _ "github.com/astaxie/beego/config/yaml" +// "github.com/astaxie/beego/config" +// ) +// +// cnf, err := config.NewConfig("yaml", "config.yaml") +// +// More docs http://beego.me/docs/module/config.md +package yaml + +import ( + _ "github.com/astaxie/beego/pkg/infrastructure/config/yaml" +) diff --git a/pkg/config/yaml/yaml_test.go b/pkg/adapter/config/yaml/yaml_test.go similarity index 98% rename from pkg/config/yaml/yaml_test.go rename to pkg/adapter/config/yaml/yaml_test.go index 0e76457f..e4e309a2 100644 --- a/pkg/config/yaml/yaml_test.go +++ b/pkg/adapter/config/yaml/yaml_test.go @@ -19,7 +19,7 @@ import ( "os" "testing" - "github.com/astaxie/beego/pkg/config" + "github.com/astaxie/beego/pkg/adapter/config" ) func TestYaml(t *testing.T) { diff --git a/pkg/adapter/context/acceptencoder.go b/pkg/adapter/context/acceptencoder.go new file mode 100644 index 00000000..e578de45 --- /dev/null +++ b/pkg/adapter/context/acceptencoder.go @@ -0,0 +1,45 @@ +// Copyright 2015 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// InitGzip init the gzipcompress +func InitGzip(minLength, compressLevel int, methods []string) { + context.InitGzip(minLength, compressLevel, methods) +} + +// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) +func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { + return context.WriteFile(encoding, writer, file) +} + +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { + return context.WriteBody(encoding, writer, content) +} + +// ParseEncoding will extract the right encoding for response +// the Accept-Encoding's sec is here: +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3 +func ParseEncoding(r *http.Request) string { + return context.ParseEncoding(r) +} diff --git a/pkg/adapter/context/context.go b/pkg/adapter/context/context.go new file mode 100644 index 00000000..f9d8c624 --- /dev/null +++ b/pkg/adapter/context/context.go @@ -0,0 +1,146 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/astaxie/beego/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +// +// more docs http://beego.me/docs/module/context.md +package context + +import ( + "bufio" + "net" + "net/http" + + "github.com/astaxie/beego/pkg/server/web/context" +) + +// commonly used mime-types +const ( + ApplicationJSON = context.ApplicationJSON + ApplicationXML = context.ApplicationXML + ApplicationYAML = context.ApplicationYAML + TextXML = context.TextXML +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return (*Context)(context.NewContext()) +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides some api to operate request and response more easily. +type Context context.Context + +// Reset init Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + (*context.Context)(ctx).Reset(rw, r) +} + +// Redirect does redirection to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + (*context.Context)(ctx).Redirect(status, localurl) +} + +// Abort stops this request. +// if beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + (*context.Context)(ctx).Abort(status, body) +} + +// WriteString Write string to response body. +// it sends response body. +func (ctx *Context) WriteString(content string) { + (*context.Context)(ctx).WriteString(content) +} + +// GetCookie Get cookie from request by a given key. +// It's alias of BeegoInput.Cookie. +func (ctx *Context) GetCookie(key string) string { + return (*context.Context)(ctx).GetCookie(key) +} + +// SetCookie Set cookie for response. +// It's alias of BeegoOutput.Cookie. +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + (*context.Context)(ctx).SetCookie(name, value, others) +} + +// GetSecureCookie Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + return (*context.Context)(ctx).GetSecureCookie(Secret, key) +} + +// SetSecureCookie Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*context.Context)(ctx).SetSecureCookie(Secret, name, value, others) +} + +// XSRFToken creates a xsrf token string and returns. +func (ctx *Context) XSRFToken(key string, expire int64) string { + return (*context.Context)(ctx).XSRFToken(key, expire) +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + return (*context.Context)(ctx).CheckXSRFCookie() +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + (*context.Context)(ctx).RenderMethodResult(result) +} + +// Response is a wrapper for the http.ResponseWriter +// started set to true if response was written to then don't execute other handler +type Response context.Response + +// Write writes the data to the connection as part of an HTTP reply, +// and sets `started` to true. +// started means the response has sent out. +func (r *Response) Write(p []byte) (int, error) { + return (*context.Response)(r).Write(p) +} + +// WriteHeader sends an HTTP response header with status code, +// and sets `started` to true. +func (r *Response) WriteHeader(code int) { + (*context.Response)(r).WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return (*context.Response)(r).Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + (*context.Response)(r).Flush() +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + return (*context.Response)(r).CloseNotify() +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + return (*context.Response)(r).Pusher() +} diff --git a/pkg/adapter/context/input.go b/pkg/adapter/context/input.go new file mode 100644 index 00000000..a1d08855 --- /dev/null +++ b/pkg/adapter/context/input.go @@ -0,0 +1,282 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoInput operates the http request header, data, cookie and body. +// it also contains router params and current session. +type BeegoInput context.BeegoInput + +// NewInput return BeegoInput generated by Context. +func NewInput() *BeegoInput { + return (*BeegoInput)(context.NewInput()) +} + +// Reset init the BeegoInput +func (input *BeegoInput) Reset(ctx *Context) { + (*context.BeegoInput)(input).Reset((*context.Context)(ctx)) +} + +// Protocol returns request protocol name, such as HTTP/1.1 . +func (input *BeegoInput) Protocol() string { + return (*context.BeegoInput)(input).Protocol() +} + +// URI returns full request url with query string, fragment. +func (input *BeegoInput) URI() string { + return input.Context.Request.RequestURI +} + +// URL returns request url path (without query string, fragment). +func (input *BeegoInput) URL() string { + return (*context.BeegoInput)(input).URL() +} + +// Site returns base site url as scheme://domain type. +func (input *BeegoInput) Site() string { + return (*context.BeegoInput)(input).Site() +} + +// Scheme returns request scheme as "http" or "https". +func (input *BeegoInput) Scheme() string { + return (*context.BeegoInput)(input).Scheme() +} + +// Domain returns host name. +// Alias of Host method. +func (input *BeegoInput) Domain() string { + return (*context.BeegoInput)(input).Domain() +} + +// Host returns host name. +// if no host info in request, return localhost. +func (input *BeegoInput) Host() string { + return (*context.BeegoInput)(input).Host() +} + +// Method returns http request method. +func (input *BeegoInput) Method() string { + return (*context.BeegoInput)(input).Method() +} + +// Is returns boolean of this request is on given method, such as Is("POST"). +func (input *BeegoInput) Is(method string) bool { + return (*context.BeegoInput)(input).Is(method) +} + +// IsGet Is this a GET method request? +func (input *BeegoInput) IsGet() bool { + return (*context.BeegoInput)(input).IsGet() +} + +// IsPost Is this a POST method request? +func (input *BeegoInput) IsPost() bool { + return (*context.BeegoInput)(input).IsPost() +} + +// IsHead Is this a Head method request? +func (input *BeegoInput) IsHead() bool { + return (*context.BeegoInput)(input).IsHead() +} + +// IsOptions Is this a OPTIONS method request? +func (input *BeegoInput) IsOptions() bool { + return (*context.BeegoInput)(input).IsOptions() +} + +// IsPut Is this a PUT method request? +func (input *BeegoInput) IsPut() bool { + return (*context.BeegoInput)(input).IsPut() +} + +// IsDelete Is this a DELETE method request? +func (input *BeegoInput) IsDelete() bool { + return (*context.BeegoInput)(input).IsDelete() +} + +// IsPatch Is this a PATCH method request? +func (input *BeegoInput) IsPatch() bool { + return (*context.BeegoInput)(input).IsPatch() +} + +// IsAjax returns boolean of this request is generated by ajax. +func (input *BeegoInput) IsAjax() bool { + return (*context.BeegoInput)(input).IsAjax() +} + +// IsSecure returns boolean of this request is in https. +func (input *BeegoInput) IsSecure() bool { + return (*context.BeegoInput)(input).IsSecure() +} + +// IsWebsocket returns boolean of this request is in webSocket. +func (input *BeegoInput) IsWebsocket() bool { + return (*context.BeegoInput)(input).IsWebsocket() +} + +// IsUpload returns boolean of whether file uploads in this request or not.. +func (input *BeegoInput) IsUpload() bool { + return (*context.BeegoInput)(input).IsUpload() +} + +// AcceptsHTML Checks if request accepts html response +func (input *BeegoInput) AcceptsHTML() bool { + return (*context.BeegoInput)(input).AcceptsHTML() +} + +// AcceptsXML Checks if request accepts xml response +func (input *BeegoInput) AcceptsXML() bool { + return (*context.BeegoInput)(input).AcceptsXML() +} + +// AcceptsJSON Checks if request accepts json response +func (input *BeegoInput) AcceptsJSON() bool { + return (*context.BeegoInput)(input).AcceptsJSON() +} + +// AcceptsYAML Checks if request accepts json response +func (input *BeegoInput) AcceptsYAML() bool { + return (*context.BeegoInput)(input).AcceptsYAML() +} + +// IP returns request client ip. +// if in proxy, return first proxy id. +// if error, return RemoteAddr. +func (input *BeegoInput) IP() string { + return (*context.BeegoInput)(input).IP() +} + +// Proxy returns proxy client ips slice. +func (input *BeegoInput) Proxy() []string { + return (*context.BeegoInput)(input).Proxy() +} + +// Referer returns http referer header. +func (input *BeegoInput) Referer() string { + return (*context.BeegoInput)(input).Referer() +} + +// Refer returns http referer header. +func (input *BeegoInput) Refer() string { + return (*context.BeegoInput)(input).Refer() +} + +// SubDomains returns sub domain string. +// if aa.bb.domain.com, returns aa.bb . +func (input *BeegoInput) SubDomains() string { + return (*context.BeegoInput)(input).SubDomains() +} + +// Port returns request client port. +// when error or empty, return 80. +func (input *BeegoInput) Port() int { + return (*context.BeegoInput)(input).Port() +} + +// UserAgent returns request client user agent string. +func (input *BeegoInput) UserAgent() string { + return (*context.BeegoInput)(input).UserAgent() +} + +// ParamsLen return the length of the params +func (input *BeegoInput) ParamsLen() int { + return (*context.BeegoInput)(input).ParamsLen() +} + +// Param returns router param by a given key. +func (input *BeegoInput) Param(key string) string { + return (*context.BeegoInput)(input).Param(key) +} + +// Params returns the map[key]value. +func (input *BeegoInput) Params() map[string]string { + return (*context.BeegoInput)(input).Params() +} + +// SetParam will set the param with key and value +func (input *BeegoInput) SetParam(key, val string) { + (*context.BeegoInput)(input).SetParam(key, val) +} + +// ResetParams clears any of the input's Params +// This function is used to clear parameters so they may be reset between filter +// passes. +func (input *BeegoInput) ResetParams() { + (*context.BeegoInput)(input).ResetParams() +} + +// Query returns input data item string by a given string. +func (input *BeegoInput) Query(key string) string { + return (*context.BeegoInput)(input).Query(key) +} + +// Header returns request header item string by a given string. +// if non-existed, return empty string. +func (input *BeegoInput) Header(key string) string { + return (*context.BeegoInput)(input).Header(key) +} + +// Cookie returns request cookie item string by a given key. +// if non-existed, return empty string. +func (input *BeegoInput) Cookie(key string) string { + return (*context.BeegoInput)(input).Cookie(key) +} + +// Session returns current session item value by a given key. +// if non-existed, return nil. +func (input *BeegoInput) Session(key interface{}) interface{} { + return (*context.BeegoInput)(input).Session(key) +} + +// CopyBody returns the raw request body data as bytes. +func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { + return (*context.BeegoInput)(input).CopyBody(MaxMemory) +} + +// Data return the implicit data in the input +func (input *BeegoInput) Data() map[interface{}]interface{} { + return (*context.BeegoInput)(input).Data() +} + +// GetData returns the stored data in this context. +func (input *BeegoInput) GetData(key interface{}) interface{} { + return (*context.BeegoInput)(input).GetData(key) +} + +// SetData stores data with given key in this context. +// This data are only available in this context. +func (input *BeegoInput) SetData(key, val interface{}) { + (*context.BeegoInput)(input).SetData(key, val) +} + +// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { + return (*context.BeegoInput)(input).ParseFormOrMulitForm(maxMemory) +} + +// Bind data from request.Form[key] to dest +// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie +// var id int beegoInput.Bind(&id, "id") id ==123 +// var isok bool beegoInput.Bind(&isok, "isok") isok ==true +// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2 +// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2] +// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array] +// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"} +func (input *BeegoInput) Bind(dest interface{}, key string) error { + return (*context.BeegoInput)(input).Bind(dest, key) +} diff --git a/pkg/adapter/context/output.go b/pkg/adapter/context/output.go new file mode 100644 index 00000000..8e2a7f7d --- /dev/null +++ b/pkg/adapter/context/output.go @@ -0,0 +1,154 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// BeegoOutput does work for sending response header. +type BeegoOutput context.BeegoOutput + +// NewOutput returns new BeegoOutput. +// it contains nothing now. +func NewOutput() *BeegoOutput { + return (*BeegoOutput)(context.NewOutput()) +} + +// Reset init BeegoOutput +func (output *BeegoOutput) Reset(ctx *Context) { + (*context.BeegoOutput)(output).Reset((*context.Context)(ctx)) +} + +// Header sets response header item string via given key. +func (output *BeegoOutput) Header(key, val string) { + (*context.BeegoOutput)(output).Header(key, val) +} + +// Body sets response body content. +// if EnableGzip, compress content string. +// it sends out response body directly. +func (output *BeegoOutput) Body(content []byte) error { + return (*context.BeegoOutput)(output).Body(content) +} + +// Cookie sets cookie value via given key. +// others are ordered as cookie's max age time, path,domain, secure and httponly. +func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { + (*context.BeegoOutput)(output).Cookie(name, value, others) +} + +// JSON writes json to response body. +// if encoding is true, it converts utf-8 to \u0000 type. +func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { + return (*context.BeegoOutput)(output).JSON(data, hasIndent, encoding) +} + +// YAML writes yaml to response body. +func (output *BeegoOutput) YAML(data interface{}) error { + return (*context.BeegoOutput)(output).YAML(data) +} + +// JSONP writes jsonp to response body. +func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).JSONP(data, hasIndent) +} + +// XML writes xml string to response body. +func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { + return (*context.BeegoOutput)(output).XML(data, hasIndent) +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { + (*context.BeegoOutput)(output).ServeFormatted(data, hasIndent, hasEncode...) +} + +// Download forces response for download file. +// it prepares the download response header automatically. +func (output *BeegoOutput) Download(file string, filename ...string) { + (*context.BeegoOutput)(output).Download(file, filename...) +} + +// ContentType sets the content type from ext string. +// MIME type is given in mime package. +func (output *BeegoOutput) ContentType(ext string) { + (*context.BeegoOutput)(output).ContentType(ext) +} + +// SetStatus sets response status code. +// It writes response header directly. +func (output *BeegoOutput) SetStatus(status int) { + (*context.BeegoOutput)(output).SetStatus(status) +} + +// IsCachable returns boolean of this request is cached. +// HTTP 304 means cached. +func (output *BeegoOutput) IsCachable() bool { + return (*context.BeegoOutput)(output).IsCachable() +} + +// IsEmpty returns boolean of this request is empty. +// HTTP 201,204 and 304 means empty. +func (output *BeegoOutput) IsEmpty() bool { + return (*context.BeegoOutput)(output).IsEmpty() +} + +// IsOk returns boolean of this request runs well. +// HTTP 200 means ok. +func (output *BeegoOutput) IsOk() bool { + return (*context.BeegoOutput)(output).IsOk() +} + +// IsSuccessful returns boolean of this request runs successfully. +// HTTP 2xx means ok. +func (output *BeegoOutput) IsSuccessful() bool { + return (*context.BeegoOutput)(output).IsSuccessful() +} + +// IsRedirect returns boolean of this request is redirection header. +// HTTP 301,302,307 means redirection. +func (output *BeegoOutput) IsRedirect() bool { + return (*context.BeegoOutput)(output).IsRedirect() +} + +// IsForbidden returns boolean of this request is forbidden. +// HTTP 403 means forbidden. +func (output *BeegoOutput) IsForbidden() bool { + return (*context.BeegoOutput)(output).IsForbidden() +} + +// IsNotFound returns boolean of this request is not found. +// HTTP 404 means not found. +func (output *BeegoOutput) IsNotFound() bool { + return (*context.BeegoOutput)(output).IsNotFound() +} + +// IsClientError returns boolean of this request client sends error data. +// HTTP 4xx means client error. +func (output *BeegoOutput) IsClientError() bool { + return (*context.BeegoOutput)(output).IsClientError() +} + +// IsServerError returns boolean of this server handler errors. +// HTTP 5xx means server internal error. +func (output *BeegoOutput) IsServerError() bool { + return (*context.BeegoOutput)(output).IsServerError() +} + +// Session sets session item value with given key. +func (output *BeegoOutput) Session(name interface{}, value interface{}) { + (*context.BeegoOutput)(output).Session(name, value) +} diff --git a/pkg/adapter/context/renderer.go b/pkg/adapter/context/renderer.go new file mode 100644 index 00000000..763fb9c4 --- /dev/null +++ b/pkg/adapter/context/renderer.go @@ -0,0 +1,8 @@ +package context + +import ( + "github.com/astaxie/beego/pkg/server/web/context" +) + +// Renderer defines an http response renderer +type Renderer context.Renderer diff --git a/pkg/adapter/context/response.go b/pkg/adapter/context/response.go new file mode 100644 index 00000000..24e196a4 --- /dev/null +++ b/pkg/adapter/context/response.go @@ -0,0 +1,26 @@ +package context + +import ( + "net/http" + "strconv" +) + +const ( + // BadRequest indicates http error 400 + BadRequest StatusCode = http.StatusBadRequest + + // NotFound indicates http error 404 + NotFound StatusCode = http.StatusNotFound +) + +// StatusCode sets the http response status code +type StatusCode int + +func (s StatusCode) Error() string { + return strconv.Itoa(int(s)) +} + +// Render sets the http status code +func (s StatusCode) Render(ctx *Context) { + ctx.Output.SetStatus(int(s)) +} diff --git a/pkg/adapter/controller.go b/pkg/adapter/controller.go new file mode 100644 index 00000000..010add64 --- /dev/null +++ b/pkg/adapter/controller.go @@ -0,0 +1,401 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "mime/multipart" + "net/url" + + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/session" + webContext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +var ( + // ErrAbort custom error when user stop request handler manually. + ErrAbort = web.ErrAbort + // GlobalControllerRouter store comments with controller. pkgpath+controller:comments + GlobalControllerRouter = web.GlobalControllerRouter +) + +// ControllerFilter store the filter for controller +type ControllerFilter web.ControllerFilter + +// ControllerFilterComments store the comment for controller level filter +type ControllerFilterComments web.ControllerFilterComments + +// ControllerImportComments store the import comment for controller needed +type ControllerImportComments web.ControllerImportComments + +// ControllerComments store the comment for the controller method +type ControllerComments web.ControllerComments + +// ControllerCommentsSlice implements the sort interface +type ControllerCommentsSlice web.ControllerCommentsSlice + +func (p ControllerCommentsSlice) Len() int { + return (web.ControllerCommentsSlice)(p).Len() +} +func (p ControllerCommentsSlice) Less(i, j int) bool { + return (web.ControllerCommentsSlice)(p).Less(i, j) +} +func (p ControllerCommentsSlice) Swap(i, j int) { + (web.ControllerCommentsSlice)(p).Swap(i, j) +} + +// Controller defines some basic http request handler operations, such as +// http context, template and view, session and xsrf. +type Controller web.Controller + +// ControllerInterface is an interface to uniform all controller handler. +type ControllerInterface web.ControllerInterface + +// Init generates default values of controller operations. +func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { + (*web.Controller)(c).Init((*webContext.Context)(ctx), controllerName, actionName, app) +} + +// Prepare runs after Init before request function execution. +func (c *Controller) Prepare() { + (*web.Controller)(c).Prepare() +} + +// Finish runs after request function execution. +func (c *Controller) Finish() { + (*web.Controller)(c).Finish() +} + +// Get adds a request function to handle GET request. +func (c *Controller) Get() { + (*web.Controller)(c).Get() +} + +// Post adds a request function to handle POST request. +func (c *Controller) Post() { + (*web.Controller)(c).Post() +} + +// Delete adds a request function to handle DELETE request. +func (c *Controller) Delete() { + (*web.Controller)(c).Delete() +} + +// Put adds a request function to handle PUT request. +func (c *Controller) Put() { + (*web.Controller)(c).Put() +} + +// Head adds a request function to handle HEAD request. +func (c *Controller) Head() { + (*web.Controller)(c).Head() +} + +// Patch adds a request function to handle PATCH request. +func (c *Controller) Patch() { + (*web.Controller)(c).Patch() +} + +// Options adds a request function to handle OPTIONS request. +func (c *Controller) Options() { + (*web.Controller)(c).Options() +} + +// Trace adds a request function to handle Trace request. +// this method SHOULD NOT be overridden. +// https://tools.ietf.org/html/rfc7231#section-4.3.8 +// The TRACE method requests a remote, application-level loop-back of +// the request message. The final recipient of the request SHOULD +// reflect the message received, excluding some fields described below, +// back to the client as the message body of a 200 (OK) response with a +// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]). +func (c *Controller) Trace() { + (*web.Controller)(c).Trace() +} + +// HandlerFunc call function with the name +func (c *Controller) HandlerFunc(fnname string) bool { + return (*web.Controller)(c).HandlerFunc(fnname) +} + +// URLMapping register the internal Controller router. +func (c *Controller) URLMapping() { + (*web.Controller)(c).URLMapping() +} + +// Mapping the method to function +func (c *Controller) Mapping(method string, fn func()) { + (*web.Controller)(c).Mapping(method, fn) +} + +// Render sends the response with rendered template bytes as text/html type. +func (c *Controller) Render() error { + return (*web.Controller)(c).Render() +} + +// RenderString returns the rendered template string. Do not send out response. +func (c *Controller) RenderString() (string, error) { + return (*web.Controller)(c).RenderString() +} + +// RenderBytes returns the bytes of rendered template string. Do not send out response. +func (c *Controller) RenderBytes() ([]byte, error) { + return (*web.Controller)(c).RenderBytes() +} + +// Redirect sends the redirection response to url with status code. +func (c *Controller) Redirect(url string, code int) { + (*web.Controller)(c).Redirect(url, code) +} + +// SetData set the data depending on the accepted +func (c *Controller) SetData(data interface{}) { + (*web.Controller)(c).SetData(data) +} + +// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. +func (c *Controller) Abort(code string) { + (*web.Controller)(c).Abort(code) +} + +// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. +func (c *Controller) CustomAbort(status int, body string) { + (*web.Controller)(c).CustomAbort(status, body) +} + +// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. +func (c *Controller) StopRun() { + (*web.Controller)(c).StopRun() +} + +// URLFor does another controller handler in this request function. +// it goes to this controller method if endpoint is not clear. +func (c *Controller) URLFor(endpoint string, values ...interface{}) string { + return (*web.Controller)(c).URLFor(endpoint, values...) +} + +// ServeJSON sends a json response with encoding charset. +func (c *Controller) ServeJSON(encoding ...bool) { + (*web.Controller)(c).ServeJSON(encoding...) +} + +// ServeJSONP sends a jsonp response. +func (c *Controller) ServeJSONP() { + (*web.Controller)(c).ServeJSONP() +} + +// ServeXML sends xml response. +func (c *Controller) ServeXML() { + (*web.Controller)(c).ServeXML() +} + +// ServeYAML sends yaml response. +func (c *Controller) ServeYAML() { + (*web.Controller)(c).ServeYAML() +} + +// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header +func (c *Controller) ServeFormatted(encoding ...bool) { + (*web.Controller)(c).ServeFormatted(encoding...) +} + +// Input returns the input data map from POST or PUT request body and query string. +func (c *Controller) Input() url.Values { + return (*web.Controller)(c).Input() +} + +// ParseForm maps input data map to obj struct. +func (c *Controller) ParseForm(obj interface{}) error { + return (*web.Controller)(c).ParseForm(obj) +} + +// GetString returns the input value by key string or the default value while it's present and input is blank +func (c *Controller) GetString(key string, def ...string) string { + return (*web.Controller)(c).GetString(key, def...) +} + +// GetStrings returns the input string slice by key string or the default value while it's present and input is blank +// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. +func (c *Controller) GetStrings(key string, def ...[]string) []string { + return (*web.Controller)(c).GetStrings(key, def...) +} + +// GetInt returns input as an int or the default value while it's present and input is blank +func (c *Controller) GetInt(key string, def ...int) (int, error) { + return (*web.Controller)(c).GetInt(key, def...) +} + +// GetInt8 return input as an int8 or the default value while it's present and input is blank +func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { + return (*web.Controller)(c).GetInt8(key, def...) +} + +// GetUint8 return input as an uint8 or the default value while it's present and input is blank +func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) { + return (*web.Controller)(c).GetUint8(key, def...) +} + +// GetInt16 returns input as an int16 or the default value while it's present and input is blank +func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { + return (*web.Controller)(c).GetInt16(key, def...) +} + +// GetUint16 returns input as an uint16 or the default value while it's present and input is blank +func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) { + return (*web.Controller)(c).GetUint16(key, def...) +} + +// GetInt32 returns input as an int32 or the default value while it's present and input is blank +func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { + return (*web.Controller)(c).GetInt32(key, def...) +} + +// GetUint32 returns input as an uint32 or the default value while it's present and input is blank +func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) { + return (*web.Controller)(c).GetUint32(key, def...) +} + +// GetInt64 returns input value as int64 or the default value while it's present and input is blank. +func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { + return (*web.Controller)(c).GetInt64(key, def...) +} + +// GetUint64 returns input value as uint64 or the default value while it's present and input is blank. +func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) { + return (*web.Controller)(c).GetUint64(key, def...) +} + +// GetBool returns input value as bool or the default value while it's present and input is blank. +func (c *Controller) GetBool(key string, def ...bool) (bool, error) { + return (*web.Controller)(c).GetBool(key, def...) +} + +// GetFloat returns input value as float64 or the default value while it's present and input is blank. +func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { + return (*web.Controller)(c).GetFloat(key, def...) +} + +// GetFile returns the file data in file upload field named as key. +// it returns the first one of multi-uploaded files. +func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) { + return (*web.Controller)(c).GetFile(key) +} + +// GetFiles return multi-upload files +// files, err:=c.GetFiles("myfiles") +// if err != nil { +// http.Error(w, err.Error(), http.StatusNoContent) +// return +// } +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// } +func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { + return (*web.Controller)(c).GetFiles(key) +} + +// SaveToFile saves uploaded file to new path. +// it only operates the first one of mutil-upload form file field. +func (c *Controller) SaveToFile(fromfile, tofile string) error { + return (*web.Controller)(c).SaveToFile(fromfile, tofile) +} + +// StartSession starts session and load old session data info this controller. +func (c *Controller) StartSession() session.Store { + s := (*web.Controller)(c).StartSession() + return session.CreateNewToOldStoreAdapter(s) +} + +// SetSession puts value into session. +func (c *Controller) SetSession(name interface{}, value interface{}) { + (*web.Controller)(c).SetSession(name, value) +} + +// GetSession gets value from session. +func (c *Controller) GetSession(name interface{}) interface{} { + return (*web.Controller)(c).GetSession(name) +} + +// DelSession removes value from session. +func (c *Controller) DelSession(name interface{}) { + (*web.Controller)(c).DelSession(name) +} + +// SessionRegenerateID regenerates session id for this session. +// the session data have no changes. +func (c *Controller) SessionRegenerateID() { + (*web.Controller)(c).SessionRegenerateID() +} + +// DestroySession cleans session data and session cookie. +func (c *Controller) DestroySession() { + (*web.Controller)(c).DestroySession() +} + +// IsAjax returns this request is ajax or not. +func (c *Controller) IsAjax() bool { + return (*web.Controller)(c).IsAjax() +} + +// GetSecureCookie returns decoded cookie value from encoded browser cookie values. +func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { + return (*web.Controller)(c).GetSecureCookie(Secret, key) +} + +// SetSecureCookie puts value into cookie after encoded the value. +func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) { + (*web.Controller)(c).SetSecureCookie(Secret, name, value, others...) +} + +// XSRFToken creates a CSRF token string and returns. +func (c *Controller) XSRFToken() string { + return (*web.Controller)(c).XSRFToken() +} + +// CheckXSRFCookie checks xsrf token in this request is valid or not. +// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" +// or in form field value named as "_xsrf". +func (c *Controller) CheckXSRFCookie() bool { + return (*web.Controller)(c).CheckXSRFCookie() +} + +// XSRFFormHTML writes an input field contains xsrf token value. +func (c *Controller) XSRFFormHTML() string { + return (*web.Controller)(c).XSRFFormHTML() +} + +// GetControllerAndAction gets the executing controller name and action name. +func (c *Controller) GetControllerAndAction() (string, string) { + return (*web.Controller)(c).GetControllerAndAction() +} diff --git a/pkg/web/doc.go b/pkg/adapter/doc.go similarity index 87% rename from pkg/web/doc.go rename to pkg/adapter/doc.go index 1425a729..c8f2174c 100644 --- a/pkg/web/doc.go +++ b/pkg/adapter/doc.go @@ -1,4 +1,4 @@ -// Copyright 2020 beego +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,5 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -// we will move all web related codes here -package web +// used to keep compatible with v1.x +package adapter diff --git a/pkg/adapter/error.go b/pkg/adapter/error.go new file mode 100644 index 00000000..4f08aa8c --- /dev/null +++ b/pkg/adapter/error.go @@ -0,0 +1,202 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + errorTypeHandler = iota + errorTypeController +) + +var tpl = ` + + + + + beego application error + + + + + +
+ + + + + + + + + + +
Request Method: {{.RequestMethod}}
Request URL: {{.RequestURL}}
RemoteAddr: {{.RemoteAddr }}
+
+ Stack +
{{.Stack}}
+
+
+ + + +` + +var errtpl = ` + + + + + {{.Title}} + + + +
+
+ +
+ {{.Content}} + Go Home
+ +
Powered by beego {{.BeegoVersion}} +
+
+
+ + +` + +// ErrorMaps holds map of http handlers for each error string. +// there is 10 kinds default error(40x and 50x) +var ErrorMaps = web.ErrorMaps + +// ErrorHandler registers http.HandlerFunc to each http err code string. +// usage: +// beego.ErrorHandler("404",NotFound) +// beego.ErrorHandler("500",InternalServerError) +func ErrorHandler(code string, h http.HandlerFunc) *App { + return (*App)(web.ErrorHandler(code, h)) +} + +// ErrorController registers ControllerInterface to each http err code string. +// usage: +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *App { + return (*App)(web.ErrorController(c)) +} + +// Exception Write HttpStatus with errCode and Exec error handler if exist. +func Exception(errCode uint64, ctx *context.Context) { + web.Exception(errCode, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/filter.go b/pkg/adapter/filter.go new file mode 100644 index 00000000..cafed773 --- /dev/null +++ b/pkg/adapter/filter.go @@ -0,0 +1,36 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +type FilterFunc func(*context.Context) + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter web.FilterRouter + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + return (*web.FilterRouter)(f).ValidRouter(url, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/flash.go b/pkg/adapter/flash.go new file mode 100644 index 00000000..02e75ed6 --- /dev/null +++ b/pkg/adapter/flash.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/server/web" +) + +// FlashData is a tools to maintain data when using across request. +type FlashData web.FlashData + +// NewFlash return a new empty FlashData struct. +func NewFlash() *FlashData { + return (*FlashData)(web.NewFlash()) +} + +// Set message to flash +func (fd *FlashData) Set(key string, msg string, args ...interface{}) { + (*web.FlashData)(fd).Set(key, msg, args...) +} + +// Success writes success message to flash. +func (fd *FlashData) Success(msg string, args ...interface{}) { + (*web.FlashData)(fd).Success(msg, args...) +} + +// Notice writes notice message to flash. +func (fd *FlashData) Notice(msg string, args ...interface{}) { + (*web.FlashData)(fd).Notice(msg, args...) +} + +// Warning writes warning message to flash. +func (fd *FlashData) Warning(msg string, args ...interface{}) { + (*web.FlashData)(fd).Warning(msg, args...) +} + +// Error writes error message to flash. +func (fd *FlashData) Error(msg string, args ...interface{}) { + (*web.FlashData)(fd).Error(msg, args...) +} + +// Store does the saving operation of flash data. +// the data are encoded and saved in cookie. +func (fd *FlashData) Store(c *Controller) { + (*web.FlashData)(fd).Store((*web.Controller)(c)) +} + +// ReadFromRequest parsed flash data from encoded values in cookie. +func ReadFromRequest(c *Controller) *FlashData { + return (*FlashData)(web.ReadFromRequest((*web.Controller)(c))) +} diff --git a/pkg/adapter/fs.go b/pkg/adapter/fs.go new file mode 100644 index 00000000..07054ca3 --- /dev/null +++ b/pkg/adapter/fs.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "path/filepath" + + "github.com/astaxie/beego/pkg/server/web" +) + +type FileSystem web.FileSystem + +func (d FileSystem) Open(name string) (http.File, error) { + return (web.FileSystem)(d).Open(name) +} + +// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. +func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { + return web.Walk(fs, root, walkFn) +} diff --git a/pkg/adapter/grace/grace.go b/pkg/adapter/grace/grace.go new file mode 100644 index 00000000..3775e395 --- /dev/null +++ b/pkg/adapter/grace/grace.go @@ -0,0 +1,94 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grace use to hot reload +// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ +// +// Usage: +// +// import( +// "log" +// "net/http" +// "os" +// +// "github.com/astaxie/beego/grace" +// ) +// +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } +// +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) +// +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } +package grace + +import ( + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +const ( + // PreSignal is the position to add filter before signal + PreSignal = iota + // PostSignal is the position to add filter after signal + PostSignal + // StateInit represent the application inited + StateInit + // StateRunning represent the application is running + StateRunning + // StateShuttingDown represent the application is shutting down + StateShuttingDown + // StateTerminate represent the application is killed + StateTerminate +) + +var ( + + // DefaultReadTimeOut is the HTTP read timeout + DefaultReadTimeOut time.Duration + // DefaultWriteTimeOut is the HTTP Write timeout + DefaultWriteTimeOut time.Duration + // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit + DefaultMaxHeaderBytes int + // DefaultTimeout is the shutdown server's timeout. default is 60s + DefaultTimeout = grace.DefaultTimeout +) + +// NewServer returns a new graceServer. +func NewServer(addr string, handler http.Handler) (srv *Server) { + return (*Server)(grace.NewServer(addr, handler)) +} + +// ListenAndServe refer http.ListenAndServe +func ListenAndServe(addr string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServe() +} + +// ListenAndServeTLS refer http.ListenAndServeTLS +func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { + server := NewServer(addr, handler) + return server.ListenAndServeTLS(certFile, keyFile) +} diff --git a/pkg/adapter/grace/server.go b/pkg/adapter/grace/server.go new file mode 100644 index 00000000..31c13f18 --- /dev/null +++ b/pkg/adapter/grace/server.go @@ -0,0 +1,48 @@ +package grace + +import ( + "os" + + "github.com/astaxie/beego/pkg/server/web/grace" +) + +// Server embedded http.Server +type Server grace.Server + +// Serve accepts incoming connections on the Listener l, +// creating a new service goroutine for each. +// The service goroutines read requests and then call srv.Handler to reply to them. +func (srv *Server) Serve() (err error) { + return (*grace.Server)(srv).Serve() +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +// to handle requests on incoming connections. If srv.Addr is blank, ":http" is +// used. +func (srv *Server) ListenAndServe() (err error) { + return (*grace.Server)(srv).ListenAndServe() +} + +// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming TLS connections. +// +// Filenames containing a certificate and matching private key for the server must +// be provided. If the certificate is signed by a certificate authority, the +// certFile should be the concatenation of the server's certificate followed by the +// CA's certificate. +// +// If srv.Addr is blank, ":https" is used. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + return (*grace.Server)(srv).ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) error { + return (*grace.Server)(srv).ListenAndServeMutualTLS(certFile, keyFile, trustFile) +} + +// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. +func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) error { + return (*grace.Server)(srv).RegisterSignalHook(ppFlag, sig, f) +} diff --git a/pkg/adapter/httplib/httplib.go b/pkg/adapter/httplib/httplib.go new file mode 100644 index 00000000..d2ef36c1 --- /dev/null +++ b/pkg/adapter/httplib/httplib.go @@ -0,0 +1,300 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httplib is used as http.Client +// Usage: +// +// import "github.com/astaxie/beego/httplib" +// +// b := httplib.Post("http://beego.me/") +// b.Param("username","astaxie") +// b.Param("password","123456") +// b.PostFile("uploadfile1", "httplib.pdf") +// b.PostFile("uploadfile2", "httplib.txt") +// str, err := b.String() +// if err != nil { +// t.Fatal(err) +// } +// fmt.Println(str) +// +// more docs http://beego.me/docs/module/httplib.md +package httplib + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/client/httplib" +) + +// SetDefaultSetting Overwrite default settings +func SetDefaultSetting(setting BeegoHTTPSettings) { + httplib.SetDefaultSetting(httplib.BeegoHTTPSettings(setting)) +} + +// NewBeegoRequest return *BeegoHttpRequest with specific method +func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { + return &BeegoHTTPRequest{ + delegate: httplib.NewBeegoRequest(rawurl, method), + } +} + +// Get returns *BeegoHttpRequest with GET method. +func Get(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "GET") +} + +// Post returns *BeegoHttpRequest with POST method. +func Post(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "POST") +} + +// Put returns *BeegoHttpRequest with PUT method. +func Put(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "PUT") +} + +// Delete returns *BeegoHttpRequest DELETE method. +func Delete(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "DELETE") +} + +// Head returns *BeegoHttpRequest with HEAD method. +func Head(url string) *BeegoHTTPRequest { + return NewBeegoRequest(url, "HEAD") +} + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings httplib.BeegoHTTPSettings + +// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +type BeegoHTTPRequest struct { + delegate *httplib.BeegoHTTPRequest +} + +// GetRequest return the request object +func (b *BeegoHTTPRequest) GetRequest() *http.Request { + return b.delegate.GetRequest() +} + +// Setting Change request settings +func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { + b.delegate.Setting(httplib.BeegoHTTPSettings(setting)) + return b +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest { + b.delegate.SetBasicAuth(username, password) + return b +} + +// SetEnableCookie sets enable/disable cookiejar +func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest { + b.delegate.SetEnableCookie(enable) + return b +} + +// SetUserAgent sets User-Agent header field +func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { + b.delegate.SetUserAgent(useragent) + return b +} + +// Debug sets show debug or not when executing request. +func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { + b.delegate.Debug(isdebug) + return b +} + +// Retries sets Retries times. +// default is 0 means no retried. +// -1 means retried forever. +// others means retried times. +func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { + b.delegate.Retries(times) + return b +} + +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.delegate.RetryDelay(delay) + return b +} + +// DumpBody setting whether need to Dump the Body. +func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { + b.delegate.DumpBody(isdump) + return b +} + +// DumpRequest return the DumpRequest +func (b *BeegoHTTPRequest) DumpRequest() []byte { + return b.delegate.DumpRequest() +} + +// SetTimeout sets connect time out and read-write time out for BeegoRequest. +func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { + b.delegate.SetTimeout(connectTimeout, readWriteTimeout) + return b +} + +// SetTLSClientConfig sets tls connection configurations if visiting https url. +func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { + b.delegate.SetTLSClientConfig(config) + return b +} + +// Header add header item string in request. +func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { + b.delegate.Header(key, value) + return b +} + +// SetHost set the request host +func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { + b.delegate.SetHost(host) + return b +} + +// SetProtocolVersion Set the protocol version for incoming requests. +// Client requests always use HTTP/1.1. +func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { + b.delegate.SetProtocolVersion(vers) + return b +} + +// SetCookie add cookie into request. +func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { + b.delegate.SetCookie(cookie) + return b +} + +// SetTransport set the setting transport +func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { + b.delegate.SetTransport(transport) + return b +} + +// SetProxy set the http proxy +// example: +// +// func(req *http.Request) (*url.URL, error) { +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } +func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { + b.delegate.SetProxy(proxy) + return b +} + +// SetCheckRedirect specifies the policy for handling redirects. +// +// If CheckRedirect is nil, the Client uses its default policy, +// which is to stop after 10 consecutive requests. +func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest { + b.delegate.SetCheckRedirect(redirect) + return b +} + +// Param adds query param in to request. +// params build query string as ?key1=value1&key2=value2... +func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { + b.delegate.Param(key, value) + return b +} + +// PostFile add a post file to the request +func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { + b.delegate.PostFile(formname, filename) + return b +} + +// Body adds request raw body. +// it supports string and []byte. +func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { + b.delegate.Body(data) + return b +} + +// XMLBody adds request raw body encoding by XML. +func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.XMLBody(obj) + return b, err +} + +// YAMLBody adds request raw body encoding by YAML. +func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.YAMLBody(obj) + return b, err +} + +// JSONBody adds request raw body encoding by JSON. +func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { + _, err := b.delegate.JSONBody(obj) + return b, err +} + +// DoRequest will do the client.Do +func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { + return b.delegate.DoRequest() +} + +// String returns the body string in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) String() (string, error) { + return b.delegate.String() +} + +// Bytes returns the body []byte in response. +// it calls Response inner. +func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { + return b.delegate.Bytes() +} + +// ToFile saves the body data in response to one file. +// it calls Response inner. +func (b *BeegoHTTPRequest) ToFile(filename string) error { + return b.delegate.ToFile(filename) +} + +// ToJSON returns the map that marshals from the body bytes as json in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { + return b.delegate.ToJSON(v) +} + +// ToXML returns the map that marshals from the body bytes as xml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToXML(v interface{}) error { + return b.delegate.ToXML(v) +} + +// ToYAML returns the map that marshals from the body bytes as yaml in response . +// it calls Response inner. +func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { + return b.delegate.ToYAML(v) +} + +// Response executes request client gets response mannually. +func (b *BeegoHTTPRequest) Response() (*http.Response, error) { + return b.delegate.Response() +} + +// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { + return httplib.TimeoutDialer(cTimeout, rwTimeout) +} diff --git a/pkg/adapter/httplib/httplib_test.go b/pkg/adapter/httplib/httplib_test.go new file mode 100644 index 00000000..e7605c87 --- /dev/null +++ b/pkg/adapter/httplib/httplib_test.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "io/ioutil" + "net" + "net/http" + "os" + "strings" + "testing" + "time" +) + +func TestResponse(t *testing.T) { + req := Get("http://httpbin.org/get") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) +} + +func TestDoRequest(t *testing.T) { + req := Get("https://goolnk.com/33BD2j") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.SetCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + }) + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + if err == nil { + t.Fatal("Response should have yielded an error") + } + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } + +} + +func TestGet(t *testing.T) { + req := Get("http://httpbin.org/get") + b, err := req.Bytes() + if err != nil { + t.Fatal(err) + } + t.Log(b) + + s, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(s) + + if string(b) != s { + t.Fatal("request data not match") + } +} + +func TestSimplePost(t *testing.T) { + v := "smallfish" + req := Post("http://httpbin.org/post") + req.Param("username", v) + + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in post") + } +} + +// func TestPostFile(t *testing.T) { +// v := "smallfish" +// req := Post("http://httpbin.org/post") +// req.Debug(true) +// req.Param("username", v) +// req.PostFile("uploadfile", "httplib_test.go") + +// str, err := req.String() +// if err != nil { +// t.Fatal(err) +// } +// t.Log(str) + +// n := strings.Index(str, v) +// if n == -1 { +// t.Fatal(v + " not found in post") +// } +// } + +func TestSimplePut(t *testing.T) { + str, err := Put("http://httpbin.org/put").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDelete(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestSimpleDeleteParam(t *testing.T) { + str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} + +func TestWithCookie(t *testing.T) { + v := "smallfish" + str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestWithBasicAuth(t *testing.T) { + str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + n := strings.Index(str, "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestWithUserAgent(t *testing.T) { + v := "beego" + str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestWithSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://httpbin.org/get").String() + if err != nil { + t.Fatal(err) + } + t.Log(str) + + n := strings.Index(str, v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestToJson(t *testing.T) { + req := Get("http://httpbin.org/ip") + resp, err := req.Response() + if err != nil { + t.Fatal(err) + } + t.Log(resp) + + // httpbin will return http remote addr + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + if err != nil { + t.Fatal(err) + } + t.Log(ip.Origin) + ips := strings.Split(ip.Origin, ",") + if len(ips) == 0 { + t.Fatal("response is not valid ip") + } + for i := range ips { + if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { + t.Fatal("response is not valid ip") + } + } + +} + +func TestToFile(t *testing.T) { + f := "beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.Remove(f) + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestToFileDir(t *testing.T) { + f := "./files/beego_testfile" + req := Get("http://httpbin.org/ip") + err := req.ToFile(f) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll("./files") + b, err := ioutil.ReadFile(f) + if n := strings.Index(string(b), "origin"); n == -1 { + t.Fatal(err) + } +} + +func TestHeader(t *testing.T) { + req := Get("http://httpbin.org/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + str, err := req.String() + if err != nil { + t.Fatal(err) + } + t.Log(str) +} diff --git a/pkg/adapter/log.go b/pkg/adapter/log.go new file mode 100644 index 00000000..d9ff6e0c --- /dev/null +++ b/pkg/adapter/log.go @@ -0,0 +1,129 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "strings" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + + webLog "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +// Log levels to control the logging output. +// Deprecated: use github.com/astaxie/beego/logs instead. +const ( + LevelEmergency = webLog.LevelEmergency + LevelAlert = webLog.LevelAlert + LevelCritical = webLog.LevelCritical + LevelError = webLog.LevelError + LevelWarning = webLog.LevelWarning + LevelNotice = webLog.LevelNotice + LevelInformational = webLog.LevelInformational + LevelDebug = webLog.LevelDebug +) + +// BeeLogger references the used application logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +var BeeLogger = logs.GetBeeLogger() + +// SetLevel sets the global log level used by the simple logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLevel(l int) { + logs.SetLevel(l) +} + +// SetLogFuncCall set the CallDepth, default is 3 +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogFuncCall(b bool) { + logs.SetLogFuncCall(b) +} + +// SetLogger sets a new logger. +// Deprecated: use github.com/astaxie/beego/logs instead. +func SetLogger(adaptername string, config string) error { + return logs.SetLogger(adaptername, config) +} + +// Emergency logs a message at emergency level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Emergency(v ...interface{}) { + logs.Emergency(generateFmtStr(len(v)), v...) +} + +// Alert logs a message at alert level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Alert(v ...interface{}) { + logs.Alert(generateFmtStr(len(v)), v...) +} + +// Critical logs a message at critical level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Critical(v ...interface{}) { + logs.Critical(generateFmtStr(len(v)), v...) +} + +// Error logs a message at error level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Error(v ...interface{}) { + logs.Error(generateFmtStr(len(v)), v...) +} + +// Warning logs a message at warning level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warning(v ...interface{}) { + logs.Warning(generateFmtStr(len(v)), v...) +} + +// Warn compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Warn(v ...interface{}) { + logs.Warn(generateFmtStr(len(v)), v...) +} + +// Notice logs a message at notice level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Notice(v ...interface{}) { + logs.Notice(generateFmtStr(len(v)), v...) +} + +// Informational logs a message at info level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Informational(v ...interface{}) { + logs.Informational(generateFmtStr(len(v)), v...) +} + +// Info compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Info(v ...interface{}) { + logs.Info(generateFmtStr(len(v)), v...) +} + +// Debug logs a message at debug level. +// Deprecated: use github.com/astaxie/beego/logs instead. +func Debug(v ...interface{}) { + logs.Debug(generateFmtStr(len(v)), v...) +} + +// Trace logs a message at trace level. +// compatibility alias for Warning() +// Deprecated: use github.com/astaxie/beego/logs instead. +func Trace(v ...interface{}) { + logs.Trace(generateFmtStr(len(v)), v...) +} + +func generateFmtStr(n int) string { + return strings.Repeat("%v ", n) +} diff --git a/pkg/adapter/metric/prometheus.go b/pkg/adapter/metric/prometheus.go new file mode 100644 index 00000000..1d3488c6 --- /dev/null +++ b/pkg/adapter/metric/prometheus.go @@ -0,0 +1,99 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/server/web" +) + +func PrometheusMiddleWare(next http.Handler) http.Handler { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status", "duration"}) + + prometheus.MustRegister(summaryVec) + + registerBuildInfo() + + return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) { + start := time.Now() + next.ServeHTTP(writer, q) + end := time.Now() + go report(end.Sub(start), writer, q, summaryVec) + }) +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": web.BConfig.AppName, + "build_version": web.BuildVersion, + "build_revision": web.BuildGitRevision, + "build_status": web.BuildStatus, + "build_tag": web.BuildTag, + "build_time": strings.Replace(web.BuildTime, "--", " ", 1), + "go_version": web.GoVersion, + "git_branch": web.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) { + ctrl := web.BeeApp.Handlers + ctx := ctrl.GetContext() + ctx.Reset(writer, q) + defer ctrl.GiveBackContext(ctx) + + // We cannot read the status code from q.Response.StatusCode + // since the http server does not set q.Response. So q.Response is nil + // Thus, we use reflection to read the status from writer whose concrete type is http.response + responseVal := reflect.ValueOf(writer).Elem() + field := responseVal.FieldByName("status") + status := -1 + if field.IsValid() && field.Kind() == reflect.Int { + status = int(field.Int()) + } + ptn := "UNKNOWN" + if rt, found := ctrl.FindRouter(ctx); found { + ptn = rt.GetPattern() + } else { + logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String()) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms)) +} diff --git a/pkg/adapter/metric/prometheus_test.go b/pkg/adapter/metric/prometheus_test.go new file mode 100644 index 00000000..87286e02 --- /dev/null +++ b/pkg/adapter/metric/prometheus_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "net/http" + "net/url" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/astaxie/beego/pkg/adapter/context" +) + +func TestPrometheusMiddleWare(t *testing.T) { + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + writer := &context.Response{} + request := &http.Request{ + URL: &url.URL{ + Host: "localhost", + RawPath: "/a/b/c", + }, + Method: "POST", + } + vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"}) + + report(time.Second, writer, request, vec) + middleware.ServeHTTP(writer, request) +} diff --git a/pkg/adapter/migration/ddl.go b/pkg/adapter/migration/ddl.go new file mode 100644 index 00000000..97e45dec --- /dev/null +++ b/pkg/adapter/migration/ddl.go @@ -0,0 +1,198 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// Index struct defines the structure of Index Columns +type Index migration.Index + +// Unique struct defines a single unique key combination +type Unique migration.Unique + +// Column struct defines a single column of a table +type Column migration.Column + +// Foreign struct defines a single foreign relationship +type Foreign migration.Foreign + +// RenameColumn struct allows renaming of columns +type RenameColumn migration.RenameColumn + +// CreateTable creates the table on system +func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) { + (*migration.Migration)(m).CreateTable(tablename, engine, charset, p...) +} + +// AlterTable set the ModifyType to alter +func (m *Migration) AlterTable(tablename string) { + (*migration.Migration)(m).AlterTable(tablename) +} + +// NewCol creates a new standard column and attaches it to m struct +func (m *Migration) NewCol(name string) *Column { + return (*Column)((*migration.Migration)(m).NewCol(name)) +} + +// PriCol creates a new primary column and attaches it to m struct +func (m *Migration) PriCol(name string) *Column { + return (*Column)((*migration.Migration)(m).PriCol(name)) +} + +// UniCol creates / appends columns to specified unique key and attaches it to m struct +func (m *Migration) UniCol(uni, name string) *Column { + return (*Column)((*migration.Migration)(m).UniCol(uni, name)) +} + +// ForeignCol creates a new foreign column and returns the instance of column +func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { + return (*Foreign)((*migration.Migration)(m).ForeignCol(colname, foreigncol, foreigntable)) +} + +// SetOnDelete sets the on delete of foreign +func (foreign *Foreign) SetOnDelete(del string) *Foreign { + (*migration.Foreign)(foreign).SetOnDelete(del) + return foreign +} + +// SetOnUpdate sets the on update of foreign +func (foreign *Foreign) SetOnUpdate(update string) *Foreign { + (*migration.Foreign)(foreign).SetOnUpdate(update) + return foreign +} + +// Remove marks the columns to be removed. +// it allows reverse m to create the column. +func (c *Column) Remove() { + (*migration.Column)(c).Remove() +} + +// SetAuto enables auto_increment of column (can be used once) +func (c *Column) SetAuto(inc bool) *Column { + (*migration.Column)(c).SetAuto(inc) + return c +} + +// SetNullable sets the column to be null +func (c *Column) SetNullable(null bool) *Column { + (*migration.Column)(c).SetNullable(null) + return c +} + +// SetDefault sets the default value, prepend with "DEFAULT " +func (c *Column) SetDefault(def string) *Column { + (*migration.Column)(c).SetDefault(def) + return c +} + +// SetUnsigned sets the column to be unsigned int +func (c *Column) SetUnsigned(unsign bool) *Column { + (*migration.Column)(c).SetUnsigned(unsign) + return c +} + +// SetDataType sets the dataType of the column +func (c *Column) SetDataType(dataType string) *Column { + (*migration.Column)(c).SetDataType(dataType) + return c +} + +// SetOldNullable allows reverting to previous nullable on reverse ms +func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldNullable(null) + return c +} + +// SetOldDefault allows reverting to previous default on reverse ms +func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDefault(def) + return c +} + +// SetOldUnsigned allows reverting to previous unsgined on reverse ms +func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { + (*migration.RenameColumn)(c).SetOldUnsigned(unsign) + return c +} + +// SetOldDataType allows reverting to previous datatype on reverse ms +func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { + (*migration.RenameColumn)(c).SetOldDataType(dataType) + return c +} + +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +func (c *Column) SetPrimary(m *Migration) *Column { + (*migration.Column)(c).SetPrimary((*migration.Migration)(m)) + return c +} + +// AddColumnsToUnique adds the columns to Unique Struct +func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { + cls := toNewColumnsArray(columns) + (*migration.Unique)(unique).AddColumnsToUnique(cls...) + return unique +} + +// AddColumns adds columns to m struct +func (m *Migration) AddColumns(columns ...*Column) *Migration { + cls := toNewColumnsArray(columns) + (*migration.Migration)(m).AddColumns(cls...) + return m +} + +func toNewColumnsArray(columns []*Column) []*migration.Column { + cls := make([]*migration.Column, 0, len(columns)) + for _, c := range columns { + cls = append(cls, (*migration.Column)(c)) + } + return cls +} + +// AddPrimary adds the column to primary in m struct +func (m *Migration) AddPrimary(primary *Column) *Migration { + (*migration.Migration)(m).AddPrimary((*migration.Column)(primary)) + return m +} + +// AddUnique adds the column to unique in m struct +func (m *Migration) AddUnique(unique *Unique) *Migration { + (*migration.Migration)(m).AddUnique((*migration.Unique)(unique)) + return m +} + +// AddForeign adds the column to foreign in m struct +func (m *Migration) AddForeign(foreign *Foreign) *Migration { + (*migration.Migration)(m).AddForeign((*migration.Foreign)(foreign)) + return m +} + +// AddIndex adds the column to index in m struct +func (m *Migration) AddIndex(index *Index) *Migration { + (*migration.Migration)(m).AddIndex((*migration.Index)(index)) + return m +} + +// RenameColumn allows renaming of columns +func (m *Migration) RenameColumn(from, to string) *RenameColumn { + return (*RenameColumn)((*migration.Migration)(m).RenameColumn(from, to)) +} + +// GetSQL returns the generated sql depending on ModifyType +func (m *Migration) GetSQL() (sql string) { + return (*migration.Migration)(m).GetSQL() +} diff --git a/pkg/migration/doc.go b/pkg/adapter/migration/doc.go similarity index 100% rename from pkg/migration/doc.go rename to pkg/adapter/migration/doc.go diff --git a/pkg/adapter/migration/migration.go b/pkg/adapter/migration/migration.go new file mode 100644 index 00000000..4ee22e5a --- /dev/null +++ b/pkg/adapter/migration/migration.go @@ -0,0 +1,111 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package migration is used for migration +// +// The table structure is as follow: +// +// CREATE TABLE `migrations` ( +// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key', +// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique', +// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back', +// `statements` longtext COMMENT 'SQL statements for this migration', +// `rollback_statements` longtext, +// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back', +// PRIMARY KEY (`id_migration`) +// ) ENGINE=InnoDB DEFAULT CHARSET=utf8; +package migration + +import ( + "github.com/astaxie/beego/pkg/client/orm/migration" +) + +// const the data format for the bee generate migration datatype +const ( + DateFormat = "20060102_150405" + DBDateFormat = "2006-01-02 15:04:05" +) + +// Migrationer is an interface for all Migration struct +type Migrationer interface { + Up() + Down() + Reset() + Exec(name, status string) error + GetCreated() int64 +} + +// Migration defines the migrations by either SQL or DDL +type Migration migration.Migration + +// Up implement in the Inheritance struct for upgrade +func (m *Migration) Up() { + (*migration.Migration)(m).Up() +} + +// Down implement in the Inheritance struct for down +func (m *Migration) Down() { + (*migration.Migration)(m).Down() +} + +// Migrate adds the SQL to the execution list +func (m *Migration) Migrate(migrationType string) { + (*migration.Migration)(m).Migrate(migrationType) +} + +// SQL add sql want to execute +func (m *Migration) SQL(sql string) { + (*migration.Migration)(m).SQL(sql) +} + +// Reset the sqls +func (m *Migration) Reset() { + (*migration.Migration)(m).Reset() +} + +// Exec execute the sql already add in the sql +func (m *Migration) Exec(name, status string) error { + return (*migration.Migration)(m).Exec(name, status) +} + +// GetCreated get the unixtime from the Created +func (m *Migration) GetCreated() int64 { + return (*migration.Migration)(m).GetCreated() +} + +// Register register the Migration in the map +func Register(name string, m Migrationer) error { + return migration.Register(name, m) +} + +// Upgrade upgrade the migration from lasttime +func Upgrade(lasttime int64) error { + return migration.Upgrade(lasttime) +} + +// Rollback rollback the migration by the name +func Rollback(name string) error { + return migration.Rollback(name) +} + +// Reset reset all migration +// run all migration's down function +func Reset() error { + return migration.Reset() +} + +// Refresh first Reset, then Upgrade +func Refresh() error { + return migration.Refresh() +} diff --git a/pkg/adapter/namespace.go b/pkg/adapter/namespace.go new file mode 100644 index 00000000..609402cf --- /dev/null +++ b/pkg/adapter/namespace.go @@ -0,0 +1,378 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + + adtContext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +type namespaceCond func(*adtContext.Context) bool + +// LinkNamespace used as link action +type LinkNamespace func(*Namespace) + +// Namespace is store all the info +type Namespace web.Namespace + +// NewNamespace get new Namespace +func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { + nps := oldToNewLinkNs(params) + return (*Namespace)(web.NewNamespace(prefix, nps...)) +} + +func oldToNewLinkNs(params []LinkNamespace) []web.LinkNamespace { + nps := make([]web.LinkNamespace, 0, len(params)) + for _, p := range params { + nps = append(nps, func(namespace *web.Namespace) { + p((*Namespace)(namespace)) + }) + } + return nps +} + +// Cond set condition function +// if cond return true can run this namespace, else can't +// usage: +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.me" { +// return true +// } +// return false +// }) +// Cond as the first filter +func (n *Namespace) Cond(cond namespaceCond) *Namespace { + (*web.Namespace)(n).Cond(func(context *context.Context) bool { + return cond((*adtContext.Context)(context)) + }) + return n +} + +// Filter add filter in the Namespace +// action has before & after +// FilterFunc +// usage: +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) +func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { + nfs := oldToNewFilter(filter) + (*web.Namespace)(n).Filter(action, nfs...) + return n +} + +func oldToNewFilter(filter []FilterFunc) []web.FilterFunc { + nfs := make([]web.FilterFunc, 0, len(filter)) + for _, f := range filter { + nfs = append(nfs, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } + return nfs +} + +// Router same as beego.Rourer +// refer: https://godoc.org/github.com/astaxie/beego#Router +func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { + (*web.Namespace)(n).Router(rootpath, c, mappingMethods...) + return n +} + +// AutoRouter same as beego.AutoRouter +// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoRouter(c) + return n +} + +// AutoPrefix same as beego.AutoPrefix +// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { + (*web.Namespace)(n).AutoPrefix(prefix, c) + return n +} + +// Get same as beego.Get +// refer: https://godoc.org/github.com/astaxie/beego#Get +func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Get(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Post same as beego.Post +// refer: https://godoc.org/github.com/astaxie/beego#Post +func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Delete same as beego.Delete +// refer: https://godoc.org/github.com/astaxie/beego#Delete +func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Delete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Put same as beego.Put +// refer: https://godoc.org/github.com/astaxie/beego#Put +func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Put(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Head same as beego.Head +// refer: https://godoc.org/github.com/astaxie/beego#Head +func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Head(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Options same as beego.Options +// refer: https://godoc.org/github.com/astaxie/beego#Options +func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Options(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Patch same as beego.Patch +// refer: https://godoc.org/github.com/astaxie/beego#Patch +func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Patch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Any same as beego.Any +// refer: https://godoc.org/github.com/astaxie/beego#Any +func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { + (*web.Namespace)(n).Any(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + return n +} + +// Handler same as beego.Handler +// refer: https://godoc.org/github.com/astaxie/beego#Handler +func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { + (*web.Namespace)(n).Handler(rootpath, h) + return n +} + +// Include add include class +// refer: https://godoc.org/github.com/astaxie/beego#Include +func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { + nL := oldToNewCtrlIntfs(cList) + (*web.Namespace)(n).Include(nL...) + return n +} + +// Namespace add nest Namespace +// usage: +// ns := beego.NewNamespace(“/v1”). +// Namespace( +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +// ) +func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { + nns := oldToNewNs(ns) + (*web.Namespace)(n).Namespace(nns...) + return n +} + +func oldToNewNs(ns []*Namespace) []*web.Namespace { + nns := make([]*web.Namespace, 0, len(ns)) + for _, n := range ns { + nns = append(nns, (*web.Namespace)(n)) + } + return nns +} + +// AddNamespace register Namespace into beego.Handler +// support multi Namespace +func AddNamespace(nl ...*Namespace) { + nnl := oldToNewNs(nl) + web.AddNamespace(nnl...) +} + +// NSCond is Namespace Condition +func NSCond(cond namespaceCond) LinkNamespace { + return func(namespace *Namespace) { + web.NSCond(func(b *context.Context) bool { + return cond((*adtContext.Context)(b)) + }) + } +} + +// NSBefore Namespace BeforeRouter filter +func NSBefore(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSBefore(nfs...) + } +} + +// NSAfter add Namespace FinishRouter filter +func NSAfter(filterList ...FilterFunc) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewFilter(filterList) + web.NSAfter(nfs...) + } +} + +// NSInclude Namespace Include ControllerInterface +func NSInclude(cList ...ControllerInterface) LinkNamespace { + return func(namespace *Namespace) { + nfs := oldToNewCtrlIntfs(cList) + web.NSInclude(nfs...) + } +} + +// NSRouter call Namespace Router +func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + return func(namespace *Namespace) { + web.Router(rootpath, c, mappingMethods...) + } +} + +// NSGet call Namespace Get +func NSGet(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSGet(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPost call Namespace Post +func NSPost(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.Post(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSHead call Namespace Head +func NSHead(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSHead(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPut call Namespace Put +func NSPut(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPut(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSDelete call Namespace Delete +func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSDelete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAny call Namespace Any +func NSAny(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSAny(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSOptions call Namespace Options +func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSOptions(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSPatch call Namespace Patch +func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + return func(ns *Namespace) { + web.NSPatch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) + } +} + +// NSAutoRouter call Namespace AutoRouter +func NSAutoRouter(c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoRouter(c) + } +} + +// NSAutoPrefix call Namespace AutoPrefix +func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + return func(ns *Namespace) { + web.NSAutoPrefix(prefix, c) + } +} + +// NSNamespace add sub Namespace +func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + return func(ns *Namespace) { + nps := oldToNewLinkNs(params) + web.NSNamespace(prefix, nps...) + } +} + +// NSHandler add handler +func NSHandler(rootpath string, h http.Handler) LinkNamespace { + return func(ns *Namespace) { + web.NSHandler(rootpath, h) + } +} diff --git a/pkg/adapter/orm/cmd.go b/pkg/adapter/orm/cmd.go new file mode 100644 index 00000000..6fee237c --- /dev/null +++ b/pkg/adapter/orm/cmd.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RunCommand listen for orm command and then run it if command arguments passed. +func RunCommand() { + orm.RunCommand() +} + +func RunSyncdb(name string, force bool, verbose bool) error { + return orm.RunSyncdb(name, force, verbose) +} diff --git a/pkg/adapter/orm/db.go b/pkg/adapter/orm/db.go new file mode 100644 index 00000000..74bca8c0 --- /dev/null +++ b/pkg/adapter/orm/db.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +var ( + // ErrMissPK missing pk error + ErrMissPK = orm.ErrMissPK +) diff --git a/pkg/adapter/orm/db_alias.go b/pkg/adapter/orm/db_alias.go new file mode 100644 index 00000000..b1f1a724 --- /dev/null +++ b/pkg/adapter/orm/db_alias.go @@ -0,0 +1,122 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// DriverType database driver constant int. +type DriverType orm.DriverType + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL = orm.DRMySQL + DRSqlite = orm.DRSqlite // sqlite + DROracle = orm.DROracle // oracle + DRPostgres = orm.DRPostgres // pgsql + DRTiDB = orm.DRTiDB // TiDB +) + +type DB orm.DB + +func (d *DB) Begin() (*sql.Tx, error) { + return (*orm.DB)(d).Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return (*orm.DB)(d).BeginTx(ctx, opts) +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return (*orm.DB)(d).Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return (*orm.DB)(d).PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).Exec(query, args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return (*orm.DB)(d).ExecContext(ctx, query, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).Query(query, args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return (*orm.DB)(d).QueryContext(ctx, query, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRow(query, args) +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return (*orm.DB)(d).QueryRowContext(ctx, query, args...) +} + +// AddAliasWthDB add a aliasName for the drivename +func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { + return orm.AddAliasWthDB(aliasName, driverName, db) +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { + opts := make([]orm.DBOption, 0, 2) + if len(params) > 0 { + opts = append(opts, orm.MaxIdleConnections(params[0])) + } + + if len(params) > 1 { + opts = append(opts, orm.MaxOpenConnections(params[1])) + } + return orm.RegisterDataBase(aliasName, driverName, dataSource, opts...) +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + return orm.RegisterDriver(driverName, orm.DriverType(typ)) +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(aliasName string, tz *time.Location) error { + return orm.SetDataBaseTZ(aliasName, tz) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + orm.SetMaxIdleConns(aliasName, maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + orm.SetMaxOpenConns(aliasName, maxOpenConns) +} + +// GetDB Get *sql.DB from registered database by db alias name. +// Use "default" as alias name if you not set. +func GetDB(aliasNames ...string) (*sql.DB, error) { + return orm.GetDB(aliasNames...) +} diff --git a/pkg/adapter/orm/models.go b/pkg/adapter/orm/models.go new file mode 100644 index 00000000..3215f5b5 --- /dev/null +++ b/pkg/adapter/orm/models.go @@ -0,0 +1,25 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + orm.ResetModelCache() +} diff --git a/pkg/adapter/orm/models_boot.go b/pkg/adapter/orm/models_boot.go new file mode 100644 index 00000000..8888ef65 --- /dev/null +++ b/pkg/adapter/orm/models_boot.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + orm.RegisterModel(models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + orm.RegisterModelWithPrefix(prefix, models) +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + orm.RegisterModelWithSuffix(suffix, models...) +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + orm.BootStrap() +} diff --git a/pkg/adapter/orm/models_fields.go b/pkg/adapter/orm/models_fields.go new file mode 100644 index 00000000..666a97dc --- /dev/null +++ b/pkg/adapter/orm/models_fields.go @@ -0,0 +1,625 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Define the Type enum +const ( + TypeBooleanField = orm.TypeBooleanField + TypeVarCharField = orm.TypeVarCharField + TypeCharField = orm.TypeCharField + TypeTextField = orm.TypeTextField + TypeTimeField = orm.TypeTimeField + TypeDateField = orm.TypeDateField + TypeDateTimeField = orm.TypeDateTimeField + TypeBitField = orm.TypeBitField + TypeSmallIntegerField = orm.TypeSmallIntegerField + TypeIntegerField = orm.TypeIntegerField + TypeBigIntegerField = orm.TypeBigIntegerField + TypePositiveBitField = orm.TypePositiveBitField + TypePositiveSmallIntegerField = orm.TypePositiveSmallIntegerField + TypePositiveIntegerField = orm.TypePositiveIntegerField + TypePositiveBigIntegerField = orm.TypePositiveBigIntegerField + TypeFloatField = orm.TypeFloatField + TypeDecimalField = orm.TypeDecimalField + TypeJSONField = orm.TypeJSONField + TypeJsonbField = orm.TypeJsonbField + RelForeignKey = orm.RelForeignKey + RelOneToOne = orm.RelOneToOne + RelManyToMany = orm.RelManyToMany + RelReverseOne = orm.RelReverseOne + RelReverseMany = orm.RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = orm.IsIntegerField + IsPositiveIntegerField = orm.IsPositiveIntegerField + IsRelField = orm.IsRelField + IsFieldType = orm.IsFieldType +) + +// BooleanField A true/false field. +type BooleanField orm.BooleanField + +// Value return the BooleanField +func (e BooleanField) Value() bool { + return orm.BooleanField(e).Value() +} + +// Set will set the BooleanField +func (e *BooleanField) Set(d bool) { + (*orm.BooleanField)(e).Set(d) +} + +// String format the Bool to string +func (e *BooleanField) String() string { + return (*orm.BooleanField)(e).String() +} + +// FieldType return BooleanField the type +func (e *BooleanField) FieldType() int { + return (*orm.BooleanField)(e).FieldType() +} + +// SetRaw set the interface to bool +func (e *BooleanField) SetRaw(value interface{}) error { + return (*orm.BooleanField)(e).SetRaw(value) +} + +// RawValue return the current value +func (e *BooleanField) RawValue() interface{} { + return (*orm.BooleanField)(e).RawValue() +} + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField orm.CharField + +// Value return the CharField's Value +func (e CharField) Value() string { + return orm.CharField(e).Value() +} + +// Set CharField value +func (e *CharField) Set(d string) { + (*orm.CharField)(e).Set(d) +} + +// String return the CharField +func (e *CharField) String() string { + return (*orm.CharField)(e).String() +} + +// FieldType return the enum type +func (e *CharField) FieldType() int { + return (*orm.CharField)(e).FieldType() +} + +// SetRaw set the interface to string +func (e *CharField) SetRaw(value interface{}) error { + return (*orm.CharField)(e).SetRaw(value) +} + +// RawValue return the CharField value +func (e *CharField) RawValue() interface{} { + return (*orm.CharField)(e).RawValue() +} + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField orm.TimeField + +// Value return the time.Time +func (e TimeField) Value() time.Time { + return orm.TimeField(e).Value() +} + +// Set set the TimeField's value +func (e *TimeField) Set(d time.Time) { + (*orm.TimeField)(e).Set(d) +} + +// String convert time to string +func (e *TimeField) String() string { + return (*orm.TimeField)(e).String() +} + +// FieldType return enum type Date +func (e *TimeField) FieldType() int { + return (*orm.TimeField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *TimeField) SetRaw(value interface{}) error { + return (*orm.TimeField)(e).SetRaw(value) +} + +// RawValue return time value +func (e *TimeField) RawValue() interface{} { + return (*orm.TimeField)(e).RawValue() +} + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField orm.DateField + +// Value return the time.Time +func (e DateField) Value() time.Time { + return orm.DateField(e).Value() +} + +// Set set the DateField's value +func (e *DateField) Set(d time.Time) { + (*orm.DateField)(e).Set(d) +} + +// String convert datetime to string +func (e *DateField) String() string { + return (*orm.DateField)(e).String() +} + +// FieldType return enum type Date +func (e *DateField) FieldType() int { + return (*orm.DateField)(e).FieldType() +} + +// SetRaw convert the interface to time.Time. Allow string and time.Time +func (e *DateField) SetRaw(value interface{}) error { + return (*orm.DateField)(e).SetRaw(value) +} + +// RawValue return Date value +func (e *DateField) RawValue() interface{} { + return (*orm.DateField)(e).RawValue() +} + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField orm.DateTimeField + +// Value return the datetime value +func (e DateTimeField) Value() time.Time { + return orm.DateTimeField(e).Value() +} + +// Set set the time.Time to datetime +func (e *DateTimeField) Set(d time.Time) { + (*orm.DateTimeField)(e).Set(d) +} + +// String return the time's String +func (e *DateTimeField) String() string { + return (*orm.DateTimeField)(e).String() +} + +// FieldType return the enum TypeDateTimeField +func (e *DateTimeField) FieldType() int { + return (*orm.DateTimeField)(e).FieldType() +} + +// SetRaw convert the string or time.Time to DateTimeField +func (e *DateTimeField) SetRaw(value interface{}) error { + return (*orm.DateTimeField)(e).SetRaw(value) +} + +// RawValue return the datetime value +func (e *DateTimeField) RawValue() interface{} { + return (*orm.DateTimeField)(e).RawValue() +} + +// verify datetime implement fielder +var _ Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField orm.FloatField + +// Value return the FloatField value +func (e FloatField) Value() float64 { + return orm.FloatField(e).Value() +} + +// Set the Float64 +func (e *FloatField) Set(d float64) { + (*orm.FloatField)(e).Set(d) +} + +// String return the string +func (e *FloatField) String() string { + return (*orm.FloatField)(e).String() +} + +// FieldType return the enum type +func (e *FloatField) FieldType() int { + return (*orm.FloatField)(e).FieldType() +} + +// SetRaw converter interface Float64 float32 or string to FloatField +func (e *FloatField) SetRaw(value interface{}) error { + return (*orm.FloatField)(e).SetRaw(value) +} + +// RawValue return the FloatField value +func (e *FloatField) RawValue() interface{} { + return (*orm.FloatField)(e).RawValue() +} + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField orm.SmallIntegerField + +// Value return int16 value +func (e SmallIntegerField) Value() int16 { + return orm.SmallIntegerField(e).Value() +} + +// Set the SmallIntegerField value +func (e *SmallIntegerField) Set(d int16) { + (*orm.SmallIntegerField)(e).Set(d) +} + +// String convert smallint to string +func (e *SmallIntegerField) String() string { + return (*orm.SmallIntegerField)(e).String() +} + +// FieldType return enum type SmallIntegerField +func (e *SmallIntegerField) FieldType() int { + return (*orm.SmallIntegerField)(e).FieldType() +} + +// SetRaw convert interface int16/string to int16 +func (e *SmallIntegerField) SetRaw(value interface{}) error { + return (*orm.SmallIntegerField)(e).SetRaw(value) +} + +// RawValue return smallint value +func (e *SmallIntegerField) RawValue() interface{} { + return (*orm.SmallIntegerField)(e).RawValue() +} + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField orm.IntegerField + +// Value return the int32 +func (e IntegerField) Value() int32 { + return orm.IntegerField(e).Value() +} + +// Set IntegerField value +func (e *IntegerField) Set(d int32) { + (*orm.IntegerField)(e).Set(d) +} + +// String convert Int32 to string +func (e *IntegerField) String() string { + return (*orm.IntegerField)(e).String() +} + +// FieldType return the enum type +func (e *IntegerField) FieldType() int { + return (*orm.IntegerField)(e).FieldType() +} + +// SetRaw convert interface int32/string to int32 +func (e *IntegerField) SetRaw(value interface{}) error { + return (*orm.IntegerField)(e).SetRaw(value) +} + +// RawValue return IntegerField value +func (e *IntegerField) RawValue() interface{} { + return (*orm.IntegerField)(e).RawValue() +} + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField orm.BigIntegerField + +// Value return int64 +func (e BigIntegerField) Value() int64 { + return orm.BigIntegerField(e).Value() +} + +// Set the BigIntegerField value +func (e *BigIntegerField) Set(d int64) { + (*orm.BigIntegerField)(e).Set(d) +} + +// String convert BigIntegerField to string +func (e *BigIntegerField) String() string { + return (*orm.BigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *BigIntegerField) FieldType() int { + return (*orm.BigIntegerField)(e).FieldType() +} + +// SetRaw convert interface int64/string to int64 +func (e *BigIntegerField) SetRaw(value interface{}) error { + return (*orm.BigIntegerField)(e).SetRaw(value) +} + +// RawValue return BigIntegerField value +func (e *BigIntegerField) RawValue() interface{} { + return (*orm.BigIntegerField)(e).RawValue() +} + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField orm.PositiveSmallIntegerField + +// Value return uint16 +func (e PositiveSmallIntegerField) Value() uint16 { + return orm.PositiveSmallIntegerField(e).Value() +} + +// Set PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) Set(d uint16) { + (*orm.PositiveSmallIntegerField)(e).Set(d) +} + +// String convert uint16 to string +func (e *PositiveSmallIntegerField) String() string { + return (*orm.PositiveSmallIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveSmallIntegerField) FieldType() int { + return (*orm.PositiveSmallIntegerField)(e).FieldType() +} + +// SetRaw convert Interface uint16/string to uint16 +func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveSmallIntegerField)(e).SetRaw(value) +} + +// RawValue returns PositiveSmallIntegerField value +func (e *PositiveSmallIntegerField) RawValue() interface{} { + return (*orm.PositiveSmallIntegerField)(e).RawValue() +} + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField orm.PositiveIntegerField + +// Value return PositiveIntegerField value. Uint32 +func (e PositiveIntegerField) Value() uint32 { + return orm.PositiveIntegerField(e).Value() +} + +// Set the PositiveIntegerField value +func (e *PositiveIntegerField) Set(d uint32) { + (*orm.PositiveIntegerField)(e).Set(d) +} + +// String convert PositiveIntegerField to string +func (e *PositiveIntegerField) String() string { + return (*orm.PositiveIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveIntegerField) FieldType() int { + return (*orm.PositiveIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint32/string to Uint32 +func (e *PositiveIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveIntegerField)(e).SetRaw(value) +} + +// RawValue return the PositiveIntegerField Value +func (e *PositiveIntegerField) RawValue() interface{} { + return (*orm.PositiveIntegerField)(e).RawValue() +} + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField orm.PositiveBigIntegerField + +// Value return uint64 +func (e PositiveBigIntegerField) Value() uint64 { + return orm.PositiveBigIntegerField(e).Value() +} + +// Set PositiveBigIntegerField value +func (e *PositiveBigIntegerField) Set(d uint64) { + (*orm.PositiveBigIntegerField)(e).Set(d) +} + +// String convert PositiveBigIntegerField to string +func (e *PositiveBigIntegerField) String() string { + return (*orm.PositiveBigIntegerField)(e).String() +} + +// FieldType return enum type +func (e *PositiveBigIntegerField) FieldType() int { + return (*orm.PositiveBigIntegerField)(e).FieldType() +} + +// SetRaw convert interface uint64/string to Uint64 +func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { + return (*orm.PositiveBigIntegerField)(e).SetRaw(value) +} + +// RawValue return PositiveBigIntegerField value +func (e *PositiveBigIntegerField) RawValue() interface{} { + return (*orm.PositiveBigIntegerField)(e).RawValue() +} + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField orm.TextField + +// Value return TextField value +func (e TextField) Value() string { + return orm.TextField(e).Value() +} + +// Set the TextField value +func (e *TextField) Set(d string) { + (*orm.TextField)(e).Set(d) +} + +// String convert TextField to string +func (e *TextField) String() string { + return (*orm.TextField)(e).String() +} + +// FieldType return enum type +func (e *TextField) FieldType() int { + return (*orm.TextField)(e).FieldType() +} + +// SetRaw convert interface string to string +func (e *TextField) SetRaw(value interface{}) error { + return (*orm.TextField)(e).SetRaw(value) +} + +// RawValue return TextField value +func (e *TextField) RawValue() interface{} { + return (*orm.TextField)(e).RawValue() +} + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField orm.JSONField + +// Value return JSONField value +func (j JSONField) Value() string { + return orm.JSONField(j).Value() +} + +// Set the JSONField value +func (j *JSONField) Set(d string) { + (*orm.JSONField)(j).Set(d) +} + +// String convert JSONField to string +func (j *JSONField) String() string { + return (*orm.JSONField)(j).String() +} + +// FieldType return enum type +func (j *JSONField) FieldType() int { + return (*orm.JSONField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JSONField) SetRaw(value interface{}) error { + return (*orm.JSONField)(j).SetRaw(value) +} + +// RawValue return JSONField value +func (j *JSONField) RawValue() interface{} { + return (*orm.JSONField)(j).RawValue() +} + +// verify JSONField implement Fielder +var _ Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField orm.JsonbField + +// Value return JsonbField value +func (j JsonbField) Value() string { + return orm.JsonbField(j).Value() +} + +// Set the JsonbField value +func (j *JsonbField) Set(d string) { + (*orm.JsonbField)(j).Set(d) +} + +// String convert JsonbField to string +func (j *JsonbField) String() string { + return (*orm.JsonbField)(j).String() +} + +// FieldType return enum type +func (j *JsonbField) FieldType() int { + return (*orm.JsonbField)(j).FieldType() +} + +// SetRaw convert interface string to string +func (j *JsonbField) SetRaw(value interface{}) error { + return (*orm.JsonbField)(j).SetRaw(value) +} + +// RawValue return JsonbField value +func (j *JsonbField) RawValue() interface{} { + return (*orm.JsonbField)(j).RawValue() +} + +// verify JsonbField implement Fielder +var _ Fielder = new(JsonbField) diff --git a/pkg/adapter/orm/orm.go b/pkg/adapter/orm/orm.go new file mode 100644 index 00000000..f8463ea2 --- /dev/null +++ b/pkg/adapter/orm/orm.go @@ -0,0 +1,314 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/astaxie/beego/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.Name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +// +// more docs: http://beego.me/docs/mvc/model/overview.md +package orm + +import ( + "context" + "database/sql" + "errors" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = orm.Debug + DebugLog = orm.DebugLog + DefaultRowsLimit = orm.DefaultRowsLimit + DefaultRelsDepth = orm.DefaultRelsDepth + DefaultTimeLoc = orm.DefaultTimeLoc + ErrTxHasBegan = errors.New(" transaction already begin") + ErrTxDone = errors.New(" transaction not begin") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") +) + +type ormer struct { + delegate orm.Ormer + txDelegate orm.TxOrmer + isTx bool +} + +var _ Ormer = new(ormer) + +// read data to model +func (o *ormer) Read(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.Read(md, cols...) + } + return o.delegate.Read(md, cols...) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *ormer) ReadForUpdate(md interface{}, cols ...string) error { + if o.isTx { + return o.txDelegate.ReadForUpdate(md, cols...) + } + return o.delegate.ReadForUpdate(md, cols...) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *ormer) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + if o.isTx { + return o.txDelegate.ReadOrCreate(md, col1, cols...) + } + return o.delegate.ReadOrCreate(md, col1, cols...) +} + +// insert model data to database +func (o *ormer) Insert(md interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.Insert(md) + } + return o.delegate.Insert(md) +} + +// insert some models to database +func (o *ormer) InsertMulti(bulk int, mds interface{}) (int64, error) { + if o.isTx { + return o.txDelegate.InsertMulti(bulk, mds) + } + return o.delegate.InsertMulti(bulk, mds) +} + +// InsertOrUpdate data to database +func (o *ormer) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + if o.isTx { + return o.txDelegate.InsertOrUpdate(md, colConflitAndArgs...) + } + return o.delegate.InsertOrUpdate(md, colConflitAndArgs...) +} + +// update model to database. +// cols set the columns those want to update. +func (o *ormer) Update(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Update(md, cols...) + } + return o.delegate.Update(md, cols...) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *ormer) Delete(md interface{}, cols ...string) (int64, error) { + if o.isTx { + return o.txDelegate.Delete(md, cols...) + } + return o.delegate.Delete(md, cols...) +} + +// create a models to models queryer +func (o *ormer) QueryM2M(md interface{}, name string) QueryM2Mer { + if o.isTx { + return o.txDelegate.QueryM2M(md, name) + } + return o.delegate.QueryM2M(md, name) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *ormer) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { + kvs := make([]utils.KV, 0, 4) + for i, arg := range args { + switch i { + case 0: + if v, ok := arg.(bool); ok { + if v { + kvs = append(kvs, hints.DefaultRelDepth()) + } + } else if v, ok := arg.(int); ok { + kvs = append(kvs, hints.RelDepth(v)) + } + case 1: + kvs = append(kvs, hints.Limit(orm.ToInt64(arg))) + case 2: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + case 3: + kvs = append(kvs, hints.Offset(orm.ToInt64(arg))) + } + } + if o.isTx { + return o.txDelegate.LoadRelated(md, name, kvs...) + } + return o.delegate.LoadRelated(md, name, kvs...) +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *ormer) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + if o.isTx { + return o.txDelegate.QueryTable(ptrStructOrTableName) + } + return o.delegate.QueryTable(ptrStructOrTableName) +} + +// switch to another registered database driver by given name. +func (o *ormer) Using(name string) error { + if o.isTx { + return ErrTxHasBegan + } + o.delegate = orm.NewOrmUsingDB(name) + return nil +} + +// begin transaction +func (o *ormer) Begin() error { + if o.isTx { + return ErrTxHasBegan + } + return o.BeginTx(context.Background(), nil) +} + +func (o *ormer) BeginTx(ctx context.Context, opts *sql.TxOptions) error { + if o.isTx { + return ErrTxHasBegan + } + txOrmer, err := o.delegate.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + o.txDelegate = txOrmer + o.isTx = true + return nil +} + +// commit transaction +func (o *ormer) Commit() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Commit() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// rollback transaction +func (o *ormer) Rollback() error { + if !o.isTx { + return ErrTxDone + } + err := o.txDelegate.Rollback() + if err == nil { + o.isTx = false + o.txDelegate = nil + } else if err == sql.ErrTxDone { + return ErrTxDone + } + return err +} + +// return a raw query seter for raw sql string. +func (o *ormer) Raw(query string, args ...interface{}) RawSeter { + if o.isTx { + return o.txDelegate.Raw(query, args...) + } + return o.delegate.Raw(query, args...) +} + +// return current using database Driver +func (o *ormer) Driver() Driver { + if o.isTx { + return o.txDelegate.Driver() + } + return o.delegate.Driver() +} + +// return sql.DBStats for current database +func (o *ormer) DBStats() *sql.DBStats { + if o.isTx { + return o.txDelegate.DBStats() + } + return o.delegate.DBStats() +} + +// NewOrm create new orm +func NewOrm() Ormer { + o := orm.NewOrm() + return &ormer{ + delegate: o, + } +} + +// NewOrmWithDB create a new ormer object with specify *sql.DB for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { + o, err := orm.NewOrmWithDB(driverName, aliasName, db) + if err != nil { + return nil, err + } + return &ormer{ + delegate: o, + }, nil +} diff --git a/pkg/adapter/orm/orm_conds.go b/pkg/adapter/orm/orm_conds.go new file mode 100644 index 00000000..986b4858 --- /dev/null +++ b/pkg/adapter/orm/orm_conds.go @@ -0,0 +1,83 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// ExprSep define the expression separation +const ( + ExprSep = "__" +) + +// Condition struct. +// work for WHERE conditions. +type Condition orm.Condition + +// NewCondition return new condition struct +func NewCondition() *Condition { + return (*Condition)(orm.NewCondition()) +} + +// Raw add raw sql to condition +func (c Condition) Raw(expr string, sql string) *Condition { + return (*Condition)((orm.Condition)(c).Raw(expr, sql)) +} + +// And add expression to condition +func (c Condition) And(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).And(expr, args...)) +} + +// AndNot add NOT expression to condition +func (c Condition) AndNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).AndNot(expr, args...)) +} + +// AndCond combine a condition to current condition +func (c *Condition) AndCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndCond((*orm.Condition)(cond))) +} + +// AndNotCond combine a AND NOT condition to current condition +func (c *Condition) AndNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).AndNotCond((*orm.Condition)(cond))) +} + +// Or add OR expression to condition +func (c Condition) Or(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).Or(expr, args...)) +} + +// OrNot add OR NOT expression to condition +func (c Condition) OrNot(expr string, args ...interface{}) *Condition { + return (*Condition)((orm.Condition)(c).OrNot(expr, args...)) +} + +// OrCond combine a OR condition to current condition +func (c *Condition) OrCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrCond((*orm.Condition)(cond))) +} + +// OrNotCond combine a OR NOT condition to current condition +func (c *Condition) OrNotCond(cond *Condition) *Condition { + return (*Condition)((*orm.Condition)(c).OrNotCond((*orm.Condition)(cond))) +} + +// IsEmpty check the condition arguments are empty or not. +func (c *Condition) IsEmpty() bool { + return (*orm.Condition)(c).IsEmpty() +} diff --git a/pkg/adapter/orm/orm_log.go b/pkg/adapter/orm/orm_log.go new file mode 100644 index 00000000..6b2b4a9b --- /dev/null +++ b/pkg/adapter/orm/orm_log.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "io" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Log implement the log.Logger +type Log orm.Log + +// costomer log func +var LogFunc = orm.LogFunc + +// NewLog set io.Writer to create a Logger. +func NewLog(out io.Writer) *Log { + return (*Log)(orm.NewLog(out)) +} diff --git a/pkg/adapter/orm/orm_queryset.go b/pkg/adapter/orm/orm_queryset.go new file mode 100644 index 00000000..5f211644 --- /dev/null +++ b/pkg/adapter/orm/orm_queryset.go @@ -0,0 +1,32 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// define Col operations +const ( + ColAdd = orm.ColAdd + ColMinus = orm.ColMinus + ColMultiply = orm.ColMultiply + ColExcept = orm.ColExcept + ColBitAnd = orm.ColBitAnd + ColBitRShift = orm.ColBitRShift + ColBitLShift = orm.ColBitLShift + ColBitXOR = orm.ColBitXOR + ColBitOr = orm.ColBitOr +) diff --git a/pkg/adapter/orm/qb.go b/pkg/adapter/orm/qb.go new file mode 100644 index 00000000..90b97797 --- /dev/null +++ b/pkg/adapter/orm/qb.go @@ -0,0 +1,27 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// QueryBuilder is the Query builder interface +type QueryBuilder orm.QueryBuilder + +// NewQueryBuilder return the QueryBuilder +func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { + return orm.NewQueryBuilder(driver) +} diff --git a/pkg/adapter/orm/qb_mysql.go b/pkg/adapter/orm/qb_mysql.go new file mode 100644 index 00000000..9566068f --- /dev/null +++ b/pkg/adapter/orm/qb_mysql.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// CommaSpace is the separation +const CommaSpace = orm.CommaSpace + +// MySQLQueryBuilder is the SQL build +type MySQLQueryBuilder orm.MySQLQueryBuilder + +// Select will join the fields +func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *MySQLQueryBuilder) Asc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *MySQLQueryBuilder) Desc() QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.MySQLQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.MySQLQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *MySQLQueryBuilder) String() string { + return (*orm.MySQLQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/qb_tidb.go b/pkg/adapter/orm/qb_tidb.go new file mode 100644 index 00000000..05c91a26 --- /dev/null +++ b/pkg/adapter/orm/qb_tidb.go @@ -0,0 +1,147 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder orm.TiDBQueryBuilder + +// Select will join the fields +func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Select(fields...) +} + +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).ForUpdate() +} + +// From join the tables +func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).From(tables...) +} + +// InnerJoin INNER JOIN the table +func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InnerJoin(table) +} + +// LeftJoin LEFT JOIN the table +func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).LeftJoin(table) +} + +// RightJoin RIGHT JOIN the table +func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).RightJoin(table) +} + +// On join with on cond +func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).On(cond) +} + +// Where join the Where cond +func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Where(cond) +} + +// And join the and cond +func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).And(cond) +} + +// Or join the or cond +func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Or(cond) +} + +// In join the IN (vals) +func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).In(vals...) +} + +// OrderBy join the Order by fields +func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).OrderBy(fields...) +} + +// Asc join the asc +func (qb *TiDBQueryBuilder) Asc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Asc() +} + +// Desc join the desc +func (qb *TiDBQueryBuilder) Desc() QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Desc() +} + +// Limit join the limit num +func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Limit(limit) +} + +// Offset join the offset num +func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Offset(offset) +} + +// GroupBy join the Group by fields +func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).GroupBy(fields...) +} + +// Having join the Having cond +func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Having(cond) +} + +// Update join the update table +func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Update(tables...) +} + +// Set join the set kv +func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Set(kv...) +} + +// Delete join the Delete tables +func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Delete(tables...) +} + +// InsertInto join the insert SQL +func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).InsertInto(table, fields...) +} + +// Values join the Values(vals) +func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { + return (*orm.TiDBQueryBuilder)(qb).Values(vals...) +} + +// Subquery join the sub as alias +func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { + return (*orm.TiDBQueryBuilder)(qb).Subquery(sub, alias) +} + +// String join all Tokens +func (qb *TiDBQueryBuilder) String() string { + return (*orm.TiDBQueryBuilder)(qb).String() +} diff --git a/pkg/adapter/orm/query_setter_adapter.go b/pkg/adapter/orm/query_setter_adapter.go new file mode 100644 index 00000000..cc24ef6b --- /dev/null +++ b/pkg/adapter/orm/query_setter_adapter.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/astaxie/beego/pkg/client/orm" +) + +type baseQuerySetter struct { +} + +func (b *baseQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} + +func (b *baseQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + panic("you should not invoke this method.") +} diff --git a/pkg/adapter/orm/types.go b/pkg/adapter/orm/types.go new file mode 100644 index 00000000..3372e301 --- /dev/null +++ b/pkg/adapter/orm/types.go @@ -0,0 +1,150 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + + "github.com/astaxie/beego/pkg/client/orm" +) + +// Params stores the Params +type Params orm.Params + +// ParamsList stores paramslist +type ParamsList orm.ParamsList + +// Driver define database driver +type Driver orm.Driver + +// Fielder define field info +type Fielder orm.Fielder + +// Ormer define the orm interface +type Ormer interface { + // read data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + // Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + // insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will set user's pk field + Insert(interface{}) (int64, error) + // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + // insert some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + // update model to database. + // cols set the columns those want to update. + // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + // delete model in database + Delete(md interface{}, cols ...string) (int64, error) + // load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + // args[0] bool true useDefaultRelsDepth ; false depth 0 + // args[0] int loadRelationDepth + // args[1] int limit default limit 1000 + // args[2] int offset default offset 0 + // args[3] string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) + // create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // switch to another registered database driver by given name. + Using(name string) error + // begin transaction + // for example: + // o := NewOrm() + // err := o.Begin() + // ... + // err = o.Rollback() + Begin() error + // begin transaction with provided context and option + // the provided context is used until the transaction is committed or rolled back. + // if the context is canceled, the transaction will be rolled back. + // the provided TxOptions is optional and may be nil if defaults should be used. + // if a non-default isolation level is used that the driver doesn't support, an error will be returned. + // for example: + // o := NewOrm() + // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + // ... + // err = o.Rollback() + BeginTx(ctx context.Context, opts *sql.TxOptions) error + // commit transaction + Commit() error + // rollback transaction + Rollback() error + // return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + Driver() Driver + DBStats() *sql.DBStats +} + +// Inserter insert prepared statement +type Inserter orm.Inserter + +// QuerySeter query seter +type QuerySeter orm.QuerySeter + +// QueryM2Mer model to model query struct +// all operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer orm.QueryM2Mer + +// RawPreparer raw query statement +type RawPreparer orm.RawPreparer + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter orm.RawSeter diff --git a/pkg/adapter/orm/utils.go b/pkg/adapter/orm/utils.go new file mode 100644 index 00000000..16d0e4e5 --- /dev/null +++ b/pkg/adapter/orm/utils.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/astaxie/beego/pkg/client/orm" +) + +type fn func(string) string + +var ( + nameStrategyMap = map[string]fn{ + defaultNameStrategy: snakeString, + SnakeAcronymNameStrategy: snakeStringWithAcronym, + } + defaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + nameStrategy = defaultNameStrategy +) + +// StrTo is the target string +type StrTo orm.StrTo + +// Set string +func (f *StrTo) Set(v string) { + (*orm.StrTo)(f).Set(v) +} + +// Clear string +func (f *StrTo) Clear() { + (*orm.StrTo)(f).Clear() +} + +// Exist check string exist +func (f StrTo) Exist() bool { + return orm.StrTo(f).Exist() +} + +// Bool string to bool +func (f StrTo) Bool() (bool, error) { + return orm.StrTo(f).Bool() +} + +// Float32 string to float32 +func (f StrTo) Float32() (float32, error) { + return orm.StrTo(f).Float32() +} + +// Float64 string to float64 +func (f StrTo) Float64() (float64, error) { + return orm.StrTo(f).Float64() +} + +// Int string to int +func (f StrTo) Int() (int, error) { + return orm.StrTo(f).Int() +} + +// Int8 string to int8 +func (f StrTo) Int8() (int8, error) { + return orm.StrTo(f).Int8() +} + +// Int16 string to int16 +func (f StrTo) Int16() (int16, error) { + return orm.StrTo(f).Int16() +} + +// Int32 string to int32 +func (f StrTo) Int32() (int32, error) { + return orm.StrTo(f).Int32() +} + +// Int64 string to int64 +func (f StrTo) Int64() (int64, error) { + return orm.StrTo(f).Int64() +} + +// Uint string to uint +func (f StrTo) Uint() (uint, error) { + return orm.StrTo(f).Uint() +} + +// Uint8 string to uint8 +func (f StrTo) Uint8() (uint8, error) { + return orm.StrTo(f).Uint8() +} + +// Uint16 string to uint16 +func (f StrTo) Uint16() (uint16, error) { + return orm.StrTo(f).Uint16() +} + +// Uint32 string to uint32 +func (f StrTo) Uint32() (uint32, error) { + return orm.StrTo(f).Uint32() +} + +// Uint64 string to uint64 +func (f StrTo) Uint64() (uint64, error) { + return orm.StrTo(f).Uint64() +} + +// String string to string +func (f StrTo) String() string { + return orm.StrTo(f).String() +} + +// ToStr interface to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +func snakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// snake string, XxYy to xx_yy , XxYY to xx_y_y +func snakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data[:])) +} + +// SetNameStrategy set different name strategy +func SetNameStrategy(s string) { + if SnakeAcronymNameStrategy != s { + nameStrategy = defaultNameStrategy + } + nameStrategy = s +} + +// camel string, xx_yy to XxYy +func camelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data[:]) +} + +type argString []string + +// get string by index from string slice +func (a argString) Get(i int, args ...string) (r string) { + if i >= 0 && i < len(a) { + r = a[i] + } else if len(args) > 0 { + r = args[0] + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// parse time to string with location +func timeParse(dateString, format string) (time.Time, error) { + tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) + return tp, err +} + +// get pointer indirect type +func indirectType(v reflect.Type) reflect.Type { + switch v.Kind() { + case reflect.Ptr: + return indirectType(v.Elem()) + default: + return v + } +} diff --git a/pkg/orm/utils_test.go b/pkg/adapter/orm/utils_test.go similarity index 100% rename from pkg/orm/utils_test.go rename to pkg/adapter/orm/utils_test.go diff --git a/pkg/adapter/plugins/apiauth/apiauth.go b/pkg/adapter/plugins/apiauth/apiauth.go new file mode 100644 index 00000000..ed43f8a0 --- /dev/null +++ b/pkg/adapter/plugins/apiauth/apiauth.go @@ -0,0 +1,94 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package apiauth provides handlers to enable apiauth support. +// +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/apiauth" +// ) +// +// func main(){ +// // apiauth every request +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.Run() +// } +// +// Advanced Usage: +// +// func getAppSecret(appid string) string { +// // get appsecret by appid +// // maybe store in configure, maybe in database +// } +// +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360)) +// +// Information: +// +// In the request user should include these params in the query +// +// 1. appid +// +// appid is assigned to the application +// +// 2. signature +// +// get the signature use apiauth.Signature() +// +// when you send to server remember use url.QueryEscape() +// +// 3. timestamp: +// +// send the request time, the format is yyyy-mm-dd HH:ii:ss +// +package apiauth + +import ( + "net/url" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/apiauth" +) + +// AppIDToAppSecret is used to get appsecret throw appid +type AppIDToAppSecret apiauth.AppIDToAppSecret + +// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) beego.FilterFunc { + f := apiauth.APIBasicAuth(appid, appkey) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// APIBaiscAuth calls APIBasicAuth for previous callers +func APIBaiscAuth(appid, appkey string) beego.FilterFunc { + return APIBasicAuth(appid, appkey) +} + +// APISecretAuth use AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { + ft := apiauth.APISecretAuth(apiauth.AppIDToAppSecret(f), timeout) + return func(ctx *context.Context) { + ft((*beecontext.Context)(ctx)) + } +} + +// Signature used to generate signature with the appsecret/method/params/RequestURI +func Signature(appsecret, method string, params url.Values, requestURL string) string { + return apiauth.Signature(appsecret, method, params, requestURL) +} diff --git a/pkg/plugins/apiauth/apiauth_test.go b/pkg/adapter/plugins/apiauth/apiauth_test.go similarity index 100% rename from pkg/plugins/apiauth/apiauth_test.go rename to pkg/adapter/plugins/apiauth/apiauth_test.go diff --git a/pkg/adapter/plugins/auth/basic.go b/pkg/adapter/plugins/auth/basic.go new file mode 100644 index 00000000..7a9cd326 --- /dev/null +++ b/pkg/adapter/plugins/auth/basic.go @@ -0,0 +1,81 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package auth provides handlers to enable basic auth support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/auth" +// ) +// +// func main(){ +// // authenticate every request +// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func SecretAuth(username, password string) bool { +// return username == "astaxie" && password == "helloBeego" +// } +// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") +// beego.InsertFilter("*", beego.BeforeRouter,authPlugin) +package auth + +import ( + "net/http" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/auth" +) + +// Basic is the http basic auth +func Basic(username string, password string) beego.FilterFunc { + return func(c *context.Context) { + f := auth.Basic(username, password) + f((*beecontext.Context)(c)) + } +} + +// NewBasicAuthenticator return the BasicAuth +func NewBasicAuthenticator(secrets SecretProvider, realm string) beego.FilterFunc { + f := auth.NewBasicAuthenticator(auth.SecretProvider(secrets), realm) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} + +// SecretProvider is the SecretProvider function +type SecretProvider auth.SecretProvider + +// BasicAuth store the SecretProvider and Realm +type BasicAuth auth.BasicAuth + +// CheckAuth Checks the username/password combination from the request. Returns +// either an empty string (authentication failed) or the name of the +// authenticated user. +// Supports MD5 and SHA1 password entries +func (a *BasicAuth) CheckAuth(r *http.Request) string { + return (*auth.BasicAuth)(a).CheckAuth(r) +} + +// RequireAuth http.Handler for BasicAuth which initiates the authentication process +// (or requires reauthentication). +func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { + (*auth.BasicAuth)(a).RequireAuth(w, r) +} diff --git a/pkg/adapter/plugins/authz/authz.go b/pkg/adapter/plugins/authz/authz.go new file mode 100644 index 00000000..c38be9cb --- /dev/null +++ b/pkg/adapter/plugins/authz/authz.go @@ -0,0 +1,80 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. +// Simple Usage: +// import( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/authz" +// "github.com/casbin/casbin" +// ) +// +// func main(){ +// // mediate the access for every request +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) +// beego.Run() +// } +// +// +// Advanced Usage: +// +// func main(){ +// e := casbin.NewEnforcer("authz_model.conf", "") +// e.AddRoleForUser("alice", "admin") +// e.AddPolicy(...) +// +// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e)) +// beego.Run() +// } +package authz + +import ( + "net/http" + + "github.com/casbin/casbin" + + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/authz" +) + +// NewAuthorizer returns the authorizer. +// Use a casbin enforcer as input +func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { + f := authz.NewAuthorizer(e) + return func(context *context.Context) { + f((*beecontext.Context)(context)) + } +} + +// BasicAuthorizer stores the casbin handler +type BasicAuthorizer authz.BasicAuthorizer + +// GetUserName gets the user name from the request. +// Currently, only HTTP basic authentication is supported +func (a *BasicAuthorizer) GetUserName(r *http.Request) string { + return (*authz.BasicAuthorizer)(a).GetUserName(r) +} + +// CheckPermission checks the user/method/path combination from the request. +// Returns true (permission granted) or false (permission forbidden) +func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool { + return (*authz.BasicAuthorizer)(a).CheckPermission(r) +} + +// RequirePermission returns the 403 Forbidden to the client +func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) { + (*authz.BasicAuthorizer)(a).RequirePermission(w) +} diff --git a/pkg/plugins/authz/authz_model.conf b/pkg/adapter/plugins/authz/authz_model.conf similarity index 100% rename from pkg/plugins/authz/authz_model.conf rename to pkg/adapter/plugins/authz/authz_model.conf diff --git a/pkg/plugins/authz/authz_policy.csv b/pkg/adapter/plugins/authz/authz_policy.csv similarity index 100% rename from pkg/plugins/authz/authz_policy.csv rename to pkg/adapter/plugins/authz/authz_policy.csv diff --git a/pkg/plugins/authz/authz_test.go b/pkg/adapter/plugins/authz/authz_test.go similarity index 96% rename from pkg/plugins/authz/authz_test.go rename to pkg/adapter/plugins/authz/authz_test.go index 6cc081f3..ddbda5f4 100644 --- a/pkg/plugins/authz/authz_test.go +++ b/pkg/adapter/plugins/authz/authz_test.go @@ -19,9 +19,9 @@ import ( "net/http/httptest" "testing" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/plugins/auth" + beego "github.com/astaxie/beego/pkg/adapter" + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/adapter/plugins/auth" "github.com/casbin/casbin" ) diff --git a/pkg/adapter/plugins/cors/cors.go b/pkg/adapter/plugins/cors/cors.go new file mode 100644 index 00000000..65af8b8f --- /dev/null +++ b/pkg/adapter/plugins/cors/cors.go @@ -0,0 +1,71 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cors provides handlers to enable CORS support. +// Usage +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/plugins/cors" +// ) +// +// func main() { +// // CORS for https://foo.* origins, allowing: +// // - PUT and PATCH methods +// // - Origin header +// // - Credentials share +// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ +// AllowOrigins: []string{"https://*.foo.com"}, +// AllowMethods: []string{"PUT", "PATCH"}, +// AllowHeaders: []string{"Origin"}, +// ExposeHeaders: []string{"Content-Length"}, +// AllowCredentials: true, +// })) +// beego.Run() +// } +package cors + +import ( + beego "github.com/astaxie/beego/pkg/adapter" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/cors" + + "github.com/astaxie/beego/pkg/adapter/context" +) + +// Options represents Access Control options. +type Options cors.Options + +// Header converts options into CORS headers. +func (o *Options) Header(origin string) (headers map[string]string) { + return (*cors.Options)(o).Header(origin) +} + +// PreflightHeader converts options into CORS headers for a preflight response. +func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { + return (*cors.Options)(o).PreflightHeader(origin, rMethod, rHeaders) +} + +// IsOriginAllowed looks up if the origin matches one of the patterns +// generated from Options.AllowOrigins patterns. +func (o *Options) IsOriginAllowed(origin string) bool { + return (*cors.Options)(o).IsOriginAllowed(origin) +} + +// Allow enables CORS for requests those match the provided options. +func Allow(opts *Options) beego.FilterFunc { + f := cors.Allow((*cors.Options)(opts)) + return func(c *context.Context) { + f((*beecontext.Context)(c)) + } +} diff --git a/pkg/adapter/policy.go b/pkg/adapter/policy.go new file mode 100644 index 00000000..f3759c76 --- /dev/null +++ b/pkg/adapter/policy.go @@ -0,0 +1,57 @@ +// Copyright 2016 beego authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +// PolicyFunc defines a policy function which is invoked before the controller handler is executed. +type PolicyFunc func(*context.Context) + +// FindPolicy Find Router info for URL +func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { + pf := (*web.ControllerRegister)(p).FindPolicy((*beecontext.Context)(cont)) + npf := newToOldPolicyFunc(pf) + return npf +} + +func newToOldPolicyFunc(pf []web.PolicyFunc) []PolicyFunc { + npf := make([]PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *context.Context) { + f((*beecontext.Context)(c)) + }) + } + return npf +} + +func oldToNewPolicyFunc(pf []PolicyFunc) []web.PolicyFunc { + npf := make([]web.PolicyFunc, 0, len(pf)) + for _, f := range pf { + npf = append(npf, func(c *beecontext.Context) { + f((*context.Context)(c)) + }) + } + return npf +} + +// Policy Register new policy in beego +func Policy(pattern, method string, policy ...PolicyFunc) { + pf := oldToNewPolicyFunc(policy) + web.Policy(pattern, method, pf...) +} diff --git a/pkg/adapter/router.go b/pkg/adapter/router.go new file mode 100644 index 00000000..8e8d9fdb --- /dev/null +++ b/pkg/adapter/router.go @@ -0,0 +1,282 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "net/http" + "time" + + beecontext "github.com/astaxie/beego/pkg/adapter/context" + "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// default filter execution points +const ( + BeforeStatic = web.BeforeStatic + BeforeRouter = web.BeforeRouter + BeforeExec = web.BeforeExec + AfterExec = web.AfterExec + FinishRouter = web.FinishRouter +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = web.HTTPMETHOD + + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &newToOldFtHdlAdapter{ + delegate: web.DefaultAccessLogFilter, + } +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +type newToOldFtHdlAdapter struct { + delegate web.FilterHandler +} + +func (n *newToOldFtHdlAdapter) Filter(ctx *beecontext.Context) bool { + return n.delegate.Filter((*context.Context)(ctx)) +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + web.ExceptMethodAppend(action) +} + +// ControllerInfo holds information about the controller. +type ControllerInfo web.ControllerInfo + +func (c *ControllerInfo) GetPattern() string { + return (*web.ControllerInfo)(c).GetPattern() +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister web.ControllerRegister + +// NewControllerRegister returns a new ControllerRegister. +func NewControllerRegister() *ControllerRegister { + return (*ControllerRegister)(web.NewControllerRegister()) +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { + (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + nls := oldToNewCtrlIntfs(cList) + (*web.ControllerRegister)(p).Include(nls...) +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return (*beecontext.Context)((*web.ControllerRegister)(p).GetContext()) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + (*web.ControllerRegister)(p).GiveBackContext((*context.Context)(ctx)) +} + +// Get add get method +// usage: +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Get(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Post add post method +// usage: +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Post(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Put add put method +// usage: +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Put(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Delete add delete method +// usage: +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Delete(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Head add head method +// usage: +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Head(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Patch add patch method +// usage: +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Patch(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Options add options method +// usage: +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Options(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Any add all method +// usage: +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).Any(pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// AddMethod add http method router +// usage: +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { + (*web.ControllerRegister)(p).AddMethod(method, pattern, func(ctx *context.Context) { + f((*beecontext.Context)(ctx)) + }) +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + (*web.ControllerRegister)(p).Handler(pattern, h, options) +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainContorlller{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + (*web.ControllerRegister)(p).AddAuto(c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + (*web.ControllerRegister)(p).AddAutoPrefix(prefix, c) +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { + opts := oldToNewFilterOpts(params) + return (*web.ControllerRegister)(p).InsertFilter(pattern, pos, func(ctx *context.Context) { + filter((*beecontext.Context)(ctx)) + }, opts...) +} + +func oldToNewFilterOpts(params []bool) []web.FilterOpt { + opts := make([]web.FilterOpt, 0, 4) + if len(params) > 0 { + opts = append(opts, web.WithReturnOnOutput(params[0])) + } else { + // the default value should be true + opts = append(opts, web.WithReturnOnOutput(true)) + } + if len(params) > 1 { + opts = append(opts, web.WithResetParams(params[1])) + } + return opts +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + return (*web.ControllerRegister)(p).URLFor(endpoint, values...) +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + (*web.ControllerRegister)(p).ServeHTTP(rw, r) +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(ctx *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + r, ok := (*web.ControllerRegister)(p).FindRouter((*context.Context)(ctx)) + return (*ControllerInfo)(r), ok +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + web.LogAccess((*context.Context)(ctx), startTime, statusCode) +} diff --git a/pkg/adapter/session/couchbase/sess_couchbase.go b/pkg/adapter/session/couchbase/sess_couchbase.go new file mode 100644 index 00000000..bce09641 --- /dev/null +++ b/pkg/adapter/session/couchbase/sess_couchbase.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package couchbase for session provider +// +// depend on github.com/couchbaselabs/go-couchbasee +// +// go install github.com/couchbaselabs/go-couchbase +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/couchbase" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package couchbase + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beecb "github.com/astaxie/beego/pkg/infrastructure/session/couchbase" +) + +// SessionStore store each session +type SessionStore beecb.SessionStore + +// Provider couchabse provided +type Provider beecb.Provider + +// Set value to couchabse session +func (cs *SessionStore) Set(key, value interface{}) error { + return (*beecb.SessionStore)(cs).Set(context.Background(), key, value) +} + +// Get value from couchabse session +func (cs *SessionStore) Get(key interface{}) interface{} { + return (*beecb.SessionStore)(cs).Get(context.Background(), key) +} + +// Delete value in couchbase session by given key +func (cs *SessionStore) Delete(key interface{}) error { + return (*beecb.SessionStore)(cs).Delete(context.Background(), key) +} + +// Flush Clean all values in couchbase session +func (cs *SessionStore) Flush() error { + return (*beecb.SessionStore)(cs).Flush(context.Background()) +} + +// SessionID Get couchbase session store id +func (cs *SessionStore) SessionID() string { + return (*beecb.SessionStore)(cs).SessionID(context.Background()) +} + +// SessionRelease Write couchbase session with Gob string +func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beecb.SessionStore)(cs).SessionRelease(context.Background(), w) +} + +// SessionInit init couchbase session +// savepath like couchbase server REST/JSON URL +// e.g. http://host:port/, Pool, Bucket +func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beecb.Provider)(cp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read couchbase session by sid +func (cp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist Check couchbase session exist. +// it checkes sid exist or not. +func (cp *Provider) SessionExist(sid string) bool { + res, _ := (*beecb.Provider)(cp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate remove oldsid and use sid to generate new session +func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beecb.Provider)(cp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy Remove bucket in this couchbase +func (cp *Provider) SessionDestroy(sid string) error { + return (*beecb.Provider)(cp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle +func (cp *Provider) SessionGC() { + (*beecb.Provider)(cp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (cp *Provider) SessionAll() int { + return (*beecb.Provider)(cp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/ledis/ledis_session.go b/pkg/adapter/session/ledis/ledis_session.go new file mode 100644 index 00000000..96198837 --- /dev/null +++ b/pkg/adapter/session/ledis/ledis_session.go @@ -0,0 +1,86 @@ +// Package ledis provide session Provider +package ledis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + beeLedis "github.com/astaxie/beego/pkg/infrastructure/session/ledis" +) + +// SessionStore ledis session store +type SessionStore beeLedis.SessionStore + +// Set value in ledis session +func (ls *SessionStore) Set(key, value interface{}) error { + return (*beeLedis.SessionStore)(ls).Set(context.Background(), key, value) +} + +// Get value in ledis session +func (ls *SessionStore) Get(key interface{}) interface{} { + return (*beeLedis.SessionStore)(ls).Get(context.Background(), key) +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(key interface{}) error { + return (*beeLedis.SessionStore)(ls).Delete(context.Background(), key) +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush() error { + return (*beeLedis.SessionStore)(ls).Flush(context.Background()) +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID() string { + return (*beeLedis.SessionStore)(ls).SessionID(context.Background()) +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeLedis.SessionStore)(ls).SessionRelease(context.Background(), w) +} + +// Provider ledis session provider +type Provider beeLedis.Provider + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// e.g. 127.0.0.1:6379,100,astaxie +func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeLedis.Provider)(lp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(sid string) bool { + res, _ := (*beeLedis.Provider)(lp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeLedis.Provider)(lp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(sid string) error { + return (*beeLedis.Provider)(lp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (lp *Provider) SessionGC() { + (*beeLedis.Provider)(lp).SessionGC(context.Background()) +} + +// SessionAll return all active session +func (lp *Provider) SessionAll() int { + return (*beeLedis.Provider)(lp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/memcache/sess_memcache.go b/pkg/adapter/session/memcache/sess_memcache.go new file mode 100644 index 00000000..8afa79aa --- /dev/null +++ b/pkg/adapter/session/memcache/sess_memcache.go @@ -0,0 +1,118 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for session provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/memcache" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package memcache + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beemem "github.com/astaxie/beego/pkg/infrastructure/session/memcache" +) + +// SessionStore memcache session store +type SessionStore beemem.SessionStore + +// Set value in memcache session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beemem.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in memcache session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beemem.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in memcache session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beemem.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in memcache session +func (rs *SessionStore) Flush() error { + return (*beemem.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get memcache session id +func (rs *SessionStore) SessionID() string { + return (*beemem.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to memcache +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beemem.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// MemProvider memcache session provider +type MemProvider beemem.MemProvider + +// SessionInit init memcache session +// savepath like +// e.g. 127.0.0.1:9090 +func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*beemem.MemProvider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read memcache session by sid +func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check memcache session exist by sid +func (rp *MemProvider) SessionExist(sid string) bool { + res, _ := (*beemem.MemProvider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for memcache session +func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beemem.MemProvider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete memcache session by id +func (rp *MemProvider) SessionDestroy(sid string) error { + return (*beemem.MemProvider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *MemProvider) SessionGC() { + (*beemem.MemProvider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *MemProvider) SessionAll() int { + return (*beemem.MemProvider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/mysql/sess_mysql.go b/pkg/adapter/session/mysql/sess_mysql.go new file mode 100644 index 00000000..1850a380 --- /dev/null +++ b/pkg/adapter/session/mysql/sess_mysql.go @@ -0,0 +1,135 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql for session provider +// +// depends on github.com/go-sql-driver/mysql: +// +// go install github.com/go-sql-driver/mysql +// +// mysql session support need create table as sql: +// CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +// ) ENGINE=MyISAM DEFAULT CHARSET=utf8; +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/mysql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package mysql + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + "github.com/astaxie/beego/pkg/infrastructure/session/mysql" + + // import mysql driver + _ "github.com/go-sql-driver/mysql" +) + +var ( + // TableName store the session in MySQL + TableName = mysql.TableName + mysqlpder = &Provider{} +) + +// SessionStore mysql session store +type SessionStore mysql.SessionStore + +// Set value in mysql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*mysql.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from mysql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*mysql.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in mysql session +func (st *SessionStore) Delete(key interface{}) error { + return (*mysql.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in mysql session +func (st *SessionStore) Flush() error { + return (*mysql.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this mysql session store +func (st *SessionStore) SessionID() string { + return (*mysql.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save mysql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*mysql.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider mysql session provider +type Provider mysql.Provider + +// SessionInit init mysql session. +// savepath is the connection string of mysql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*mysql.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get mysql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check mysql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*mysql.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for mysql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*mysql.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete mysql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*mysql.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in mysql session +func (mp *Provider) SessionGC() { + (*mysql.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in mysql session +func (mp *Provider) SessionAll() int { + return (*mysql.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/postgres/sess_postgresql.go b/pkg/adapter/session/postgres/sess_postgresql.go new file mode 100644 index 00000000..de1adbc4 --- /dev/null +++ b/pkg/adapter/session/postgres/sess_postgresql.go @@ -0,0 +1,139 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package postgres for session provider +// +// depends on github.com/lib/pq: +// +// go install github.com/lib/pq +// +// +// needs this table in your database: +// +// CREATE TABLE session ( +// session_key char(64) NOT NULL, +// session_data bytea, +// session_expiry timestamp NOT NULL, +// CONSTRAINT session_key PRIMARY KEY(session_key) +// ); +// +// will be activated with these settings in app.conf: +// +// SessionOn = true +// SessionProvider = postgresql +// SessionSavePath = "user=a password=b dbname=c sslmode=disable" +// SessionName = session +// +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/postgresql" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package postgres + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + // import postgresql Driver + _ "github.com/lib/pq" + + "github.com/astaxie/beego/pkg/infrastructure/session/postgres" +) + +// SessionStore postgresql session store +type SessionStore postgres.SessionStore + +// Set value in postgresql session. +// it is temp value in map. +func (st *SessionStore) Set(key, value interface{}) error { + return (*postgres.SessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from postgresql session +func (st *SessionStore) Get(key interface{}) interface{} { + return (*postgres.SessionStore)(st).Get(context.Background(), key) +} + +// Delete value in postgresql session +func (st *SessionStore) Delete(key interface{}) error { + return (*postgres.SessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in postgresql session +func (st *SessionStore) Flush() error { + return (*postgres.SessionStore)(st).Flush(context.Background()) +} + +// SessionID get session id of this postgresql session store +func (st *SessionStore) SessionID() string { + return (*postgres.SessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease save postgresql session values to database. +// must call this method to save values to database. +func (st *SessionStore) SessionRelease(w http.ResponseWriter) { + (*postgres.SessionStore)(st).SessionRelease(context.Background(), w) +} + +// Provider postgresql session provider +type Provider postgres.Provider + +// SessionInit init postgresql session. +// savepath is the connection string of postgresql. +func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*postgres.Provider)(mp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get postgresql session by sid +func (mp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check postgresql session exist +func (mp *Provider) SessionExist(sid string) bool { + res, _ := (*postgres.Provider)(mp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for postgresql session +func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*postgres.Provider)(mp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete postgresql session by sid +func (mp *Provider) SessionDestroy(sid string) error { + return (*postgres.Provider)(mp).SessionDestroy(context.Background(), sid) +} + +// SessionGC delete expired values in postgresql session +func (mp *Provider) SessionGC() { + (*postgres.Provider)(mp).SessionGC(context.Background()) +} + +// SessionAll count values in postgresql session +func (mp *Provider) SessionAll() int { + return (*postgres.Provider)(mp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/provider_adapter.go b/pkg/adapter/session/provider_adapter.go new file mode 100644 index 00000000..11177a4d --- /dev/null +++ b/pkg/adapter/session/provider_adapter.go @@ -0,0 +1,104 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type oldToNewProviderAdapter struct { + delegate Provider +} + +func (o *oldToNewProviderAdapter) SessionInit(ctx context.Context, gclifetime int64, config string) error { + return o.delegate.SessionInit(gclifetime, config) +} + +func (o *oldToNewProviderAdapter) SessionRead(ctx context.Context, sid string) (session.Store, error) { + store, err := o.delegate.SessionRead(sid) + return &oldToNewStoreAdapter{ + delegate: store, + }, err +} + +func (o *oldToNewProviderAdapter) SessionExist(ctx context.Context, sid string) (bool, error) { + return o.delegate.SessionExist(sid), nil +} + +func (o *oldToNewProviderAdapter) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + s, err := o.delegate.SessionRegenerate(oldsid, sid) + return &oldToNewStoreAdapter{ + delegate: s, + }, err +} + +func (o *oldToNewProviderAdapter) SessionDestroy(ctx context.Context, sid string) error { + return o.delegate.SessionDestroy(sid) +} + +func (o *oldToNewProviderAdapter) SessionAll(ctx context.Context) int { + return o.delegate.SessionAll() +} + +func (o *oldToNewProviderAdapter) SessionGC(ctx context.Context) { + o.delegate.SessionGC() +} + +type newToOldProviderAdapter struct { + delegate session.Provider +} + +func (n *newToOldProviderAdapter) SessionInit(gclifetime int64, config string) error { + return n.delegate.SessionInit(context.Background(), gclifetime, config) +} + +func (n *newToOldProviderAdapter) SessionRead(sid string) (Store, error) { + s, err := n.delegate.SessionRead(context.Background(), sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionExist(sid string) bool { + res, _ := n.delegate.SessionExist(context.Background(), sid) + return res +} + +func (n *newToOldProviderAdapter) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := n.delegate.SessionRegenerate(context.Background(), oldsid, sid) + if adt, ok := s.(*oldToNewStoreAdapter); err == nil && ok { + return adt.delegate, err + } + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +func (n *newToOldProviderAdapter) SessionDestroy(sid string) error { + return n.delegate.SessionDestroy(context.Background(), sid) +} + +func (n *newToOldProviderAdapter) SessionAll() int { + return n.delegate.SessionAll(context.Background()) +} + +func (n *newToOldProviderAdapter) SessionGC() { + n.delegate.SessionGC(context.Background()) +} diff --git a/pkg/adapter/session/redis/sess_redis.go b/pkg/adapter/session/redis/sess_redis.go new file mode 100644 index 00000000..6c521e50 --- /dev/null +++ b/pkg/adapter/session/redis/sess_redis.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeRedis "github.com/astaxie/beego/pkg/infrastructure/session/redis" +) + +// MaxPoolSize redis max pool size +var MaxPoolSize = beeRedis.MaxPoolSize + +// SessionStore redis session store +type SessionStore beeRedis.SessionStore + +// Set value in redis session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*beeRedis.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*beeRedis.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis session +func (rs *SessionStore) Delete(key interface{}) error { + return (*beeRedis.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush() error { + return (*beeRedis.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID() string { + return (*beeRedis.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeRedis.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis session provider +type Provider beeRedis.Provider + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// e.g. 127.0.0.1:6379,100,astaxie,0,30 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*beeRedis.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*beeRedis.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeRedis.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*beeRedis.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*beeRedis.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*beeRedis.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_cluster/redis_cluster.go b/pkg/adapter/session/redis_cluster/redis_cluster.go new file mode 100644 index 00000000..03a805e4 --- /dev/null +++ b/pkg/adapter/session/redis_cluster/redis_cluster.go @@ -0,0 +1,120 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_cluster" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package redis_cluster + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + cluster "github.com/astaxie/beego/pkg/infrastructure/session/redis_cluster" +) + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = cluster.MaxPoolSize + +// SessionStore redis_cluster session store +type SessionStore cluster.SessionStore + +// Set value in redis_cluster session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*cluster.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*cluster.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(key interface{}) error { + return (*cluster.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush() error { + return (*cluster.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID() string { + return (*cluster.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*cluster.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_cluster session provider +type Provider cluster.Provider + +// SessionInit init redis_cluster session +// savepath like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*cluster.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*cluster.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*cluster.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*cluster.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*cluster.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*cluster.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 00000000..f5eb8a4f --- /dev/null +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,121 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/go-redis/redis +// +// go install github.com/go-redis/redis +// +// Usage: +// import( +// _ "github.com/astaxie/beego/session/redis_sentinel" +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + sentinel "github.com/astaxie/beego/pkg/infrastructure/session/redis_sentinel" +) + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = sentinel.DefaultPoolSize + +// SessionStore redis_sentinel session store +type SessionStore sentinel.SessionStore + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(key, value interface{}) error { + return (*sentinel.SessionStore)(rs).Set(context.Background(), key, value) +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(key interface{}) interface{} { + return (*sentinel.SessionStore)(rs).Get(context.Background(), key) +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(key interface{}) error { + return (*sentinel.SessionStore)(rs).Delete(context.Background(), key) +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush() error { + return (*sentinel.SessionStore)(rs).Flush(context.Background()) +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID() string { + return (*sentinel.SessionStore)(rs).SessionID(context.Background()) +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { + (*sentinel.SessionStore)(rs).SessionRelease(context.Background(), w) +} + +// Provider redis_sentinel session provider +type Provider sentinel.Provider + +// SessionInit init redis_sentinel session +// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { + return (*sentinel.Provider)(rp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(sid string) bool { + res, _ := (*sentinel.Provider)(rp).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*sentinel.Provider)(rp).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(sid string) error { + return (*sentinel.Provider)(rp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC() { + (*sentinel.Provider)(rp).SessionGC(context.Background()) +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll() int { + return (*sentinel.Provider)(rp).SessionAll(context.Background()) +} diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go similarity index 96% rename from pkg/session/redis_sentinel/sess_redis_sentinel_test.go rename to pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go index bd31741f..7c33985f 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/pkg/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/adapter/session" ) func TestRedisSentinel(t *testing.T) { @@ -23,7 +23,7 @@ func TestRedisSentinel(t *testing.T) { t.Log(e) return } - //todo test if e==nil + // todo test if e==nil go globalSessions.GC() r, _ := http.NewRequest("GET", "/", nil) diff --git a/pkg/adapter/session/sess_cookie.go b/pkg/adapter/session/sess_cookie.go new file mode 100644 index 00000000..32216040 --- /dev/null +++ b/pkg/adapter/session/sess_cookie.go @@ -0,0 +1,114 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// CookieSessionStore Cookie SessionStore +type CookieSessionStore session.CookieSessionStore + +// Set value to cookie session. +// the value are encoded as gob with hash block string. +func (st *CookieSessionStore) Set(key, value interface{}) error { + return (*session.CookieSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from cookie session +func (st *CookieSessionStore) Get(key interface{}) interface{} { + return (*session.CookieSessionStore)(st).Get(context.Background(), key) +} + +// Delete value in cookie session +func (st *CookieSessionStore) Delete(key interface{}) error { + return (*session.CookieSessionStore)(st).Delete(context.Background(), key) +} + +// Flush Clean all values in cookie session +func (st *CookieSessionStore) Flush() error { + return (*session.CookieSessionStore)(st).Flush(context.Background()) +} + +// SessionID Return id of this cookie session +func (st *CookieSessionStore) SessionID() string { + return (*session.CookieSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Write cookie session to http response cookie +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.CookieSessionStore)(st).SessionRelease(context.Background(), w) +} + +// CookieProvider Cookie session provider +type CookieProvider session.CookieProvider + +// SessionInit Init cookie session provider with max lifetime and config json. +// maxlifetime is ignored. +// json config: +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + return (*session.CookieProvider)(pder).SessionInit(context.Background(), maxlifetime, config) +} + +// SessionRead Get SessionStore in cooke. +// decode cooke string to map and put into SessionStore with sid. +func (pder *CookieProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Cookie session is always existed +func (pder *CookieProvider) SessionExist(sid string) bool { + res, _ := (*session.CookieProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate Implement method, no used. +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.CookieProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Implement method, no used. +func (pder *CookieProvider) SessionDestroy(sid string) error { + return (*session.CookieProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC Implement method, no used. +func (pder *CookieProvider) SessionGC() { + (*session.CookieProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll Implement method, return 0. +func (pder *CookieProvider) SessionAll() int { + return (*session.CookieProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate Implement method, no used. +func (pder *CookieProvider) SessionUpdate(sid string) error { + return (*session.CookieProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/session/sess_cookie_test.go b/pkg/adapter/session/sess_cookie_test.go similarity index 100% rename from pkg/session/sess_cookie_test.go rename to pkg/adapter/session/sess_cookie_test.go diff --git a/pkg/adapter/session/sess_file.go b/pkg/adapter/session/sess_file.go new file mode 100644 index 00000000..b9648998 --- /dev/null +++ b/pkg/adapter/session/sess_file.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// FileSessionStore File session store +type FileSessionStore session.FileSessionStore + +// Set value to file session +func (fs *FileSessionStore) Set(key, value interface{}) error { + return (*session.FileSessionStore)(fs).Set(context.Background(), key, value) +} + +// Get value from file session +func (fs *FileSessionStore) Get(key interface{}) interface{} { + return (*session.FileSessionStore)(fs).Get(context.Background(), key) +} + +// Delete value in file session by given key +func (fs *FileSessionStore) Delete(key interface{}) error { + return (*session.FileSessionStore)(fs).Delete(context.Background(), key) +} + +// Flush Clean all values in file session +func (fs *FileSessionStore) Flush() error { + return (*session.FileSessionStore)(fs).Flush(context.Background()) +} + +// SessionID Get file session store id +func (fs *FileSessionStore) SessionID() string { + return (*session.FileSessionStore)(fs).SessionID(context.Background()) +} + +// SessionRelease Write file session to local file with Gob string +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.FileSessionStore)(fs).SessionRelease(context.Background(), w) +} + +// FileProvider File session provider +type FileProvider session.FileProvider + +// SessionInit Init file session provider. +// savePath sets the session files path. +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.FileProvider)(fp).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead Read file session by sid. +// if file is not exist, create it. +// the file path is generated from sid string. +func (fp *FileProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist Check file session exist. +// it checks the file named from sid exist or not. +func (fp *FileProvider) SessionExist(sid string) bool { + res, _ := (*session.FileProvider)(fp).SessionExist(context.Background(), sid) + return res +} + +// SessionDestroy Remove all files in this save path +func (fp *FileProvider) SessionDestroy(sid string) error { + return (*session.FileProvider)(fp).SessionDestroy(context.Background(), sid) +} + +// SessionGC Recycle files in save path +func (fp *FileProvider) SessionGC() { + (*session.FileProvider)(fp).SessionGC(context.Background()) +} + +// SessionAll Get active file session number. +// it walks save path to count files. +func (fp *FileProvider) SessionAll() int { + return (*session.FileProvider)(fp).SessionAll(context.Background()) +} + +// SessionRegenerate Generate new sid for file session. +// it delete old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.FileProvider)(fp).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} diff --git a/pkg/session/sess_file_test.go b/pkg/adapter/session/sess_file_test.go similarity index 75% rename from pkg/session/sess_file_test.go rename to pkg/adapter/session/sess_file_test.go index a27d30a6..4c90a3ac 100644 --- a/pkg/session/sess_file_test.go +++ b/pkg/adapter/session/sess_file_test.go @@ -30,23 +30,6 @@ var ( mutex sync.Mutex ) -func TestFileProvider_SessionInit(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - if fp.maxlifetime != 180 { - t.Error() - } - - if fp.savePath != sessionPath { - t.Error() - } -} - func TestFileProvider_SessionExist(t *testing.T) { mutex.Lock() defer mutex.Unlock() @@ -56,24 +39,16 @@ func TestFileProvider_SessionExist(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } - _, err = fp.SessionRead(sid) + _, err := fp.SessionRead(sid) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sid) { t.Error() } } @@ -87,27 +62,15 @@ func TestFileProvider_SessionExist2(t *testing.T) { _ = fp.SessionInit(180, sessionPath) - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } - exists, err = fp.SessionExist("") - if err == nil { - t.Error() - } - if exists { + if fp.SessionExist("") { t.Error() } - exists, err = fp.SessionExist("1") - if err == nil { - t.Error() - } - if exists { + if fp.SessionExist("1") { t.Error() } } @@ -191,11 +154,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sid) { t.Error() } @@ -204,19 +163,11 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error(err) } - exists, err = fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } - exists, err = fp.SessionExist(sidNew) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sidNew) { t.Error() } } @@ -235,11 +186,7 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - exists, err := fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if !exists { + if !fp.SessionExist(sid) { t.Error() } @@ -248,11 +195,7 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error(err) } - exists, err = fp.SessionExist(sid) - if err != nil { - t.Error(err) - } - if exists { + if fp.SessionExist(sid) { t.Error() } } @@ -391,36 +334,3 @@ func TestFileSessionStore_SessionID(t *testing.T) { } } } - -func TestFileSessionStore_SessionRelease(t *testing.T) { - mutex.Lock() - defer mutex.Unlock() - os.RemoveAll(sessionPath) - defer os.RemoveAll(sessionPath) - fp := &FileProvider{} - - _ = fp.SessionInit(180, sessionPath) - filepder.savePath = sessionPath - sessionCount := 85 - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - - s.Set(i, i) - s.SessionRelease(nil) - } - - for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) - if err != nil { - t.Error(err) - } - - if s.Get(i).(int) != i { - t.Error() - } - } -} diff --git a/pkg/adapter/session/sess_mem.go b/pkg/adapter/session/sess_mem.go new file mode 100644 index 00000000..818c8329 --- /dev/null +++ b/pkg/adapter/session/sess_mem.go @@ -0,0 +1,106 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// MemSessionStore memory session store. +// it saved sessions in a map in memory. +type MemSessionStore session.MemSessionStore + +// Set value to memory session +func (st *MemSessionStore) Set(key, value interface{}) error { + return (*session.MemSessionStore)(st).Set(context.Background(), key, value) +} + +// Get value from memory session by key +func (st *MemSessionStore) Get(key interface{}) interface{} { + return (*session.MemSessionStore)(st).Get(context.Background(), key) +} + +// Delete in memory session by key +func (st *MemSessionStore) Delete(key interface{}) error { + return (*session.MemSessionStore)(st).Delete(context.Background(), key) +} + +// Flush clear all values in memory session +func (st *MemSessionStore) Flush() error { + return (*session.MemSessionStore)(st).Flush(context.Background()) +} + +// SessionID get this id of memory session store +func (st *MemSessionStore) SessionID() string { + return (*session.MemSessionStore)(st).SessionID(context.Background()) +} + +// SessionRelease Implement method, no used. +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { + (*session.MemSessionStore)(st).SessionRelease(context.Background(), w) +} + +// MemProvider Implement the provider interface +type MemProvider session.MemProvider + +// SessionInit init memory session +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + return (*session.MemProvider)(pder).SessionInit(context.Background(), maxlifetime, savePath) +} + +// SessionRead get memory session store by sid +func (pder *MemProvider) SessionRead(sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRead(context.Background(), sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionExist check session store exist in memory session by sid +func (pder *MemProvider) SessionExist(sid string) bool { + res, _ := (*session.MemProvider)(pder).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate generate new sid for session store in memory session +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { + s, err := (*session.MemProvider)(pder).SessionRegenerate(context.Background(), oldsid, sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy delete session store in memory session by id +func (pder *MemProvider) SessionDestroy(sid string) error { + return (*session.MemProvider)(pder).SessionDestroy(context.Background(), sid) +} + +// SessionGC clean expired session stores in memory session +func (pder *MemProvider) SessionGC() { + (*session.MemProvider)(pder).SessionGC(context.Background()) +} + +// SessionAll get count number of memory session +func (pder *MemProvider) SessionAll() int { + return (*session.MemProvider)(pder).SessionAll(context.Background()) +} + +// SessionUpdate expand time of session store by id in memory session +func (pder *MemProvider) SessionUpdate(sid string) error { + return (*session.MemProvider)(pder).SessionUpdate(context.Background(), sid) +} diff --git a/pkg/session/sess_mem_test.go b/pkg/adapter/session/sess_mem_test.go similarity index 100% rename from pkg/session/sess_mem_test.go rename to pkg/adapter/session/sess_mem_test.go diff --git a/pkg/adapter/session/sess_test.go b/pkg/adapter/session/sess_test.go new file mode 100644 index 00000000..aba702ca --- /dev/null +++ b/pkg/adapter/session/sess_test.go @@ -0,0 +1,51 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "testing" +) + +func Test_gob(t *testing.T) { + a := make(map[interface{}]interface{}) + a["username"] = "astaxie" + a[12] = 234 + a["user"] = User{"asta", "xie"} + b, err := EncodeGob(a) + if err != nil { + t.Error(err) + } + c, err := DecodeGob(b) + if err != nil { + t.Error(err) + } + if len(c) == 0 { + t.Error("decodeGob empty") + } + if c["username"] != "astaxie" { + t.Error("decode string error") + } + if c[12] != 234 { + t.Error("decode int error") + } + if c["user"].(User).Username != "asta" { + t.Error("decode struct error") + } +} + +type User struct { + Username string + NickName string +} diff --git a/pkg/adapter/session/sess_utils.go b/pkg/adapter/session/sess_utils.go new file mode 100644 index 00000000..3d107198 --- /dev/null +++ b/pkg/adapter/session/sess_utils.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// EncodeGob encode the obj to gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + return session.EncodeGob(obj) +} + +// DecodeGob decode data to map +func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) { + return session.DecodeGob(encoded) +} diff --git a/pkg/adapter/session/session.go b/pkg/adapter/session/session.go new file mode 100644 index 00000000..eea2f90e --- /dev/null +++ b/pkg/adapter/session/session.go @@ -0,0 +1,166 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session provider +// +// Usage: +// import( +// "github.com/astaxie/beego/session" +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } +// +// more docs: http://beego.me/docs/module/session.md +package session + +import ( + "io" + "net/http" + "os" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +// Store contains all data for one session process with specific id. +type Store interface { + Set(key, value interface{}) error // set session value + Get(key interface{}) interface{} // get session value + Delete(key interface{}) error // delete session value + SessionID() string // back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error // delete all data +} + +// Provider contains global session methods and saved SessionStores. +// it can operate a SessionStore by its id. +type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (Store, error) + SessionExist(sid string) bool + SessionRegenerate(oldsid, sid string) (Store, error) + SessionDestroy(sid string) error + SessionAll() int // get all active session + SessionGC() +} + +// SLogger a helpful variable to log information about session +var SLogger = NewSessionLog(os.Stderr) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + session.Register(name, &oldToNewProviderAdapter{ + delegate: provide, + }) +} + +// GetProvider +func GetProvider(name string) (Provider, error) { + res, err := session.GetProvider(name) + if adt, ok := res.(*oldToNewProviderAdapter); err == nil && ok { + return adt.delegate, err + } + + return &newToOldProviderAdapter{ + delegate: res, + }, err +} + +// ManagerConfig define the session config +type ManagerConfig session.ManagerConfig + +// Manager contains Provider and its configuration. +type Manager session.Manager + +// NewManager Create new Manager with provider name and json config string. +// provider name: +// 1. cookie +// 2. file +// 3. memory +// 4. redis +// 5. mysql +// json config: +// 1. is https default false +// 2. hashfunc default sha1 +// 3. hashkey default beegosessionkey +// 4. maxage default is none +func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { + m, err := session.NewManager(provideName, (*session.ManagerConfig)(cf)) + return (*Manager)(m), err +} + +// GetProvider return current manager's provider +func (manager *Manager) GetProvider() Provider { + return &newToOldProviderAdapter{ + delegate: (*session.Manager)(manager).GetProvider(), + } +} + +// SessionStart generate or read the session id from http request. +// if session id exists, return SessionStore with this id. +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (Store, error) { + s, err := (*session.Manager)(manager).SessionStart(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// SessionDestroy Destroy session by its id in http request cookie. +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + (*session.Manager)(manager).SessionDestroy(w, r) +} + +// GetSessionStore Get SessionStore by its id. +func (manager *Manager) GetSessionStore(sid string) (Store, error) { + s, err := (*session.Manager)(manager).GetSessionStore(sid) + return &NewToOldStoreAdapter{ + delegate: s, + }, err +} + +// GC Start session gc process. +// it can do gc in times after gc lifetime. +func (manager *Manager) GC() { + (*session.Manager)(manager).GC() +} + +// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) Store { + s := (*session.Manager)(manager).SessionRegenerateID(w, r) + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +// GetActiveSession Get all active sessions count number. +func (manager *Manager) GetActiveSession() int { + return (*session.Manager)(manager).GetActiveSession() +} + +// SetSecure Set cookie with https. +func (manager *Manager) SetSecure(secure bool) { + (*session.Manager)(manager).SetSecure(secure) +} + +// Log implement the log.Logger +type Log session.Log + +// NewSessionLog set io.Writer to create a Logger for session. +func NewSessionLog(out io.Writer) *Log { + return (*Log)(session.NewSessionLog(out)) +} diff --git a/pkg/adapter/session/ssdb/sess_ssdb.go b/pkg/adapter/session/ssdb/sess_ssdb.go new file mode 100644 index 00000000..aee3a364 --- /dev/null +++ b/pkg/adapter/session/ssdb/sess_ssdb.go @@ -0,0 +1,84 @@ +package ssdb + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/adapter/session" + + beeSsdb "github.com/astaxie/beego/pkg/infrastructure/session/ssdb" +) + +// Provider holds ssdb client and configs +type Provider beeSsdb.Provider + +// SessionInit init the ssdb with the config +func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { + return (*beeSsdb.Provider)(p).SessionInit(context.Background(), maxLifetime, savePath) +} + +// SessionRead return a ssdb client session Store +func (p *Provider) SessionRead(sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRead(context.Background(), sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionExist judged whether sid is exist in session +func (p *Provider) SessionExist(sid string) bool { + res, _ := (*beeSsdb.Provider)(p).SessionExist(context.Background(), sid) + return res +} + +// SessionRegenerate regenerate session with new sid and delete oldsid +func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { + s, err := (*beeSsdb.Provider)(p).SessionRegenerate(context.Background(), oldsid, sid) + return session.CreateNewToOldStoreAdapter(s), err +} + +// SessionDestroy destroy the sid +func (p *Provider) SessionDestroy(sid string) error { + return (*beeSsdb.Provider)(p).SessionDestroy(context.Background(), sid) +} + +// SessionGC not implemented +func (p *Provider) SessionGC() { + (*beeSsdb.Provider)(p).SessionGC(context.Background()) +} + +// SessionAll not implemented +func (p *Provider) SessionAll() int { + return (*beeSsdb.Provider)(p).SessionAll(context.Background()) +} + +// SessionStore holds the session information which stored in ssdb +type SessionStore beeSsdb.SessionStore + +// Set the key and value +func (s *SessionStore) Set(key, value interface{}) error { + return (*beeSsdb.SessionStore)(s).Set(context.Background(), key, value) +} + +// Get return the value by the key +func (s *SessionStore) Get(key interface{}) interface{} { + return (*beeSsdb.SessionStore)(s).Get(context.Background(), key) +} + +// Delete the key in session store +func (s *SessionStore) Delete(key interface{}) error { + return (*beeSsdb.SessionStore)(s).Delete(context.Background(), key) +} + +// Flush delete all keys and values +func (s *SessionStore) Flush() error { + return (*beeSsdb.SessionStore)(s).Flush(context.Background()) +} + +// SessionID return the sessionID +func (s *SessionStore) SessionID() string { + return (*beeSsdb.SessionStore)(s).SessionID(context.Background()) +} + +// SessionRelease Store the keyvalues into ssdb +func (s *SessionStore) SessionRelease(w http.ResponseWriter) { + (*beeSsdb.SessionStore)(s).SessionRelease(context.Background(), w) +} diff --git a/pkg/adapter/session/store_adapter.go b/pkg/adapter/session/store_adapter.go new file mode 100644 index 00000000..c1a03c38 --- /dev/null +++ b/pkg/adapter/session/store_adapter.go @@ -0,0 +1,84 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +type NewToOldStoreAdapter struct { + delegate session.Store +} + +func CreateNewToOldStoreAdapter(s session.Store) Store { + return &NewToOldStoreAdapter{ + delegate: s, + } +} + +func (n *NewToOldStoreAdapter) Set(key, value interface{}) error { + return n.delegate.Set(context.Background(), key, value) +} + +func (n *NewToOldStoreAdapter) Get(key interface{}) interface{} { + return n.delegate.Get(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) Delete(key interface{}) error { + return n.delegate.Delete(context.Background(), key) +} + +func (n *NewToOldStoreAdapter) SessionID() string { + return n.delegate.SessionID(context.Background()) +} + +func (n *NewToOldStoreAdapter) SessionRelease(w http.ResponseWriter) { + n.delegate.SessionRelease(context.Background(), w) +} + +func (n *NewToOldStoreAdapter) Flush() error { + return n.delegate.Flush(context.Background()) +} + +type oldToNewStoreAdapter struct { + delegate Store +} + +func (o *oldToNewStoreAdapter) Set(ctx context.Context, key, value interface{}) error { + return o.delegate.Set(key, value) +} + +func (o *oldToNewStoreAdapter) Get(ctx context.Context, key interface{}) interface{} { + return o.delegate.Get(key) +} + +func (o *oldToNewStoreAdapter) Delete(ctx context.Context, key interface{}) error { + return o.delegate.Delete(key) +} + +func (o *oldToNewStoreAdapter) SessionID(ctx context.Context) string { + return o.delegate.SessionID() +} + +func (o *oldToNewStoreAdapter) SessionRelease(ctx context.Context, w http.ResponseWriter) { + o.delegate.SessionRelease(w) +} + +func (o *oldToNewStoreAdapter) Flush(ctx context.Context) error { + return o.delegate.Flush() +} diff --git a/pkg/adapter/swagger/swagger.go b/pkg/adapter/swagger/swagger.go new file mode 100644 index 00000000..214959d9 --- /dev/null +++ b/pkg/adapter/swagger/swagger.go @@ -0,0 +1,68 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Swagger™ is a project used to describe and document RESTful APIs. +// +// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools. +// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software. + +// Package swagger struct definition +package swagger + +import ( + "github.com/astaxie/beego/pkg/server/web/swagger" +) + +// Swagger list the resource +type Swagger swagger.Swagger + +// Information Provides metadata about the API. The metadata can be used by the clients if needed. +type Information swagger.Information + +// Contact information for the exposed API. +type Contact swagger.Contact + +// License information for the exposed API. +type License swagger.License + +// Item Describes the operations available on a single path. +type Item swagger.Item + +// Operation Describes a single API operation on a path. +type Operation swagger.Operation + +// Parameter Describes a single operation parameter. +type Parameter swagger.Parameter + +// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body". +// http://swagger.io/specification/#itemsObject +type ParameterItems swagger.ParameterItems + +// Schema Object allows the definition of input and output data types. +type Schema swagger.Schema + +// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification +type Propertie swagger.Propertie + +// Response as they are returned from executing this operation. +type Response swagger.Response + +// Security Allows the definition of a security scheme that can be used by the operations +type Security swagger.Security + +// Tag Allows adding meta data to a single tag that is used by the Operation Object +type Tag swagger.Tag + +// ExternalDocs include Additional external documentation +type ExternalDocs swagger.ExternalDocs diff --git a/pkg/adapter/template.go b/pkg/adapter/template.go new file mode 100644 index 00000000..1f943caf --- /dev/null +++ b/pkg/adapter/template.go @@ -0,0 +1,108 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "io" + "net/http" + + "github.com/astaxie/beego/pkg/server/web" +) + +// ExecuteTemplate applies the template with name to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { + return web.ExecuteTemplate(wr, name, data) +} + +// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object, +// writing the output to wr. +// A template will be executed safely in parallel. +func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error { + return web.ExecuteViewPathTemplate(wr, name, viewPath, data) +} + +// AddFuncMap let user to register a func in the template. +func AddFuncMap(key string, fn interface{}) error { + return web.AddFuncMap(key, fn) +} + +type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error) + +type templateFile struct { + root string + files map[string][]string +} + +// HasTemplateExt return this path contains supported template extension of beego or not. +func HasTemplateExt(paths string) bool { + return web.HasTemplateExt(paths) +} + +// AddTemplateExt add new extension for template. +func AddTemplateExt(ext string) { + web.AddTemplateExt(ext) +} + +// AddViewPath adds a new path to the supported view paths. +// Can later be used by setting a controller ViewPath to this folder +// will panic if called after beego.Run() +func AddViewPath(viewPath string) error { + return web.AddViewPath(viewPath) +} + +// BuildTemplate will build all template files in a directory. +// it makes beego can render any template file in view directory. +func BuildTemplate(dir string, files ...string) error { + return web.BuildTemplate(dir, files...) +} + +type templateFSFunc func() http.FileSystem + +func defaultFSFunc() http.FileSystem { + return FileSystem{} +} + +// SetTemplateFSFunc set default filesystem function +func SetTemplateFSFunc(fnt templateFSFunc) { + web.SetTemplateFSFunc(func() http.FileSystem { + return fnt() + }) +} + +// SetViewsPath sets view directory path in beego application. +func SetViewsPath(path string) *App { + return (*App)(web.SetViewsPath(path)) +} + +// SetStaticPath sets static directory path and proper url pattern in beego application. +// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". +func SetStaticPath(url string, path string) *App { + return (*App)(web.SetStaticPath(url, path)) +} + +// DelStaticPath removes the static folder setting in this url pattern in beego application. +func DelStaticPath(url string) *App { + return (*App)(web.DelStaticPath(url)) +} + +// AddTemplateEngine add a new templatePreProcessor which support extension +func AddTemplateEngine(extension string, fn templatePreProcessor) *App { + return (*App)(web.AddTemplateEngine(extension, func(root, path string, funcs template.FuncMap) (*template.Template, error) { + return fn(root, path, funcs) + })) +} diff --git a/pkg/adapter/templatefunc.go b/pkg/adapter/templatefunc.go new file mode 100644 index 00000000..5130d590 --- /dev/null +++ b/pkg/adapter/templatefunc.go @@ -0,0 +1,151 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +const ( + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// Substr returns the substr from start to length. +func Substr(s string, start, length int) string { + return web.Substr(s, start, length) +} + +// HTML2str returns escaping text convert from html. +func HTML2str(html string) string { + return web.HTML2str(html) +} + +// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat" +func DateFormat(t time.Time, layout string) (datestring string) { + return web.DateFormat(t, layout) +} + +// DateParse Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + return web.DateParse(dateString, format) +} + +// Date takes a PHP like date func to Go's time format. +func Date(t time.Time, format string) string { + return web.Date(t, format) +} + +// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal. +// Whitespace is trimmed. Used by the template parser as "eq". +func Compare(a, b interface{}) (equal bool) { + return web.Compare(a, b) +} + +// CompareNot !Compare +func CompareNot(a, b interface{}) (equal bool) { + return web.CompareNot(a, b) +} + +// NotNil the same as CompareNot +func NotNil(a interface{}) (isNil bool) { + return web.NotNil(a) +} + +// GetConfig get the Appconfig +func GetConfig(returnType, key string, defaultVal interface{}) (interface{}, error) { + return web.GetConfig(returnType, key, defaultVal) +} + +// Str2html Convert string to template.HTML type. +func Str2html(raw string) template.HTML { + return web.Str2html(raw) +} + +// Htmlquote returns quoted html string. +func Htmlquote(text string) string { + return web.Htmlquote(text) +} + +// Htmlunquote returns unquoted html string. +func Htmlunquote(text string) string { + return web.Htmlunquote(text) +} + +// URLFor returns url string with another registered controller handler with params. +// usage: +// +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe +// +// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +func URLFor(endpoint string, values ...interface{}) string { + return web.URLFor(endpoint, values...) +} + +// AssetsJs returns script tag with src string. +func AssetsJs(text string) template.HTML { + return web.AssetsJs(text) +} + +// AssetsCSS returns stylesheet link tag with src string. +func AssetsCSS(text string) template.HTML { + + text = "" + + return template.HTML(text) +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + return web.ParseForm(form, obj) +} + +// RenderForm will render object to form html. +// obj must be a struct pointer. +func RenderForm(obj interface{}) template.HTML { + return web.RenderForm(obj) +} + +// MapGet getting value from map by keys +// usage: +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } +// +// {{ map_get m "a" }} // return 1 +// {{ map_get m 1 "c" }} // return 4 +func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) { + return web.MapGet(arg1, arg2...) +} diff --git a/pkg/adapter/templatefunc_test.go b/pkg/adapter/templatefunc_test.go new file mode 100644 index 00000000..f5113606 --- /dev/null +++ b/pkg/adapter/templatefunc_test.go @@ -0,0 +1,304 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "html/template" + "net/url" + "testing" + "time" +) + +func TestSubstr(t *testing.T) { + s := `012345` + if Substr(s, 0, 2) != "01" { + t.Error("should be equal") + } + if Substr(s, 0, 100) != "012345" { + t.Error("should be equal") + } + if Substr(s, 12, 100) != "012345" { + t.Error("should be equal") + } +} + +func TestHtml2str(t *testing.T) { + h := `<123> 123\n + + + \n` + if HTML2str(h) != "123\\n\n\\n" { + t.Error("should be equal") + } +} + +func TestDateFormat(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } +} + +func TestDate(t *testing.T) { + ts := "Mon, 01 Jul 2013 13:27:42 CST" + tt, _ := time.Parse(time.RFC1123, ts) + + if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { + t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) + } + if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { + t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) + } + if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { + t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) + } + if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { + t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) + } +} + +func TestCompareRelated(t *testing.T) { + if !Compare("abc", "abc") { + t.Error("should be equal") + } + if Compare("abc", "aBc") { + t.Error("should be not equal") + } + if !Compare("1", 1) { + t.Error("should be equal") + } + if CompareNot("abc", "abc") { + t.Error("should be equal") + } + if !CompareNot("abc", "aBc") { + t.Error("should be not equal") + } + if !NotNil("a string") { + t.Error("should not be nil") + } +} + +func TestHtmlquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlquote(s) != h { + t.Error("should be equal") + } +} + +func TestHtmlunquote(t *testing.T) { + h := `<' ”“&">` + s := `<' ”“&">` + if Htmlunquote(h) != s { + t.Error("should be equal") + } +} + +func TestParseForm(t *testing.T) { + type ExtendInfo struct { + Hobby []string `form:"hobby"` + Memo string + } + + type OtherInfo struct { + Organization string `form:"organization"` + Title string `form:"title"` + ExtendInfo + } + + type user struct { + ID int `form:"-"` + tag string `form:"tag"` + Name interface{} `form:"username"` + Age int `form:"age,text"` + Email string + Intro string `form:",textarea"` + StrBool bool `form:"strbool"` + Date time.Time `form:"date,2006-01-02"` + OtherInfo + } + + u := user{} + form := url.Values{ + "ID": []string{"1"}, + "-": []string{"1"}, + "tag": []string{"no"}, + "username": []string{"test"}, + "age": []string{"40"}, + "Email": []string{"test@gmail.com"}, + "Intro": []string{"I am an engineer!"}, + "strbool": []string{"yes"}, + "date": []string{"2014-11-12"}, + "organization": []string{"beego"}, + "title": []string{"CXO"}, + "hobby": []string{"", "Basketball", "Football"}, + "memo": []string{"nothing"}, + } + if err := ParseForm(form, u); err == nil { + t.Fatal("nothing will be changed") + } + if err := ParseForm(form, &u); err != nil { + t.Fatal(err) + } + if u.ID != 0 { + t.Errorf("ID should equal 0 but got %v", u.ID) + } + if len(u.tag) != 0 { + t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) + } + if u.Name.(string) != "test" { + t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) + } + if u.Age != 40 { + t.Errorf("Age should equal 40 but got %v", u.Age) + } + if u.Email != "test@gmail.com" { + t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) + } + if u.Intro != "I am an engineer!" { + t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) + } + if !u.StrBool { + t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) + } + y, m, d := u.Date.Date() + if y != 2014 || m.String() != "November" || d != 12 { + t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) + } + if u.Organization != "beego" { + t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) + } + if u.Title != "CXO" { + t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) + } + if u.Hobby[0] != "" { + t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) + } + if u.Hobby[1] != "Basketball" { + t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) + } + if u.Hobby[2] != "Football" { + t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) + } + if len(u.Memo) != 0 { + t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) + } +} + +func TestRenderForm(t *testing.T) { + type user struct { + ID int `form:"-"` + Name interface{} `form:"username"` + Age int `form:"age,text,年龄:"` + Sex string + Email []string + Intro string `form:",textarea"` + Ignored string `form:"-"` + } + + u := user{Name: "test", Intro: "Some Text"} + output := RenderForm(u) + if output != template.HTML("") { + t.Errorf("output should be empty but got %v", output) + } + output = RenderForm(&u) + result := template.HTML( + `Name:
` + + `年龄:
` + + `Sex:
` + + `Intro: `) + if output != result { + t.Errorf("output should equal `%v` but got `%v`", result, output) + } +} + +func TestMapGet(t *testing.T) { + // test one level map + m1 := map[string]int64{ + "a": 1, + "1": 2, + } + + if res, err := MapGet(m1, "a"); err == nil { + if res.(int64) != 1 { + t.Errorf("Should return 1, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, "1"); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + if res, err := MapGet(m1, 1); err == nil { + if res.(int64) != 2 { + t.Errorf("Should return 2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 2 level map + m2 := M{ + "1": map[string]float64{ + "2": 3.5, + }, + } + + if res, err := MapGet(m2, 1, 2); err == nil { + if res.(float64) != 3.5 { + t.Errorf("Should return 3.5, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // test 5 level map + m5 := M{ + "1": M{ + "2": M{ + "3": M{ + "4": M{ + "5": 1.2, + }, + }, + }, + }, + } + + if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { + if res.(float64) != 1.2 { + t.Errorf("Should return 1.2, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } + + // check whether element not exists in map + if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { + if res != nil { + t.Errorf("Should return nil, but return %v", res) + } + } else { + t.Errorf("Error happens %v", err) + } +} diff --git a/pkg/adapter/testing/client.go b/pkg/adapter/testing/client.go new file mode 100644 index 00000000..688aa6f3 --- /dev/null +++ b/pkg/adapter/testing/client.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "github.com/astaxie/beego/pkg/client/httplib/testing" +) + +var port = "" +var baseURL = "http://localhost:" + +// TestHTTPRequest beego test request client +type TestHTTPRequest testing.TestHTTPRequest + +// Get returns test client in GET method +func Get(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Get(path)) +} + +// Post returns test client in POST method +func Post(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Post(path)) +} + +// Put returns test client in PUT method +func Put(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Put(path)) +} + +// Delete returns test client in DELETE method +func Delete(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Delete(path)) +} + +// Head returns test client in HEAD method +func Head(path string) *TestHTTPRequest { + return (*TestHTTPRequest)(testing.Head(path)) +} diff --git a/pkg/adapter/toolbox/healthcheck.go b/pkg/adapter/toolbox/healthcheck.go new file mode 100644 index 00000000..56be8089 --- /dev/null +++ b/pkg/adapter/toolbox/healthcheck.go @@ -0,0 +1,52 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package toolbox healthcheck +// +// type DatabaseCheck struct { +// } +// +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } +// +// AddHealthCheck("database",&DatabaseCheck{}) +// +// more docs: http://beego.me/docs/module/toolbox.md +package toolbox + +import ( + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +// AdminCheckList holds health checker map +// Deprecated using governor.AdminCheckList +var AdminCheckList map[string]HealthChecker + +// HealthChecker health checker interface +type HealthChecker governor.HealthChecker + +// AddHealthCheck add health checker with name string +func AddHealthCheck(name string, hc HealthChecker) { + governor.AddHealthCheck(name, hc) + AdminCheckList[name] = hc +} + +func init() { + AdminCheckList = make(map[string]HealthChecker) +} diff --git a/pkg/adapter/toolbox/profile.go b/pkg/adapter/toolbox/profile.go new file mode 100644 index 00000000..16cf80b1 --- /dev/null +++ b/pkg/adapter/toolbox/profile.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "io" + "os" + "time" + + "github.com/astaxie/beego/pkg/infrastructure/governor" +) + +var startTime = time.Now() +var pid int + +func init() { + pid = os.Getpid() +} + +// ProcessInput parse input command string +func ProcessInput(input string, w io.Writer) { + governor.ProcessInput(input, w) +} + +// MemProf record memory profile in pprof +func MemProf(w io.Writer) { + governor.MemProf(w) +} + +// GetCPUProfile start cpu profile monitor +func GetCPUProfile(w io.Writer) { + governor.GetCPUProfile(w) +} + +// PrintGCSummary print gc information to io.Writer +func PrintGCSummary(w io.Writer) { + governor.PrintGCSummary(w) +} diff --git a/pkg/toolbox/profile_test.go b/pkg/adapter/toolbox/profile_test.go similarity index 100% rename from pkg/toolbox/profile_test.go rename to pkg/adapter/toolbox/profile_test.go diff --git a/pkg/adapter/toolbox/statistics.go b/pkg/adapter/toolbox/statistics.go new file mode 100644 index 00000000..b7d3bda9 --- /dev/null +++ b/pkg/adapter/toolbox/statistics.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "time" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Statistics struct +type Statistics web.Statistics + +// URLMap contains several statistics struct to log different data +type URLMap web.URLMap + +// AddStatistics add statistics task. +// it needs request method, request url, request controller and statistics time duration +func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) { + (*web.URLMap)(m).AddStatistics(requestMethod, requestURL, requestController, requesttime) +} + +// GetMap put url statistics result in io.Writer +func (m *URLMap) GetMap() map[string]interface{} { + return (*web.URLMap)(m).GetMap() +} + +// GetMapData return all mapdata +func (m *URLMap) GetMapData() []map[string]interface{} { + return (*web.URLMap)(m).GetMapData() +} + +// StatisticsMap hosld global statistics data map +var StatisticsMap *URLMap + +func init() { + StatisticsMap = (*URLMap)(web.StatisticsMap) +} diff --git a/pkg/toolbox/statistics_test.go b/pkg/adapter/toolbox/statistics_test.go similarity index 100% rename from pkg/toolbox/statistics_test.go rename to pkg/adapter/toolbox/statistics_test.go diff --git a/pkg/adapter/toolbox/task.go b/pkg/adapter/toolbox/task.go new file mode 100644 index 00000000..2a6d9aa6 --- /dev/null +++ b/pkg/adapter/toolbox/task.go @@ -0,0 +1,286 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolbox + +import ( + "context" + "sort" + "time" + + "github.com/astaxie/beego/pkg/task" +) + +// The bounds for each field. +var ( + AdminTaskList map[string]Tasker +) + +const ( + // Set the top bit if a star was included in the expression. + starBit = 1 << 63 +) + +// Schedule time taks schedule +type Schedule task.Schedule + +// TaskFunc task func type +type TaskFunc func() error + +// Tasker task interface +type Tasker interface { + GetSpec() string + GetStatus() string + Run() error + SetNext(time.Time) + GetNext() time.Time + SetPrev(time.Time) + GetPrev() time.Time +} + +// task error +type taskerr struct { + t time.Time + errinfo string +} + +// Task task struct +// Deprecated +type Task struct { + // Deprecated + Taskname string + // Deprecated + Spec *Schedule + // Deprecated + SpecStr string + // Deprecated + DoFunc TaskFunc + // Deprecated + Prev time.Time + // Deprecated + Next time.Time + // Deprecated + Errlist []*taskerr // like errtime:errinfo + // Deprecated + ErrLimit int // max length for the errlist, 0 stand for no limit + + delegate *task.Task +} + +// NewTask add new task with name, time and func +func NewTask(tname string, spec string, f TaskFunc) *Task { + + task := task.NewTask(tname, spec, func(ctx context.Context) error { + return f() + }) + return &Task{ + delegate: task, + } +} + +// GetSpec get spec string +func (t *Task) GetSpec() string { + t.initDelegate() + + return t.delegate.GetSpec(context.Background()) +} + +// GetStatus get current task status +func (t *Task) GetStatus() string { + + t.initDelegate() + + return t.delegate.GetStatus(context.Background()) +} + +// Run run all tasks +func (t *Task) Run() error { + t.initDelegate() + return t.delegate.Run(context.Background()) +} + +// SetNext set next time for this task +func (t *Task) SetNext(now time.Time) { + t.initDelegate() + t.delegate.SetNext(context.Background(), now) +} + +// GetNext get the next call time of this task +func (t *Task) GetNext() time.Time { + t.initDelegate() + return t.delegate.GetNext(context.Background()) +} + +// SetPrev set prev time of this task +func (t *Task) SetPrev(now time.Time) { + t.initDelegate() + t.delegate.SetPrev(context.Background(), now) +} + +// GetPrev get prev time of this task +func (t *Task) GetPrev() time.Time { + t.initDelegate() + return t.delegate.GetPrev(context.Background()) +} + +// six columns mean: +// second:0-59 +// minute:0-59 +// hour:1-23 +// day:1-31 +// month:1-12 +// week:0-6(0 means Sunday) + +// SetCron some signals: +// *: any time +// ,:  separate signal +//    -:duration +// /n : do as n times of time duration +// /////////////////////////////////////////////////////// +// 0/30 * * * * * every 30s +// 0 43 21 * * * 21:43 +// 0 15 05 * * *    05:15 +// 0 0 17 * * * 17:00 +// 0 0 17 * * 1 17:00 in every Monday +// 0 0,10 17 * * 0,2,3 17:00 and 17:10 in every Sunday, Tuesday and Wednesday +// 0 0-10 17 1 * * 17:00 to 17:10 in 1 min duration each time on the first day of month +// 0 0 0 1,15 * 1 0:00 on the 1st day and 15th day of month +// 0 42 4 1 * *     4:42 on the 1st day of month +// 0 0 21 * * 1-6   21:00 from Monday to Saturday +// 0 0,10,20,30,40,50 * * * *  every 10 min duration +// 0 */10 * * * *        every 10 min duration +// 0 * 1 * * *         1:00 to 1:59 in 1 min duration each time +// 0 0 1 * * *         1:00 +// 0 0 */1 * * *        0 min of hour in 1 hour duration +// 0 0 * * * *         0 min of hour in 1 hour duration +// 0 2 8-20/3 * * *       8:02, 11:02, 14:02, 17:02, 20:02 +// 0 30 5 1,15 * *       5:30 on the 1st day and 15th day of month +func (t *Task) SetCron(spec string) { + t.initDelegate() + t.delegate.SetCron(spec) +} + +func (t *Task) initDelegate() { + if t.delegate == nil { + t.delegate = &task.Task{ + Taskname: t.Taskname, + Spec: (*task.Schedule)(t.Spec), + SpecStr: t.SpecStr, + DoFunc: func(ctx context.Context) error { + return t.DoFunc() + }, + Prev: t.Prev, + Next: t.Next, + ErrLimit: t.ErrLimit, + } + } +} + +// Next set schedule to next time +func (s *Schedule) Next(t time.Time) time.Time { + return (*task.Schedule)(s).Next(t) +} + +// StartTask start all tasks +func StartTask() { + task.StartTask() +} + +// StopTask stop all tasks +func StopTask() { + task.StopTask() +} + +// AddTask add task with name +func AddTask(taskname string, t Tasker) { + task.AddTask(taskname, &oldToNewAdapter{delegate: t}) +} + +// DeleteTask delete task with name +func DeleteTask(taskname string) { + task.DeleteTask(taskname) +} + +// MapSorter sort map for tasker +type MapSorter task.MapSorter + +// NewMapSorter create new tasker map +func NewMapSorter(m map[string]Tasker) *MapSorter { + + newTaskerMap := make(map[string]task.Tasker, len(m)) + + for key, value := range m { + newTaskerMap[key] = &oldToNewAdapter{ + delegate: value, + } + } + + return (*MapSorter)(task.NewMapSorter(newTaskerMap)) +} + +// Sort sort tasker map +func (ms *MapSorter) Sort() { + sort.Sort(ms) +} + +func (ms *MapSorter) Len() int { return len(ms.Keys) } +func (ms *MapSorter) Less(i, j int) bool { + if ms.Vals[i].GetNext(context.Background()).IsZero() { + return false + } + if ms.Vals[j].GetNext(context.Background()).IsZero() { + return true + } + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) +} +func (ms *MapSorter) Swap(i, j int) { + ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] + ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] +} + +func init() { + AdminTaskList = make(map[string]Tasker) +} + +type oldToNewAdapter struct { + delegate Tasker +} + +func (o *oldToNewAdapter) GetSpec(ctx context.Context) string { + return o.delegate.GetSpec() +} + +func (o *oldToNewAdapter) GetStatus(ctx context.Context) string { + return o.delegate.GetStatus() +} + +func (o *oldToNewAdapter) Run(ctx context.Context) error { + return o.delegate.Run() +} + +func (o *oldToNewAdapter) SetNext(ctx context.Context, t time.Time) { + o.delegate.SetNext(t) +} + +func (o *oldToNewAdapter) GetNext(ctx context.Context) time.Time { + return o.delegate.GetNext() +} + +func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { + o.delegate.SetPrev(t) +} + +func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { + return o.delegate.GetPrev() +} diff --git a/pkg/toolbox/task_test.go b/pkg/adapter/toolbox/task_test.go similarity index 76% rename from pkg/toolbox/task_test.go rename to pkg/adapter/toolbox/task_test.go index b63f4391..596bc9c5 100644 --- a/pkg/toolbox/task_test.go +++ b/pkg/adapter/toolbox/task_test.go @@ -15,13 +15,10 @@ package toolbox import ( - "errors" "fmt" "sync" "testing" "time" - - "github.com/stretchr/testify/assert" ) func TestParse(t *testing.T) { @@ -56,25 +53,6 @@ func TestSpec(t *testing.T) { } } -func TestTask_Run(t *testing.T) { - cnt := -1 - task := func() error { - cnt++ - fmt.Printf("Hello, world! %d \n", cnt) - return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) - } - tk := NewTask("taska", "0/30 * * * * *", task) - for i := 0; i < 200; i++ { - e := tk.Run() - assert.NotNil(t, e) - } - - l := tk.Errlist - assert.Equal(t, 100, len(l)) - assert.Equal(t, "Hello, world! 100", l[0].errinfo) - assert.Equal(t, "Hello, world! 101", l[1].errinfo) -} - func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() { diff --git a/pkg/adapter/tree.go b/pkg/adapter/tree.go new file mode 100644 index 00000000..2e3cd0d0 --- /dev/null +++ b/pkg/adapter/tree.go @@ -0,0 +1,49 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/server/web" +) + +// Tree has three elements: FixRouter/wildcard/leaves +// fixRouter stores Fixed Router +// wildcard stores params +// leaves store the endpoint information +type Tree web.Tree + +// NewTree return a new Tree +func NewTree() *Tree { + return (*Tree)(web.NewTree()) +} + +// AddTree will add tree to the exist Tree +// prefix should has no params +func (t *Tree) AddTree(prefix string, tree *Tree) { + (*web.Tree)(t).AddTree(prefix, (*web.Tree)(tree)) +} + +// AddRouter call addseg function +func (t *Tree) AddRouter(pattern string, runObject interface{}) { + (*web.Tree)(t).AddRouter(pattern, runObject) +} + +// Match router to runObject & params +func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { + return (*web.Tree)(t).Match(pattern, (*beecontext.Context)(ctx)) +} diff --git a/pkg/adapter/tree_test.go b/pkg/adapter/tree_test.go new file mode 100644 index 00000000..309ed072 --- /dev/null +++ b/pkg/adapter/tree_test.go @@ -0,0 +1,249 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapter + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" +) + +type testinfo struct { + url string + requesturl string + params map[string]string +} + +var routers []testinfo + +func init() { + routers = make([]testinfo, 0) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) + routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) + routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) + routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) + routers = append(routers, testinfo{"/", "/", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) + routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) + routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) + routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) + routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) + routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) + routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) + routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) + routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", + "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", + map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) + routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) + routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", + "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", + map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) + routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) + routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) + routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) + routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) + routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + tr := NewTree() + tr.AddRouter(r.url, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requesturl, ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) + } + } + } + } +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset((*beecontext.Context)(ctx)) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} diff --git a/pkg/adapter/utils/caller.go b/pkg/adapter/utils/caller.go new file mode 100644 index 00000000..d4fcc456 --- /dev/null +++ b/pkg/adapter/utils/caller.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetFuncName get function name +func GetFuncName(i interface{}) string { + return utils.GetFuncName(i) +} diff --git a/pkg/utils/caller_test.go b/pkg/adapter/utils/caller_test.go similarity index 100% rename from pkg/utils/caller_test.go rename to pkg/adapter/utils/caller_test.go diff --git a/pkg/utils/captcha/LICENSE b/pkg/adapter/utils/captcha/LICENSE similarity index 100% rename from pkg/utils/captcha/LICENSE rename to pkg/adapter/utils/captcha/LICENSE diff --git a/pkg/utils/captcha/README.md b/pkg/adapter/utils/captcha/README.md similarity index 100% rename from pkg/utils/captcha/README.md rename to pkg/adapter/utils/captcha/README.md diff --git a/pkg/adapter/utils/captcha/captcha.go b/pkg/adapter/utils/captcha/captcha.go new file mode 100644 index 00000000..faadc8bf --- /dev/null +++ b/pkg/adapter/utils/captcha/captcha.go @@ -0,0 +1,124 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package captcha implements generation and verification of image CAPTCHAs. +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplName = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
+// {{create_captcha}} +// +//
+// ``` +package captcha + +import ( + "html/template" + "net/http" + "time" + + "github.com/astaxie/beego/pkg/server/web/captcha" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + + "github.com/astaxie/beego/pkg/adapter/cache" + "github.com/astaxie/beego/pkg/adapter/context" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + // default captcha attributes + challengeNums = 6 + expiration = 600 * time.Second + fieldIDName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + defaultURLPrefix = "/captcha/" +) + +// Captcha struct +type Captcha captcha.Captcha + +// Handler beego filter handler for serve captcha image +func (c *Captcha) Handler(ctx *context.Context) { + (*captcha.Captcha)(c).Handler((*beecontext.Context)(ctx)) +} + +// CreateCaptchaHTML template func for output html +func (c *Captcha) CreateCaptchaHTML() template.HTML { + return (*captcha.Captcha)(c).CreateCaptchaHTML() +} + +// CreateCaptcha create a new captcha id +func (c *Captcha) CreateCaptcha() (string, error) { + return (*captcha.Captcha)(c).CreateCaptcha() +} + +// VerifyReq verify from a request +func (c *Captcha) VerifyReq(req *http.Request) bool { + return (*captcha.Captcha)(c).VerifyReq(req) +} + +// Verify direct verify id and challenge string +func (c *Captcha) Verify(id string, challenge string) (success bool) { + return (*captcha.Captcha)(c).Verify(id, challenge) +} + +// NewCaptcha create a new captcha.Captcha +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewCaptcha(urlPrefix, store)) +} + +// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a template func for output html +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + return (*Captcha)(captcha.NewWithFilter(urlPrefix, store)) +} diff --git a/pkg/adapter/utils/captcha/image.go b/pkg/adapter/utils/captcha/image.go new file mode 100644 index 00000000..9979db84 --- /dev/null +++ b/pkg/adapter/utils/captcha/image.go @@ -0,0 +1,35 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "io" + + "github.com/astaxie/beego/pkg/server/web/captcha" +) + +// Image struct +type Image captcha.Image + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + return (*Image)(captcha.NewImage(digits, width, height)) +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + return (*captcha.Image)(m).WriteTo(w) +} diff --git a/pkg/adapter/utils/captcha/image_test.go b/pkg/adapter/utils/captcha/image_test.go new file mode 100644 index 00000000..bce2134a --- /dev/null +++ b/pkg/adapter/utils/captcha/image_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/pkg/adapter/utils" +) + +const ( + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/pkg/adapter/utils/debug.go b/pkg/adapter/utils/debug.go new file mode 100644 index 00000000..d39f3d3e --- /dev/null +++ b/pkg/adapter/utils/debug.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Display print the data in console +func Display(data ...interface{}) { + utils.Display(data...) +} + +// GetDisplayString return data print string +func GetDisplayString(data ...interface{}) string { + return utils.GetDisplayString(data...) +} + +// Stack get stack bytes +func Stack(skip int, indent string) []byte { + return utils.Stack(skip, indent) +} diff --git a/pkg/utils/debug_test.go b/pkg/adapter/utils/debug_test.go similarity index 100% rename from pkg/utils/debug_test.go rename to pkg/adapter/utils/debug_test.go diff --git a/pkg/adapter/utils/file.go b/pkg/adapter/utils/file.go new file mode 100644 index 00000000..8979389e --- /dev/null +++ b/pkg/adapter/utils/file.go @@ -0,0 +1,47 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// SelfPath gets compiled executable file absolute path +func SelfPath() string { + return utils.SelfPath() +} + +// SelfDir gets compiled executable file directory +func SelfDir() string { + return utils.SelfDir() +} + +// FileExists reports whether the named file or directory exists. +func FileExists(name string) bool { + return utils.FileExists(name) +} + +// SearchFile Search a file in paths. +// this is often used in search config file in /etc ~/ +func SearchFile(filename string, paths ...string) (fullpath string, err error) { + return utils.SearchFile(filename, paths...) +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(patten string, filename string) (lines []string, err error) { + return utils.GrepFile(patten, filename) +} diff --git a/pkg/adapter/utils/mail.go b/pkg/adapter/utils/mail.go new file mode 100644 index 00000000..35a58756 --- /dev/null +++ b/pkg/adapter/utils/mail.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "io" + + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// Email is the type used for email messages +type Email utils.Email + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment utils.Attachment + +// NewEMail create new Email struct with config json. +// config json is followed from Email struct fields. +func NewEMail(config string) *Email { + return (*Email)(utils.NewEMail(config)) +} + +// Bytes Make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + return (*utils.Email)(e).Bytes() +} + +// AttachFile Add attach file to the send mail +func (e *Email) AttachFile(args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).AttachFile(args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Attach is used to attach content from an io.Reader to the email. +// Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. +func (e *Email) Attach(r io.Reader, filename string, args ...string) (*Attachment, error) { + a, err := (*utils.Email)(e).Attach(r, filename, args...) + if err != nil { + return nil, err + } + return (*Attachment)(a), err +} + +// Send will send out the mail +func (e *Email) Send() error { + return (*utils.Email)(e).Send() +} diff --git a/pkg/utils/mail_test.go b/pkg/adapter/utils/mail_test.go similarity index 100% rename from pkg/utils/mail_test.go rename to pkg/adapter/utils/mail_test.go diff --git a/pkg/adapter/utils/pagination/controller.go b/pkg/adapter/utils/pagination/controller.go new file mode 100644 index 00000000..a908d8b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/controller.go @@ -0,0 +1,26 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "github.com/astaxie/beego/pkg/adapter/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/pagination" +) + +// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). +func SetPaginator(ctx *context.Context, per int, nums int64) (paginator *Paginator) { + return (*Paginator)(pagination.SetPaginator((*beecontext.Context)(ctx), per, nums)) +} diff --git a/pkg/utils/pagination/doc.go b/pkg/adapter/utils/pagination/doc.go similarity index 96% rename from pkg/utils/pagination/doc.go rename to pkg/adapter/utils/pagination/doc.go index 718f5e7a..9abc6d78 100644 --- a/pkg/utils/pagination/doc.go +++ b/pkg/adapter/utils/pagination/doc.go @@ -8,7 +8,7 @@ In your beego.Controller: package controllers - import "github.com/astaxie/beego/pkg/utils/pagination" + import "github.com/astaxie/beego/utils/pagination" type PostsController struct { beego.Controller diff --git a/pkg/adapter/utils/pagination/paginator.go b/pkg/adapter/utils/pagination/paginator.go new file mode 100644 index 00000000..4bd4a1b0 --- /dev/null +++ b/pkg/adapter/utils/pagination/paginator.go @@ -0,0 +1,112 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagination + +import ( + "net/http" + + "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" +) + +// Paginator within the state of a http request. +type Paginator pagination.Paginator + +// PageNums Returns the total number of pages. +func (p *Paginator) PageNums() int { + return (*pagination.Paginator)(p).PageNums() +} + +// Nums Returns the total number of items (e.g. from doing SQL count). +func (p *Paginator) Nums() int64 { + return (*pagination.Paginator)(p).Nums() +} + +// SetNums Sets the total number of items. +func (p *Paginator) SetNums(nums interface{}) { + (*pagination.Paginator)(p).SetNums(nums) +} + +// Page Returns the current page. +func (p *Paginator) Page() int { + return (*pagination.Paginator)(p).Page() +} + +// Pages Returns a list of all pages. +// +// Usage (in a view template): +// +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} +func (p *Paginator) Pages() []int { + return (*pagination.Paginator)(p).Pages() +} + +// PageLink Returns URL for a given page index. +func (p *Paginator) PageLink(page int) string { + return (*pagination.Paginator)(p).PageLink(page) +} + +// PageLinkPrev Returns URL to the previous page. +func (p *Paginator) PageLinkPrev() (link string) { + return (*pagination.Paginator)(p).PageLinkPrev() +} + +// PageLinkNext Returns URL to the next page. +func (p *Paginator) PageLinkNext() (link string) { + return (*pagination.Paginator)(p).PageLinkNext() +} + +// PageLinkFirst Returns URL to the first page. +func (p *Paginator) PageLinkFirst() (link string) { + return (*pagination.Paginator)(p).PageLinkFirst() +} + +// PageLinkLast Returns URL to the last page. +func (p *Paginator) PageLinkLast() (link string) { + return (*pagination.Paginator)(p).PageLinkLast() +} + +// HasPrev Returns true if the current page has a predecessor. +func (p *Paginator) HasPrev() bool { + return (*pagination.Paginator)(p).HasPrev() +} + +// HasNext Returns true if the current page has a successor. +func (p *Paginator) HasNext() bool { + return (*pagination.Paginator)(p).HasNext() +} + +// IsActive Returns true if the given page index points to the current page. +func (p *Paginator) IsActive(page int) bool { + return (*pagination.Paginator)(p).IsActive(page) +} + +// Offset Returns the current offset. +func (p *Paginator) Offset() int { + return (*pagination.Paginator)(p).Offset() +} + +// HasPages Returns true if there is more than one page. +func (p *Paginator) HasPages() bool { + return (*pagination.Paginator)(p).HasPages() +} + +// NewPaginator Instantiates a paginator struct for the current http request. +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + return (*Paginator)(pagination.NewPaginator(req, per, nums)) +} diff --git a/pkg/adapter/utils/rand.go b/pkg/adapter/utils/rand.go new file mode 100644 index 00000000..ae415cf3 --- /dev/null +++ b/pkg/adapter/utils/rand.go @@ -0,0 +1,24 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + return utils.RandomCreateBytes(n, alphabets...) +} diff --git a/pkg/utils/rand_test.go b/pkg/adapter/utils/rand_test.go similarity index 100% rename from pkg/utils/rand_test.go rename to pkg/adapter/utils/rand_test.go diff --git a/pkg/adapter/utils/safemap.go b/pkg/adapter/utils/safemap.go new file mode 100644 index 00000000..13e7bb46 --- /dev/null +++ b/pkg/adapter/utils/safemap.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// BeeMap is a map with lock +type BeeMap utils.BeeMap + +// NewBeeMap return new safemap +func NewBeeMap() *BeeMap { + return (*BeeMap)(utils.NewBeeMap()) +} + +// Get from maps return the k's value +func (m *BeeMap) Get(k interface{}) interface{} { + return (*utils.BeeMap)(m).Get(k) +} + +// Set Maps the given key and value. Returns false +// if the key is already in the map and changes nothing. +func (m *BeeMap) Set(k interface{}, v interface{}) bool { + return (*utils.BeeMap)(m).Set(k, v) +} + +// Check Returns true if k is exist in the map. +func (m *BeeMap) Check(k interface{}) bool { + return (*utils.BeeMap)(m).Check(k) +} + +// Delete the given key and value. +func (m *BeeMap) Delete(k interface{}) { + (*utils.BeeMap)(m).Delete(k) +} + +// Items returns all items in safemap. +func (m *BeeMap) Items() map[interface{}]interface{} { + return (*utils.BeeMap)(m).Items() +} + +// Count returns the number of items within the map. +func (m *BeeMap) Count() int { + return (*utils.BeeMap)(m).Count() +} diff --git a/pkg/utils/safemap_test.go b/pkg/adapter/utils/safemap_test.go similarity index 100% rename from pkg/utils/safemap_test.go rename to pkg/adapter/utils/safemap_test.go diff --git a/pkg/adapter/utils/slice.go b/pkg/adapter/utils/slice.go new file mode 100644 index 00000000..24d19ad2 --- /dev/null +++ b/pkg/adapter/utils/slice.go @@ -0,0 +1,101 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +type reducetype func(interface{}) interface{} +type filtertype func(interface{}) bool + +// InSlice checks given string in string slice or not. +func InSlice(v string, sl []string) bool { + return utils.InSlice(v, sl) +} + +// InSliceIface checks given interface in interface slice. +func InSliceIface(v interface{}, sl []interface{}) bool { + return utils.InSliceIface(v, sl) +} + +// SliceRandList generate an int slice from min to max. +func SliceRandList(min, max int) []int { + return utils.SliceRandList(min, max) +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + return utils.SliceMerge(slice1, slice2) +} + +// SliceReduce generates a new slice after parsing every value by reduce function +func SliceReduce(slice []interface{}, a reducetype) (dslice []interface{}) { + return utils.SliceReduce(slice, func(i interface{}) interface{} { + return a(i) + }) +} + +// SliceRand returns random one from slice. +func SliceRand(a []interface{}) (b interface{}) { + return utils.SliceRand(a) +} + +// SliceSum sums all values in int64 slice. +func SliceSum(intslice []int64) (sum int64) { + return utils.SliceSum(intslice) +} + +// SliceFilter generates a new slice after filter function. +func SliceFilter(slice []interface{}, a filtertype) (ftslice []interface{}) { + return utils.SliceFilter(slice, func(i interface{}) bool { + return a(i) + }) +} + +// SliceDiff returns diff slice of slice1 - slice2. +func SliceDiff(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceDiff(slice1, slice2) +} + +// SliceIntersect returns slice that are present in all the slice1 and slice2. +func SliceIntersect(slice1, slice2 []interface{}) (diffslice []interface{}) { + return utils.SliceIntersect(slice1, slice2) +} + +// SliceChunk separates one slice to some sized slice. +func SliceChunk(slice []interface{}, size int) (chunkslice [][]interface{}) { + return utils.SliceChunk(slice, size) +} + +// SliceRange generates a new slice from begin to end with step duration of int64 number. +func SliceRange(start, end, step int64) (intslice []int64) { + return utils.SliceRange(start, end, step) +} + +// SlicePad prepends size number of val into slice. +func SlicePad(slice []interface{}, size int, val interface{}) []interface{} { + return utils.SlicePad(slice, size, val) +} + +// SliceUnique cleans repeated values in slice. +func SliceUnique(slice []interface{}) (uniqueslice []interface{}) { + return utils.SliceUnique(slice) +} + +// SliceShuffle shuffles a slice. +func SliceShuffle(slice []interface{}) []interface{} { + return utils.SliceShuffle(slice) +} diff --git a/pkg/utils/slice_test.go b/pkg/adapter/utils/slice_test.go similarity index 100% rename from pkg/utils/slice_test.go rename to pkg/adapter/utils/slice_test.go diff --git a/pkg/adapter/utils/utils.go b/pkg/adapter/utils/utils.go new file mode 100644 index 00000000..1f3bcd31 --- /dev/null +++ b/pkg/adapter/utils/utils.go @@ -0,0 +1,10 @@ +package utils + +import ( + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + return utils.GetGOPATHs() +} diff --git a/pkg/adapter/validation/util.go b/pkg/adapter/validation/util.go new file mode 100644 index 00000000..729712e0 --- /dev/null +++ b/pkg/adapter/validation/util.go @@ -0,0 +1,62 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "reflect" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +const ( + // ValidTag struct tag + ValidTag = validation.ValidTag + + LabelTag = validation.LabelTag +) + +var ( + ErrInt64On32 = validation.ErrInt64On32 +) + +// CustomFunc is for custom validate function +type CustomFunc func(v *Validation, obj interface{}, key string) + +// AddCustomFunc Add a custom function to validation +// The name can not be: +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// If the name is same with exists function, it will replace the origin valid function +func AddCustomFunc(name string, f CustomFunc) error { + return validation.AddCustomFunc(name, func(v *validation.Validation, obj interface{}, key string) { + f((*Validation)(v), obj, key) + }) +} + +// ValidFunc Valid function type +type ValidFunc validation.ValidFunc + +// Funcs Validate function map +type Funcs validation.Funcs + +// Call validate values with named type string +func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) { + return (validation.Funcs(f)).Call(name, params...) +} diff --git a/pkg/adapter/validation/validation.go b/pkg/adapter/validation/validation.go new file mode 100644 index 00000000..1cdb8dda --- /dev/null +++ b/pkg/adapter/validation/validation.go @@ -0,0 +1,274 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package validation for validations +// +// import ( +// "github.com/astaxie/beego/validation" +// "log" +// ) +// +// type User struct { +// Name string +// Age int +// } +// +// func main() { +// u := User{"man", 40} +// valid := validation.Validation{} +// valid.Required(u.Name, "name") +// valid.MaxSize(u.Name, 15, "nameMax") +// valid.Range(u.Age, 0, 140, "age") +// if valid.HasErrors() { +// // validation does not pass +// // print invalid message +// for _, err := range valid.Errors { +// log.Println(err.Key, err.Message) +// } +// } +// // or use like this +// if v := valid.Max(u.Age, 140, "ageMax"); !v.Ok { +// log.Println(v.Error.Key, v.Error.Message) +// } +// } +// +// more info: http://beego.me/docs/mvc/controller/validation.md +package validation + +import ( + "fmt" + "regexp" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// ValidFormer valid interface +type ValidFormer interface { + Valid(*Validation) +} + +// Error show the error +type Error validation.Error + +// String Returns the Message. +func (e *Error) String() string { + if e == nil { + return "" + } + return e.Message +} + +// Implement Error interface. +// Return e.String() +func (e *Error) Error() string { return e.String() } + +// Result is returned from every validation method. +// It provides an indication of success, and a pointer to the Error (if any). +type Result validation.Result + +// Key Get Result by given key string. +func (r *Result) Key(key string) *Result { + if r.Error != nil { + r.Error.Key = key + } + return r +} + +// Message Set Result message by string or format string with args +func (r *Result) Message(message string, args ...interface{}) *Result { + if r.Error != nil { + if len(args) == 0 { + r.Error.Message = message + } else { + r.Error.Message = fmt.Sprintf(message, args...) + } + } + return r +} + +// A Validation context manages data validation and error messages. +type Validation validation.Validation + +// Clear Clean all ValidationError. +func (v *Validation) Clear() { + (*validation.Validation)(v).Clear() +} + +// HasErrors Has ValidationError nor not. +func (v *Validation) HasErrors() bool { + return (*validation.Validation)(v).HasErrors() +} + +// ErrorMap Return the errors mapped by key. +// If there are multiple validation errors associated with a single key, the +// first one "wins". (Typically the first validation will be the more basic). +func (v *Validation) ErrorMap() map[string][]*Error { + newErrors := (*validation.Validation)(v).ErrorMap() + res := make(map[string][]*Error, len(newErrors)) + for n, es := range newErrors { + errs := make([]*Error, 0, len(es)) + + for _, e := range es { + errs = append(errs, (*Error)(e)) + } + + res[n] = errs + } + return res +} + +// Error Add an error to the validation context. +func (v *Validation) Error(message string, args ...interface{}) *Result { + return (*Result)((*validation.Validation)(v).Error(message, args...)) +} + +// Required Test that the argument is non-nil and non-empty (if string or list) +func (v *Validation) Required(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Required(obj, key)) +} + +// Min Test that the obj is greater than min if obj's type is int +func (v *Validation) Min(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).Min(obj, min, key)) +} + +// Max Test that the obj is less than max if obj's type is int +func (v *Validation) Max(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Max(obj, max, key)) +} + +// Range Test that the obj is between mni and max if obj's type is int +func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).Range(obj, min, max, key)) +} + +// MinSize Test that the obj is longer than min size if type is string or slice +func (v *Validation) MinSize(obj interface{}, min int, key string) *Result { + return (*Result)((*validation.Validation)(v).MinSize(obj, min, key)) +} + +// MaxSize Test that the obj is shorter than max size if type is string or slice +func (v *Validation) MaxSize(obj interface{}, max int, key string) *Result { + return (*Result)((*validation.Validation)(v).MaxSize(obj, max, key)) +} + +// Length Test that the obj is same length to n if type is string or slice +func (v *Validation) Length(obj interface{}, n int, key string) *Result { + return (*Result)((*validation.Validation)(v).Length(obj, n, key)) +} + +// Alpha Test that the obj is [a-zA-Z] if type is string +func (v *Validation) Alpha(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Alpha(obj, key)) +} + +// Numeric Test that the obj is [0-9] if type is string +func (v *Validation) Numeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Numeric(obj, key)) +} + +// AlphaNumeric Test that the obj is [0-9a-zA-Z] if type is string +func (v *Validation) AlphaNumeric(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaNumeric(obj, key)) +} + +// Match Test that the obj matches regexp if type is string +func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).Match(obj, regex, key)) +} + +// NoMatch Test that the obj doesn't match regexp if type is string +func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *Result { + return (*Result)((*validation.Validation)(v).NoMatch(obj, regex, key)) +} + +// AlphaDash Test that the obj is [0-9a-zA-Z_-] if type is string +func (v *Validation) AlphaDash(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).AlphaDash(obj, key)) +} + +// Email Test that the obj is email address if type is string +func (v *Validation) Email(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Email(obj, key)) +} + +// IP Test that the obj is IP address if type is string +func (v *Validation) IP(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).IP(obj, key)) +} + +// Base64 Test that the obj is base64 encoded if type is string +func (v *Validation) Base64(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Base64(obj, key)) +} + +// Mobile Test that the obj is chinese mobile number if type is string +func (v *Validation) Mobile(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Mobile(obj, key)) +} + +// Tel Test that the obj is chinese telephone number if type is string +func (v *Validation) Tel(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Tel(obj, key)) +} + +// Phone Test that the obj is chinese mobile or telephone number if type is string +func (v *Validation) Phone(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).Phone(obj, key)) +} + +// ZipCode Test that the obj is chinese zip code if type is string +func (v *Validation) ZipCode(obj interface{}, key string) *Result { + return (*Result)((*validation.Validation)(v).ZipCode(obj, key)) +} + +// key must like aa.bb.cc or aa.bb. +// AddError adds independent error message for the provided key +func (v *Validation) AddError(key, message string) { + (*validation.Validation)(v).AddError(key, message) +} + +// SetError Set error message for one field in ValidationError +func (v *Validation) SetError(fieldName string, errMsg string) *Error { + return (*Error)((*validation.Validation)(v).SetError(fieldName, errMsg)) +} + +// Check Apply a group of validators to a field, in order, and return the +// ValidationResult from the first one that fails, or the last one that +// succeeds. +func (v *Validation) Check(obj interface{}, checks ...Validator) *Result { + vldts := make([]validation.Validator, 0, len(checks)) + for _, v := range checks { + vldts = append(vldts, validation.Validator(v)) + } + return (*Result)((*validation.Validation)(v).Check(obj, vldts...)) +} + +// Valid Validate a struct. +// the obj parameter must be a struct or a struct pointer +func (v *Validation) Valid(obj interface{}) (b bool, err error) { + return (*validation.Validation)(v).Valid(obj) +} + +// RecursiveValid Recursively validate a struct. +// Step1: Validate by v.Valid +// Step2: If pass on step1, then reflect obj's fields +// Step3: Do the Recursively validation to all struct or struct pointer fields +func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { + return (*validation.Validation)(v).RecursiveValid(objc) +} + +func (v *Validation) CanSkipAlso(skipFunc string) { + (*validation.Validation)(v).CanSkipAlso(skipFunc) +} diff --git a/pkg/validation/validation_test.go b/pkg/adapter/validation/validation_test.go similarity index 100% rename from pkg/validation/validation_test.go rename to pkg/adapter/validation/validation_test.go diff --git a/pkg/adapter/validation/validators.go b/pkg/adapter/validation/validators.go new file mode 100644 index 00000000..1a063749 --- /dev/null +++ b/pkg/adapter/validation/validators.go @@ -0,0 +1,512 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "sync" + + "github.com/astaxie/beego/pkg/infrastructure/validation" +) + +// CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty +var CanSkipFuncs = validation.CanSkipFuncs + +// MessageTmpls store commond validate template +var MessageTmpls = map[string]string{ + "Required": "Can not be empty", + "Min": "Minimum is %d", + "Max": "Maximum is %d", + "Range": "Range is %d to %d", + "MinSize": "Minimum size is %d", + "MaxSize": "Maximum size is %d", + "Length": "Required length is %d", + "Alpha": "Must be valid alpha characters", + "Numeric": "Must be valid numeric characters", + "AlphaNumeric": "Must be valid alpha or numeric characters", + "Match": "Must match %s", + "NoMatch": "Must not match %s", + "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", + "Email": "Must be a valid email address", + "IP": "Must be a valid ip address", + "Base64": "Must be valid base64 characters", + "Mobile": "Must be valid mobile number", + "Tel": "Must be valid telephone number", + "Phone": "Must be valid telephone or mobile phone number", + "ZipCode": "Must be valid zipcode", +} + +var once sync.Once + +// SetDefaultMessage set default messages +// if not set, the default messages are +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", +func SetDefaultMessage(msg map[string]string) { + validation.SetDefaultMessage(msg) +} + +// Validator interface +type Validator interface { + IsSatisfied(interface{}) bool + DefaultMessage() string + GetKey() string + GetLimitValue() interface{} +} + +// Required struct +type Required validation.Required + +// IsSatisfied judge whether obj has value +func (r Required) IsSatisfied(obj interface{}) bool { + return validation.Required(r).IsSatisfied(obj) +} + +// DefaultMessage return the default error message +func (r Required) DefaultMessage() string { + return validation.Required(r).DefaultMessage() +} + +// GetKey return the r.Key +func (r Required) GetKey() string { + return validation.Required(r).GetKey() +} + +// GetLimitValue return nil now +func (r Required) GetLimitValue() interface{} { + return validation.Required(r).GetLimitValue() +} + +// Min check struct +type Min validation.Min + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Min) IsSatisfied(obj interface{}) bool { + return validation.Min(m).IsSatisfied(obj) +} + +// DefaultMessage return the default min error message +func (m Min) DefaultMessage() string { + return validation.Min(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Min) GetKey() string { + return validation.Min(m).GetKey() +} + +// GetLimitValue return the limit value, Min +func (m Min) GetLimitValue() interface{} { + return validation.Min(m).GetLimitValue() +} + +// Max validate struct +type Max validation.Max + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (m Max) IsSatisfied(obj interface{}) bool { + return validation.Max(m).IsSatisfied(obj) +} + +// DefaultMessage return the default max error message +func (m Max) DefaultMessage() string { + return validation.Max(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Max) GetKey() string { + return validation.Max(m).GetKey() +} + +// GetLimitValue return the limit value, Max +func (m Max) GetLimitValue() interface{} { + return validation.Max(m).GetLimitValue() +} + +// Range Requires an integer to be within Min, Max inclusive. +type Range validation.Range + +// IsSatisfied judge whether obj is valid +// not support int64 on 32-bit platform +func (r Range) IsSatisfied(obj interface{}) bool { + return validation.Range(r).IsSatisfied(obj) +} + +// DefaultMessage return the default Range error message +func (r Range) DefaultMessage() string { + return validation.Range(r).DefaultMessage() +} + +// GetKey return the m.Key +func (r Range) GetKey() string { + return validation.Range(r).GetKey() +} + +// GetLimitValue return the limit value, Max +func (r Range) GetLimitValue() interface{} { + return validation.Range(r).GetLimitValue() +} + +// MinSize Requires an array or string to be at least a given length. +type MinSize validation.MinSize + +// IsSatisfied judge whether obj is valid +func (m MinSize) IsSatisfied(obj interface{}) bool { + return validation.MinSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MinSize error message +func (m MinSize) DefaultMessage() string { + return validation.MinSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MinSize) GetKey() string { + return validation.MinSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MinSize) GetLimitValue() interface{} { + return validation.MinSize(m).GetLimitValue() +} + +// MaxSize Requires an array or string to be at most a given length. +type MaxSize validation.MaxSize + +// IsSatisfied judge whether obj is valid +func (m MaxSize) IsSatisfied(obj interface{}) bool { + return validation.MaxSize(m).IsSatisfied(obj) +} + +// DefaultMessage return the default MaxSize error message +func (m MaxSize) DefaultMessage() string { + return validation.MaxSize(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m MaxSize) GetKey() string { + return validation.MaxSize(m).GetKey() +} + +// GetLimitValue return the limit value +func (m MaxSize) GetLimitValue() interface{} { + return validation.MaxSize(m).GetLimitValue() +} + +// Length Requires an array or string to be exactly a given length. +type Length validation.Length + +// IsSatisfied judge whether obj is valid +func (l Length) IsSatisfied(obj interface{}) bool { + return validation.Length(l).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (l Length) DefaultMessage() string { + return validation.Length(l).DefaultMessage() +} + +// GetKey return the m.Key +func (l Length) GetKey() string { + return validation.Length(l).GetKey() +} + +// GetLimitValue return the limit value +func (l Length) GetLimitValue() interface{} { + return validation.Length(l).GetLimitValue() +} + +// Alpha check the alpha +type Alpha validation.Alpha + +// IsSatisfied judge whether obj is valid +func (a Alpha) IsSatisfied(obj interface{}) bool { + return validation.Alpha(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a Alpha) DefaultMessage() string { + return validation.Alpha(a).DefaultMessage() +} + +// GetKey return the m.Key +func (a Alpha) GetKey() string { + return validation.Alpha(a).GetKey() +} + +// GetLimitValue return the limit value +func (a Alpha) GetLimitValue() interface{} { + return validation.Alpha(a).GetLimitValue() +} + +// Numeric check number +type Numeric validation.Numeric + +// IsSatisfied judge whether obj is valid +func (n Numeric) IsSatisfied(obj interface{}) bool { + return validation.Numeric(n).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (n Numeric) DefaultMessage() string { + return validation.Numeric(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n Numeric) GetKey() string { + return validation.Numeric(n).GetKey() +} + +// GetLimitValue return the limit value +func (n Numeric) GetLimitValue() interface{} { + return validation.Numeric(n).GetLimitValue() +} + +// AlphaNumeric check alpha and number +type AlphaNumeric validation.AlphaNumeric + +// IsSatisfied judge whether obj is valid +func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { + return validation.AlphaNumeric(a).IsSatisfied(obj) +} + +// DefaultMessage return the default Length error message +func (a AlphaNumeric) DefaultMessage() string { + return validation.AlphaNumeric(a).DefaultMessage() +} + +// GetKey return the a.Key +func (a AlphaNumeric) GetKey() string { + return validation.AlphaNumeric(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaNumeric) GetLimitValue() interface{} { + return validation.AlphaNumeric(a).GetLimitValue() +} + +// Match Requires a string to match a given regex. +type Match validation.Match + +// IsSatisfied judge whether obj is valid +func (m Match) IsSatisfied(obj interface{}) bool { + return validation.Match(m).IsSatisfied(obj) +} + +// DefaultMessage return the default Match error message +func (m Match) DefaultMessage() string { + return validation.Match(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Match) GetKey() string { + return validation.Match(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Match) GetLimitValue() interface{} { + return validation.Match(m).GetLimitValue() +} + +// NoMatch Requires a string to not match a given regex. +type NoMatch validation.NoMatch + +// IsSatisfied judge whether obj is valid +func (n NoMatch) IsSatisfied(obj interface{}) bool { + return validation.NoMatch(n).IsSatisfied(obj) +} + +// DefaultMessage return the default NoMatch error message +func (n NoMatch) DefaultMessage() string { + return validation.NoMatch(n).DefaultMessage() +} + +// GetKey return the n.Key +func (n NoMatch) GetKey() string { + return validation.NoMatch(n).GetKey() +} + +// GetLimitValue return the limit value +func (n NoMatch) GetLimitValue() interface{} { + return validation.NoMatch(n).GetLimitValue() +} + +// AlphaDash check not Alpha +type AlphaDash validation.AlphaDash + +// DefaultMessage return the default AlphaDash error message +func (a AlphaDash) DefaultMessage() string { + return validation.AlphaDash(a).DefaultMessage() +} + +// GetKey return the n.Key +func (a AlphaDash) GetKey() string { + return validation.AlphaDash(a).GetKey() +} + +// GetLimitValue return the limit value +func (a AlphaDash) GetLimitValue() interface{} { + return validation.AlphaDash(a).GetLimitValue() +} + +// Email check struct +type Email validation.Email + +// DefaultMessage return the default Email error message +func (e Email) DefaultMessage() string { + return validation.Email(e).DefaultMessage() +} + +// GetKey return the n.Key +func (e Email) GetKey() string { + return validation.Email(e).GetKey() +} + +// GetLimitValue return the limit value +func (e Email) GetLimitValue() interface{} { + return validation.Email(e).GetLimitValue() +} + +// IP check struct +type IP validation.IP + +// DefaultMessage return the default IP error message +func (i IP) DefaultMessage() string { + return validation.IP(i).DefaultMessage() +} + +// GetKey return the i.Key +func (i IP) GetKey() string { + return validation.IP(i).GetKey() +} + +// GetLimitValue return the limit value +func (i IP) GetLimitValue() interface{} { + return validation.IP(i).GetLimitValue() +} + +// Base64 check struct +type Base64 validation.Base64 + +// DefaultMessage return the default Base64 error message +func (b Base64) DefaultMessage() string { + return validation.Base64(b).DefaultMessage() +} + +// GetKey return the b.Key +func (b Base64) GetKey() string { + return validation.Base64(b).GetKey() +} + +// GetLimitValue return the limit value +func (b Base64) GetLimitValue() interface{} { + return validation.Base64(b).GetLimitValue() +} + +// Mobile check struct +type Mobile validation.Mobile + +// DefaultMessage return the default Mobile error message +func (m Mobile) DefaultMessage() string { + return validation.Mobile(m).DefaultMessage() +} + +// GetKey return the m.Key +func (m Mobile) GetKey() string { + return validation.Mobile(m).GetKey() +} + +// GetLimitValue return the limit value +func (m Mobile) GetLimitValue() interface{} { + return validation.Mobile(m).GetLimitValue() +} + +// Tel check telephone struct +type Tel validation.Tel + +// DefaultMessage return the default Tel error message +func (t Tel) DefaultMessage() string { + return validation.Tel(t).DefaultMessage() +} + +// GetKey return the t.Key +func (t Tel) GetKey() string { + return validation.Tel(t).GetKey() +} + +// GetLimitValue return the limit value +func (t Tel) GetLimitValue() interface{} { + return validation.Tel(t).GetLimitValue() +} + +// Phone just for chinese telephone or mobile phone number +type Phone validation.Phone + +// IsSatisfied judge whether obj is valid +func (p Phone) IsSatisfied(obj interface{}) bool { + return validation.Phone(p).IsSatisfied(obj) +} + +// DefaultMessage return the default Phone error message +func (p Phone) DefaultMessage() string { + return validation.Phone(p).DefaultMessage() +} + +// GetKey return the p.Key +func (p Phone) GetKey() string { + return validation.Phone(p).GetKey() +} + +// GetLimitValue return the limit value +func (p Phone) GetLimitValue() interface{} { + return validation.Phone(p).GetLimitValue() +} + +// ZipCode check the zip struct +type ZipCode validation.ZipCode + +// DefaultMessage return the default Zip error message +func (z ZipCode) DefaultMessage() string { + return validation.ZipCode(z).DefaultMessage() +} + +// GetKey return the z.Key +func (z ZipCode) GetKey() string { + return validation.ZipCode(z).GetKey() +} + +// GetLimitValue return the limit value +func (z ZipCode) GetLimitValue() interface{} { + return validation.ZipCode(z).GetLimitValue() +} diff --git a/pkg/cache/README.md b/pkg/client/cache/README.md similarity index 100% rename from pkg/cache/README.md rename to pkg/client/cache/README.md diff --git a/pkg/cache/cache.go b/pkg/client/cache/cache.go similarity index 100% rename from pkg/cache/cache.go rename to pkg/client/cache/cache.go diff --git a/pkg/cache/cache_test.go b/pkg/client/cache/cache_test.go similarity index 100% rename from pkg/cache/cache_test.go rename to pkg/client/cache/cache_test.go diff --git a/pkg/cache/conv.go b/pkg/client/cache/conv.go similarity index 100% rename from pkg/cache/conv.go rename to pkg/client/cache/conv.go diff --git a/pkg/cache/conv_test.go b/pkg/client/cache/conv_test.go similarity index 100% rename from pkg/cache/conv_test.go rename to pkg/client/cache/conv_test.go diff --git a/pkg/cache/file.go b/pkg/client/cache/file.go similarity index 100% rename from pkg/cache/file.go rename to pkg/client/cache/file.go diff --git a/pkg/cache/memcache/memcache.go b/pkg/client/cache/memcache/memcache.go similarity index 98% rename from pkg/cache/memcache/memcache.go rename to pkg/client/cache/memcache/memcache.go index 94fc61dc..60712c2f 100644 --- a/pkg/cache/memcache/memcache.go +++ b/pkg/client/cache/memcache/memcache.go @@ -37,7 +37,7 @@ import ( "github.com/bradfitz/gomemcache/memcache" - "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/client/cache" ) // Cache Memcache adapter. diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/client/cache/memcache/memcache_test.go similarity index 92% rename from pkg/cache/memcache/memcache_test.go rename to pkg/client/cache/memcache/memcache_test.go index b7dad8fc..df2ba37f 100644 --- a/pkg/cache/memcache/memcache_test.go +++ b/pkg/client/cache/memcache/memcache_test.go @@ -15,17 +15,26 @@ package memcache import ( + "fmt" + "os" + _ "github.com/bradfitz/gomemcache/memcache" "strconv" "testing" "time" - "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/client/cache" ) func TestMemcacheCache(t *testing.T) { - bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) + + addr := os.Getenv("MEMCACHE_ADDR") + if addr == "" { + addr = "127.0.0.1:11211" + } + + bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) if err != nil { t.Error("init err") } diff --git a/pkg/cache/memory.go b/pkg/client/cache/memory.go similarity index 100% rename from pkg/cache/memory.go rename to pkg/client/cache/memory.go diff --git a/pkg/cache/redis/redis.go b/pkg/client/cache/redis/redis.go similarity index 99% rename from pkg/cache/redis/redis.go rename to pkg/client/cache/redis/redis.go index 68a934bf..2cd20503 100644 --- a/pkg/cache/redis/redis.go +++ b/pkg/client/cache/redis/redis.go @@ -39,7 +39,7 @@ import ( "github.com/gomodule/redigo/redis" - "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/client/cache" ) var ( diff --git a/pkg/cache/redis/redis_test.go b/pkg/client/cache/redis/redis_test.go similarity index 86% rename from pkg/cache/redis/redis_test.go rename to pkg/client/cache/redis/redis_test.go index 8de331ab..dc0ca40f 100644 --- a/pkg/cache/redis/redis_test.go +++ b/pkg/client/cache/redis/redis_test.go @@ -16,17 +16,24 @@ package redis import ( "fmt" + "os" "testing" "time" "github.com/gomodule/redigo/redis" "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/client/cache" ) func TestRedisCache(t *testing.T) { - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + redisAddr = "127.0.0.1:6379" + } + + bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) if err != nil { t.Error("init err") } @@ -110,24 +117,31 @@ func TestRedisCache(t *testing.T) { func TestCache_Scan(t *testing.T) { timeoutDuration := 10 * time.Second + + addr := os.Getenv("REDIS_ADDR") + if addr == "" { + addr = "127.0.0.1:6379" + } + // init - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) + bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, addr)) if err != nil { t.Error("init err") } // insert all - for i := 0; i < 10000; i++ { + for i := 0; i < 100; i++ { if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { t.Error("set Error", err) } } + time.Sleep(time.Second) // scan all for the first time keys, err := bm.(*Cache).Scan(DefaultKey + ":*") if err != nil { t.Error("scan Error", err) } - assert.Equal(t, 10000, len(keys), "scan all error") + assert.Equal(t, 100, len(keys), "scan all error") // clear all if err = bm.ClearAll(); err != nil { diff --git a/pkg/cache/ssdb/ssdb.go b/pkg/client/cache/ssdb/ssdb.go similarity index 99% rename from pkg/cache/ssdb/ssdb.go rename to pkg/client/cache/ssdb/ssdb.go index 038b2ebe..10ff72b0 100644 --- a/pkg/cache/ssdb/ssdb.go +++ b/pkg/client/cache/ssdb/ssdb.go @@ -9,7 +9,7 @@ import ( "github.com/ssdb/gossdb/ssdb" - "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/client/cache" ) // Cache SSDB adapter diff --git a/pkg/cache/ssdb/ssdb_test.go b/pkg/client/cache/ssdb/ssdb_test.go similarity index 90% rename from pkg/cache/ssdb/ssdb_test.go rename to pkg/client/cache/ssdb/ssdb_test.go index 7390ea94..bd6ede4e 100644 --- a/pkg/cache/ssdb/ssdb_test.go +++ b/pkg/client/cache/ssdb/ssdb_test.go @@ -1,15 +1,23 @@ package ssdb import ( + "fmt" + "os" "strconv" "testing" "time" - "github.com/astaxie/beego/pkg/cache" + "github.com/astaxie/beego/pkg/client/cache" ) func TestSsdbcacheCache(t *testing.T) { - ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) + + ssdbAddr := os.Getenv("SSDB_ADDR") + if ssdbAddr == "" { + ssdbAddr = "127.0.0.1:8888" + } + + ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) if err != nil { t.Error("init err") } diff --git a/pkg/httplib/README.md b/pkg/client/httplib/README.md similarity index 100% rename from pkg/httplib/README.md rename to pkg/client/httplib/README.md diff --git a/pkg/httplib/filter.go b/pkg/client/httplib/filter.go similarity index 100% rename from pkg/httplib/filter.go rename to pkg/client/httplib/filter.go diff --git a/pkg/httplib/filter/opentracing/filter.go b/pkg/client/httplib/filter/opentracing/filter.go similarity index 98% rename from pkg/httplib/filter/opentracing/filter.go rename to pkg/client/httplib/filter/opentracing/filter.go index 6cc4d6b0..93376843 100644 --- a/pkg/httplib/filter/opentracing/filter.go +++ b/pkg/client/httplib/filter/opentracing/filter.go @@ -18,7 +18,7 @@ import ( "context" "net/http" - "github.com/astaxie/beego/pkg/httplib" + "github.com/astaxie/beego/pkg/client/httplib" logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" diff --git a/pkg/httplib/filter/opentracing/filter_test.go b/pkg/client/httplib/filter/opentracing/filter_test.go similarity index 96% rename from pkg/httplib/filter/opentracing/filter_test.go rename to pkg/client/httplib/filter/opentracing/filter_test.go index 8849a9ad..46937803 100644 --- a/pkg/httplib/filter/opentracing/filter_test.go +++ b/pkg/client/httplib/filter/opentracing/filter_test.go @@ -23,7 +23,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/httplib" + "github.com/astaxie/beego/pkg/client/httplib" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/httplib/filter/prometheus/filter.go b/pkg/client/httplib/filter/prometheus/filter.go similarity index 91% rename from pkg/httplib/filter/prometheus/filter.go rename to pkg/client/httplib/filter/prometheus/filter.go index e7f7316f..917d4720 100644 --- a/pkg/httplib/filter/prometheus/filter.go +++ b/pkg/client/httplib/filter/prometheus/filter.go @@ -22,8 +22,8 @@ import ( "github.com/prometheus/client_golang/prometheus" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/httplib" + "github.com/astaxie/beego/pkg/client/httplib" + "github.com/astaxie/beego/pkg/server/web" ) type FilterChainBuilder struct { @@ -36,9 +36,9 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt Name: "beego", Subsystem: "remote_http_request", ConstLabels: map[string]string{ - "server": beego.BConfig.ServerName, - "env": beego.BConfig.RunMode, - "appname": beego.BConfig.AppName, + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, }, Help: "The statics info for remote http requests", }, []string{"proto", "scheme", "method", "host", "path", "status", "duration", "isError"}) diff --git a/pkg/httplib/filter/prometheus/filter_test.go b/pkg/client/httplib/filter/prometheus/filter_test.go similarity index 96% rename from pkg/httplib/filter/prometheus/filter_test.go rename to pkg/client/httplib/filter/prometheus/filter_test.go index 2964e6c5..2964a4a2 100644 --- a/pkg/httplib/filter/prometheus/filter_test.go +++ b/pkg/client/httplib/filter/prometheus/filter_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/httplib" + "github.com/astaxie/beego/pkg/client/httplib" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/httplib/httplib.go b/pkg/client/httplib/httplib.go similarity index 100% rename from pkg/httplib/httplib.go rename to pkg/client/httplib/httplib.go diff --git a/pkg/httplib/httplib_test.go b/pkg/client/httplib/httplib_test.go similarity index 100% rename from pkg/httplib/httplib_test.go rename to pkg/client/httplib/httplib_test.go diff --git a/pkg/testing/client.go b/pkg/client/httplib/testing/client.go similarity index 90% rename from pkg/testing/client.go rename to pkg/client/httplib/testing/client.go index 0062b857..00fa3059 100644 --- a/pkg/testing/client.go +++ b/pkg/client/httplib/testing/client.go @@ -15,8 +15,9 @@ package testing import ( - "github.com/astaxie/beego/pkg/config" - "github.com/astaxie/beego/pkg/httplib" + "github.com/astaxie/beego/pkg/client/httplib" + + "github.com/astaxie/beego/pkg/infrastructure/config" ) var port = "" @@ -33,7 +34,10 @@ func getPort() string { if err != nil { return "8080" } - port = config.String("httpport") + port, err = config.String(nil, "httpport") + if err != nil { + return "8080" + } return port } return port diff --git a/pkg/orm/README.md b/pkg/client/orm/README.md similarity index 100% rename from pkg/orm/README.md rename to pkg/client/orm/README.md diff --git a/pkg/orm/cmd.go b/pkg/client/orm/cmd.go similarity index 91% rename from pkg/orm/cmd.go rename to pkg/client/orm/cmd.go index f03382e9..b0661971 100644 --- a/pkg/orm/cmd.go +++ b/pkg/client/orm/cmd.go @@ -99,13 +99,17 @@ func (d *commandSyncDb) Parse(args []string) { // Run orm line command. func (d *commandSyncDb) Run() error { var drops []string + var err error if d.force { - drops = getDbDropSQL(d.al) + drops, err = modelCache.getDbDropSQL(d.al) + if err != nil { + return err + } } db := d.al.DB - if d.force { + if d.force && len(drops) > 0 { for i, mi := range modelCache.allOrdered() { query := drops[i] if !d.noInfo { @@ -124,7 +128,10 @@ func (d *commandSyncDb) Run() error { } } - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := modelCache.getDbCreateSQL(d.al) + if err != nil { + return err + } tables, err := d.al.DbBaser.GetTables(db) if err != nil { @@ -135,6 +142,12 @@ func (d *commandSyncDb) Run() error { } for i, mi := range modelCache.allOrdered() { + + if !isApplicableTableForDB(mi.addrField, d.al.Name) { + fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name) + continue + } + if tables[mi.table] { if !d.noInfo { fmt.Printf("table `%s` already exists, skip\n", mi.table) @@ -201,7 +214,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf("create table `%s` \n", mi.table) } - queries := []string{sqls[i]} + queries := []string{createQueries[i]} for _, idx := range indexes[mi.table] { queries = append(queries, idx.SQL) } @@ -245,10 +258,13 @@ func (d *commandSQLAll) Parse(args []string) { // Run orm line command. func (d *commandSQLAll) Run() error { - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := modelCache.getDbCreateSQL(d.al) + if err != nil { + return err + } var all []string for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} + queries := []string{createQueries[i]} for _, idx := range indexes[mi.table] { queries = append(queries, idx.SQL) } diff --git a/pkg/client/orm/cmd_utils.go b/pkg/client/orm/cmd_utils.go new file mode 100644 index 00000000..8d6c0c33 --- /dev/null +++ b/pkg/client/orm/cmd_utils.go @@ -0,0 +1,171 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" +) + +type dbIndex struct { + Table string + Name string + SQL string +} + +// get database column type string. +func getColumnTyp(al *alias, fi *fieldInfo) (col string) { + T := al.DbBaser.DbTypes() + fieldType := fi.fieldType + fieldSize := fi.size + +checkColumn: + switch fieldType { + case TypeBooleanField: + col = T["bool"] + case TypeVarCharField: + if al.Driver == DRPostgres && fi.toText { + col = T["string-text"] + } else { + col = fmt.Sprintf(T["string"], fieldSize) + } + case TypeCharField: + col = fmt.Sprintf(T["string-char"], fieldSize) + case TypeTextField: + col = T["string-text"] + case TypeTimeField: + col = T["time.Time-clock"] + case TypeDateField: + col = T["time.Time-date"] + case TypeDateTimeField: + // the precision of sqlite is not implemented + if al.Driver == 2 || fi.timePrecision == nil { + col = T["time.Time"] + } else { + s := T["time.Time-precision"] + col = fmt.Sprintf(s, *fi.timePrecision) + } + + case TypeBitField: + col = T["int8"] + case TypeSmallIntegerField: + col = T["int16"] + case TypeIntegerField: + col = T["int32"] + case TypeBigIntegerField: + if al.Driver == DRSqlite { + fieldType = TypeIntegerField + goto checkColumn + } + col = T["int64"] + case TypePositiveBitField: + col = T["uint8"] + case TypePositiveSmallIntegerField: + col = T["uint16"] + case TypePositiveIntegerField: + col = T["uint32"] + case TypePositiveBigIntegerField: + col = T["uint64"] + case TypeFloatField: + col = T["float64"] + case TypeDecimalField: + s := T["float64-decimal"] + if !strings.Contains(s, "%d") { + col = s + } else { + col = fmt.Sprintf(s, fi.digits, fi.decimals) + } + case TypeJSONField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["json"] + case TypeJsonbField: + if al.Driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["jsonb"] + case RelForeignKey, RelOneToOne: + fieldType = fi.relModelInfo.fields.pk.fieldType + fieldSize = fi.relModelInfo.fields.pk.size + goto checkColumn + } + + return +} + +// create alter sql string. +func getColumnAddQuery(al *alias, fi *fieldInfo) string { + Q := al.DbBaser.TableQuote() + typ := getColumnTyp(al, fi) + + if !fi.null { + typ += " " + "NOT NULL" + } + + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", + Q, fi.mi.table, Q, + Q, fi.column, Q, + typ, getColumnDefault(fi), + ) +} + +// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands +func getColumnDefault(fi *fieldInfo) string { + var ( + v, t, d string + ) + + // Skip default attribute if field is in relations + if fi.rel || fi.reverse { + return v + } + + t = " DEFAULT '%s' " + + // These defaults will be useful if there no config value orm:"default" and NOT NULL is on + switch fi.fieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: + return v + + case TypeBitField, TypeSmallIntegerField, TypeIntegerField, + TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, + TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, + TypeDecimalField: + t = " DEFAULT %s " + d = "0" + case TypeBooleanField: + t = " DEFAULT %s " + d = "FALSE" + case TypeJSONField, TypeJsonbField: + d = "{}" + } + + if fi.colDefault { + if !fi.initial.Exist() { + v = fmt.Sprintf(t, "") + } else { + v = fmt.Sprintf(t, fi.initial.String()) + } + } else { + if !fi.null { + v = fmt.Sprintf(t, d) + } + } + + return v +} diff --git a/pkg/orm/db.go b/pkg/client/orm/db.go similarity index 99% rename from pkg/orm/db.go rename to pkg/client/orm/db.go index 905c8189..820435ca 100644 --- a/pkg/orm/db.go +++ b/pkg/client/orm/db.go @@ -22,7 +22,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/client/orm/hints" ) const ( @@ -1355,7 +1355,14 @@ setValue: t time.Time err error ) - if len(s) >= 19 { + + if fi.timePrecision != nil && len(s) >= (20+*fi.timePrecision) { + layout := formatDateTime + "." + for i := 0; i < *fi.timePrecision; i++ { + layout += "0" + } + t, err = time.ParseInLocation(layout, s[:20+*fi.timePrecision], tz) + } else if len(s) >= 19 { s = s[:19] t, err = time.ParseInLocation(formatDateTime, s, tz) } else if len(s) >= 10 { diff --git a/pkg/orm/db_alias.go b/pkg/client/orm/db_alias.go similarity index 86% rename from pkg/orm/db_alias.go rename to pkg/client/orm/db_alias.go index 0a53ad31..29e0904c 100644 --- a/pkg/orm/db_alias.go +++ b/pkg/client/orm/db_alias.go @@ -21,11 +21,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/orm/hints" - lru "github.com/hashicorp/golang-lru" - - "github.com/astaxie/beego/pkg/common" ) // DriverType database driver constant int. @@ -279,6 +275,7 @@ type alias struct { MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration + StmtCacheSize int DB *DB DbBaser dbBaser TZ *time.Location @@ -341,7 +338,7 @@ func detectTZ(al *alias) { } } -func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { +func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) if _, ok := dataBaseCache.get(aliasName); ok { return nil, existErr @@ -359,32 +356,35 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV return al, nil } -func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.KV) (*alias, error) { - kvs := common.NewKVs(params...) +func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { + + al := &alias{} + al.DB = &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + } + + for _, p := range params { + p(al) + } var stmtCache *lru.Cache var stmtCacheSize int - maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int) - if maxStmtCacheSize > 0 { - _stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize) + if al.StmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize) if errC != nil { return nil, errC } else { stmtCache = _stmtCache - stmtCacheSize = maxStmtCacheSize + stmtCacheSize = al.StmtCacheSize } } - al := new(alias) al.Name = aliasName al.DriverName = driverName - al.DB = &DB{ - RWMutex: new(sync.RWMutex), - DB: db, - stmtDecorators: stmtCache, - stmtDecoratorsLimit: stmtCacheSize, - } + al.DB.stmtDecorators = stmtCache + al.DB.stmtDecoratorsLimit = stmtCacheSize if dr, ok := drivers[driverName]; ok { al.DbBaser = dbBasers[dr] @@ -400,31 +400,48 @@ func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...common.K detectTZ(al) - kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxIdleConns(al, m) - } - }).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) { - if m, ok := value.(int); ok { - SetMaxOpenConns(al, m) - } - }).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) { - if m, ok := value.(time.Duration); ok { - SetConnMaxLifetime(al, m) - } - }) - return al, nil } +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + al := getDbAlias(aliasName) + al.SetMaxIdleConns(maxOpenConns) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxIdleConns(maxIdleConns int) { + al.MaxIdleConns = maxIdleConns + al.DB.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func (al *alias) SetMaxOpenConns(maxOpenConns int) { + al.MaxOpenConns = maxOpenConns + al.DB.DB.SetMaxOpenConns(maxOpenConns) +} + +func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { + al.ConnMaxLifetime = lifeTime + al.DB.DB.SetConnMaxLifetime(lifeTime) +} + // AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...common.KV) error { +func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error { _, err := addAliasWthDB(aliasName, driverName, db, params...) return err } // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...common.KV) error { +func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error { var ( err error db *sql.DB @@ -477,23 +494,6 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error { return nil } -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(al *alias, maxIdleConns int) { - al.MaxIdleConns = maxIdleConns - al.DB.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(al *alias, maxOpenConns int) { - al.MaxOpenConns = maxOpenConns - al.DB.DB.SetMaxOpenConns(maxOpenConns) -} - -func SetConnMaxLifetime(al *alias, lifeTime time.Duration) { - al.ConnMaxLifetime = lifeTime - al.DB.DB.SetConnMaxLifetime(lifeTime) -} - // GetDB Get *sql.DB from registered database by db alias name. // Use "default" as alias name if you not set. func GetDB(aliasNames ...string) (*sql.DB, error) { @@ -554,3 +554,33 @@ func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { } return cache, nil } + +type DBOption func(al *alias) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(maxIdleConn int) DBOption { + return func(al *alias) { + al.SetMaxIdleConns(maxIdleConn) + } +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(maxOpenConn int) DBOption { + return func(al *alias) { + al.SetMaxOpenConns(maxOpenConn) + } +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) DBOption { + return func(al *alias) { + al.SetConnMaxLifetime(v) + } +} + +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) DBOption { + return func(al *alias) { + al.StmtCacheSize = v + } +} diff --git a/pkg/orm/db_alias_test.go b/pkg/client/orm/db_alias_test.go similarity index 90% rename from pkg/orm/db_alias_test.go rename to pkg/client/orm/db_alias_test.go index 4a561a27..6275cb2a 100644 --- a/pkg/orm/db_alias_test.go +++ b/pkg/client/orm/db_alias_test.go @@ -18,16 +18,14 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/orm/hints" - "github.com/stretchr/testify/assert" ) func TestRegisterDataBase(t *testing.T) { err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, - hints.MaxIdleConnections(20), - hints.MaxOpenConnections(300), - hints.ConnMaxLifetime(time.Minute)) + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute)) assert.Nil(t, err) al := getDbAlias("test-params") @@ -39,7 +37,7 @@ func TestRegisterDataBase(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -49,7 +47,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -59,7 +57,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) al := getDbAlias(aliasName) @@ -69,7 +67,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" - err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841)) + err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) al := getDbAlias(aliasName) diff --git a/pkg/orm/db_mysql.go b/pkg/client/orm/db_mysql.go similarity index 85% rename from pkg/orm/db_mysql.go rename to pkg/client/orm/db_mysql.go index d934d842..f602fd0a 100644 --- a/pkg/orm/db_mysql.go +++ b/pkg/client/orm/db_mysql.go @@ -42,24 +42,25 @@ var mysqlOperators = map[string]string{ // mysql column field types. var mysqlTypes = map[string]string{ - "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "longtext", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", + "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "longtext", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "time.Time-precision": "datetime(%d)", } // mysql dbBaser implementation. diff --git a/pkg/orm/db_oracle.go b/pkg/client/orm/db_oracle.go similarity index 84% rename from pkg/orm/db_oracle.go rename to pkg/client/orm/db_oracle.go index 66246ec4..7c1bf1b3 100644 --- a/pkg/orm/db_oracle.go +++ b/pkg/client/orm/db_oracle.go @@ -18,7 +18,7 @@ import ( "fmt" "strings" - "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/client/orm/hints" ) // oracle operators. @@ -33,23 +33,24 @@ var oracleOperators = map[string]string{ // oracle column field types. var oracleTypes = map[string]string{ - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "VARCHAR2(%d)", - "string-char": "CHAR(%d)", - "string-text": "VARCHAR2(%d)", - "time.Time-date": "DATE", - "time.Time": "TIMESTAMP", - "int8": "INTEGER", - "int16": "INTEGER", - "int32": "INTEGER", - "int64": "INTEGER", - "uint8": "INTEGER", - "uint16": "INTEGER", - "uint32": "INTEGER", - "uint64": "INTEGER", - "float64": "NUMBER", - "float64-decimal": "NUMBER(%d, %d)", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-char": "CHAR(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", + "time.Time-precision": "TIMESTAMP(%d)", } // oracle dbBaser diff --git a/pkg/orm/db_postgres.go b/pkg/client/orm/db_postgres.go similarity index 83% rename from pkg/orm/db_postgres.go rename to pkg/client/orm/db_postgres.go index 35471ddc..12431d6e 100644 --- a/pkg/orm/db_postgres.go +++ b/pkg/client/orm/db_postgres.go @@ -39,26 +39,27 @@ var postgresOperators = map[string]string{ // postgresql column field types. var postgresTypes = map[string]string{ - "auto": "serial NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "timestamp with time zone", - "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, - "uint16": `integer CHECK("%COL%" >= 0)`, - "uint32": `bigint CHECK("%COL%" >= 0)`, - "uint64": `bigint CHECK("%COL%" >= 0)`, - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", - "json": "json", - "jsonb": "jsonb", + "auto": "serial NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "timestamp with time zone", + "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, + "uint16": `integer CHECK("%COL%" >= 0)`, + "uint32": `bigint CHECK("%COL%" >= 0)`, + "uint64": `bigint CHECK("%COL%" >= 0)`, + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", + "time.Time-precision": "timestamp(%d) with time zone", } // postgresql dbBaser. diff --git a/pkg/orm/db_sqlite.go b/pkg/client/orm/db_sqlite.go similarity index 84% rename from pkg/orm/db_sqlite.go rename to pkg/client/orm/db_sqlite.go index f9d379ce..6d7a5617 100644 --- a/pkg/orm/db_sqlite.go +++ b/pkg/client/orm/db_sqlite.go @@ -21,7 +21,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/client/orm/hints" ) // sqlite operators. @@ -44,24 +44,25 @@ var sqliteOperators = map[string]string{ // sqlite column types. var sqliteTypes = map[string]string{ - "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "character(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "real", - "float64-decimal": "decimal", + "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "character(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "datetime", + "time.Time-precision": "datetime(%d)", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "real", + "float64-decimal": "decimal", } // sqlite dbBaser. diff --git a/pkg/orm/db_tables.go b/pkg/client/orm/db_tables.go similarity index 100% rename from pkg/orm/db_tables.go rename to pkg/client/orm/db_tables.go diff --git a/pkg/orm/db_tidb.go b/pkg/client/orm/db_tidb.go similarity index 100% rename from pkg/orm/db_tidb.go rename to pkg/client/orm/db_tidb.go diff --git a/pkg/orm/db_utils.go b/pkg/client/orm/db_utils.go similarity index 100% rename from pkg/orm/db_utils.go rename to pkg/client/orm/db_utils.go diff --git a/pkg/orm/do_nothing_orm.go b/pkg/client/orm/do_nothing_orm.go similarity index 96% rename from pkg/orm/do_nothing_orm.go rename to pkg/client/orm/do_nothing_orm.go index afc428f2..e27e7f3a 100644 --- a/pkg/orm/do_nothing_orm.go +++ b/pkg/client/orm/do_nothing_orm.go @@ -18,7 +18,7 @@ import ( "context" "database/sql" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation @@ -54,11 +54,11 @@ func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, return false, 0, nil } -func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { +func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { return 0, nil } -func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { +func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { return 0, nil } diff --git a/pkg/orm/do_nothing_orm_test.go b/pkg/client/orm/do_nothing_orm_test.go similarity index 100% rename from pkg/orm/do_nothing_orm_test.go rename to pkg/client/orm/do_nothing_orm_test.go diff --git a/pkg/orm/filter.go b/pkg/client/orm/filter.go similarity index 100% rename from pkg/orm/filter.go rename to pkg/client/orm/filter.go diff --git a/pkg/orm/filter/bean/default_value_filter.go b/pkg/client/orm/filter/bean/default_value_filter.go similarity index 96% rename from pkg/orm/filter/bean/default_value_filter.go rename to pkg/client/orm/filter/bean/default_value_filter.go index b3ef7415..a367a6e2 100644 --- a/pkg/orm/filter/bean/default_value_filter.go +++ b/pkg/client/orm/filter/bean/default_value_filter.go @@ -19,9 +19,10 @@ import ( "reflect" "strings" - "github.com/astaxie/beego/pkg/bean" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/infrastructure/bean" ) // DefaultValueFilterChainBuilder only works for InsertXXX method, diff --git a/pkg/orm/filter/bean/default_value_filter_test.go b/pkg/client/orm/filter/bean/default_value_filter_test.go similarity index 97% rename from pkg/orm/filter/bean/default_value_filter_test.go rename to pkg/client/orm/filter/bean/default_value_filter_test.go index 2c754a3e..fde7abf8 100644 --- a/pkg/orm/filter/bean/default_value_filter_test.go +++ b/pkg/client/orm/filter/bean/default_value_filter_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/client/orm" ) func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/orm/filter/opentracing/filter.go b/pkg/client/orm/filter/opentracing/filter.go similarity index 98% rename from pkg/orm/filter/opentracing/filter.go rename to pkg/client/orm/filter/opentracing/filter.go index 0b6968b7..1a2ee541 100644 --- a/pkg/orm/filter/opentracing/filter.go +++ b/pkg/client/orm/filter/opentracing/filter.go @@ -20,7 +20,7 @@ import ( "github.com/opentracing/opentracing-go" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/client/orm" ) // FilterChainBuilder provides an extension point diff --git a/pkg/orm/filter/opentracing/filter_test.go b/pkg/client/orm/filter/opentracing/filter_test.go similarity index 96% rename from pkg/orm/filter/opentracing/filter_test.go rename to pkg/client/orm/filter/opentracing/filter_test.go index 8f6d4807..9e89e5da 100644 --- a/pkg/orm/filter/opentracing/filter_test.go +++ b/pkg/client/orm/filter/opentracing/filter_test.go @@ -21,7 +21,7 @@ import ( "github.com/opentracing/opentracing-go" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/client/orm" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/orm/filter/prometheus/filter.go b/pkg/client/orm/filter/prometheus/filter.go similarity index 93% rename from pkg/orm/filter/prometheus/filter.go rename to pkg/client/orm/filter/prometheus/filter.go index fb2b473d..175b26be 100644 --- a/pkg/orm/filter/prometheus/filter.go +++ b/pkg/client/orm/filter/prometheus/filter.go @@ -22,8 +22,8 @@ import ( "github.com/prometheus/client_golang/prometheus" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/server/web" ) // FilterChainBuilder is an extension point, @@ -42,9 +42,9 @@ func NewFilterChainBuilder() *FilterChainBuilder { Name: "beego", Subsystem: "orm_operation", ConstLabels: map[string]string{ - "server": beego.BConfig.ServerName, - "env": beego.BConfig.RunMode, - "appname": beego.BConfig.AppName, + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, }, Help: "The statics info for orm operation", }, []string{"method", "name", "duration", "insideTx", "txName"}) diff --git a/pkg/orm/filter/prometheus/filter_test.go b/pkg/client/orm/filter/prometheus/filter_test.go similarity index 97% rename from pkg/orm/filter/prometheus/filter_test.go rename to pkg/client/orm/filter/prometheus/filter_test.go index aad62c18..1b55b989 100644 --- a/pkg/orm/filter/prometheus/filter_test.go +++ b/pkg/client/orm/filter/prometheus/filter_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/client/orm" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/orm/filter_orm_decorator.go b/pkg/client/orm/filter_orm_decorator.go similarity index 98% rename from pkg/orm/filter_orm_decorator.go rename to pkg/client/orm/filter_orm_decorator.go index 97221a07..d0c5c537 100644 --- a/pkg/orm/filter_orm_decorator.go +++ b/pkg/client/orm/filter_orm_decorator.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) const ( @@ -137,11 +137,11 @@ func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interfa return res[0].(bool), res[1].(int64), f.convertError(res[2]) } -func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { +func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { return f.LoadRelatedWithCtx(context.Background(), md, name, args...) } -func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { +func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { mi, _ := modelCache.getByMd(md) inv := &Invocation{ diff --git a/pkg/orm/filter_orm_decorator_test.go b/pkg/client/orm/filter_orm_decorator_test.go similarity index 99% rename from pkg/orm/filter_orm_decorator_test.go rename to pkg/client/orm/filter_orm_decorator_test.go index 47f20854..7acb7d46 100644 --- a/pkg/orm/filter_orm_decorator_test.go +++ b/pkg/client/orm/filter_orm_decorator_test.go @@ -21,7 +21,7 @@ import ( "sync" "testing" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/stretchr/testify/assert" ) @@ -362,7 +362,7 @@ func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{} return errors.New("read for update error") } -func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { +func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { return 99, errors.New("load related error") } diff --git a/pkg/orm/filter_test.go b/pkg/client/orm/filter_test.go similarity index 100% rename from pkg/orm/filter_test.go rename to pkg/client/orm/filter_test.go diff --git a/pkg/orm/hints/db_hints.go b/pkg/client/orm/hints/db_hints.go similarity index 72% rename from pkg/orm/hints/db_hints.go rename to pkg/client/orm/hints/db_hints.go index 0649ab9f..7340bd07 100644 --- a/pkg/orm/hints/db_hints.go +++ b/pkg/client/orm/hints/db_hints.go @@ -15,20 +15,12 @@ package hints import ( - "time" - - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) const ( - //db level - KeyMaxIdleConnections = iota - KeyMaxOpenConnections - KeyConnMaxLifetime - KeyMaxStmtCacheSize - //query level - KeyForceIndex + KeyForceIndex = iota KeyUseIndex KeyIgnoreIndex KeyForUpdate @@ -43,7 +35,7 @@ type Hint struct { value interface{} } -var _ common.KV = new(Hint) +var _ utils.KV = new(Hint) // GetKey return key func (s *Hint) GetKey() interface{} { @@ -55,27 +47,7 @@ func (s *Hint) GetValue() interface{} { return s.value } -var _ common.KV = new(Hint) - -// MaxIdleConnections return a hint about MaxIdleConnections -func MaxIdleConnections(v int) *Hint { - return NewHint(KeyMaxIdleConnections, v) -} - -// MaxOpenConnections return a hint about MaxOpenConnections -func MaxOpenConnections(v int) *Hint { - return NewHint(KeyMaxOpenConnections, v) -} - -// ConnMaxLifetime return a hint about ConnMaxLifetime -func ConnMaxLifetime(v time.Duration) *Hint { - return NewHint(KeyConnMaxLifetime, v) -} - -// MaxStmtCacheSize return a hint about MaxStmtCacheSize -func MaxStmtCacheSize(v int) *Hint { - return NewHint(KeyMaxStmtCacheSize, v) -} +var _ utils.KV = new(Hint) // ForceIndex return a hint about ForceIndex func ForceIndex(indexes ...string) *Hint { diff --git a/pkg/orm/hints/db_hints_test.go b/pkg/client/orm/hints/db_hints_test.go similarity index 81% rename from pkg/orm/hints/db_hints_test.go rename to pkg/client/orm/hints/db_hints_test.go index 4e962a8f..510f9f16 100644 --- a/pkg/orm/hints/db_hints_test.go +++ b/pkg/client/orm/hints/db_hints_test.go @@ -48,34 +48,6 @@ func TestNewHint_float(t *testing.T) { assert.Equal(t, hint.GetValue(), value) } -func TestMaxOpenConnections(t *testing.T) { - i := 887423 - hint := MaxOpenConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections) -} - -func TestConnMaxLifetime(t *testing.T) { - i := time.Hour - hint := ConnMaxLifetime(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime) -} - -func TestMaxIdleConnections(t *testing.T) { - i := 42316 - hint := MaxIdleConnections(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections) -} - -func TestMaxStmtCacheSize(t *testing.T) { - i := 94157 - hint := MaxStmtCacheSize(i) - assert.Equal(t, hint.GetValue(), i) - assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize) -} - func TestForceIndex(t *testing.T) { s := []string{`f_index1`, `f_index2`, `f_index3`} hint := ForceIndex(s...) diff --git a/pkg/orm/invocation.go b/pkg/client/orm/invocation.go similarity index 100% rename from pkg/orm/invocation.go rename to pkg/client/orm/invocation.go diff --git a/pkg/migration/ddl.go b/pkg/client/orm/migration/ddl.go similarity index 85% rename from pkg/migration/ddl.go rename to pkg/client/orm/migration/ddl.go index b6d3a2d9..e8b13212 100644 --- a/pkg/migration/ddl.go +++ b/pkg/client/orm/migration/ddl.go @@ -17,7 +17,7 @@ package migration import ( "fmt" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // Index struct defines the structure of Index Columns @@ -31,7 +31,7 @@ type Unique struct { Columns []*Column } -//Column struct defines a single column of a table +// Column struct defines a single column of a table type Column struct { Name string Inc string @@ -84,7 +84,7 @@ func (m *Migration) NewCol(name string) *Column { return col } -//PriCol creates a new primary column and attaches it to m struct +// PriCol creates a new primary column and attaches it to m struct func (m *Migration) PriCol(name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -92,7 +92,7 @@ func (m *Migration) PriCol(name string) *Column { return col } -//UniCol creates / appends columns to specified unique key and attaches it to m struct +// UniCol creates / appends columns to specified unique key and attaches it to m struct func (m *Migration) UniCol(uni, name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -114,7 +114,7 @@ func (m *Migration) UniCol(uni, name string) *Column { return col } -//ForeignCol creates a new foreign column and returns the instance of column +// ForeignCol creates a new foreign column and returns the instance of column func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} @@ -123,25 +123,25 @@ func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreig return foreign } -//SetOnDelete sets the on delete of foreign +// SetOnDelete sets the on delete of foreign func (foreign *Foreign) SetOnDelete(del string) *Foreign { foreign.OnDelete = "ON DELETE" + del return foreign } -//SetOnUpdate sets the on update of foreign +// SetOnUpdate sets the on update of foreign func (foreign *Foreign) SetOnUpdate(update string) *Foreign { foreign.OnUpdate = "ON UPDATE" + update return foreign } -//Remove marks the columns to be removed. -//it allows reverse m to create the column. +// Remove marks the columns to be removed. +// it allows reverse m to create the column. func (c *Column) Remove() { c.remove = true } -//SetAuto enables auto_increment of column (can be used once) +// SetAuto enables auto_increment of column (can be used once) func (c *Column) SetAuto(inc bool) *Column { if inc { c.Inc = "auto_increment" @@ -149,7 +149,7 @@ func (c *Column) SetAuto(inc bool) *Column { return c } -//SetNullable sets the column to be null +// SetNullable sets the column to be null func (c *Column) SetNullable(null bool) *Column { if null { c.Null = "" @@ -160,13 +160,13 @@ func (c *Column) SetNullable(null bool) *Column { return c } -//SetDefault sets the default value, prepend with "DEFAULT " +// SetDefault sets the default value, prepend with "DEFAULT " func (c *Column) SetDefault(def string) *Column { c.Default = "DEFAULT " + def return c } -//SetUnsigned sets the column to be unsigned int +// SetUnsigned sets the column to be unsigned int func (c *Column) SetUnsigned(unsign bool) *Column { if unsign { c.Unsign = "UNSIGNED" @@ -174,13 +174,13 @@ func (c *Column) SetUnsigned(unsign bool) *Column { return c } -//SetDataType sets the dataType of the column +// SetDataType sets the dataType of the column func (c *Column) SetDataType(dataType string) *Column { c.DataType = dataType return c } -//SetOldNullable allows reverting to previous nullable on reverse ms +// SetOldNullable allows reverting to previous nullable on reverse ms func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { if null { c.OldNull = "" @@ -191,13 +191,13 @@ func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { return c } -//SetOldDefault allows reverting to previous default on reverse ms +// SetOldDefault allows reverting to previous default on reverse ms func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { c.OldDefault = def return c } -//SetOldUnsigned allows reverting to previous unsgined on reverse ms +// SetOldUnsigned allows reverting to previous unsgined on reverse ms func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { if unsign { c.OldUnsign = "UNSIGNED" @@ -205,19 +205,19 @@ func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { return c } -//SetOldDataType allows reverting to previous datatype on reverse ms +// SetOldDataType allows reverting to previous datatype on reverse ms func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { c.OldDataType = dataType return c } -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) func (c *Column) SetPrimary(m *Migration) *Column { m.Primary = append(m.Primary, c) return c } -//AddColumnsToUnique adds the columns to Unique Struct +// AddColumnsToUnique adds the columns to Unique Struct func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { unique.Columns = append(unique.Columns, columns...) @@ -225,7 +225,7 @@ func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { return unique } -//AddColumns adds columns to m struct +// AddColumns adds columns to m struct func (m *Migration) AddColumns(columns ...*Column) *Migration { m.Columns = append(m.Columns, columns...) @@ -233,38 +233,38 @@ func (m *Migration) AddColumns(columns ...*Column) *Migration { return m } -//AddPrimary adds the column to primary in m struct +// AddPrimary adds the column to primary in m struct func (m *Migration) AddPrimary(primary *Column) *Migration { m.Primary = append(m.Primary, primary) return m } -//AddUnique adds the column to unique in m struct +// AddUnique adds the column to unique in m struct func (m *Migration) AddUnique(unique *Unique) *Migration { m.Uniques = append(m.Uniques, unique) return m } -//AddForeign adds the column to foreign in m struct +// AddForeign adds the column to foreign in m struct func (m *Migration) AddForeign(foreign *Foreign) *Migration { m.Foreigns = append(m.Foreigns, foreign) return m } -//AddIndex adds the column to index in m struct +// AddIndex adds the column to index in m struct func (m *Migration) AddIndex(index *Index) *Migration { m.Indexes = append(m.Indexes, index) return m } -//RenameColumn allows renaming of columns +// RenameColumn allows renaming of columns func (m *Migration) RenameColumn(from, to string) *RenameColumn { rename := &RenameColumn{OldName: from, NewName: to} m.Renames = append(m.Renames, rename) return rename } -//GetSQL returns the generated sql depending on ModifyType +// GetSQL returns the generated sql depending on ModifyType func (m *Migration) GetSQL() (sql string) { sql = "" switch m.ModifyType { diff --git a/pkg/client/orm/migration/doc.go b/pkg/client/orm/migration/doc.go new file mode 100644 index 00000000..0c6564d4 --- /dev/null +++ b/pkg/client/orm/migration/doc.go @@ -0,0 +1,32 @@ +// Package migration enables you to generate migrations back and forth. It generates both migrations. +// +// //Creates a table +// m.CreateTable("tablename","InnoDB","utf8"); +// +// //Alter a table +// m.AlterTable("tablename") +// +// Standard Column Methods +// * SetDataType +// * SetNullable +// * SetDefault +// * SetUnsigned (use only on integer types unless produces error) +// +// //Sets a primary column, multiple calls allowed, standard column methods available +// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true) +// +// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index +// m.UniCol("index","column") +// +// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove +// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false) +// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false) +// +// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to +// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)") +// m.RenameColumn("from","to")... +// +// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately. +// //Supports standard column methods, automatic reverse. +// m.ForeignCol("local_col","foreign_col","foreign_table") +package migration diff --git a/pkg/migration/migration.go b/pkg/client/orm/migration/migration.go similarity index 98% rename from pkg/migration/migration.go rename to pkg/client/orm/migration/migration.go index c62fd901..3f740594 100644 --- a/pkg/migration/migration.go +++ b/pkg/client/orm/migration/migration.go @@ -33,8 +33,8 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/orm" + "github.com/astaxie/beego/pkg/client/orm" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // const the data format for the bee generate migration datatype diff --git a/pkg/orm/model_utils_test.go b/pkg/client/orm/model_utils_test.go similarity index 100% rename from pkg/orm/model_utils_test.go rename to pkg/client/orm/model_utils_test.go diff --git a/pkg/client/orm/models.go b/pkg/client/orm/models.go new file mode 100644 index 00000000..24f564ab --- /dev/null +++ b/pkg/client/orm/models.go @@ -0,0 +1,564 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "errors" + "fmt" + "reflect" + "runtime/debug" + "strings" + "sync" +) + +const ( + odCascade = "cascade" + odSetNULL = "set_null" + odSetDefault = "set_default" + odDoNothing = "do_nothing" + defaultStructTagName = "orm" + defaultStructTagDelim = ";" +) + +var ( + modelCache = &_modelCache{ + cache: make(map[string]*modelInfo), + cacheByFullName: make(map[string]*modelInfo), + } +) + +// model info collection +type _modelCache struct { + sync.RWMutex // only used outsite for bootStrap + orders []string + cache map[string]*modelInfo + cacheByFullName map[string]*modelInfo + done bool +} + +// get all model info +func (mc *_modelCache) all() map[string]*modelInfo { + m := make(map[string]*modelInfo, len(mc.cache)) + for k, v := range mc.cache { + m[k] = v + } + return m +} + +// get ordered model info +func (mc *_modelCache) allOrdered() []*modelInfo { + m := make([]*modelInfo, 0, len(mc.orders)) + for _, table := range mc.orders { + m = append(m, mc.cache[table]) + } + return m +} + +// get model info by table name +func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { + mi, ok = mc.cache[table] + return +} + +// get model info by full name +func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { + mi, ok = mc.cacheByFullName[name] + return +} + +func (mc *_modelCache) getByMd(md interface{}) (*modelInfo, bool) { + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + name := getFullName(typ) + return mc.getByFullName(name) +} + +// set model info to collection +func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { + mii := mc.cache[table] + mc.cache[table] = mi + mc.cacheByFullName[mi.fullName] = mi + if mii == nil { + mc.orders = append(mc.orders, table) + } + return mii +} + +// clean all model info. +func (mc *_modelCache) clean() { + mc.Lock() + defer mc.Unlock() + + mc.orders = make([]string, 0) + mc.cache = make(map[string]*modelInfo) + mc.cacheByFullName = make(map[string]*modelInfo) + mc.done = false +} + +//bootstrap bootstrap for models +func (mc *_modelCache) bootstrap() { + mc.Lock() + defer mc.Unlock() + if mc.done { + return + } + var ( + err error + models map[string]*modelInfo + ) + if dataBaseCache.getDefault() == nil { + err = fmt.Errorf("must have one register DataBase alias named `default`") + goto end + } + + // set rel and reverse model + // RelManyToMany set the relTable + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.columns { + if fi.rel || fi.reverse { + elm := fi.addrValue.Type().Elem() + if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already register + name := getFullName(elm) + mii, ok := mc.getByFullName(name) + if !ok || mii.pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) + goto end + } + fi.relModelInfo = mii + + switch fi.fieldType { + case RelManyToMany: + if fi.relThrough != "" { + if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { + pn := fi.relThrough[:i] + rmi, ok := mc.getByFullName(fi.relThrough) + if !ok || pn != rmi.pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) + goto end + } + fi.relThroughModelInfo = rmi + fi.relTable = rmi.table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) + goto end + } + } else { + i := newM2MModelInfo(mi, mii) + if fi.relTable != "" { + i.table = fi.relTable + } + if v := mc.set(i.table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) + goto end + } + fi.relTable = i.table + fi.relThroughModelInfo = i + } + + fi.relThroughModelInfo.isThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.relModelInfo.fields.fieldsReverse { + if ffi.relModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.relModelInfo + ffi := new(fieldInfo) + ffi.name = mi.name + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + ffi.reverse = true + ffi.relModelInfo = mi + ffi.mi = rmi + if fi.fieldType == RelOneToOne { + ffi.fieldType = RelReverseOne + } else { + ffi.fieldType = RelReverseMany + } + if !rmi.fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) + ffi.column = ffi.name + ffi.fullName = rmi.fullName + "." + ffi.name + if added = rmi.fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) + } + } + } + } + } + } + + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsRel { + switch fi.fieldType { + case RelManyToMany: + for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { + switch ffi.fieldType { + case RelOneToOne, RelForeignKey: + if ffi.relModelInfo == fi.relModelInfo { + fi.reverseFieldInfoTwo = ffi + } + if ffi.relModelInfo == mi { + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + } + } + } + if fi.reverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.relThroughModelInfo.fullName) + goto end + } + } + } + } + + models = mc.all() + for _, mi := range models { + for _, fi := range mi.fields.fieldsReverse { + switch fi.fieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { + if ffi.relModelInfo == mi { + found = true + fi.reverseField = ffi.name + fi.reverseFieldInfo = ffi + + ffi.reverseField = fi.name + ffi.reverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { + conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || + fi.relTable != "" && fi.relTable == ffi.relTable || + fi.relThrough == "" && fi.relTable == "" + if ffi.relModelInfo == mi && conditions { + found = true + + fi.reverseField = ffi.reverseFieldInfoTwo.name + fi.reverseFieldInfo = ffi.reverseFieldInfoTwo + fi.relThroughModelInfo = ffi.relThroughModelInfo + fi.reverseFieldInfoTwo = ffi.reverseFieldInfo + fi.reverseFieldInfoM2M = ffi + ffi.reverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + } + modelCache.done = true + return +} + +// register register models to model cache +func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { + if mc.done { + err = fmt.Errorf("register must be run before BootStrap") + return + } + + for _, model := range models { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + err = fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ)) + return + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) + return + } + + table := getTableName(val) + + if prefixOrSuffixStr != "" { + if prefixOrSuffix { + table = prefixOrSuffixStr + table + } else { + table = table + prefixOrSuffixStr + } + } + + // models's fullname is pkgpath + struct name + name := getFullName(typ) + if _, ok := mc.getByFullName(name); ok { + err = fmt.Errorf(" model `%s` repeat register, must be unique\n", name) + return + } + + if _, ok := mc.get(table); ok { + err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) + return + } + + mi := newModelInfo(val) + if mi.fields.pk == nil { + outFor: + for _, fi := range mi.fields.fieldsDB { + if strings.ToLower(fi.name) == "id" { + switch fi.addrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.auto = true + fi.pk = true + mi.fields.pk = fi + break outFor + } + } + } + + if mi.fields.pk == nil { + err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) + return + } + + } + + mi.table = table + mi.pkg = typ.PkgPath() + mi.model = model + mi.manual = true + + mc.set(table, mi) + } + return +} + +//getDbDropSQL get database scheme drop sql queries +func (mc *_modelCache) getDbDropSQL(al *alias) (queries []string, err error) { + if len(modelCache.cache) == 0 { + err = errors.New("no Model found, need register your model") + return + } + + Q := al.DbBaser.TableQuote() + + for _, mi := range modelCache.allOrdered() { + queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) + } + return queries, nil +} + +//getDbCreateSQL get database scheme creation sql queries +func (mc *_modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) { + if len(modelCache.cache) == 0 { + err = errors.New("no Model found, need register your model") + return + } + + Q := al.DbBaser.TableQuote() + T := al.DbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range modelCache.allOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) + + columns := make([]string, 0, len(mi.fields.fieldsDB)) + + sqlIndexes := [][]string{} + + for _, fi := range mi.fields.fieldsDB { + + column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) + col := getColumnTyp(al, fi) + + if fi.auto { + switch al.Driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.null { + column += " " + "NOT NULL" + } + + //if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + //} + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.unique { + column += " " + "UNIQUE" + } + + if fi.index { + sqlIndexes = append(sqlIndexes, []string{fi.column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.column, -1) + } + + if fi.description != "" && al.Driver != DRSqlite { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) + } + + columns = append(columns, column) + } + + if mi.model != nil { + allnames := getTableUnique(mi.addrField) + if !mi.manual && len(mi.uniques) > 0 { + allnames = append(allnames, mi.uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.Driver == DRMySQL { + var engine string + if mi.model != nil { + engine = getTableEngine(mi.addrField) + } + if engine == "" { + engine = al.Engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + queries = append(queries, sql) + + if mi.model != nil { + for _, names := range getTableIndex(mi.addrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { + cols = append(cols, fi.column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.table + index.Name = name + index.SQL = sql + + tableIndexes[mi.table] = append(tableIndexes[mi.table], index) + } + + } + + return +} + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + modelCache.clean() +} diff --git a/pkg/client/orm/models_boot.go b/pkg/client/orm/models_boot.go new file mode 100644 index 00000000..407cf536 --- /dev/null +++ b/pkg/client/orm/models_boot.go @@ -0,0 +1,47 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" +) + +// RegisterModel register models +func RegisterModel(models ...interface{}) { + if modelCache.done { + panic(fmt.Errorf("RegisterModel must be run before BootStrap")) + } + RegisterModelWithPrefix("", models...) +} + +// RegisterModelWithPrefix register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + if err := modelCache.register(prefix, true, models...); err != nil { + panic(err) + } +} + +// RegisterModelWithSuffix register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + if err := modelCache.register(suffix, false, models...); err != nil { + panic(err) + } +} + +// BootStrap bootstrap models. +// make all model parsed and can not add more models +func BootStrap() { + modelCache.bootstrap() +} diff --git a/pkg/orm/models_fields.go b/pkg/client/orm/models_fields.go similarity index 100% rename from pkg/orm/models_fields.go rename to pkg/client/orm/models_fields.go diff --git a/pkg/orm/models_info_f.go b/pkg/client/orm/models_info_f.go similarity index 97% rename from pkg/orm/models_info_f.go rename to pkg/client/orm/models_info_f.go index 7044b0bd..7152fada 100644 --- a/pkg/orm/models_info_f.go +++ b/pkg/client/orm/models_info_f.go @@ -137,6 +137,7 @@ type fieldInfo struct { isFielder bool // implement Fielder interface onDelete string description string + timePrecision *int } // new field info @@ -177,7 +178,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN decimals := tags["decimals"] size := tags["size"] onDelete := tags["on_delete"] - + precision := tags["precision"] initial.Clear() if v, ok := tags["default"]; ok { initial.Set(v) @@ -377,6 +378,18 @@ checkType: fi.index = false fi.unique = false case TypeTimeField, TypeDateField, TypeDateTimeField: + if fieldType == TypeDateTimeField { + if precision != "" { + v, e := StrTo(precision).Int() + if e != nil { + err = fmt.Errorf("convert %s to int error:%v", precision, e) + } else { + fi.timePrecision = &v + } + } + + } + if attrs["auto_now"] { fi.autoNow = true } else if attrs["auto_now_add"] { diff --git a/pkg/orm/models_info_m.go b/pkg/client/orm/models_info_m.go similarity index 100% rename from pkg/orm/models_info_m.go rename to pkg/client/orm/models_info_m.go diff --git a/pkg/orm/models_test.go b/pkg/client/orm/models_test.go similarity index 71% rename from pkg/orm/models_test.go rename to pkg/client/orm/models_test.go index 8fcd2e06..f0044f6d 100644 --- a/pkg/orm/models_test.go +++ b/pkg/client/orm/models_test.go @@ -22,8 +22,6 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/orm/hints" - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -145,55 +143,56 @@ type Data struct { } type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` + ID int `orm:"column(id)"` + Boolean bool `orm:"null"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + DateTimePrecision time.Time `orm:"null;type(datetime);precision(4)"` + Byte byte `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Int8 int8 `orm:"null"` + Int16 int16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint uint `orm:"null"` + Uint8 uint8 `orm:"null"` + Uint16 uint16 `orm:"null"` + Uint32 uint32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float32 float32 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` } type String string @@ -241,6 +240,21 @@ type UserBig struct { Name string } +type TM struct { + ID int `orm:"column(id)"` + TMPrecision1 time.Time `orm:"type(datetime);precision(3)"` + TMPrecision2 time.Time `orm:"auto_now_add;type(datetime);precision(4)"` +} + +func (t *TM) TableName() string { + return "tm" +} + +func NewTM() *TM { + obj := new(TM) + return obj +} + type User struct { ID int `orm:"column(id)"` UserName string `orm:"size(30);unique"` @@ -297,13 +311,14 @@ func NewProfile() *Profile { } type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/orm.PostTags)"` + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + UpdatedPrecision time.Time `orm:"auto_now;type(datetime);precision(4)"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.PostTags)"` } func (u *Post) TableIndex() [][]string { @@ -361,7 +376,7 @@ type Group struct { type Permission struct { ID int `orm:"column(id)"` Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/orm.GroupPermissions)"` + Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.GroupPermissions)"` } type GroupPermissions struct { @@ -470,7 +485,7 @@ var ( usage: - go get -u github.com/astaxie/beego/pkg/orm + go get -u github.com/astaxie/beego/pkg/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq @@ -480,20 +495,20 @@ var ( mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/pkg/orm + go test -v github.com/astaxie/beego/pkg/client/orm #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/pkg/orm + go test -v github.com/astaxie/beego/pkg/client/orm #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/pkg/orm + go test -v github.com/astaxie/beego/pkg/client/orm #### TiDB export ORM_DRIVER=tidb @@ -504,14 +519,15 @@ var ( ) func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() + // Debug, _ = StrTo(DBARGS.Debug).Bool() + Debug = true if DBARGS.Driver == "" || DBARGS.Source == "" { fmt.Println(helpinfo) os.Exit(2) } - err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20)) + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) diff --git a/pkg/orm/models_utils.go b/pkg/client/orm/models_utils.go similarity index 93% rename from pkg/orm/models_utils.go rename to pkg/client/orm/models_utils.go index 71127a6b..950ca243 100644 --- a/pkg/orm/models_utils.go +++ b/pkg/client/orm/models_utils.go @@ -45,6 +45,7 @@ var supportTag = map[string]int{ "on_delete": 2, "type": 2, "description": 2, + "precision": 2, } // get reflect.Type name with package path. @@ -106,6 +107,18 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// get whether the table needs to be created for the database alias +func isApplicableTableForDB(val reflect.Value, db string) bool { + fun := val.MethodByName("IsApplicableTableForDB") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) + if len(vals) > 0 && vals[0].Kind() == reflect.Bool { + return vals[0].Bool() + } + } + return true +} + // get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col diff --git a/pkg/client/orm/models_utils_test.go b/pkg/client/orm/models_utils_test.go new file mode 100644 index 00000000..0a6995b3 --- /dev/null +++ b/pkg/client/orm/models_utils_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type NotApplicableModel struct { + Id int +} + +func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { + return db == "default" +} + +func Test_IsApplicableTableForDB(t *testing.T) { + assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) + assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) +} diff --git a/pkg/orm/orm.go b/pkg/client/orm/orm.go similarity index 97% rename from pkg/orm/orm.go rename to pkg/client/orm/orm.go index d7dc3915..bfb710d1 100644 --- a/pkg/orm/orm.go +++ b/pkg/client/orm/orm.go @@ -21,7 +21,7 @@ // // import ( // "fmt" -// "github.com/astaxie/beego/pkg/orm" +// "github.com/astaxie/beego/pkg/client/orm" // _ "github.com/go-sql-driver/mysql" // import your used driver // ) // @@ -62,10 +62,10 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/common" - "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/client/orm/hints" + "github.com/astaxie/beego/pkg/infrastructure/utils" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // DebugQueries define the debug @@ -307,19 +307,17 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri // for _,tag := range post.Tags{...} // // make sure the relation is defined in model struct tags. -func (o *ormBase) LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) { +func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } -func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) +func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { + _, fi, ind, qs := o.queryRelated(md, name) var relDepth int var limit, offset int64 var order string - kvs := common.NewKVs(args...) + kvs := utils.NewKVs(args...) kvs.IfContains(hints.KeyRelDepth, func(value interface{}) { if v, ok := value.(bool); ok { if v { @@ -377,7 +375,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s } // get QuerySeter for related models to md model -func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { +func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -603,7 +601,7 @@ func NewOrmUsingDB(aliasName string) Ormer { } // NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...common.KV) (Ormer, error) { +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...DBOption) (Ormer, error) { al, err := newAliasWithDb(aliasName, driverName, db, params...) if err != nil { return nil, err diff --git a/pkg/orm/orm_conds.go b/pkg/client/orm/orm_conds.go similarity index 100% rename from pkg/orm/orm_conds.go rename to pkg/client/orm/orm_conds.go diff --git a/pkg/orm/orm_log.go b/pkg/client/orm/orm_log.go similarity index 100% rename from pkg/orm/orm_log.go rename to pkg/client/orm/orm_log.go diff --git a/pkg/orm/orm_object.go b/pkg/client/orm/orm_object.go similarity index 100% rename from pkg/orm/orm_object.go rename to pkg/client/orm/orm_object.go diff --git a/pkg/orm/orm_querym2m.go b/pkg/client/orm/orm_querym2m.go similarity index 100% rename from pkg/orm/orm_querym2m.go rename to pkg/client/orm/orm_querym2m.go diff --git a/pkg/orm/orm_queryset.go b/pkg/client/orm/orm_queryset.go similarity index 99% rename from pkg/orm/orm_queryset.go rename to pkg/client/orm/orm_queryset.go index 3a50fcae..906505de 100644 --- a/pkg/orm/orm_queryset.go +++ b/pkg/client/orm/orm_queryset.go @@ -18,7 +18,7 @@ import ( "context" "fmt" - "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/client/orm/hints" ) type colValue struct { diff --git a/pkg/orm/orm_raw.go b/pkg/client/orm/orm_raw.go similarity index 98% rename from pkg/orm/orm_raw.go rename to pkg/client/orm/orm_raw.go index c2539147..e11e97fa 100644 --- a/pkg/orm/orm_raw.go +++ b/pkg/client/orm/orm_raw.go @@ -330,6 +330,8 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return err } + structTagMap := make(map[reflect.StructTag]map[string]string) + defer rows.Close() if rows.Next() { @@ -396,7 +398,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { recursiveSetField(f) } - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + // thanks @Gazeboxu. + tags := structTagMap[fe.Tag] + if tags == nil { + _, tags = parseStructTag(fe.Tag.Get(defaultStructTagName)) + structTagMap[fe.Tag] = tags + } var col string if col = tags["column"]; col == "" { col = nameStrategyMap[nameStrategy](fe.Name) diff --git a/pkg/orm/orm_test.go b/pkg/client/orm/orm_test.go similarity index 99% rename from pkg/orm/orm_test.go rename to pkg/client/orm/orm_test.go index 40314ab4..8c4bf55d 100644 --- a/pkg/orm/orm_test.go +++ b/pkg/client/orm/orm_test.go @@ -31,7 +31,7 @@ import ( "testing" "time" - "github.com/astaxie/beego/pkg/orm/hints" + "github.com/astaxie/beego/pkg/client/orm/hints" "github.com/stretchr/testify/assert" ) @@ -204,6 +204,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(PtrPk)) RegisterModel(new(Index)) RegisterModel(new(StrPk)) + RegisterModel(new(TM)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -230,6 +231,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(PtrPk)) RegisterModel(new(Index)) RegisterModel(new(StrPk)) + RegisterModel(new(TM)) BootStrap() @@ -313,6 +315,24 @@ func TestDataTypes(t *testing.T) { } } +func TestTM(t *testing.T) { + // The precision of sqlite is not implemented + if dORM.Driver().Type() == 2 { + return + } + var recTM TM + tm := NewTM() + tm.TMPrecision1 = time.Unix(1596766024, 123456789) + tm.TMPrecision2 = time.Unix(1596766024, 123456789) + _, err := dORM.Insert(tm) + throwFail(t, err) + + err = dORM.QueryTable("tm").One(&recTM) + throwFail(t, err) + throwFail(t, AssertIs(recTM.TMPrecision1.String(), "2020-08-07 02:07:04.123 +0000 UTC")) + throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) +} + func TestNullDataTypes(t *testing.T) { d := DataNull{} diff --git a/pkg/orm/qb.go b/pkg/client/orm/qb.go similarity index 100% rename from pkg/orm/qb.go rename to pkg/client/orm/qb.go diff --git a/pkg/orm/qb_mysql.go b/pkg/client/orm/qb_mysql.go similarity index 100% rename from pkg/orm/qb_mysql.go rename to pkg/client/orm/qb_mysql.go diff --git a/pkg/orm/qb_tidb.go b/pkg/client/orm/qb_tidb.go similarity index 100% rename from pkg/orm/qb_tidb.go rename to pkg/client/orm/qb_tidb.go diff --git a/pkg/orm/types.go b/pkg/client/orm/types.go similarity index 98% rename from pkg/orm/types.go rename to pkg/client/orm/types.go index 06ba12f2..b0c793b7 100644 --- a/pkg/orm/types.go +++ b/pkg/client/orm/types.go @@ -20,7 +20,7 @@ import ( "reflect" "time" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // TableNaming is usually used by model @@ -75,6 +75,11 @@ type TableUniqueI interface { TableUnique() [][]string } +// IsApplicableTableForDB if return false, we won't create table to this db +type IsApplicableTableForDB interface { + IsApplicableTableForDB(db string) bool +} + // Driver define database driver type Driver interface { Name() string @@ -183,8 +188,8 @@ type DQL interface { // hints.Offset int offset default offset 0 // hints.OrderBy string order for example : "-Id" // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...common.KV) (int64, error) - LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...common.KV) (int64, error) + LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) // create a models to models queryer // for example: diff --git a/pkg/orm/utils.go b/pkg/client/orm/utils.go similarity index 100% rename from pkg/orm/utils.go rename to pkg/client/orm/utils.go diff --git a/pkg/client/orm/utils_test.go b/pkg/client/orm/utils_test.go new file mode 100644 index 00000000..7d94cada --- /dev/null +++ b/pkg/client/orm/utils_test.go @@ -0,0 +1,70 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := camelString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeString(t *testing.T) { + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeString(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestSnakeStringWithAcronym(t *testing.T) { + camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} + snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} + + answer := make(map[string]string) + for i, v := range camel { + answer[v] = snake[i] + } + + for _, v := range camel { + res := snakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} diff --git a/pkg/doc.go b/pkg/doc.go index a1cdbbb0..2d9c2bfe 100644 --- a/pkg/doc.go +++ b/pkg/doc.go @@ -1,17 +1,15 @@ -/* -Package beego provide a MVC framework -beego: an open-source, high-performance, modular, full-stack web framework +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -It is used for rapid development of RESTful APIs, web apps and backend services in Go. -beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. - - package main - import "github.com/astaxie/beego/pkg" - - func main() { - beego.Run() - } - -more information: http://beego.me -*/ -package beego +package pkg diff --git a/pkg/bean/context.go b/pkg/infrastructure/bean/context.go similarity index 100% rename from pkg/bean/context.go rename to pkg/infrastructure/bean/context.go diff --git a/pkg/bean/doc.go b/pkg/infrastructure/bean/doc.go similarity index 100% rename from pkg/bean/doc.go rename to pkg/infrastructure/bean/doc.go diff --git a/pkg/bean/factory.go b/pkg/infrastructure/bean/factory.go similarity index 100% rename from pkg/bean/factory.go rename to pkg/infrastructure/bean/factory.go diff --git a/pkg/bean/metadata.go b/pkg/infrastructure/bean/metadata.go similarity index 100% rename from pkg/bean/metadata.go rename to pkg/infrastructure/bean/metadata.go diff --git a/pkg/bean/tag_auto_wire_bean_factory.go b/pkg/infrastructure/bean/tag_auto_wire_bean_factory.go similarity index 99% rename from pkg/bean/tag_auto_wire_bean_factory.go rename to pkg/infrastructure/bean/tag_auto_wire_bean_factory.go index 569ffb0d..f80388f9 100644 --- a/pkg/bean/tag_auto_wire_bean_factory.go +++ b/pkg/infrastructure/bean/tag_auto_wire_bean_factory.go @@ -22,7 +22,7 @@ import ( "github.com/pkg/errors" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) const DefaultValueTagKey = "default" diff --git a/pkg/bean/tag_auto_wire_bean_factory_test.go b/pkg/infrastructure/bean/tag_auto_wire_bean_factory_test.go similarity index 100% rename from pkg/bean/tag_auto_wire_bean_factory_test.go rename to pkg/infrastructure/bean/tag_auto_wire_bean_factory_test.go diff --git a/pkg/bean/time_type_adapter.go b/pkg/infrastructure/bean/time_type_adapter.go similarity index 100% rename from pkg/bean/time_type_adapter.go rename to pkg/infrastructure/bean/time_type_adapter.go diff --git a/pkg/bean/time_type_adapter_test.go b/pkg/infrastructure/bean/time_type_adapter_test.go similarity index 100% rename from pkg/bean/time_type_adapter_test.go rename to pkg/infrastructure/bean/time_type_adapter_test.go diff --git a/pkg/bean/type_adapter.go b/pkg/infrastructure/bean/type_adapter.go similarity index 100% rename from pkg/bean/type_adapter.go rename to pkg/infrastructure/bean/type_adapter.go diff --git a/pkg/config/base_config_test.go b/pkg/infrastructure/config/base_config_test.go similarity index 56% rename from pkg/config/base_config_test.go rename to pkg/infrastructure/config/base_config_test.go index 3d37bc91..74cef184 100644 --- a/pkg/config/base_config_test.go +++ b/pkg/infrastructure/config/base_config_test.go @@ -15,6 +15,7 @@ package config import ( + "context" "errors" "testing" @@ -23,43 +24,43 @@ import ( func TestBaseConfiger_DefaultBool(t *testing.T) { bc := newBaseConfier("true") - assert.True(t, bc.DefaultBool("key1", false)) - assert.True(t, bc.DefaultBool("key2", true)) + assert.True(t, bc.DefaultBool(context.Background(), "key1", false)) + assert.True(t, bc.DefaultBool(context.Background(), "key2", true)) } func TestBaseConfiger_DefaultFloat(t *testing.T) { bc := newBaseConfier("12.3") - assert.Equal(t, 12.3, bc.DefaultFloat("key1", 0.1)) - assert.Equal(t, 0.1, bc.DefaultFloat("key2", 0.1)) + assert.Equal(t, 12.3, bc.DefaultFloat(context.Background(), "key1", 0.1)) + assert.Equal(t, 0.1, bc.DefaultFloat(context.Background(), "key2", 0.1)) } func TestBaseConfiger_DefaultInt(t *testing.T) { bc := newBaseConfier("10") - assert.Equal(t, 10, bc.DefaultInt("key1", 8)) - assert.Equal(t, 8, bc.DefaultInt("key2", 8)) + assert.Equal(t, 10, bc.DefaultInt(context.Background(), "key1", 8)) + assert.Equal(t, 8, bc.DefaultInt(context.Background(), "key2", 8)) } func TestBaseConfiger_DefaultInt64(t *testing.T) { bc := newBaseConfier("64") - assert.Equal(t, int64(64), bc.DefaultInt64("key1", int64(8))) - assert.Equal(t, int64(8), bc.DefaultInt64("key2", int64(8))) + assert.Equal(t, int64(64), bc.DefaultInt64(context.Background(), "key1", int64(8))) + assert.Equal(t, int64(8), bc.DefaultInt64(context.Background(), "key2", int64(8))) } func TestBaseConfiger_DefaultString(t *testing.T) { bc := newBaseConfier("Hello") - assert.Equal(t, "Hello", bc.DefaultString("key1", "world")) - assert.Equal(t, "world", bc.DefaultString("key2", "world")) + assert.Equal(t, "Hello", bc.DefaultString(context.Background(), "key1", "world")) + assert.Equal(t, "world", bc.DefaultString(context.Background(), "key2", "world")) } func TestBaseConfiger_DefaultStrings(t *testing.T) { bc := newBaseConfier("Hello;world") - assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings("key1", []string{"world"})) - assert.Equal(t, []string{"world"}, bc.DefaultStrings("key2", []string{"world"})) + assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings(context.Background(), "key1", []string{"world"})) + assert.Equal(t, []string{"world"}, bc.DefaultStrings(context.Background(), "key2", []string{"world"})) } func newBaseConfier(str1 string) *BaseConfiger { return &BaseConfiger{ - reader: func(key string) (string, error) { + reader: func(ctx context.Context, key string) (string, error) { if key == "key1" { return str1, nil } else { diff --git a/pkg/config/config.go b/pkg/infrastructure/config/config.go similarity index 66% rename from pkg/config/config.go rename to pkg/infrastructure/config/config.go index b17f6208..0891e571 100644 --- a/pkg/config/config.go +++ b/pkg/infrastructure/config/config.go @@ -41,6 +41,7 @@ package config import ( + "context" "errors" "fmt" "os" @@ -53,138 +54,148 @@ import ( // Configer defines how to get and set value from configuration raw data. type Configer interface { // support section::key type in given key when using ini type. - Set(key, val string) error + Set(ctx context.Context, key, val string) error // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - String(key string) string + String(ctx context.Context, key string) (string, error) // get string slice - Strings(key string) []string - Int(key string) (int, error) - Int64(key string) (int64, error) - Bool(key string) (bool, error) - Float(key string) (float64, error) + Strings(ctx context.Context, key string) ([]string, error) + Int(ctx context.Context, key string) (int, error) + Int64(ctx context.Context, key string) (int64, error) + Bool(ctx context.Context, key string) (bool, error) + Float(ctx context.Context, key string) (float64, error) // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - DefaultString(key string, defaultVal string) string + DefaultString(ctx context.Context, key string, defaultVal string) string // get string slice - DefaultStrings(key string, defaultVal []string) []string - DefaultInt(key string, defaultVal int) int - DefaultInt64(key string, defaultVal int64) int64 - DefaultBool(key string, defaultVal bool) bool - DefaultFloat(key string, defaultVal float64) float64 - DIY(key string) (interface{}, error) - GetSection(section string) (map[string]string, error) + DefaultStrings(ctx context.Context, key string, defaultVal []string) []string + DefaultInt(ctx context.Context, key string, defaultVal int) int + DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 + DefaultBool(ctx context.Context, key string, defaultVal bool) bool + DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 + DIY(ctx context.Context, key string) (interface{}, error) - Unmarshaler(obj interface{}) error - Sub(key string) (Configer, error) - OnChange(fn func(cfg Configer)) - // GetByPrefix(prefix string) ([]byte, error) - // GetSerializer() Serializer - SaveConfigFile(filename string) error + GetSection(ctx context.Context, section string) (map[string]string, error) + + Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error + Sub(ctx context.Context, key string) (Configer, error) + OnChange(ctx context.Context, key string, fn func(value string)) + SaveConfigFile(ctx context.Context, filename string) error } type BaseConfiger struct { // The reader should support key like "a.b.c" - reader func(key string) (string, error) + reader func(ctx context.Context, key string) (string, error) } -func (c *BaseConfiger) Int(key string) (int, error) { - res, err := c.reader(key) +func NewBaseConfiger(reader func(ctx context.Context, key string) (string, error)) BaseConfiger { + return BaseConfiger{ + reader: reader, + } +} + +func (c *BaseConfiger) Int(ctx context.Context, key string) (int, error) { + res, err := c.reader(context.TODO(), key) if err != nil { return 0, err } return strconv.Atoi(res) } -func (c *BaseConfiger) Int64(key string) (int64, error) { - res, err := c.reader(key) +func (c *BaseConfiger) Int64(ctx context.Context, key string) (int64, error) { + res, err := c.reader(context.TODO(), key) if err != nil { return 0, err } return strconv.ParseInt(res, 10, 64) } -func (c *BaseConfiger) Bool(key string) (bool, error) { - res, err := c.reader(key) +func (c *BaseConfiger) Bool(ctx context.Context, key string) (bool, error) { + res, err := c.reader(context.TODO(), key) if err != nil { return false, err } - return strconv.ParseBool(res) + return ParseBool(res) } -func (c *BaseConfiger) Float(key string) (float64, error) { - res, err := c.reader(key) +func (c *BaseConfiger) Float(ctx context.Context, key string) (float64, error) { + res, err := c.reader(context.TODO(), key) if err != nil { return 0, err } return strconv.ParseFloat(res, 64) } -func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { - if res := c.String(key); res != "" { +// DefaultString returns the string value for a given key. +// if err != nil or value is empty return defaultval +func (c *BaseConfiger) DefaultString(ctx context.Context, key string, defaultVal string) string { + if res, err := c.String(ctx, key); res != "" && err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { - if res := c.Strings(key); len(res) > 0 { +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *BaseConfiger) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + if res, err := c.Strings(ctx, key); len(res) > 0 && err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultInt(key string, defaultVal int) int { - if res, err := c.Int(key); err == nil { +func (c *BaseConfiger) DefaultInt(ctx context.Context, key string, defaultVal int) int { + if res, err := c.Int(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultInt64(key string, defaultVal int64) int64 { - if res, err := c.Int64(key); err == nil { +func (c *BaseConfiger) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + if res, err := c.Int64(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultBool(key string, defaultVal bool) bool { - if res, err := c.Bool(key); err == nil { +func (c *BaseConfiger) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + if res, err := c.Bool(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { - if res, err := c.Float(key); err == nil { +func (c *BaseConfiger) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + if res, err := c.Float(ctx, key); err == nil { return res } return defaultVal } -func (c *BaseConfiger) String(key string) string { - res, _ := c.reader(key) - return res +func (c *BaseConfiger) String(ctx context.Context, key string) (string, error) { + return c.reader(context.TODO(), key) } -func (c *BaseConfiger) Strings(key string) []string { - res, err := c.reader(key) +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *BaseConfiger) Strings(ctx context.Context, key string) ([]string, error) { + res, err := c.String(nil, key) if err != nil || res == "" { - return nil + return nil, err } - return strings.Split(res, ";") + return strings.Split(res, ";"), nil } // TODO remove this before release v2.0.0 -func (c *BaseConfiger) Unmarshaler(obj interface{}) error { +func (c *BaseConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...DecodeOption) error { return errors.New("unsupported operation") } // TODO remove this before release v2.0.0 -func (c *BaseConfiger) Sub(key string) (Configer, error) { +func (c *BaseConfiger) Sub(ctx context.Context, key string) (Configer, error) { return nil, errors.New("unsupported operation") } // TODO remove this before release v2.0.0 -func (c *BaseConfiger) OnChange(fn func(cfg Configer)) { +func (c *BaseConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { // do nothing } @@ -361,3 +372,8 @@ func ToString(x interface{}) string { // Fallback to fmt package for anything else like numeric types return fmt.Sprint(x) } + +type DecodeOption func(options decodeOptions) + +type decodeOptions struct { +} diff --git a/pkg/infrastructure/config/config_test.go b/pkg/infrastructure/config/config_test.go new file mode 100644 index 00000000..15d6ffa6 --- /dev/null +++ b/pkg/infrastructure/config/config_test.go @@ -0,0 +1,55 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" +) + +func TestExpandValueEnv(t *testing.T) { + + testCases := []struct { + item string + want string + }{ + {"", ""}, + {"$", "$"}, + {"{", "{"}, + {"{}", "{}"}, + {"${}", ""}, + {"${|}", ""}, + {"${}", ""}, + {"${{}}", ""}, + {"${{||}}", "}"}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}", ""}, + {"${pwd||}}", "}"}, + {"${pwd||{{||}}}", "{{||}}"}, + {"${GOPATH}", os.Getenv("GOPATH")}, + {"${GOPATH||}", os.Getenv("GOPATH")}, + {"${GOPATH||root}", os.Getenv("GOPATH")}, + {"${GOPATH_NOT||root}", "root"}, + {"${GOPATH_NOT||||root}", "||root"}, + } + + for _, c := range testCases { + if got := ExpandValueEnv(c.item); got != c.want { + t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) + } + } + +} diff --git a/pkg/config/env/env.go b/pkg/infrastructure/config/env/env.go similarity index 97% rename from pkg/config/env/env.go rename to pkg/infrastructure/config/env/env.go index 5d8e47de..83155b34 100644 --- a/pkg/config/env/env.go +++ b/pkg/infrastructure/config/env/env.go @@ -21,7 +21,7 @@ import ( "os" "strings" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) var env *utils.BeeMap diff --git a/pkg/infrastructure/config/env/env_test.go b/pkg/infrastructure/config/env/env_test.go new file mode 100644 index 00000000..3f1d4dba --- /dev/null +++ b/pkg/infrastructure/config/env/env_test.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package env + +import ( + "os" + "testing" +) + +func TestEnvGet(t *testing.T) { + gopath := Get("GOPATH", "") + if gopath != os.Getenv("GOPATH") { + t.Error("expected GOPATH not empty.") + } + + noExistVar := Get("NOEXISTVAR", "foo") + if noExistVar != "foo" { + t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar) + } +} + +func TestEnvMustGet(t *testing.T) { + gopath, err := MustGet("GOPATH") + if err != nil { + t.Error(err) + } + + if gopath != os.Getenv("GOPATH") { + t.Errorf("expected GOPATH to be the same, got %s.", gopath) + } + + _, err = MustGet("NOEXISTVAR") + if err == nil { + t.Error("expected error to be non-nil") + } +} + +func TestEnvSet(t *testing.T) { + Set("MYVAR", "foo") + myVar := Get("MYVAR", "bar") + if myVar != "foo" { + t.Errorf("expected MYVAR to equal foo, got %s.", myVar) + } +} + +func TestEnvMustSet(t *testing.T) { + err := MustSet("FOO", "bar") + if err != nil { + t.Error(err) + } + + fooVar := os.Getenv("FOO") + if fooVar != "bar" { + t.Errorf("expected FOO variable to equal bar, got %s.", fooVar) + } +} + +func TestEnvGetAll(t *testing.T) { + envMap := GetAll() + if len(envMap) == 0 { + t.Error("expected environment not empty.") + } +} diff --git a/pkg/infrastructure/config/etcd/config.go b/pkg/infrastructure/config/etcd/config.go new file mode 100644 index 00000000..94057d73 --- /dev/null +++ b/pkg/infrastructure/config/etcd/config.go @@ -0,0 +1,214 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/coreos/etcd/clientv3" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" + "google.golang.org/grpc" + + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" +) + +const etcdOpts = "etcdOpts" + +type EtcdConfiger struct { + prefix string + client *clientv3.Client + config.BaseConfiger +} + +func newEtcdConfiger(client *clientv3.Client, prefix string) *EtcdConfiger { + res := &EtcdConfiger{ + client: client, + prefix: prefix, + } + + res.BaseConfiger = config.NewBaseConfiger(res.reader) + return res +} + +// reader is an general implementation that read config from etcd. +func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) { + resp, err := get(e.client, ctx, e.prefix+key) + if err != nil { + return "", err + } + + if resp.Count > 0 { + return string(resp.Kvs[0].Value), nil + } + + return "", nil +} + +// Set do nothing and return an error +// I think write data to remote config center is not a good practice +func (e *EtcdConfiger) Set(ctx context.Context, key, val string) error { + return errors.New("Unsupported operation") +} + +// DIY return the original response from etcd +// be careful when you decide to use this +func (e *EtcdConfiger) DIY(ctx context.Context, key string) (interface{}, error) { + return get(e.client, context.TODO(), key) +} + +// GetSection in this implementation, we use section as prefix +func (e *EtcdConfiger) GetSection(ctx context.Context, section string) (map[string]string, error) { + var ( + resp *clientv3.GetResponse + err error + ) + + if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { + opts = append(opts, clientv3.WithPrefix()) + resp, err = e.client.Get(context.TODO(), e.prefix+section, opts...) + } else { + resp, err = e.client.Get(context.TODO(), e.prefix+section, clientv3.WithPrefix()) + } + + if err != nil { + return nil, errors.WithMessage(err, "GetSection failed") + } + res := make(map[string]string, len(resp.Kvs)) + for _, kv := range resp.Kvs { + res[string(kv.Key)] = string(kv.Value) + } + return res, nil +} + +func (e *EtcdConfiger) SaveConfigFile(ctx context.Context, filename string) error { + return errors.New("Unsupported operation") +} + +// Unmarshaler is not very powerful because we lost the type information when we get configuration from etcd +// for example, when we got "5", we are not sure whether it's int 5, or it's string "5" +// TODO(support more complicated decoder) +func (e *EtcdConfiger) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + res, err := e.GetSection(ctx, prefix) + if err != nil { + return errors.WithMessage(err, fmt.Sprintf("could not read config with prefix: %s", prefix)) + } + + prefixLen := len(e.prefix + prefix) + m := make(map[string]string, len(res)) + for k, v := range res { + m[k[prefixLen:]] = v + } + return mapstructure.Decode(m, obj) +} + +// Sub return an sub configer. +func (e *EtcdConfiger) Sub(ctx context.Context, key string) (config.Configer, error) { + return newEtcdConfiger(e.client, e.prefix+key), nil +} + +// TODO remove this before release v2.0.0 +func (e *EtcdConfiger) OnChange(ctx context.Context, key string, fn func(value string)) { + + buildOptsFunc := func() []clientv3.OpOption { + if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { + opts = append(opts, clientv3.WithCreatedNotify()) + return opts + } + return []clientv3.OpOption{} + } + + rch := e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...) + go func() { + for { + for resp := range rch { + if err := resp.Err(); err != nil { + logs.Error("listen to key but got error callback", err) + break + } + + for _, e := range resp.Events { + if e.Kv == nil { + continue + } + fn(string(e.Kv.Value)) + } + } + time.Sleep(time.Second) + rch = e.client.Watch(ctx, e.prefix+key, buildOptsFunc()...) + } + }() + +} + +type EtcdConfigerProvider struct { +} + +// Parse = ParseData([]byte(key)) +// key must be json +func (provider *EtcdConfigerProvider) Parse(key string) (config.Configer, error) { + return provider.ParseData([]byte(key)) +} + +// ParseData try to parse key as clientv3.Config, using this to build etcdClient +func (provider *EtcdConfigerProvider) ParseData(data []byte) (config.Configer, error) { + cfg := &clientv3.Config{} + err := json.Unmarshal(data, cfg) + if err != nil { + return nil, errors.WithMessage(err, "parse data to etcd config failed, please check your input") + } + + cfg.DialOptions = []grpc.DialOption{ + grpc.WithBlock(), + grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor), + grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor), + } + client, err := clientv3.New(*cfg) + if err != nil { + return nil, errors.WithMessage(err, "create etcd client failed") + } + + return newEtcdConfiger(client, ""), nil +} + +func get(client *clientv3.Client, ctx context.Context, key string) (*clientv3.GetResponse, error) { + var ( + resp *clientv3.GetResponse + err error + ) + if opts, ok := ctx.Value(etcdOpts).([]clientv3.OpOption); ok { + resp, err = client.Get(ctx, key, opts...) + } else { + resp, err = client.Get(ctx, key) + } + + if err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("read config from etcd with key %s failed", key)) + } + return resp, err +} + +func WithEtcdOption(ctx context.Context, opts ...clientv3.OpOption) context.Context { + return context.WithValue(ctx, etcdOpts, opts) +} + +func init() { + config.Register("json", &EtcdConfigerProvider{}) +} diff --git a/pkg/infrastructure/config/etcd/config_test.go b/pkg/infrastructure/config/etcd/config_test.go new file mode 100644 index 00000000..7ccf6b96 --- /dev/null +++ b/pkg/infrastructure/config/etcd/config_test.go @@ -0,0 +1,123 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/stretchr/testify/assert" +) + +func TestWithEtcdOption(t *testing.T) { + ctx := WithEtcdOption(context.Background(), clientv3.WithPrefix()) + assert.NotNil(t, ctx.Value(etcdOpts)) +} + +func TestEtcdConfigerProvider_Parse(t *testing.T) { + provider := &EtcdConfigerProvider{} + cfger, err := provider.Parse(readEtcdConfig()) + assert.Nil(t, err) + assert.NotNil(t, cfger) +} + +func TestEtcdConfiger(t *testing.T) { + + provider := &EtcdConfigerProvider{} + cfger, _ := provider.Parse(readEtcdConfig()) + + subCfger, err := cfger.Sub(nil, "sub.") + assert.Nil(t, err) + assert.NotNil(t, subCfger) + + subSubCfger, err := subCfger.Sub(nil, "sub.") + assert.NotNil(t, subSubCfger) + assert.Nil(t, err) + + str, err := subSubCfger.String(nil, "key1") + assert.Nil(t, err) + assert.Equal(t, "sub.sub.key", str) + + // we cannot test it + subSubCfger.OnChange(context.Background(), "watch", func(value string) { + // do nothing + }) + + defStr := cfger.DefaultString(nil, "not_exit", "default value") + assert.Equal(t, "default value", defStr) + + defInt64 := cfger.DefaultInt64(nil, "not_exit", -1) + assert.Equal(t, int64(-1), defInt64) + + defInt := cfger.DefaultInt(nil, "not_exit", -2) + assert.Equal(t, -2, defInt) + + defFlt := cfger.DefaultFloat(nil, "not_exit", 12.3) + assert.Equal(t, 12.3, defFlt) + + defBl := cfger.DefaultBool(nil, "not_exit", true) + assert.True(t, defBl) + + defStrs := cfger.DefaultStrings(nil, "not_exit", []string{"hello"}) + assert.Equal(t, []string{"hello"}, defStrs) + + fl, err := cfger.Float(nil, "current.float") + assert.Nil(t, err) + assert.Equal(t, 1.23, fl) + + bl, err := cfger.Bool(nil, "current.bool") + assert.Nil(t, err) + assert.True(t, bl) + + it, err := cfger.Int(nil, "current.int") + assert.Nil(t, err) + assert.Equal(t, 11, it) + + str, err = cfger.String(nil, "current.string") + assert.Nil(t, err) + assert.Equal(t, "hello", str) + + tn := &TestEntity{} + err = cfger.Unmarshaler(context.Background(), "current.serialize.", tn) + assert.Nil(t, err) + assert.Equal(t, "test", tn.Name) +} + +type TestEntity struct { + Name string `yaml:"name"` + Sub SubEntity `yaml:"sub"` +} + +type SubEntity struct { + SubName string `yaml:"subName"` +} + +func readEtcdConfig() string { + addr := os.Getenv("ETCD_ADDR") + if addr == "" { + addr = "localhost:2379" + } + + obj := clientv3.Config{ + Endpoints: []string{addr}, + DialTimeout: 3 * time.Second, + } + cfg, _ := json.Marshal(obj) + return string(cfg) +} diff --git a/pkg/config/fake.go b/pkg/infrastructure/config/fake.go similarity index 50% rename from pkg/config/fake.go rename to pkg/infrastructure/config/fake.go index ddbc99b8..b606be01 100644 --- a/pkg/config/fake.go +++ b/pkg/infrastructure/config/fake.go @@ -15,6 +15,7 @@ package config import ( + "context" "errors" "strconv" "strings" @@ -29,99 +30,71 @@ func (c *fakeConfigContainer) getData(key string) string { return c.data[strings.ToLower(key)] } -func (c *fakeConfigContainer) Set(key, val string) error { +func (c *fakeConfigContainer) Set(ctx context.Context, key, val string) error { c.data[strings.ToLower(key)] = val return nil } -func (c *fakeConfigContainer) String(key string) string { - return c.getData(key) -} - -func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -func (c *fakeConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - -func (c *fakeConfigContainer) Int(key string) (int, error) { +func (c *fakeConfigContainer) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(c.getData(key)) } -func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +func (c *fakeConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) Int64(key string) (int64, error) { +func (c *fakeConfigContainer) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } -func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +func (c *fakeConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) Bool(key string) (bool, error) { +func (c *fakeConfigContainer) Bool(ctx context.Context, key string) (bool, error) { return ParseBool(c.getData(key)) } -func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +func (c *fakeConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) Float(key string) (float64, error) { +func (c *fakeConfigContainer) Float(ctx context.Context, key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } -func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +func (c *fakeConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } -func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { +func (c *fakeConfigContainer) DIY(ctx context.Context, key string) (interface{}, error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } return nil, errors.New("key not find") } -func (c *fakeConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *fakeConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { return nil, errors.New("not implement in the fakeConfigContainer") } -func (c *fakeConfigContainer) SaveConfigFile(filename string) error { +func (c *fakeConfigContainer) SaveConfigFile(ctx context.Context, filename string) error { return errors.New("not implement in the fakeConfigContainer") } @@ -129,7 +102,11 @@ var _ Configer = new(fakeConfigContainer) // NewFakeConfig return a fake Configer func NewFakeConfig() Configer { - return &fakeConfigContainer{ + res := &fakeConfigContainer{ data: make(map[string]string), } + res.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return res.getData(key), nil + }) + return res } diff --git a/pkg/config/ini.go b/pkg/infrastructure/config/ini.go similarity index 82% rename from pkg/config/ini.go rename to pkg/infrastructure/config/ini.go index 0bef67d4..cc67e4cd 100644 --- a/pkg/config/ini.go +++ b/pkg/infrastructure/config/ini.go @@ -17,6 +17,7 @@ package config import ( "bufio" "bytes" + "context" "errors" "io" "io/ioutil" @@ -65,6 +66,10 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e keyComment: make(map[string]string), RWMutex: sync.RWMutex{}, } + + cfg.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return cfg.getdata(key), nil + }) cfg.Lock() defer cfg.Unlock() @@ -233,102 +238,102 @@ type IniConfigContainer struct { } // Bool returns the boolean value for a given key. -func (c *IniConfigContainer) Bool(key string) (bool, error) { +func (c *IniConfigContainer) Bool(ctx context.Context, key string) (bool, error) { return ParseBool(c.getdata(key)) } // DefaultBool returns the boolean value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int returns the integer value for a given key. -func (c *IniConfigContainer) Int(key string) (int, error) { +func (c *IniConfigContainer) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(c.getdata(key)) } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. -func (c *IniConfigContainer) Int64(key string) (int64, error) { +func (c *IniConfigContainer) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(c.getdata(key), 10, 64) } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Float returns the float value for a given key. -func (c *IniConfigContainer) Float(key string) (float64, error) { +func (c *IniConfigContainer) Float(ctx context.Context, key string) (float64, error) { return strconv.ParseFloat(c.getdata(key), 64) } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *IniConfigContainer) String(key string) string { - return c.getdata(key) +func (c *IniConfigContainer) String(ctx context.Context, key string) (string, error) { + return c.getdata(key), nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + v, err := c.String(nil, key) + if v == "" || err != nil { + return defaultVal } return v } // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. -func (c *IniConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *IniConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + v, err := c.String(nil, key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + v, err := c.Strings(ctx, key) + if v == nil || err != nil { + return defaultVal } return v } // GetSection returns map for the given section -func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *IniConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v, nil } @@ -338,7 +343,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro // SaveConfigFile save the config into file. // // BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function. -func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *IniConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -438,7 +443,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Set writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. -func (c *IniConfigContainer) Set(key, value string) error { +func (c *IniConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() if len(key) == 0 { @@ -461,12 +466,12 @@ func (c *IniConfigContainer) Set(key, value string) error { if _, ok := c.data[section]; !ok { c.data[section] = make(map[string]string) } - c.data[section][k] = value + c.data[section][k] = val return nil } // DIY returns the raw value by a given key. -func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *IniConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } diff --git a/pkg/infrastructure/config/ini_test.go b/pkg/infrastructure/config/ini_test.go new file mode 100644 index 00000000..d4972ddd --- /dev/null +++ b/pkg/infrastructure/config/ini_test.go @@ -0,0 +1,191 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestIni(t *testing.T) { + + var ( + inicontext = ` +;comment one +#comment two +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415976 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "pi": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "demo::key1": "asta", + "demo::key2": "xie", + "demo::CaseInsensitive": true, + "demo::peers": []string{"one", "two", "three"}, + "demo::password": os.Getenv("GOPATH"), + "null": "", + "demo2::key1": "", + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testini.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(inicontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testini.conf") + iniconf, err := NewConfig("ini", "testini.conf") + if err != nil { + t.Fatal(err) + } + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = iniconf.Int(nil, k) + case int64: + value, err = iniconf.Int64(nil, k) + case float64: + value, err = iniconf.Float(nil, k) + case bool: + value, err = iniconf.Bool(nil, k) + case []string: + value, err = iniconf.Strings(nil, k) + case string: + value, err = iniconf.String(nil, k) + default: + value, err = iniconf.DIY(nil, k) + } + if err != nil { + t.Fatalf("get key %q value fail,err %s", k, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = iniconf.Set(nil, "name", "astaxie"); err != nil { + t.Fatal(err) + } + res, _ := iniconf.String(nil, "name") + if res != "astaxie" { + t.Fatal("get name error") + } + +} + +func TestIniSave(t *testing.T) { + + const ( + inicontext = ` +app = app +;comment one +#comment two +# comment three +appname = beeapi +httpport = 8080 +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name = mysql +` + + saveResult = ` +app=app +#comment one +#comment two +# comment three +appname=beeapi +httpport=8080 + +# DB Info +# enable db +[dbinfo] +# db type name +# suport mysql,sqlserver +name=mysql +` + ) + cfg, err := NewConfigData("ini", []byte(inicontext)) + if err != nil { + t.Fatal(err) + } + name := "newIniConfig.ini" + if err := cfg.SaveConfigFile(nil, name); err != nil { + t.Fatal(err) + } + defer os.Remove(name) + + if data, err := ioutil.ReadFile(name); err != nil { + t.Fatal(err) + } else { + cfgData := string(data) + datas := strings.Split(saveResult, "\n") + for _, line := range datas { + if !strings.Contains(cfgData, line+"\n") { + t.Fatalf("different after save ini config file. need contains %q", line) + } + } + + } +} diff --git a/pkg/config/json/json.go b/pkg/infrastructure/config/json/json.go similarity index 60% rename from pkg/config/json/json.go rename to pkg/infrastructure/config/json/json.go index 876077e1..c65eff4d 100644 --- a/pkg/config/json/json.go +++ b/pkg/infrastructure/config/json/json.go @@ -15,6 +15,7 @@ package json import ( + "context" "encoding/json" "errors" "fmt" @@ -24,7 +25,10 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/config" + "github.com/mitchellh/mapstructure" + + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // JSONConfig is a json config parser and implements Config interface. @@ -69,13 +73,50 @@ func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) { // JSONConfigContainer is a config which represents the json configuration. // Only when get value, support key as section:name type. type JSONConfigContainer struct { - config.BaseConfiger data map[string]interface{} sync.RWMutex } +func (c *JSONConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(ctx, prefix) + if err != nil { + return err + } + return mapstructure.Decode(sub, obj) +} + +func (c *JSONConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + sub, err := c.sub(ctx, key) + if err != nil { + return nil, err + } + return &JSONConfigContainer{ + data: sub, + }, nil +} + +func (c *JSONConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { + if key == "" { + return c.data, nil + } + value, ok := c.data[key] + if !ok { + return nil, errors.New(fmt.Sprintf("key is not found: %s", key)) + } + + res, ok := value.(map[string]interface{}) + if !ok { + return nil, errors.New(fmt.Sprintf("the type of value is invalid, key: %s", key)) + } + return res, nil +} + +func (c *JSONConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + logs.Warn("unsupported operation") +} + // Bool returns the boolean value for a given key. -func (c *JSONConfigContainer) Bool(key string) (bool, error) { +func (c *JSONConfigContainer) Bool(ctx context.Context, key string) (bool, error) { val := c.getData(key) if val != nil { return config.ParseBool(val) @@ -85,15 +126,15 @@ func (c *JSONConfigContainer) Bool(key string) (bool, error) { // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { - if v, err := c.Bool(key); err == nil { +func (c *JSONConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + if v, err := c.Bool(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // Int returns the integer value for a given key. -func (c *JSONConfigContainer) Int(key string) (int, error) { +func (c *JSONConfigContainer) Int(ctx context.Context, key string) (int, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -108,15 +149,15 @@ func (c *JSONConfigContainer) Int(key string) (int, error) { // DefaultInt returns the integer value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { - if v, err := c.Int(key); err == nil { +func (c *JSONConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + if v, err := c.Int(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // Int64 returns the int64 value for a given key. -func (c *JSONConfigContainer) Int64(key string) (int64, error) { +func (c *JSONConfigContainer) Int64(ctx context.Context, key string) (int64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -129,15 +170,15 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - if v, err := c.Int64(key); err == nil { +func (c *JSONConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + if v, err := c.Int64(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // Float returns the float value for a given key. -func (c *JSONConfigContainer) Float(key string) (float64, error) { +func (c *JSONConfigContainer) Float(ctx context.Context, key string) (float64, error) { val := c.getData(key) if val != nil { if v, ok := val.(float64); ok { @@ -150,54 +191,54 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - if v, err := c.Float(key); err == nil { +func (c *JSONConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + if v, err := c.Float(ctx, key); err == nil { return v } - return defaultval + return defaultVal } // String returns the string value for a given key. -func (c *JSONConfigContainer) String(key string) string { +func (c *JSONConfigContainer) String(ctx context.Context, key string) (string, error) { val := c.getData(key) if val != nil { if v, ok := val.(string); ok { - return v + return v, nil } } - return "" + return "", nil } // DefaultString returns the string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { +func (c *JSONConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { // TODO FIXME should not use "" to replace non existence - if v := c.String(key); v != "" { + if v, err := c.String(ctx, key); v != "" && err == nil { return v } - return defaultval + return defaultVal } // Strings returns the []string value for a given key. -func (c *JSONConfigContainer) Strings(key string) []string { - stringVal := c.String(key) - if stringVal == "" { - return nil +func (c *JSONConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + stringVal, err := c.String(nil, key) + if stringVal == "" || err != nil { + return nil, err } - return strings.Split(c.String(key), ";") + return strings.Split(stringVal, ";"), nil } // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); v != nil { +func (c *JSONConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + if v, err := c.Strings(ctx, key); v != nil && err == nil { return v } - return defaultval + return defaultVal } // GetSection returns map for the given section -func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *JSONConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil } @@ -205,7 +246,7 @@ func (c *JSONConfigContainer) GetSection(section string) (map[string]string, err } // SaveConfigFile save the config into file -func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *JSONConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -221,7 +262,7 @@ func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *JSONConfigContainer) Set(key, val string) error { +func (c *JSONConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -229,7 +270,7 @@ func (c *JSONConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *JSONConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { val := c.getData(key) if val != nil { return val, nil diff --git a/pkg/infrastructure/config/json/json_test.go b/pkg/infrastructure/config/json/json_test.go new file mode 100644 index 00000000..5275ee57 --- /dev/null +++ b/pkg/infrastructure/config/json/json_test.go @@ -0,0 +1,252 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package json + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +func TestJsonStartsWithArray(t *testing.T) { + + const jsoncontextwitharray = `[ + { + "url": "user", + "serviceAPI": "http://www.test.com/user" + }, + { + "url": "employee", + "serviceAPI": "http://www.test.com/employee" + } +]` + f, err := os.Create("testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontextwitharray) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjsonWithArray.conf") + jsonconf, err := config.NewConfig("json", "testjsonWithArray.conf") + if err != nil { + t.Fatal(err) + } + rootArray, err := jsonconf.DIY(nil, "rootArray") + if err != nil { + t.Error("array does not exist as element") + } + rootArrayCasted := rootArray.([]interface{}) + if rootArrayCasted == nil { + t.Error("array from root is nil") + } else { + elem := rootArrayCasted[0].(map[string]interface{}) + if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" { + t.Error("array[0] values are not valid") + } + + elem2 := rootArrayCasted[1].(map[string]interface{}) + if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" { + t.Error("array[1] values are not valid") + } + } +} + +func TestJson(t *testing.T) { + + var ( + jsoncontext = `{ +"appname": "beeapi", +"testnames": "foo;bar", +"httpport": 8080, +"mysqlport": 3600, +"PI": 3.1415976, +"runmode": "dev", +"autorender": false, +"copyrequestbody": true, +"session": "on", +"cookieon": "off", +"newreg": "OFF", +"needlogin": "ON", +"enableSession": "Y", +"enableCookie": "N", +"flag": 1, +"path1": "${GOPATH}", +"path2": "${GOPATH||/home/go}", +"database": { + "host": "host", + "port": "port", + "database": "database", + "username": "username", + "password": "${GOPATH}", + "conns":{ + "maxconnection":12, + "autoconnect":true, + "connectioninfo":"info", + "root": "${GOPATH}" + } + } +}` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "testnames": []string{"foo", "bar"}, + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "session": true, + "cookieon": false, + "newreg": false, + "needlogin": true, + "enableSession": true, + "enableCookie": false, + "flag": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "database::host": "host", + "database::port": "port", + "database::database": "database", + "database::password": os.Getenv("GOPATH"), + "database::conns::maxconnection": 12, + "database::conns::autoconnect": true, + "database::conns::connectioninfo": "info", + "database::conns::root": os.Getenv("GOPATH"), + "unknown": "", + } + ) + + f, err := os.Create("testjson.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(jsoncontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testjson.conf") + jsonconf, err := config.NewConfig("json", "testjson.conf") + if err != nil { + t.Fatal(err) + } + + for k, v := range keyValue { + var err error + var value interface{} + switch v.(type) { + case int: + value, err = jsonconf.Int(nil, k) + case int64: + value, err = jsonconf.Int64(nil, k) + case float64: + value, err = jsonconf.Float(nil, k) + case bool: + value, err = jsonconf.Bool(nil, k) + case []string: + value, err = jsonconf.Strings(nil, k) + case string: + value, err = jsonconf.String(nil, k) + default: + value, err = jsonconf.DIY(nil, k) + } + if err != nil { + t.Fatalf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Fatalf("get key %q value, want %v got %v .", k, v, value) + } + + } + if err = jsonconf.Set(nil, "name", "astaxie"); err != nil { + t.Fatal(err) + } + + res, _ := jsonconf.String(nil, "name") + if res != "astaxie" { + t.Fatal("get name error") + } + + if db, err := jsonconf.DIY(nil, "database"); err != nil { + t.Fatal(err) + } else if m, ok := db.(map[string]interface{}); !ok { + t.Log(db) + t.Fatal("db not map[string]interface{}") + } else { + if m["host"].(string) != "host" { + t.Fatal("get host err") + } + } + + if _, err := jsonconf.Int(nil, "unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int") + } + + if _, err := jsonconf.Int64(nil, "unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an Int64") + } + + if _, err := jsonconf.Float(nil, "unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Float") + } + + if _, err := jsonconf.DIY(nil, "unknown"); err == nil { + t.Error("unknown keys should return an error when expecting an interface{}") + } + + if val, _ := jsonconf.String(nil, "unknown"); val != "" { + t.Error("unknown keys should return an empty string when expecting a String") + } + + if _, err := jsonconf.Bool(nil, "unknown"); err == nil { + t.Error("unknown keys should return an error when expecting a Bool") + } + + if !jsonconf.DefaultBool(nil, "unknown", true) { + t.Error("unknown keys with default value wrong") + } + + sub, err := jsonconf.Sub(context.Background(), "database") + assert.Nil(t, err) + assert.NotNil(t, sub) + + sub, err = sub.Sub(context.Background(), "conns") + assert.Nil(t, err) + + maxCon, _ := sub.Int(context.Background(), "maxconnection") + assert.Equal(t, 12, maxCon) + + dbCfg := &DatabaseConfig{} + err = sub.Unmarshaler(context.Background(), "", dbCfg) + assert.Nil(t, err) + assert.Equal(t, 12, dbCfg.MaxConnection) + assert.True(t, dbCfg.Autoconnect) + assert.Equal(t, "info", dbCfg.Connectioninfo) +} + +type DatabaseConfig struct { + MaxConnection int `json:"maxconnection"` + Autoconnect bool `json:"autoconnect"` + Connectioninfo string `json:"connectioninfo"` +} diff --git a/pkg/config/xml/xml.go b/pkg/infrastructure/config/xml/xml.go similarity index 53% rename from pkg/config/xml/xml.go rename to pkg/infrastructure/config/xml/xml.go index 9b5ec791..e5096b9b 100644 --- a/pkg/config/xml/xml.go +++ b/pkg/infrastructure/config/xml/xml.go @@ -26,10 +26,11 @@ // // cnf, err := config.NewConfig("xml", "config.xml") // -//More docs http://beego.me/docs/module/config.md +// More docs http://beego.me/docs/module/config.md package xml import ( + "context" "encoding/xml" "errors" "fmt" @@ -39,7 +40,11 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/config" + "github.com/mitchellh/mapstructure" + + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/beego/x2j" ) @@ -74,13 +79,55 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { // ConfigContainer is a Config which represents the xml configuration. type ConfigContainer struct { - config.BaseConfiger data map[string]interface{} sync.Mutex } +// Unmarshaler is a little be inconvenient since the xml library doesn't know type. +// So when you use +// 1 +// The "1" is a string, not int +func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(ctx, prefix) + if err != nil { + return err + } + return mapstructure.Decode(sub, obj) +} + +func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + sub, err := c.sub(ctx, key) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: sub, + }, nil + +} + +func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { + if key == "" { + return c.data, nil + } + value, ok := c.data[key] + if !ok { + return nil, errors.New(fmt.Sprintf("the key is not found: %s", key)) + } + res, ok := value.(map[string]interface{}) + if !ok { + return nil, errors.New(fmt.Sprintf("the value of this key is not a structure: %s", key)) + } + return res, nil +} + +func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + logs.Warn("Unsupported operation") +} + // Bool returns the boolean value for a given key. -func (c *ConfigContainer) Bool(key string) (bool, error) { +func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { if v := c.data[key]; v != nil { return config.ParseBool(v) } @@ -88,100 +135,100 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { } // DefaultBool return the bool value if has no error -// otherwise return the defaultval -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +// otherwise return the defaultVal +func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int returns the integer value for a given key. -func (c *ConfigContainer) Int(key string) (int, error) { +func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { return strconv.Atoi(c.data[key].(string)) } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. -func (c *ConfigContainer) Int64(key string) (int64, error) { +func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) { return strconv.ParseInt(c.data[key].(string), 10, 64) } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Float returns the float value for a given key. -func (c *ConfigContainer) Float(key string) (float64, error) { +func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) { return strconv.ParseFloat(c.data[key].(string), 64) } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) string { +func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) { if v, ok := c.data[key].(string); ok { - return v + return v, nil } - return "" + return "", nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + v, err := c.String(ctx, key) + if v == "" || err != nil { + return defaultVal } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + v, err := c.String(ctx, key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + v, err := c.Strings(ctx, key) + if v == nil || err != nil { + return defaultVal } return v } // GetSection returns map for the given section -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section].(map[string]interface{}); ok { mapstr := make(map[string]string) for k, val := range v { @@ -193,7 +240,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) } // SaveConfigFile save the config into file -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -209,7 +256,7 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *ConfigContainer) Set(key, val string) error { +func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -217,7 +264,7 @@ func (c *ConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { if v, ok := c.data[key]; ok { return v, nil } diff --git a/pkg/infrastructure/config/xml/xml_test.go b/pkg/infrastructure/config/xml/xml_test.go new file mode 100644 index 00000000..0a3eb313 --- /dev/null +++ b/pkg/infrastructure/config/xml/xml_test.go @@ -0,0 +1,158 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xml + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +func TestXML(t *testing.T) { + + var ( + // xml parse should incluce in tags + xmlcontext = ` + +beeapi +8080 +3600 +3.1415976 +dev +false +true +${GOPATH} +${GOPATH||/home/go} + +1 +MySection + + +` + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testxml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(xmlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testxml.conf") + + xmlconf, err := config.NewConfig("xml", "testxml.conf") + if err != nil { + t.Fatal(err) + } + + var xmlsection map[string]string + xmlsection, err = xmlconf.GetSection(nil, "mysection") + if err != nil { + t.Fatal(err) + } + + if len(xmlsection) == 0 { + t.Error("section should not be empty") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = xmlconf.Int(nil, k) + case int64: + value, err = xmlconf.Int64(nil, k) + case float64: + value, err = xmlconf.Float(nil, k) + case bool: + value, err = xmlconf.Bool(nil, k) + case []string: + value, err = xmlconf.Strings(nil, k) + case string: + value, err = xmlconf.String(nil, k) + default: + value, err = xmlconf.DIY(nil, k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = xmlconf.Set(nil, "name", "astaxie"); err != nil { + t.Fatal(err) + } + + res, _ := xmlconf.String(context.Background(), "name") + if res != "astaxie" { + t.Fatal("get name error") + } + + sub, err := xmlconf.Sub(context.Background(), "mysection") + assert.Nil(t, err) + assert.NotNil(t, sub) + name, err := sub.String(context.Background(), "name") + assert.Nil(t, err) + assert.Equal(t, "MySection", name) + + id, err := sub.Int(context.Background(), "id") + assert.Nil(t, err) + assert.Equal(t, 1, id) + + sec := &Section{} + + err = sub.Unmarshaler(context.Background(), "", sec) + assert.Nil(t, err) + assert.Equal(t, "MySection", sec.Name) + + sec = &Section{} + + err = xmlconf.Unmarshaler(context.Background(), "mysection", sec) + assert.Nil(t, err) + assert.Equal(t, "MySection", sec.Name) + +} + +type Section struct { + Name string `xml:"name"` +} diff --git a/pkg/config/yaml/yaml.go b/pkg/infrastructure/config/yaml/yaml.go similarity index 60% rename from pkg/config/yaml/yaml.go rename to pkg/infrastructure/config/yaml/yaml.go index 5c77e88f..61ea45b9 100644 --- a/pkg/config/yaml/yaml.go +++ b/pkg/infrastructure/config/yaml/yaml.go @@ -26,11 +26,12 @@ // // cnf, err := config.NewConfig("yaml", "config.yaml") // -//More docs http://beego.me/docs/module/config.md +// More docs http://beego.me/docs/module/config.md package yaml import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -40,8 +41,11 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/config" "github.com/beego/goyaml2" + "gopkg.in/yaml.v2" + + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // Config is a yaml config parser and implements Config interface. @@ -118,13 +122,63 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) { // ConfigContainer is a config which represents the yaml configuration. type ConfigContainer struct { - config.BaseConfiger data map[string]interface{} sync.RWMutex } +// Unmarshaler is similar to Sub +func (c *ConfigContainer) Unmarshaler(ctx context.Context, prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(ctx, prefix) + if err != nil { + return err + } + + bytes, err := yaml.Marshal(sub) + if err != nil { + return err + } + return yaml.Unmarshal(bytes, obj) +} + +func (c *ConfigContainer) Sub(ctx context.Context, key string) (config.Configer, error) { + sub, err := c.sub(ctx, key) + if err != nil { + return nil, err + } + return &ConfigContainer{ + data: sub, + }, nil +} + +func (c *ConfigContainer) sub(ctx context.Context, key string) (map[string]interface{}, error) { + tmpData := c.data + keys := strings.Split(key, ".") + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch v.(type) { + case map[string]interface{}: + { + tmpData = v.(map[string]interface{}) + if idx == len(keys)-1 { + return tmpData, nil + } + } + default: + return nil, errors.New(fmt.Sprintf("the key is invalid: %s", key)) + } + } + } + + return tmpData, nil +} + +func (c *ConfigContainer) OnChange(ctx context.Context, key string, fn func(value string)) { + // do nothing + logs.Warn("Unsupported operation: OnChange") +} + // Bool returns the boolean value for a given key. -func (c *ConfigContainer) Bool(key string) (bool, error) { +func (c *ConfigContainer) Bool(ctx context.Context, key string) (bool, error) { v, err := c.getData(key) if err != nil { return false, err @@ -133,17 +187,17 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { } // DefaultBool return the bool value if has no error -// otherwise return the defaultval -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { - v, err := c.Bool(key) +// otherwise return the defaultVal +func (c *ConfigContainer) DefaultBool(ctx context.Context, key string, defaultVal bool) bool { + v, err := c.Bool(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int returns the integer value for a given key. -func (c *ConfigContainer) Int(key string) (int, error) { +func (c *ConfigContainer) Int(ctx context.Context, key string) (int, error) { if v, err := c.getData(key); err != nil { return 0, err } else if vv, ok := v.(int); ok { @@ -155,17 +209,17 @@ func (c *ConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { - v, err := c.Int(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt(ctx context.Context, key string, defaultVal int) int { + v, err := c.Int(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. -func (c *ConfigContainer) Int64(key string) (int64, error) { +func (c *ConfigContainer) Int64(ctx context.Context, key string) (int64, error) { if v, err := c.getData(key); err != nil { return 0, err } else if vv, ok := v.(int64); ok { @@ -175,17 +229,17 @@ func (c *ConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { - v, err := c.Int64(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt64(ctx context.Context, key string, defaultVal int64) int64 { + v, err := c.Int64(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // Float returns the float value for a given key. -func (c *ConfigContainer) Float(key string) (float64, error) { +func (c *ConfigContainer) Float(ctx context.Context, key string) (float64, error) { if v, err := c.getData(key); err != nil { return 0.0, err } else if vv, ok := v.(float64); ok { @@ -199,56 +253,56 @@ func (c *ConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { - v, err := c.Float(key) +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultFloat(ctx context.Context, key string, defaultVal float64) float64 { + v, err := c.Float(ctx, key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) string { +func (c *ConfigContainer) String(ctx context.Context, key string) (string, error) { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { - return vv + return vv, nil } } - return "" + return "", nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultString(ctx context.Context, key string, defaultVal string) string { + v, err := c.String(nil, key) + if v == "" || err != nil { + return defaultVal } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *ConfigContainer) Strings(ctx context.Context, key string) ([]string, error) { + v, err := c.String(nil, key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultStrings(ctx context.Context, key string, defaultVal []string) []string { + v, err := c.Strings(ctx, key) + if v == nil || err != nil { + return defaultVal } return v } // GetSection returns map for the given section -func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { +func (c *ConfigContainer) GetSection(ctx context.Context, section string) (map[string]string, error) { if v, ok := c.data[section]; ok { return v.(map[string]string), nil @@ -257,7 +311,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) } // SaveConfigFile save the config into file -func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { +func (c *ConfigContainer) SaveConfigFile(ctx context.Context, filename string) (err error) { // Write configuration file by filename. f, err := os.Create(filename) if err != nil { @@ -269,7 +323,7 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { } // Set writes a new value for key. -func (c *ConfigContainer) Set(key, val string) error { +func (c *ConfigContainer) Set(ctx context.Context, key, val string) error { c.Lock() defer c.Unlock() c.data[key] = val @@ -277,7 +331,7 @@ func (c *ConfigContainer) Set(key, val string) error { } // DIY returns the raw value by a given key. -func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { +func (c *ConfigContainer) DIY(ctx context.Context, key string) (v interface{}, err error) { return c.getData(key) } @@ -289,7 +343,7 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { c.RLock() defer c.RUnlock() - keys := strings.Split(key, ".") + keys := strings.Split(c.key(key), ".") tmpData := c.data for idx, k := range keys { if v, ok := tmpData[k]; ok { @@ -312,6 +366,10 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { return nil, fmt.Errorf("not exist key %q", key) } +func (c *ConfigContainer) key(key string) string { + return key +} + func init() { config.Register("yaml", &Config{}) } diff --git a/pkg/infrastructure/config/yaml/yaml_test.go b/pkg/infrastructure/config/yaml/yaml_test.go new file mode 100644 index 00000000..1fd4e894 --- /dev/null +++ b/pkg/infrastructure/config/yaml/yaml_test.go @@ -0,0 +1,152 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/astaxie/beego/pkg/infrastructure/config" +) + +func TestYaml(t *testing.T) { + + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +"user": + "name": "tom" + "age": 13 +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + res, _ := yamlconf.String(nil, "appname") + if res != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(nil, k) + case int64: + value, err = yamlconf.Int64(nil, k) + case float64: + value, err = yamlconf.Float(nil, k) + case bool: + value, err = yamlconf.Bool(nil, k) + case []string: + value, err = yamlconf.Strings(nil, k) + case string: + value, err = yamlconf.String(nil, k) + default: + value, err = yamlconf.DIY(nil, k) + } + if err != nil { + t.Errorf("get key %q value fatal,%v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set(nil, "name", "astaxie"); err != nil { + t.Fatal(err) + } + res, _ = yamlconf.String(nil, "name") + if res != "astaxie" { + t.Fatal("get name error") + } + + sub, err := yamlconf.Sub(context.Background(), "user") + assert.Nil(t, err) + assert.NotNil(t, sub) + name, err := sub.String(context.Background(), "name") + assert.Nil(t, err) + assert.Equal(t, "tom", name) + + age, err := sub.Int(context.Background(), "age") + assert.Nil(t, err) + assert.Equal(t, 13, age) + + user := &User{} + + err = sub.Unmarshaler(context.Background(), "", user) + assert.Nil(t, err) + assert.Equal(t, "tom", user.Name) + assert.Equal(t, 13, user.Age) + + user = &User{} + + err = yamlconf.Unmarshaler(context.Background(), "user", user) + assert.Nil(t, err) + assert.Equal(t, "tom", user.Name) + assert.Equal(t, 13, user.Age) +} + +type User struct { + Name string `yaml:"name"` + Age int `yaml:"age"` +} diff --git a/pkg/toolbox/healthcheck.go b/pkg/infrastructure/governor/healthcheck.go similarity index 96% rename from pkg/toolbox/healthcheck.go rename to pkg/infrastructure/governor/healthcheck.go index e3544b3a..a91f09fa 100644 --- a/pkg/toolbox/healthcheck.go +++ b/pkg/infrastructure/governor/healthcheck.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package toolbox healthcheck +// Package governor healthcheck // // type DatabaseCheck struct { // } @@ -28,7 +28,7 @@ // AddHealthCheck("database",&DatabaseCheck{}) // // more docs: http://beego.me/docs/module/toolbox.md -package toolbox +package governor // AdminCheckList holds health checker map var AdminCheckList map[string]HealthChecker diff --git a/pkg/toolbox/profile.go b/pkg/infrastructure/governor/profile.go similarity index 82% rename from pkg/toolbox/profile.go rename to pkg/infrastructure/governor/profile.go index 06e40ede..c40cf6ba 100644 --- a/pkg/toolbox/profile.go +++ b/pkg/infrastructure/governor/profile.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package governor import ( "fmt" @@ -25,6 +25,8 @@ import ( "runtime/pprof" "strconv" "time" + + "github.com/astaxie/beego/pkg/infrastructure/utils" ) var startTime = time.Now() @@ -112,15 +114,15 @@ func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", gcstats.NumGC, - toS(lastPause), - toS(avg(gcstats.Pause)), + utils.ToShortTimeFormat(lastPause), + utils.ToShortTimeFormat(avg(gcstats.Pause)), overhead, toH(memStats.Alloc), toH(memStats.Sys), toH(uint64(allocatedRate)), - toS(gcstats.PauseQuantiles[94]), - toS(gcstats.PauseQuantiles[98]), - toS(gcstats.PauseQuantiles[99])) + utils.ToShortTimeFormat(gcstats.PauseQuantiles[94]), + utils.ToShortTimeFormat(gcstats.PauseQuantiles[98]), + utils.ToShortTimeFormat(gcstats.PauseQuantiles[99])) } else { // while GC has disabled elapsed := time.Now().Sub(startTime) @@ -154,31 +156,3 @@ func toH(bytes uint64) string { return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) } } - -// short string format -func toS(d time.Duration) string { - - u := uint64(d) - if u < uint64(time.Second) { - switch { - case u == 0: - return "0" - case u < uint64(time.Microsecond): - return fmt.Sprintf("%.2fns", float64(u)) - case u < uint64(time.Millisecond): - return fmt.Sprintf("%.2fus", float64(u)/1000) - default: - return fmt.Sprintf("%.2fms", float64(u)/1000/1000) - } - } else { - switch { - case u < uint64(time.Minute): - return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) - case u < uint64(time.Hour): - return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) - default: - return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) - } - } - -} diff --git a/pkg/infrastructure/governor/profile_test.go b/pkg/infrastructure/governor/profile_test.go new file mode 100644 index 00000000..530b0637 --- /dev/null +++ b/pkg/infrastructure/governor/profile_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package governor + +import ( + "os" + "testing" +) + +func TestProcessInput(t *testing.T) { + ProcessInput("lookup goroutine", os.Stdout) + ProcessInput("lookup heap", os.Stdout) + ProcessInput("lookup threadcreate", os.Stdout) + ProcessInput("lookup block", os.Stdout) + ProcessInput("gc summary", os.Stdout) +} diff --git a/pkg/logs/README.md b/pkg/infrastructure/logs/README.md similarity index 100% rename from pkg/logs/README.md rename to pkg/infrastructure/logs/README.md diff --git a/pkg/logs/accesslog.go b/pkg/infrastructure/logs/accesslog.go similarity index 100% rename from pkg/logs/accesslog.go rename to pkg/infrastructure/logs/accesslog.go diff --git a/pkg/logs/alils/alils.go b/pkg/infrastructure/logs/alils/alils.go similarity index 94% rename from pkg/logs/alils/alils.go rename to pkg/infrastructure/logs/alils/alils.go index 425071f8..03e97045 100644 --- a/pkg/logs/alils/alils.go +++ b/pkg/infrastructure/logs/alils/alils.go @@ -5,8 +5,8 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/common" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/gogo/protobuf/proto" ) @@ -50,10 +50,10 @@ func NewAliLS() logs.Logger { } // Init parses config and initializes struct -func (c *aliLSWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (c *aliLSWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := logs.GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/alils/config.go b/pkg/infrastructure/logs/alils/config.go similarity index 100% rename from pkg/logs/alils/config.go rename to pkg/infrastructure/logs/alils/config.go diff --git a/pkg/logs/alils/log.pb.go b/pkg/infrastructure/logs/alils/log.pb.go similarity index 100% rename from pkg/logs/alils/log.pb.go rename to pkg/infrastructure/logs/alils/log.pb.go diff --git a/pkg/logs/alils/log_config.go b/pkg/infrastructure/logs/alils/log_config.go similarity index 100% rename from pkg/logs/alils/log_config.go rename to pkg/infrastructure/logs/alils/log_config.go diff --git a/pkg/logs/alils/log_project.go b/pkg/infrastructure/logs/alils/log_project.go similarity index 100% rename from pkg/logs/alils/log_project.go rename to pkg/infrastructure/logs/alils/log_project.go diff --git a/pkg/logs/alils/log_store.go b/pkg/infrastructure/logs/alils/log_store.go similarity index 100% rename from pkg/logs/alils/log_store.go rename to pkg/infrastructure/logs/alils/log_store.go diff --git a/pkg/logs/alils/machine_group.go b/pkg/infrastructure/logs/alils/machine_group.go similarity index 100% rename from pkg/logs/alils/machine_group.go rename to pkg/infrastructure/logs/alils/machine_group.go diff --git a/pkg/logs/alils/request.go b/pkg/infrastructure/logs/alils/request.go similarity index 100% rename from pkg/logs/alils/request.go rename to pkg/infrastructure/logs/alils/request.go diff --git a/pkg/logs/alils/signature.go b/pkg/infrastructure/logs/alils/signature.go similarity index 100% rename from pkg/logs/alils/signature.go rename to pkg/infrastructure/logs/alils/signature.go diff --git a/pkg/logs/conn.go b/pkg/infrastructure/logs/conn.go similarity index 94% rename from pkg/logs/conn.go rename to pkg/infrastructure/logs/conn.go index 9a520bda..f7d44d7f 100644 --- a/pkg/logs/conn.go +++ b/pkg/infrastructure/logs/conn.go @@ -19,7 +19,7 @@ import ( "io" "net" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // connWriter implements LoggerInterface. @@ -48,10 +48,10 @@ func (c *connWriter) Format(lm *LogMsg) string { // Init initializes a connection writer with json config. // json config only needs they "level" key -func (c *connWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (c *connWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/conn_test.go b/pkg/infrastructure/logs/conn_test.go similarity index 100% rename from pkg/logs/conn_test.go rename to pkg/infrastructure/logs/conn_test.go diff --git a/pkg/logs/console.go b/pkg/infrastructure/logs/console.go similarity index 94% rename from pkg/logs/console.go rename to pkg/infrastructure/logs/console.go index a3e5fb5a..802d79f5 100644 --- a/pkg/logs/console.go +++ b/pkg/infrastructure/logs/console.go @@ -19,7 +19,7 @@ import ( "os" "strings" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" "github.com/shiena/ansicolor" ) @@ -77,10 +77,10 @@ func NewConsole() Logger { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' -func (c *consoleWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (c *consoleWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/console_test.go b/pkg/infrastructure/logs/console_test.go similarity index 100% rename from pkg/logs/console_test.go rename to pkg/infrastructure/logs/console_test.go diff --git a/pkg/logs/es/es.go b/pkg/infrastructure/logs/es/es.go similarity index 91% rename from pkg/logs/es/es.go rename to pkg/infrastructure/logs/es/es.go index dc9304c8..857a1a34 100644 --- a/pkg/logs/es/es.go +++ b/pkg/infrastructure/logs/es/es.go @@ -12,8 +12,8 @@ import ( "github.com/elastic/go-elasticsearch/v6" "github.com/elastic/go-elasticsearch/v6/esapi" - "github.com/astaxie/beego/pkg/common" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // NewES returns a LoggerInterface @@ -42,10 +42,10 @@ func (el *esLogger) Format(lm *logs.LogMsg) string { } // {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (el *esLogger) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := logs.GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/file.go b/pkg/infrastructure/logs/file.go similarity index 97% rename from pkg/logs/file.go rename to pkg/infrastructure/logs/file.go index 42148c3a..0c96918c 100644 --- a/pkg/logs/file.go +++ b/pkg/infrastructure/logs/file.go @@ -28,7 +28,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // fileLogWriter implements LoggerInterface. @@ -108,10 +108,10 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { // "rotate":true, // "perm":"0600" // } -func (w *fileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (w *fileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := GetFormatter(elem) if err != nil { return err @@ -329,7 +329,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { _, err = os.Lstat(w.Filename) if err != nil { - //even if the file is not exist or other ,we should RESTART the logger + // even if the file is not exist or other ,we should RESTART the logger goto RESTART_LOGGER } diff --git a/pkg/logs/file_test.go b/pkg/infrastructure/logs/file_test.go similarity index 100% rename from pkg/logs/file_test.go rename to pkg/infrastructure/logs/file_test.go diff --git a/pkg/logs/jianliao.go b/pkg/infrastructure/logs/jianliao.go similarity index 92% rename from pkg/logs/jianliao.go rename to pkg/infrastructure/logs/jianliao.go index 81d0195b..88750125 100644 --- a/pkg/logs/jianliao.go +++ b/pkg/infrastructure/logs/jianliao.go @@ -6,7 +6,7 @@ import ( "net/http" "net/url" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -26,9 +26,9 @@ func newJLWriter() Logger { } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (s *JLWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/log.go b/pkg/infrastructure/logs/log.go similarity index 96% rename from pkg/logs/log.go rename to pkg/infrastructure/logs/log.go index e18ea95b..2d400eba 100644 --- a/pkg/logs/log.go +++ b/pkg/infrastructure/logs/log.go @@ -44,7 +44,7 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // RFC5424 log message levels. @@ -87,7 +87,7 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { - Init(config string, opts ...common.SimpleKV) error + Init(config string, opts ...utils.KV) error WriteMsg(lm *LogMsg) error Format(lm *LogMsg) string Destroy() @@ -213,7 +213,7 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { // Global formatter overrides the default set formatter // but not adapter specific formatters set with logs.SetLoggerWithOpts() if bl.globalFormatter != nil { - err = lg.Init(config, common.SimpleKV{Key: "formatter", Value: bl.globalFormatter}) + err = lg.Init(config, &utils.SimpleKV{Key: "formatter", Value: bl.globalFormatter}) } else { err = lg.Init(config) } @@ -321,7 +321,7 @@ func (bl *BeeLogger) writeMsg(lm *LogMsg, v ...interface{}) error { lm.Msg = fmt.Sprintf("[%s:%d] %s", lm.FilePath, lm.LineNumber, lm.Msg) } - //set level info in front of filename info + // set level info in front of filename info if lm.Level == levelLoggerImpl { // set to emergency to ensure all log will be print out correctly lm.Level = LevelEmergency @@ -406,9 +406,9 @@ func (bl *BeeLogger) startLogger() { // Get the formatter from the opts common.SimpleKV structure // Looks for a key: "formatter" with value: func(*LogMsg) string -func GetFormatter(opts common.SimpleKV) (func(*LogMsg) string, error) { - if strings.ToLower(opts.Key.(string)) == "formatter" { - formatterInterface := reflect.ValueOf(opts.Value).Interface() +func GetFormatter(opts utils.KV) (func(*LogMsg) string, error) { + if strings.ToLower(opts.GetKey().(string)) == "formatter" { + formatterInterface := reflect.ValueOf(opts.GetValue()).Interface() formatterFunc := formatterInterface.(func(*LogMsg) string) return formatterFunc, nil } @@ -418,7 +418,7 @@ func GetFormatter(opts common.SimpleKV) (func(*LogMsg) string, error) { // SetLoggerWithOpts sets a log adapter with a user defined logging format. Config must be valid JSON // such as: {"interval":360} -func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts common.SimpleKV, configs ...string) error { +func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error { config := append(configs, "{}")[0] for _, l := range bl.outputs { if l.name == adapterName { @@ -431,7 +431,7 @@ func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts common.SimpleKV, return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName) } - if opts.Key == nil { + if opts.GetKey() == nil { return fmt.Errorf("No SimpleKV struct set for %s log adapter", adapterName) } @@ -451,7 +451,7 @@ func (bl *BeeLogger) setLoggerWithOpts(adapterName string, opts common.SimpleKV, } // SetLogger provides a given logger adapter into BeeLogger with config string. -func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts common.SimpleKV, configs ...string) error { +func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts utils.KV, configs ...string) error { bl.lock.Lock() defer bl.lock.Unlock() if !bl.init { @@ -465,7 +465,7 @@ func (bl *BeeLogger) SetLoggerWithOpts(adapterName string, opts common.SimpleKV, // Log Adapter must be given in the form common.SimpleKV{Key: "formatter": Value: struct.FormatFunc} // where FormatFunc has the signature func(*LogMsg) string // func SetLoggerWithOpts(adapter string, config []string, formatterFunc func(*LogMsg) string) error { -func SetLoggerWithOpts(adapter string, config []string, opts common.SimpleKV) error { +func SetLoggerWithOpts(adapter string, config []string, opts utils.KV) error { err := beeLogger.SetLoggerWithOpts(adapter, opts, config...) if err != nil { log.Fatal(err) @@ -880,9 +880,9 @@ func formatLog(f interface{}, v ...interface{}) string { return msg } if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { - //format string + // format string } else { - //do not contain format char + // do not contain format char msg += strings.Repeat(" %v", len(v)) } default: diff --git a/pkg/infrastructure/logs/log_formatter_test.go b/pkg/infrastructure/logs/log_formatter_test.go new file mode 100644 index 00000000..73281cf6 --- /dev/null +++ b/pkg/infrastructure/logs/log_formatter_test.go @@ -0,0 +1,35 @@ +package logs + +import ( + "fmt" + "testing" + + "github.com/astaxie/beego/pkg/infrastructure/utils" +) + +func customFormatter(lm *LogMsg) string { + return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg) +} + +func globalFormatter(lm *LogMsg) string { + return fmt.Sprintf("[GLOBAL] %s", lm.Msg) +} + +func TestCustomLoggingFormatter(t *testing.T) { + // beego.BConfig.Log.AccessLogs = true + + SetLoggerWithOpts("console", []string{`{"color":true}`}, &utils.SimpleKV{Key: "formatter", Value: customFormatter}) + + // Message will be formatted by the customFormatter with colorful text set to true + Informational("Test message") +} + +func TestGlobalLoggingFormatter(t *testing.T) { + SetGlobalFormatter(globalFormatter) + + SetLogger("console", `{"color":true}`) + + // Message will be formatted by globalFormatter + Informational("Test message") + +} diff --git a/pkg/logs/logger.go b/pkg/infrastructure/logs/logger.go similarity index 100% rename from pkg/logs/logger.go rename to pkg/infrastructure/logs/logger.go diff --git a/pkg/logs/logger_test.go b/pkg/infrastructure/logs/logger_test.go similarity index 100% rename from pkg/logs/logger_test.go rename to pkg/infrastructure/logs/logger_test.go diff --git a/pkg/logs/multifile.go b/pkg/infrastructure/logs/multifile.go similarity index 95% rename from pkg/logs/multifile.go rename to pkg/infrastructure/logs/multifile.go index c1b7cfdd..bf589b91 100644 --- a/pkg/logs/multifile.go +++ b/pkg/infrastructure/logs/multifile.go @@ -17,7 +17,7 @@ package logs import ( "encoding/json" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // A filesLogWriter manages several fileLogWriter @@ -47,9 +47,9 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"], // } -func (f *multiFileLogWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (f *multiFileLogWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/multifile_test.go b/pkg/infrastructure/logs/multifile_test.go similarity index 100% rename from pkg/logs/multifile_test.go rename to pkg/infrastructure/logs/multifile_test.go diff --git a/pkg/logs/slack.go b/pkg/infrastructure/logs/slack.go similarity index 92% rename from pkg/logs/slack.go rename to pkg/infrastructure/logs/slack.go index 0fc75149..d56b9acd 100644 --- a/pkg/logs/slack.go +++ b/pkg/infrastructure/logs/slack.go @@ -6,7 +6,7 @@ import ( "net/http" "net/url" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -27,7 +27,7 @@ func (s *SLACKWriter) Format(lm *LogMsg) string { } // Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (s *SLACKWriter) Init(jsonConfig string, opts ...utils.KV) error { // if elem != nil { // s.UseCustomFormatter = true // s.CustomFormatter = elem diff --git a/pkg/logs/smtp.go b/pkg/infrastructure/logs/smtp.go similarity index 96% rename from pkg/logs/smtp.go rename to pkg/infrastructure/logs/smtp.go index 9b67e343..904a89df 100644 --- a/pkg/logs/smtp.go +++ b/pkg/infrastructure/logs/smtp.go @@ -22,7 +22,7 @@ import ( "net/smtp" "strings" - "github.com/astaxie/beego/pkg/common" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -53,10 +53,10 @@ func newSMTPWriter() Logger { // "sendTos":["email1","email2"], // "level":LevelError // } -func (s *SMTPWriter) Init(jsonConfig string, opts ...common.SimpleKV) error { +func (s *SMTPWriter) Init(jsonConfig string, opts ...utils.KV) error { for _, elem := range opts { - if elem.Key == "formatter" { + if elem.GetKey() == "formatter" { formatter, err := GetFormatter(elem) if err != nil { return err diff --git a/pkg/logs/smtp_test.go b/pkg/infrastructure/logs/smtp_test.go similarity index 100% rename from pkg/logs/smtp_test.go rename to pkg/infrastructure/logs/smtp_test.go diff --git a/pkg/session/README.md b/pkg/infrastructure/session/README.md similarity index 100% rename from pkg/session/README.md rename to pkg/infrastructure/session/README.md diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/infrastructure/session/couchbase/sess_couchbase.go similarity index 82% rename from pkg/session/couchbase/sess_couchbase.go rename to pkg/infrastructure/session/couchbase/sess_couchbase.go index b824a938..ddb4be58 100644 --- a/pkg/session/couchbase/sess_couchbase.go +++ b/pkg/infrastructure/session/couchbase/sess_couchbase.go @@ -33,13 +33,14 @@ package couchbase import ( + "context" "net/http" "strings" "sync" couchbase "github.com/couchbase/go-couchbase" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" ) var couchbpder = &Provider{} @@ -63,7 +64,7 @@ type Provider struct { } // Set value to couchabse session -func (cs *SessionStore) Set(key, value interface{}) error { +func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values[key] = value @@ -71,7 +72,7 @@ func (cs *SessionStore) Set(key, value interface{}) error { } // Get value from couchabse session -func (cs *SessionStore) Get(key interface{}) interface{} { +func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { cs.lock.RLock() defer cs.lock.RUnlock() if v, ok := cs.values[key]; ok { @@ -81,7 +82,7 @@ func (cs *SessionStore) Get(key interface{}) interface{} { } // Delete value in couchbase session by given key -func (cs *SessionStore) Delete(key interface{}) error { +func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() delete(cs.values, key) @@ -89,7 +90,7 @@ func (cs *SessionStore) Delete(key interface{}) error { } // Flush Clean all values in couchbase session -func (cs *SessionStore) Flush() error { +func (cs *SessionStore) Flush(context.Context) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values = make(map[interface{}]interface{}) @@ -97,12 +98,12 @@ func (cs *SessionStore) Flush() error { } // SessionID Get couchbase session store id -func (cs *SessionStore) SessionID() string { +func (cs *SessionStore) SessionID(context.Context) string { return cs.sid } // SessionRelease Write couchbase session with Gob string -func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer cs.b.Close() bo, err := session.EncodeGob(cs.values) @@ -135,7 +136,7 @@ func (cp *Provider) getBucket() *couchbase.Bucket { // SessionInit init couchbase session // savepath like couchbase server REST/JSON URL // e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { cp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -152,7 +153,7 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read couchbase session by sid -func (cp *Provider) SessionRead(sid string) (session.Store, error) { +func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { cp.b = cp.getBucket() var ( @@ -179,7 +180,7 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist Check couchbase session exist. // it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) (bool, error) { +func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { cp.b = cp.getBucket() defer cp.b.Close() @@ -192,7 +193,7 @@ func (cp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate remove oldsid and use sid to generate new session -func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { cp.b = cp.getBucket() var doc []byte @@ -225,7 +226,7 @@ func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy Remove bucket in this couchbase -func (cp *Provider) SessionDestroy(sid string) error { +func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { cp.b = cp.getBucket() defer cp.b.Close() @@ -234,11 +235,11 @@ func (cp *Provider) SessionDestroy(sid string) error { } // SessionGC Recycle -func (cp *Provider) SessionGC() { +func (cp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (cp *Provider) SessionAll() int { +func (cp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/session/ledis/ledis_session.go b/pkg/infrastructure/session/ledis/ledis_session.go similarity index 73% rename from pkg/session/ledis/ledis_session.go rename to pkg/infrastructure/session/ledis/ledis_session.go index e43d70a0..74bf9b65 100644 --- a/pkg/session/ledis/ledis_session.go +++ b/pkg/infrastructure/session/ledis/ledis_session.go @@ -2,6 +2,7 @@ package ledis import ( + "context" "net/http" "strconv" "strings" @@ -10,7 +11,7 @@ import ( "github.com/ledisdb/ledisdb/config" "github.com/ledisdb/ledisdb/ledis" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" ) var ( @@ -27,7 +28,7 @@ type SessionStore struct { } // Set value in ledis session -func (ls *SessionStore) Set(key, value interface{}) error { +func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values[key] = value @@ -35,7 +36,7 @@ func (ls *SessionStore) Set(key, value interface{}) error { } // Get value in ledis session -func (ls *SessionStore) Get(key interface{}) interface{} { +func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} { ls.lock.RLock() defer ls.lock.RUnlock() if v, ok := ls.values[key]; ok { @@ -45,7 +46,7 @@ func (ls *SessionStore) Get(key interface{}) interface{} { } // Delete value in ledis session -func (ls *SessionStore) Delete(key interface{}) error { +func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() delete(ls.values, key) @@ -53,7 +54,7 @@ func (ls *SessionStore) Delete(key interface{}) error { } // Flush clear all values in ledis session -func (ls *SessionStore) Flush() error { +func (ls *SessionStore) Flush(context.Context) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values = make(map[interface{}]interface{}) @@ -61,12 +62,12 @@ func (ls *SessionStore) Flush() error { } // SessionID get ledis session id -func (ls *SessionStore) SessionID() string { +func (ls *SessionStore) SessionID(context.Context) string { return ls.sid } // SessionRelease save session values to ledis -func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { +func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(ls.values) if err != nil { return @@ -85,7 +86,7 @@ type Provider struct { // SessionInit init ledis session // savepath like ledis server saveDataPath,pool size // e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { var err error lp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") @@ -111,7 +112,7 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read ledis session by sid -func (lp *Provider) SessionRead(sid string) (session.Store, error) { +func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var ( kv map[interface{}]interface{} err error @@ -132,13 +133,13 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) (bool, error) { +func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { count, _ := c.Exists([]byte(sid)) return count != 0, nil } // SessionRegenerate generate new sid for ledis session -func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { count, _ := c.Exists([]byte(sid)) if count == 0 { // oldsid doesn't exists, set the new sid directly @@ -151,21 +152,21 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Set([]byte(sid), data) c.Expire([]byte(sid), lp.maxlifetime) } - return lp.SessionRead(sid) + return lp.SessionRead(context.Background(), sid) } // SessionDestroy delete ledis session by id -func (lp *Provider) SessionDestroy(sid string) error { +func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error { c.Del([]byte(sid)) return nil } // SessionGC Impelment method, no used. -func (lp *Provider) SessionGC() { +func (lp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (lp *Provider) SessionAll() int { +func (lp *Provider) SessionAll(context.Context) int { return 0 } func init() { diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/infrastructure/session/memcache/sess_memcache.go similarity index 82% rename from pkg/session/memcache/sess_memcache.go rename to pkg/infrastructure/session/memcache/sess_memcache.go index 7fab842a..57df2844 100644 --- a/pkg/session/memcache/sess_memcache.go +++ b/pkg/infrastructure/session/memcache/sess_memcache.go @@ -33,11 +33,12 @@ package memcache import ( + "context" "net/http" "strings" "sync" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" "github.com/bradfitz/gomemcache/memcache" ) @@ -54,7 +55,7 @@ type SessionStore struct { } // Set value in memcache session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -62,7 +63,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in memcache session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -72,7 +73,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in memcache session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -80,7 +81,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in memcache session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -88,12 +89,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get memcache session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to memcache -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -113,7 +114,7 @@ type MemProvider struct { // SessionInit init memcache session // savepath like // e.g. 127.0.0.1:9090 -func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime rp.conninfo = strings.Split(savePath, ";") client = memcache.New(rp.conninfo...) @@ -121,7 +122,7 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { +func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -149,7 +150,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) (bool, error) { +func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { if client == nil { if err := rp.connectInit(); err != nil { return false, err @@ -162,7 +163,7 @@ func (rp *MemProvider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -201,7 +202,7 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err } // SessionDestroy delete memcache session by id -func (rp *MemProvider) SessionDestroy(sid string) error { +func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error { if client == nil { if err := rp.connectInit(); err != nil { return err @@ -217,11 +218,11 @@ func (rp *MemProvider) connectInit() error { } // SessionGC Impelment method, no used. -func (rp *MemProvider) SessionGC() { +func (rp *MemProvider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *MemProvider) SessionAll() int { +func (rp *MemProvider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/infrastructure/session/mysql/sess_mysql.go similarity index 84% rename from pkg/session/mysql/sess_mysql.go rename to pkg/infrastructure/session/mysql/sess_mysql.go index c641a4bf..fe1d69dc 100644 --- a/pkg/session/mysql/sess_mysql.go +++ b/pkg/infrastructure/session/mysql/sess_mysql.go @@ -41,12 +41,13 @@ package mysql import ( + "context" "database/sql" "net/http" "sync" "time" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" // import mysql driver _ "github.com/go-sql-driver/mysql" ) @@ -67,7 +68,7 @@ type SessionStore struct { // Set value in mysql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -75,7 +76,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from mysql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -85,7 +86,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in mysql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -93,7 +94,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in mysql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -101,13 +102,13 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this mysql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save mysql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -134,14 +135,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init mysql session. // savepath is the connection string of mysql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get mysql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte @@ -164,7 +165,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) @@ -180,7 +181,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for mysql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte @@ -203,7 +204,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete mysql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Close() @@ -211,14 +212,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in mysql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() } // SessionAll count values in mysql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/infrastructure/session/postgres/sess_postgresql.go similarity index 84% rename from pkg/session/postgres/sess_postgresql.go rename to pkg/infrastructure/session/postgres/sess_postgresql.go index 688c0e36..2fadbed0 100644 --- a/pkg/session/postgres/sess_postgresql.go +++ b/pkg/infrastructure/session/postgres/sess_postgresql.go @@ -51,12 +51,13 @@ package postgres import ( + "context" "database/sql" "net/http" "sync" "time" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" // import postgresql Driver _ "github.com/lib/pq" ) @@ -73,7 +74,7 @@ type SessionStore struct { // Set value in postgresql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -81,7 +82,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from postgresql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -91,7 +92,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in postgresql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -99,7 +100,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in postgresql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -107,13 +108,13 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this postgresql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save postgresql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -141,14 +142,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init postgresql session. // savepath is the connection string of postgresql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get postgresql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte @@ -178,7 +179,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) @@ -194,7 +195,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for postgresql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", oldsid) var sessiondata []byte @@ -218,7 +219,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete postgresql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM session where session_key=$1", sid) c.Close() @@ -226,14 +227,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in postgresql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Close() } // SessionAll count values in postgresql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/pkg/session/redis/sess_redis.go b/pkg/infrastructure/session/redis/sess_redis.go similarity index 83% rename from pkg/session/redis/sess_redis.go rename to pkg/infrastructure/session/redis/sess_redis.go index b68ee012..c7bfbcbf 100644 --- a/pkg/session/redis/sess_redis.go +++ b/pkg/infrastructure/session/redis/sess_redis.go @@ -33,13 +33,14 @@ package redis import ( + "context" "net/http" "strconv" "strings" "sync" "time" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" "github.com/go-redis/redis/v7" ) @@ -59,7 +60,7 @@ type SessionStore struct { } // Set value in redis session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -67,7 +68,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -77,7 +78,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -85,7 +86,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -93,12 +94,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -123,7 +124,7 @@ type Provider struct { // SessionInit init redis session // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second // e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -185,7 +186,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() @@ -205,7 +206,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { @@ -215,7 +216,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, _ := c.Exists(oldsid).Result(); existed == 0 { // oldsid doesn't exists, set the new sid directly @@ -226,11 +227,11 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) @@ -238,11 +239,11 @@ func (rp *Provider) SessionDestroy(sid string) error { } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/session/redis/sess_redis_test.go b/pkg/infrastructure/session/redis/sess_redis_test.go similarity index 64% rename from pkg/session/redis/sess_redis_test.go rename to pkg/infrastructure/session/redis/sess_redis_test.go index db5bb2c7..df77204d 100644 --- a/pkg/session/redis/sess_redis_test.go +++ b/pkg/infrastructure/session/redis/sess_redis_test.go @@ -1,11 +1,13 @@ package redis import ( + "fmt" "net/http" "net/http/httptest" + "os" "testing" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" ) func TestRedis(t *testing.T) { @@ -16,8 +18,14 @@ func TestRedis(t *testing.T) { Maxlifetime: 3600, Secure: false, CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,30", } + + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + redisAddr = "127.0.0.1:6379" + } + + sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr) globalSession, err := session.NewManager("redis", sessionConfig) if err != nil { t.Fatal("could not create manager:", err) @@ -32,57 +40,57 @@ func TestRedis(t *testing.T) { if err != nil { t.Fatal("session start failed:", err) } - defer sess.SessionRelease(w) + defer sess.SessionRelease(nil, w) // SET AND GET - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set username failed:", err) } - username := sess.Get("username") + username := sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } // DELETE - err = sess.Delete("username") + err = sess.Delete(nil, "username") if err != nil { t.Fatal("delete username failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("delete username failed") } // FLUSH - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set failed:", err) } - err = sess.Set("password", "1qaz2wsx") + err = sess.Set(nil, "password", "1qaz2wsx") if err != nil { t.Fatal("set failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } - password := sess.Get("password") + password := sess.Get(nil, "password") if password != "1qaz2wsx" { t.Fatal("get password failed") } - err = sess.Flush() + err = sess.Flush(nil) if err != nil { t.Fatal("flush failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("flush failed") } - password = sess.Get("password") + password = sess.Get(nil, "password") if password != nil { t.Fatal("flush failed") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) } diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/infrastructure/session/redis_cluster/redis_cluster.go similarity index 83% rename from pkg/session/redis_cluster/redis_cluster.go rename to pkg/infrastructure/session/redis_cluster/redis_cluster.go index dcdfae85..95907a5f 100644 --- a/pkg/session/redis_cluster/redis_cluster.go +++ b/pkg/infrastructure/session/redis_cluster/redis_cluster.go @@ -33,13 +33,14 @@ package redis_cluster import ( + "context" "net/http" "strconv" "strings" "sync" "time" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" rediss "github.com/go-redis/redis/v7" ) @@ -58,7 +59,7 @@ type SessionStore struct { } // Set value in redis_cluster session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis_cluster session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis_cluster session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis_cluster session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis_cluster session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis_cluster -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -122,7 +123,7 @@ type Provider struct { // SessionInit init redis_cluster session // savepath like redis server addr,pool size,password,dbnum // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -182,7 +183,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis_cluster session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() if err != nil && err != rediss.Nil { @@ -201,7 +202,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err @@ -210,7 +211,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis_cluster session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { @@ -222,22 +223,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) return nil } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go similarity index 84% rename from pkg/session/redis_sentinel/sess_redis_sentinel.go rename to pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go index 6721539a..1b9c841b 100644 --- a/pkg/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go @@ -33,13 +33,14 @@ package redis_sentinel import ( + "context" "net/http" "strconv" "strings" "sync" "time" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" "github.com/go-redis/redis/v7" ) @@ -58,7 +59,7 @@ type SessionStore struct { } // Set value in redis_sentinel session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis_sentinel session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis_sentinel session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis_sentinel session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis_sentinel session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis_sentinel -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -123,7 +124,7 @@ type Provider struct { // SessionInit init redis_sentinel session // savepath like redis sentinel addr,pool size,password,dbnum,masterName // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -195,7 +196,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis_sentinel session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() if err != nil && err != redis.Nil { @@ -214,7 +215,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err @@ -223,7 +224,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis_sentinel session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { @@ -235,22 +236,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) return nil } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 00000000..fcec9806 --- /dev/null +++ b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,90 @@ +package redis_sentinel + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/astaxie/beego/pkg/infrastructure/session" +) + +func TestRedisSentinel(t *testing.T) { + sessionConfig := &session.ManagerConfig{ + CookieName: "gosessionid", + EnableSetCookie: true, + Gclifetime: 3600, + Maxlifetime: 3600, + Secure: false, + CookieLifeTime: 3600, + ProviderConfig: "127.0.0.1:6379,100,,0,master", + } + globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { + t.Log(e) + return + } + //todo test if e==nil + go globalSessions.GC() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(nil, w) + + // SET AND GET + err = sess.Set(nil, "username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get(nil, "username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete(nil, "username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get(nil, "username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set(nil, "username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set(nil, "password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get(nil, "username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get(nil, "password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush(nil) + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get(nil, "username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get(nil, "password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(nil, w) + +} diff --git a/pkg/session/sess_cookie.go b/pkg/infrastructure/session/sess_cookie.go similarity index 77% rename from pkg/session/sess_cookie.go rename to pkg/infrastructure/session/sess_cookie.go index 30a7032e..649f6510 100644 --- a/pkg/session/sess_cookie.go +++ b/pkg/infrastructure/session/sess_cookie.go @@ -15,6 +15,7 @@ package session import ( + "context" "crypto/aes" "crypto/cipher" "encoding/json" @@ -34,7 +35,7 @@ type CookieSessionStore struct { // Set value to cookie session. // the value are encoded as gob with hash block string. -func (st *CookieSessionStore) Set(key, value interface{}) error { +func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -42,7 +43,7 @@ func (st *CookieSessionStore) Set(key, value interface{}) error { } // Get value from cookie session -func (st *CookieSessionStore) Get(key interface{}) interface{} { +func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -52,7 +53,7 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} { } // Delete value in cookie session -func (st *CookieSessionStore) Delete(key interface{}) error { +func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -60,7 +61,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error { } // Flush Clean all values in cookie session -func (st *CookieSessionStore) Flush() error { +func (st *CookieSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -68,12 +69,12 @@ func (st *CookieSessionStore) Flush() error { } // SessionID Return id of this cookie session -func (st *CookieSessionStore) SessionID() string { +func (st *CookieSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Write cookie session to http response cookie -func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { st.lock.Lock() encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) st.lock.Unlock() @@ -112,7 +113,7 @@ type CookieProvider struct { // securityName - recognized name in encoded cookie string // cookieName - cookie name // maxage - cookie max life time. -func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { +func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error { pder.config = &cookieConfig{} err := json.Unmarshal([]byte(config), pder.config) if err != nil { @@ -134,7 +135,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error // SessionRead Get SessionStore in cooke. // decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (Store, error) { +func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) { maps, _ := decodeCookie(pder.block, pder.config.SecurityKey, pder.config.SecurityName, @@ -147,31 +148,31 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) (bool, error) { +func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) { return true, nil } // SessionRegenerate Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { return nil, nil } // SessionDestroy Implement method, no used. -func (pder *CookieProvider) SessionDestroy(sid string) error { +func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error { return nil } // SessionGC Implement method, no used. -func (pder *CookieProvider) SessionGC() { +func (pder *CookieProvider) SessionGC(context.Context) { } // SessionAll Implement method, return 0. -func (pder *CookieProvider) SessionAll() int { +func (pder *CookieProvider) SessionAll(context.Context) int { return 0 } // SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { +func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error { return nil } diff --git a/pkg/infrastructure/session/sess_cookie_test.go b/pkg/infrastructure/session/sess_cookie_test.go new file mode 100644 index 00000000..a9fc876d --- /dev/null +++ b/pkg/infrastructure/session/sess_cookie_test.go @@ -0,0 +1,105 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + err = sess.Set(nil, "username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get(nil, "username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(nil, w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} + +func TestDestorySessionCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, err := NewManager("cookie", conf) + if err != nil { + t.Fatal("init cookie session err", err) + } + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + session, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start err,", err) + } + + // request again ,will get same sesssion id . + r1, _ := http.NewRequest("GET", "/", nil) + r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + w = httptest.NewRecorder() + newSession, err := globalSessions.SessionStart(w, r1) + if err != nil { + t.Fatal("session start err,", err) + } + if newSession.SessionID(nil) != session.SessionID(nil) { + t.Fatal("get cookie session id is not the same again.") + } + + // After destroy session , will get a new session id . + globalSessions.SessionDestroy(w, r1) + r2, _ := http.NewRequest("GET", "/", nil) + r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + + w = httptest.NewRecorder() + newSession, err = globalSessions.SessionStart(w, r2) + if err != nil { + t.Fatal("session start error") + } + if newSession.SessionID(nil) == session.SessionID(nil) { + t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") + } +} diff --git a/pkg/session/sess_file.go b/pkg/infrastructure/session/sess_file.go similarity index 87% rename from pkg/session/sess_file.go rename to pkg/infrastructure/session/sess_file.go index 37d5bd68..90de9a79 100644 --- a/pkg/session/sess_file.go +++ b/pkg/infrastructure/session/sess_file.go @@ -15,6 +15,7 @@ package session import ( + "context" "errors" "fmt" "io/ioutil" @@ -40,7 +41,7 @@ type FileSessionStore struct { } // Set value to file session -func (fs *FileSessionStore) Set(key, value interface{}) error { +func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values[key] = value @@ -48,7 +49,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error { } // Get value from file session -func (fs *FileSessionStore) Get(key interface{}) interface{} { +func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} { fs.lock.RLock() defer fs.lock.RUnlock() if v, ok := fs.values[key]; ok { @@ -58,7 +59,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} { } // Delete value in file session by given key -func (fs *FileSessionStore) Delete(key interface{}) error { +func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() delete(fs.values, key) @@ -66,7 +67,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error { } // Flush Clean all values in file session -func (fs *FileSessionStore) Flush() error { +func (fs *FileSessionStore) Flush(context.Context) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values = make(map[interface{}]interface{}) @@ -74,12 +75,12 @@ func (fs *FileSessionStore) Flush() error { } // SessionID Get file session store id -func (fs *FileSessionStore) SessionID() string { +func (fs *FileSessionStore) SessionID(context.Context) string { return fs.sid } // SessionRelease Write file session to local file with Gob string -func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { +func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { filepder.lock.Lock() defer filepder.lock.Unlock() b, err := EncodeGob(fs.values) @@ -119,7 +120,7 @@ type FileProvider struct { // SessionInit Init file session provider. // savePath sets the session files path. -func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { +func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { fp.maxlifetime = maxlifetime fp.savePath = savePath return nil @@ -128,7 +129,7 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { // SessionRead Read file session by sid. // if file is not exist, create it. // the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (Store, error) { +func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) { invalidChars := "./" if strings.ContainsAny(sid, invalidChars) { return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) @@ -176,7 +177,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) (bool, error) { +func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -190,7 +191,7 @@ func (fp *FileProvider) SessionExist(sid string) (bool, error) { } // SessionDestroy Remove all files in this save path -func (fp *FileProvider) SessionDestroy(sid string) error { +func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error { filepder.lock.Lock() defer filepder.lock.Unlock() os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) @@ -198,7 +199,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error { } // SessionGC Recycle files in save path -func (fp *FileProvider) SessionGC() { +func (fp *FileProvider) SessionGC(context.Context) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -208,7 +209,7 @@ func (fp *FileProvider) SessionGC() { // SessionAll Get active file session number. // it walks save path to count files. -func (fp *FileProvider) SessionAll() int { +func (fp *FileProvider) SessionAll(context.Context) int { a := &activeSession{} err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { return a.visit(path, f, err) @@ -222,7 +223,7 @@ func (fp *FileProvider) SessionAll() int { // SessionRegenerate Generate new sid for file session. // it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() diff --git a/pkg/infrastructure/session/sess_file_test.go b/pkg/infrastructure/session/sess_file_test.go new file mode 100644 index 00000000..f40de69f --- /dev/null +++ b/pkg/infrastructure/session/sess_file_test.go @@ -0,0 +1,427 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + "time" +) + +const sid = "Session_id" +const sidNew = "Session_id_new" +const sessionPath = "./_session_runtime" + +var ( + mutex sync.Mutex +) + +func TestFileProvider_SessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProvider_SessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } + + _, err = fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err = fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } +} + +func TestFileProvider_SessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } + + exists, err = fp.SessionExist(context.Background(), "") + if err == nil { + t.Error() + } + if exists { + t.Error() + } + + exists, err = fp.SessionExist(context.Background(), "1") + if err == nil { + t.Error() + } + if exists { + t.Error() + } +} + +func TestFileProvider_SessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + s, err := fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + _ = s.Set(nil, "sessionValue", 18975) + v := s.Get(nil, "sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProvider_SessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + _, err := fp.SessionRead(context.Background(), "") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead(context.Background(), "1") + if err == nil { + t.Error(err) + } +} + +func TestFileProvider_SessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll(nil) != sessionCount { + t.Error() + } +} + +func TestFileProvider_SessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + _, err := fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } + + _, err = fp.SessionRegenerate(context.Background(), sid, sidNew) + if err != nil { + t.Error(err) + } + + exists, err = fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } + + exists, err = fp.SessionExist(context.Background(), sidNew) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } +} + +func TestFileProvider_SessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + _, err := fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } + + err = fp.SessionDestroy(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err = fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } +} + +func TestFileProvider_SessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC(nil) + if fp.SessionAll(nil) != 0 { + t.Error() + } +} + +func TestFileSessionStore_Set(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(context.Background(), sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(nil, i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStore_Get(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(context.Background(), sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(nil, i, i) + + v := s.Get(nil, i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStore_Delete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + s, _ := fp.SessionRead(context.Background(), sid) + s.Set(nil, "1", 1) + + if s.Get(nil, "1") == nil { + t.Error() + } + + s.Delete(nil, "1") + + if s.Get(nil, "1") != nil { + t.Error() + } +} + +func TestFileSessionStore_Flush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(context.Background(), sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(nil, i, i) + } + + _ = s.Flush(nil) + + for i := 1; i <= sessionCount; i++ { + if s.Get(nil, i) != nil { + t.Error() + } + } +} + +func TestFileSessionStore_SessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStore_SessionRelease(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + s.Set(nil, i, i) + s.SessionRelease(nil, nil) + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(nil, i).(int) != i { + t.Error() + } + } +} diff --git a/pkg/session/sess_mem.go b/pkg/infrastructure/session/sess_mem.go similarity index 79% rename from pkg/session/sess_mem.go rename to pkg/infrastructure/session/sess_mem.go index bd69ff80..27e24c73 100644 --- a/pkg/session/sess_mem.go +++ b/pkg/infrastructure/session/sess_mem.go @@ -16,6 +16,7 @@ package session import ( "container/list" + "context" "net/http" "sync" "time" @@ -33,7 +34,7 @@ type MemSessionStore struct { } // Set value to memory session -func (st *MemSessionStore) Set(key, value interface{}) error { +func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.value[key] = value @@ -41,7 +42,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error { } // Get value from memory session by key -func (st *MemSessionStore) Get(key interface{}) interface{} { +func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.value[key]; ok { @@ -51,7 +52,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} { } // Delete in memory session by key -func (st *MemSessionStore) Delete(key interface{}) error { +func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.value, key) @@ -59,7 +60,7 @@ func (st *MemSessionStore) Delete(key interface{}) error { } // Flush clear all values in memory session -func (st *MemSessionStore) Flush() error { +func (st *MemSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.value = make(map[interface{}]interface{}) @@ -67,12 +68,12 @@ func (st *MemSessionStore) Flush() error { } // SessionID get this id of memory session store -func (st *MemSessionStore) SessionID() string { +func (st *MemSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Implement method, no used. -func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { } // MemProvider Implement the provider interface @@ -85,17 +86,17 @@ type MemProvider struct { } // SessionInit init memory session -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { pder.maxlifetime = maxlifetime pder.savePath = savePath return nil } // SessionRead get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (Store, error) { +func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) + go pder.SessionUpdate(nil, sid) pder.lock.RUnlock() return element.Value.(*MemSessionStore), nil } @@ -109,7 +110,7 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) (bool, error) { +func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { @@ -119,10 +120,10 @@ func (pder *MemProvider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) + go pder.SessionUpdate(nil, oldsid) pder.lock.RUnlock() pder.lock.Lock() element.Value.(*MemSessionStore).sid = sid @@ -141,7 +142,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { } // SessionDestroy delete session store in memory session by id -func (pder *MemProvider) SessionDestroy(sid string) error { +func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { @@ -153,7 +154,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error { } // SessionGC clean expired session stores in memory session -func (pder *MemProvider) SessionGC() { +func (pder *MemProvider) SessionGC(context.Context) { pder.lock.RLock() for { element := pder.list.Back() @@ -175,12 +176,12 @@ func (pder *MemProvider) SessionGC() { } // SessionAll get count number of memory session -func (pder *MemProvider) SessionAll() int { +func (pder *MemProvider) SessionAll(context.Context) int { return pder.list.Len() } // SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { +func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { diff --git a/pkg/infrastructure/session/sess_mem_test.go b/pkg/infrastructure/session/sess_mem_test.go new file mode 100644 index 00000000..e6d35476 --- /dev/null +++ b/pkg/infrastructure/session/sess_mem_test.go @@ -0,0 +1,58 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}` + conf := new(ManagerConfig) + if err := json.Unmarshal([]byte(config), conf); err != nil { + t.Fatal("json decode error", err) + } + globalSessions, _ := NewManager("memory", conf) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("set error,", err) + } + defer sess.SessionRelease(nil, w) + err = sess.Set(nil, "username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get(nil, "username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/pkg/session/sess_test.go b/pkg/infrastructure/session/sess_test.go similarity index 100% rename from pkg/session/sess_test.go rename to pkg/infrastructure/session/sess_test.go diff --git a/pkg/session/sess_utils.go b/pkg/infrastructure/session/sess_utils.go similarity index 99% rename from pkg/session/sess_utils.go rename to pkg/infrastructure/session/sess_utils.go index 5b70afa8..906e1c4b 100644 --- a/pkg/session/sess_utils.go +++ b/pkg/infrastructure/session/sess_utils.go @@ -29,7 +29,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) func init() { diff --git a/pkg/session/session.go b/pkg/infrastructure/session/session.go similarity index 87% rename from pkg/session/session.go rename to pkg/infrastructure/session/session.go index 92e35de4..bb7e5bd6 100644 --- a/pkg/session/session.go +++ b/pkg/infrastructure/session/session.go @@ -28,6 +28,7 @@ package session import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -43,24 +44,24 @@ import ( // Store contains all data for one session process with specific id. type Store interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data + Set(ctx context.Context, key, value interface{}) error //set session value + Get(ctx context.Context, key interface{}) interface{} //get session value + Delete(ctx context.Context, key interface{}) error //delete session value + SessionID(ctx context.Context) string //back current sessionID + SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush(ctx context.Context) error //delete all data } // Provider contains global session methods and saved SessionStores. // it can operate a SessionStore by its id. type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (Store, error) - SessionExist(sid string) (bool, error) - SessionRegenerate(oldsid, sid string) (Store, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() + SessionInit(ctx context.Context, gclifetime int64, config string) error + SessionRead(ctx context.Context, sid string) (Store, error) + SessionExist(ctx context.Context, sid string) (bool, error) + SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) + SessionDestroy(ctx context.Context, sid string) error + SessionAll(ctx context.Context) int //get all active session + SessionGC(ctx context.Context) } var provides = make(map[string]Provider) @@ -148,7 +149,7 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { } } - err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig) if err != nil { return nil, err } @@ -212,12 +213,12 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } if sid != "" { - exists, err := manager.provider.SessionExist(sid) + exists, err := manager.provider.SessionExist(nil, sid) if err != nil { return nil, err } if exists { - return manager.provider.SessionRead(sid) + return manager.provider.SessionRead(nil, sid) } } @@ -227,7 +228,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - session, err = manager.provider.SessionRead(sid) + session, err = manager.provider.SessionRead(nil, sid) if err != nil { return nil, err } @@ -269,7 +270,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { } sid, _ := url.QueryUnescape(cookie.Value) - manager.provider.SessionDestroy(sid) + manager.provider.SessionDestroy(nil, sid) if manager.config.EnableSetCookie { expiration := time.Now() cookie = &http.Cookie{Name: manager.config.CookieName, @@ -285,14 +286,14 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { // GetSessionStore Get SessionStore by its id. func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { - sessions, err = manager.provider.SessionRead(sid) + sessions, err = manager.provider.SessionRead(nil, sid) return } // GC Start session gc process. // it can do gc in times after gc lifetime. func (manager *Manager) GC() { - manager.provider.SessionGC() + manager.provider.SessionGC(nil) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } @@ -305,7 +306,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { //delete old cookie - session, _ = manager.provider.SessionRead(sid) + session, _ = manager.provider.SessionRead(nil, sid) cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", @@ -315,7 +316,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque } } else { oldsid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRegenerate(oldsid, sid) + session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid) cookie.Value = url.QueryEscape(sid) cookie.HttpOnly = true cookie.Path = "/" @@ -339,7 +340,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque // GetActiveSession Get all active sessions count number. func (manager *Manager) GetActiveSession() int { - return manager.provider.SessionAll() + return manager.provider.SessionAll(nil) } // SetSecure Set cookie with https. diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/infrastructure/session/ssdb/sess_ssdb.go similarity index 77% rename from pkg/session/ssdb/sess_ssdb.go rename to pkg/infrastructure/session/ssdb/sess_ssdb.go index f950b835..6e4f341e 100644 --- a/pkg/session/ssdb/sess_ssdb.go +++ b/pkg/infrastructure/session/ssdb/sess_ssdb.go @@ -1,13 +1,14 @@ package ssdb import ( + "context" "errors" "net/http" "strconv" "strings" "sync" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" "github.com/ssdb/gossdb/ssdb" ) @@ -31,7 +32,7 @@ func (p *Provider) connectInit() error { } // SessionInit init the ssdb with the config -func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { +func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error { p.maxLifetime = maxLifetime address := strings.Split(savePath, ":") p.host = address[0] @@ -44,7 +45,7 @@ func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { } // SessionRead return a ssdb client session Store -func (p *Provider) SessionRead(sid string) (session.Store, error) { +func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if p.client == nil { if err := p.connectInit(); err != nil { return nil, err @@ -68,7 +69,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) (bool, error) { +func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { if p.client == nil { if err := p.connectInit(); err != nil { return false, err @@ -85,7 +86,7 @@ func (p *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate regenerate session with new sid and delete oldsid -func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { //conn.Do("setx", key, v, ttl) if p.client == nil { if err := p.connectInit(); err != nil { @@ -118,7 +119,7 @@ func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy destroy the sid -func (p *Provider) SessionDestroy(sid string) error { +func (p *Provider) SessionDestroy(ctx context.Context, sid string) error { if p.client == nil { if err := p.connectInit(); err != nil { return err @@ -129,11 +130,11 @@ func (p *Provider) SessionDestroy(sid string) error { } // SessionGC not implemented -func (p *Provider) SessionGC() { +func (p *Provider) SessionGC(context.Context) { } // SessionAll not implemented -func (p *Provider) SessionAll() int { +func (p *Provider) SessionAll(context.Context) int { return 0 } @@ -147,7 +148,7 @@ type SessionStore struct { } // Set the key and value -func (s *SessionStore) Set(key, value interface{}) error { +func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error { s.lock.Lock() defer s.lock.Unlock() s.values[key] = value @@ -155,7 +156,7 @@ func (s *SessionStore) Set(key, value interface{}) error { } // Get return the value by the key -func (s *SessionStore) Get(key interface{}) interface{} { +func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} { s.lock.Lock() defer s.lock.Unlock() if value, ok := s.values[key]; ok { @@ -165,7 +166,7 @@ func (s *SessionStore) Get(key interface{}) interface{} { } // Delete the key in session store -func (s *SessionStore) Delete(key interface{}) error { +func (s *SessionStore) Delete(ctx context.Context, key interface{}) error { s.lock.Lock() defer s.lock.Unlock() delete(s.values, key) @@ -173,7 +174,7 @@ func (s *SessionStore) Delete(key interface{}) error { } // Flush delete all keys and values -func (s *SessionStore) Flush() error { +func (s *SessionStore) Flush(context.Context) error { s.lock.Lock() defer s.lock.Unlock() s.values = make(map[interface{}]interface{}) @@ -181,12 +182,12 @@ func (s *SessionStore) Flush() error { } // SessionID return the sessionID -func (s *SessionStore) SessionID() string { +func (s *SessionStore) SessionID(context.Context) string { return s.sid } // SessionRelease Store the keyvalues into ssdb -func (s *SessionStore) SessionRelease(w http.ResponseWriter) { +func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(s.values) if err != nil { return diff --git a/pkg/utils/caller.go b/pkg/infrastructure/utils/caller.go similarity index 100% rename from pkg/utils/caller.go rename to pkg/infrastructure/utils/caller.go diff --git a/pkg/infrastructure/utils/caller_test.go b/pkg/infrastructure/utils/caller_test.go new file mode 100644 index 00000000..0675f0aa --- /dev/null +++ b/pkg/infrastructure/utils/caller_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "strings" + "testing" +) + +func TestGetFuncName(t *testing.T) { + name := GetFuncName(TestGetFuncName) + t.Log(name) + if !strings.HasSuffix(name, ".TestGetFuncName") { + t.Error("get func name error") + } +} diff --git a/pkg/utils/debug.go b/pkg/infrastructure/utils/debug.go similarity index 100% rename from pkg/utils/debug.go rename to pkg/infrastructure/utils/debug.go diff --git a/pkg/infrastructure/utils/debug_test.go b/pkg/infrastructure/utils/debug_test.go new file mode 100644 index 00000000..efb8924e --- /dev/null +++ b/pkg/infrastructure/utils/debug_test.go @@ -0,0 +1,46 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +type mytype struct { + next *mytype + prev *mytype +} + +func TestPrint(t *testing.T) { + Display("v1", 1, "v2", 2, "v3", 3) +} + +func TestPrintPoint(t *testing.T) { + var v1 = new(mytype) + var v2 = new(mytype) + + v1.prev = nil + v1.next = v2 + + v2.prev = v1 + v2.next = nil + + Display("v1", v1, "v2", v2) +} + +func TestPrintString(t *testing.T) { + str := GetDisplayString("v1", 1, "v2", 2) + println(str) +} diff --git a/pkg/utils/file.go b/pkg/infrastructure/utils/file.go similarity index 100% rename from pkg/utils/file.go rename to pkg/infrastructure/utils/file.go diff --git a/pkg/utils/file_test.go b/pkg/infrastructure/utils/file_test.go similarity index 100% rename from pkg/utils/file_test.go rename to pkg/infrastructure/utils/file_test.go diff --git a/pkg/common/kv.go b/pkg/infrastructure/utils/kv.go similarity index 99% rename from pkg/common/kv.go rename to pkg/infrastructure/utils/kv.go index 80797aa9..f4e6c4d4 100644 --- a/pkg/common/kv.go +++ b/pkg/infrastructure/utils/kv.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package utils type KV interface { GetKey() interface{} diff --git a/pkg/common/kv_test.go b/pkg/infrastructure/utils/kv_test.go similarity index 98% rename from pkg/common/kv_test.go rename to pkg/infrastructure/utils/kv_test.go index 7b52a300..4c9643dc 100644 --- a/pkg/common/kv_test.go +++ b/pkg/infrastructure/utils/kv_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package utils import ( "testing" diff --git a/pkg/utils/mail.go b/pkg/infrastructure/utils/mail.go similarity index 100% rename from pkg/utils/mail.go rename to pkg/infrastructure/utils/mail.go diff --git a/pkg/infrastructure/utils/mail_test.go b/pkg/infrastructure/utils/mail_test.go new file mode 100644 index 00000000..c38356a2 --- /dev/null +++ b/pkg/infrastructure/utils/mail_test.go @@ -0,0 +1,41 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

Fancy Html is supported, too!

" + mail.AttachFile("/Users/astaxie/github/beego/beego.go") + mail.Send() +} diff --git a/pkg/infrastructure/utils/pagination/doc.go b/pkg/infrastructure/utils/pagination/doc.go new file mode 100644 index 00000000..86a2ba5d --- /dev/null +++ b/pkg/infrastructure/utils/pagination/doc.go @@ -0,0 +1,58 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +Usage + +In your beego.Controller: + + package controllers + + import "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} + +See also + +http://beego.me/docs/mvc/view/page.md + +*/ +package pagination diff --git a/pkg/utils/pagination/paginator.go b/pkg/infrastructure/utils/pagination/paginator.go similarity index 100% rename from pkg/utils/pagination/paginator.go rename to pkg/infrastructure/utils/pagination/paginator.go diff --git a/pkg/utils/pagination/utils.go b/pkg/infrastructure/utils/pagination/utils.go similarity index 100% rename from pkg/utils/pagination/utils.go rename to pkg/infrastructure/utils/pagination/utils.go diff --git a/pkg/utils/rand.go b/pkg/infrastructure/utils/rand.go similarity index 100% rename from pkg/utils/rand.go rename to pkg/infrastructure/utils/rand.go diff --git a/pkg/infrastructure/utils/rand_test.go b/pkg/infrastructure/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/pkg/infrastructure/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +} diff --git a/pkg/utils/safemap.go b/pkg/infrastructure/utils/safemap.go similarity index 100% rename from pkg/utils/safemap.go rename to pkg/infrastructure/utils/safemap.go diff --git a/pkg/infrastructure/utils/safemap_test.go b/pkg/infrastructure/utils/safemap_test.go new file mode 100644 index 00000000..65085195 --- /dev/null +++ b/pkg/infrastructure/utils/safemap_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +var safeMap *BeeMap + +func TestNewBeeMap(t *testing.T) { + safeMap = NewBeeMap() + if safeMap == nil { + t.Fatal("expected to return non-nil BeeMap", "got", safeMap) + } +} + +func TestSet(t *testing.T) { + safeMap = NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } +} + +func TestReSet(t *testing.T) { + safeMap := NewBeeMap() + if ok := safeMap.Set("astaxie", 1); !ok { + t.Error("expected", true, "got", false) + } + // set diff value + if ok := safeMap.Set("astaxie", -1); !ok { + t.Error("expected", true, "got", false) + } + + // set same value + if ok := safeMap.Set("astaxie", -1); ok { + t.Error("expected", false, "got", true) + } +} + +func TestCheck(t *testing.T) { + if exists := safeMap.Check("astaxie"); !exists { + t.Error("expected", true, "got", false) + } +} + +func TestGet(t *testing.T) { + if val := safeMap.Get("astaxie"); val.(int) != 1 { + t.Error("expected value", 1, "got", val) + } +} + +func TestDelete(t *testing.T) { + safeMap.Delete("astaxie") + if exists := safeMap.Check("astaxie"); exists { + t.Error("expected element to be deleted") + } +} + +func TestItems(t *testing.T) { + safeMap := NewBeeMap() + safeMap.Set("astaxie", "hello") + for k, v := range safeMap.Items() { + key := k.(string) + value := v.(string) + if key != "astaxie" { + t.Error("expected the key should be astaxie") + } + if value != "hello" { + t.Error("expected the value should be hello") + } + } +} + +func TestCount(t *testing.T) { + if count := safeMap.Count(); count != 0 { + t.Error("expected count to be", 0, "got", count) + } +} diff --git a/pkg/utils/slice.go b/pkg/infrastructure/utils/slice.go similarity index 100% rename from pkg/utils/slice.go rename to pkg/infrastructure/utils/slice.go diff --git a/pkg/infrastructure/utils/slice_test.go b/pkg/infrastructure/utils/slice_test.go new file mode 100644 index 00000000..142dec96 --- /dev/null +++ b/pkg/infrastructure/utils/slice_test.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" +) + +func TestInSlice(t *testing.T) { + sl := []string{"A", "b"} + if !InSlice("A", sl) { + t.Error("should be true") + } + if InSlice("B", sl) { + t.Error("should be false") + } +} diff --git a/pkg/utils/testdata/grepe.test b/pkg/infrastructure/utils/testdata/grepe.test similarity index 100% rename from pkg/utils/testdata/grepe.test rename to pkg/infrastructure/utils/testdata/grepe.test diff --git a/pkg/infrastructure/utils/time.go b/pkg/infrastructure/utils/time.go new file mode 100644 index 00000000..579b292a --- /dev/null +++ b/pkg/infrastructure/utils/time.go @@ -0,0 +1,48 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "time" +) + +// short string format +func ToShortTimeFormat(d time.Duration) string { + + u := uint64(d) + if u < uint64(time.Second) { + switch { + case u == 0: + return "0" + case u < uint64(time.Microsecond): + return fmt.Sprintf("%.2fns", float64(u)) + case u < uint64(time.Millisecond): + return fmt.Sprintf("%.2fus", float64(u)/1000) + default: + return fmt.Sprintf("%.2fms", float64(u)/1000/1000) + } + } else { + switch { + case u < uint64(time.Minute): + return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) + case u < uint64(time.Hour): + return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) + default: + return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) + } + } + +} diff --git a/pkg/utils/utils.go b/pkg/infrastructure/utils/utils.go similarity index 100% rename from pkg/utils/utils.go rename to pkg/infrastructure/utils/utils.go diff --git a/pkg/utils/utils_test.go b/pkg/infrastructure/utils/utils_test.go similarity index 100% rename from pkg/utils/utils_test.go rename to pkg/infrastructure/utils/utils_test.go diff --git a/pkg/validation/README.md b/pkg/infrastructure/validation/README.md similarity index 100% rename from pkg/validation/README.md rename to pkg/infrastructure/validation/README.md diff --git a/pkg/validation/util.go b/pkg/infrastructure/validation/util.go similarity index 100% rename from pkg/validation/util.go rename to pkg/infrastructure/validation/util.go diff --git a/pkg/validation/util_test.go b/pkg/infrastructure/validation/util_test.go similarity index 100% rename from pkg/validation/util_test.go rename to pkg/infrastructure/validation/util_test.go diff --git a/pkg/validation/validation.go b/pkg/infrastructure/validation/validation.go similarity index 100% rename from pkg/validation/validation.go rename to pkg/infrastructure/validation/validation.go diff --git a/pkg/infrastructure/validation/validation_test.go b/pkg/infrastructure/validation/validation_test.go new file mode 100644 index 00000000..b4b5b1b6 --- /dev/null +++ b/pkg/infrastructure/validation/validation_test.go @@ -0,0 +1,609 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "regexp" + "testing" + "time" +) + +func TestRequired(t *testing.T) { + valid := Validation{} + + if valid.Required(nil, "nil").Ok { + t.Error("nil object should be false") + } + if !valid.Required(true, "bool").Ok { + t.Error("Bool value should always return true") + } + if !valid.Required(false, "bool").Ok { + t.Error("Bool value should always return true") + } + if valid.Required("", "string").Ok { + t.Error("\"'\" string should be false") + } + if valid.Required(" ", "string").Ok { + t.Error("\" \" string should be false") // For #2361 + } + if valid.Required("\n", "string").Ok { + t.Error("new line string should be false") // For #2361 + } + if !valid.Required("astaxie", "string").Ok { + t.Error("string should be true") + } + if valid.Required(0, "zero").Ok { + t.Error("Integer should not be equal 0") + } + if !valid.Required(1, "int").Ok { + t.Error("Integer except 0 should be true") + } + if !valid.Required(time.Now(), "time").Ok { + t.Error("time should be true") + } + if valid.Required([]string{}, "emptySlice").Ok { + t.Error("empty slice should be false") + } + if !valid.Required([]interface{}{"ok"}, "slice").Ok { + t.Error("slice should be true") + } +} + +func TestMin(t *testing.T) { + valid := Validation{} + + if valid.Min(-1, 0, "min0").Ok { + t.Error("-1 is less than the minimum value of 0 should be false") + } + if !valid.Min(1, 0, "min0").Ok { + t.Error("1 is greater or equal than the minimum value of 0 should be true") + } +} + +func TestMax(t *testing.T) { + valid := Validation{} + + if valid.Max(1, 0, "max0").Ok { + t.Error("1 is greater than the minimum value of 0 should be false") + } + if !valid.Max(-1, 0, "max0").Ok { + t.Error("-1 is less or equal than the maximum value of 0 should be true") + } +} + +func TestRange(t *testing.T) { + valid := Validation{} + + if valid.Range(-1, 0, 1, "range0_1").Ok { + t.Error("-1 is between 0 and 1 should be false") + } + if !valid.Range(1, 0, 1, "range0_1").Ok { + t.Error("1 is between 0 and 1 should be true") + } +} + +func TestMinSize(t *testing.T) { + valid := Validation{} + + if valid.MinSize("", 1, "minSize1").Ok { + t.Error("the length of \"\" is less than the minimum value of 1 should be false") + } + if !valid.MinSize("ok", 1, "minSize1").Ok { + t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") + } + if valid.MinSize([]string{}, 1, "minSize1").Ok { + t.Error("the length of empty slice is less than the minimum value of 1 should be false") + } + if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { + t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") + } +} + +func TestMaxSize(t *testing.T) { + valid := Validation{} + + if valid.MaxSize("ok", 1, "maxSize1").Ok { + t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize("", 1, "maxSize1").Ok { + t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") + } + if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { + t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") + } + if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { + t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") + } +} + +func TestLength(t *testing.T) { + valid := Validation{} + + if valid.Length("", 1, "length1").Ok { + t.Error("the length of \"\" must equal 1 should be false") + } + if !valid.Length("1", 1, "length1").Ok { + t.Error("the length of \"1\" must equal 1 should be true") + } + if valid.Length([]string{}, 1, "length1").Ok { + t.Error("the length of empty slice must equal 1 should be false") + } + if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { + t.Error("the length of [\"ok\"] must equal 1 should be true") + } +} + +func TestAlpha(t *testing.T) { + valid := Validation{} + + if valid.Alpha("a,1-@ $", "alpha").Ok { + t.Error("\"a,1-@ $\" are valid alpha characters should be false") + } + if !valid.Alpha("abCD", "alpha").Ok { + t.Error("\"abCD\" are valid alpha characters should be true") + } +} + +func TestNumeric(t *testing.T) { + valid := Validation{} + + if valid.Numeric("a,1-@ $", "numeric").Ok { + t.Error("\"a,1-@ $\" are valid numeric characters should be false") + } + if !valid.Numeric("1234", "numeric").Ok { + t.Error("\"1234\" are valid numeric characters should be true") + } +} + +func TestAlphaNumeric(t *testing.T) { + valid := Validation{} + + if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false") + } + if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok { + t.Error("\"1234aB\" are valid alpha or numeric characters should be true") + } +} + +func TestMatch(t *testing.T) { + valid := Validation{} + + if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") + } + if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") + } +} + +func TestNoMatch(t *testing.T) { + valid := Validation{} + + if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") + } + if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok { + t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") + } +} + +func TestAlphaDash(t *testing.T) { + valid := Validation{} + + if valid.AlphaDash("a,1-@ $", "alphaDash").Ok { + t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false") + } + if !valid.AlphaDash("1234aB-_", "alphaDash").Ok { + t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true") + } +} + +func TestEmail(t *testing.T) { + valid := Validation{} + + if valid.Email("not@a email", "email").Ok { + t.Error("\"not@a email\" is a valid email address should be false") + } + if !valid.Email("suchuangji@gmail.com", "email").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") + } + if valid.Email("@suchuangji@gmail.com", "email").Ok { + t.Error("\"@suchuangji@gmail.com\" is a valid email address should be false") + } + if valid.Email("suchuangji@gmail.com ok", "email").Ok { + t.Error("\"suchuangji@gmail.com ok\" is a valid email address should be false") + } +} + +func TestIP(t *testing.T) { + valid := Validation{} + + if valid.IP("11.255.255.256", "IP").Ok { + t.Error("\"11.255.255.256\" is a valid ip address should be false") + } + if !valid.IP("01.11.11.11", "IP").Ok { + t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true") + } +} + +func TestBase64(t *testing.T) { + valid := Validation{} + + if valid.Base64("suchuangji@gmail.com", "base64").Ok { + t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") + } + if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { + t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true") + } +} + +func TestMobile(t *testing.T) { + valid := Validation{} + + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } + } +} + +func TestTel(t *testing.T) { + valid := Validation{} + + if valid.Tel("222-00008888", "telephone").Ok { + t.Error("\"222-00008888\" is a valid telephone number should be false") + } + if !valid.Tel("022-70008888", "telephone").Ok { + t.Error("\"022-70008888\" is a valid telephone number should be true") + } + if !valid.Tel("02270008888", "telephone").Ok { + t.Error("\"02270008888\" is a valid telephone number should be true") + } + if !valid.Tel("70008888", "telephone").Ok { + t.Error("\"70008888\" is a valid telephone number should be true") + } +} + +func TestPhone(t *testing.T) { + valid := Validation{} + + if valid.Phone("222-00008888", "phone").Ok { + t.Error("\"222-00008888\" is a valid phone number should be false") + } + if !valid.Mobile("+8614700008888", "phone").Ok { + t.Error("\"+8614700008888\" is a valid phone number should be true") + } + if !valid.Tel("02270008888", "phone").Ok { + t.Error("\"02270008888\" is a valid phone number should be true") + } +} + +func TestZipCode(t *testing.T) { + valid := Validation{} + + if valid.ZipCode("", "zipcode").Ok { + t.Error("\"00008888\" is a valid zipcode should be false") + } + if !valid.ZipCode("536000", "zipcode").Ok { + t.Error("\"536000\" is a valid zipcode should be true") + } +} + +func TestValid(t *testing.T) { + type user struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + valid := Validation{} + + u := user{Name: "test@/test/;com", Age: 40} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Error("validation should be passed") + } + + uptr := &user{Name: "test", Age: 40} + valid.Clear() + b, err = valid.Valid(uptr) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Name.Match" { + t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + } + + u = user{Name: "test@/test/;com", Age: 180} + valid.Clear() + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } + if len(valid.Errors) != 1 { + t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) + } + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) + } +} + +func TestRecursiveValid(t *testing.T) { + type User struct { + ID int + Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age int `valid:"Required;Range(1, 140)"` + } + + type AnonymouseUser struct { + ID2 int + Name2 string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"` + Age2 int `valid:"Required;Range(1, 140)"` + } + + type Account struct { + Password string `valid:"Required"` + U User + AnonymouseUser + } + valid := Validation{} + + u := Account{Password: "abc123_", U: User{}} + b, err := valid.RecursiveValid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Error("validation should not be passed") + } +} + +func TestSkipValid(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + + IP string `valid:"IP"` + ReqIP string `valid:"Required;IP"` + + Mobile string `valid:"Mobile"` + ReqMobile string `valid:"Required;Mobile"` + + Tel string `valid:"Tel"` + ReqTel string `valid:"Required;Tel"` + + Phone string `valid:"Phone"` + ReqPhone string `valid:"Required;Phone"` + + ZipCode string `valid:"ZipCode"` + ReqZipCode string `valid:"Required;ZipCode"` + } + + u := User{ + ReqEmail: "a@a.com", + ReqIP: "127.0.0.1", + ReqMobile: "18888888888", + ReqTel: "02088888888", + ReqPhone: "02088888888", + ReqZipCode: "510000", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } +} + +func TestPointer(t *testing.T) { + type User struct { + ID int + + Email *string `valid:"Email"` + ReqEmail *string `valid:"Required;Email"` + } + + u := User{ + ReqEmail: nil, + Email: nil, + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + validEmail := "a@a.com" + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: nil, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + invalidEmail := "a@a" + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{RequiredFirst: true} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + u = User{ + ReqEmail: &validEmail, + Email: &invalidEmail, + } + + valid = Validation{} + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } +} + +func TestCanSkipAlso(t *testing.T) { + type User struct { + ID int + + Email string `valid:"Email"` + ReqEmail string `valid:"Required;Email"` + MatchRange int `valid:"Range(10, 20)"` + } + + u := User{ + ReqEmail: "a@a.com", + Email: "", + MatchRange: 0, + } + + valid := Validation{RequiredFirst: true} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should not be passed") + } + + valid = Validation{RequiredFirst: true} + valid.CanSkipAlso("Range") + b, err = valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if !b { + t.Fatal("validation should be passed") + } + +} diff --git a/pkg/validation/validators.go b/pkg/infrastructure/validation/validators.go similarity index 99% rename from pkg/validation/validators.go rename to pkg/infrastructure/validation/validators.go index 534a371e..94152b89 100644 --- a/pkg/validation/validators.go +++ b/pkg/infrastructure/validation/validators.go @@ -23,7 +23,7 @@ import ( "time" "unicode/utf8" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty diff --git a/pkg/logs/logformattertest/log_formatter_test.go b/pkg/logs/logformattertest/log_formatter_test.go deleted file mode 100644 index 2d99a8e6..00000000 --- a/pkg/logs/logformattertest/log_formatter_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package logformattertest - -import ( - "fmt" - "testing" - - "github.com/astaxie/beego/pkg/common" - "github.com/astaxie/beego/pkg/logs" -) - -func customFormatter(lm *logs.LogMsg) string { - return fmt.Sprintf("[CUSTOM CONSOLE LOGGING] %s", lm.Msg) -} - -func globalFormatter(lm *logs.LogMsg) string { - return fmt.Sprintf("[GLOBAL] %s", lm.Msg) -} - -func TestCustomLoggingFormatter(t *testing.T) { - // beego.BConfig.Log.AccessLogs = true - - logs.SetLoggerWithOpts("console", []string{`{"color":true}`}, common.SimpleKV{Key: "formatter", Value: customFormatter}) - - // Message will be formatted by the customFormatter with colorful text set to true - logs.Informational("Test message") -} - -func TestGlobalLoggingFormatter(t *testing.T) { - logs.SetGlobalFormatter(globalFormatter) - - logs.SetLogger("console", `{"color":true}`) - - // Message will be formatted by globalFormatter - logs.Informational("Test message") - -} diff --git a/pkg/orm/cmd_utils.go b/pkg/orm/cmd_utils.go deleted file mode 100644 index 692a079f..00000000 --- a/pkg/orm/cmd_utils.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "strings" -) - -type dbIndex struct { - Table string - Name string - SQL string -} - -// create database drop sql. -func getDbDropSQL(al *alias) (sqls []string) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - - for _, mi := range modelCache.allOrdered() { - sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) - } - return sqls -} - -// get database column type string. -func getColumnTyp(al *alias, fi *fieldInfo) (col string) { - T := al.DbBaser.DbTypes() - fieldType := fi.fieldType - fieldSize := fi.size - -checkColumn: - switch fieldType { - case TypeBooleanField: - col = T["bool"] - case TypeVarCharField: - if al.Driver == DRPostgres && fi.toText { - col = T["string-text"] - } else { - col = fmt.Sprintf(T["string"], fieldSize) - } - case TypeCharField: - col = fmt.Sprintf(T["string-char"], fieldSize) - case TypeTextField: - col = T["string-text"] - case TypeTimeField: - col = T["time.Time-clock"] - case TypeDateField: - col = T["time.Time-date"] - case TypeDateTimeField: - col = T["time.Time"] - case TypeBitField: - col = T["int8"] - case TypeSmallIntegerField: - col = T["int16"] - case TypeIntegerField: - col = T["int32"] - case TypeBigIntegerField: - if al.Driver == DRSqlite { - fieldType = TypeIntegerField - goto checkColumn - } - col = T["int64"] - case TypePositiveBitField: - col = T["uint8"] - case TypePositiveSmallIntegerField: - col = T["uint16"] - case TypePositiveIntegerField: - col = T["uint32"] - case TypePositiveBigIntegerField: - col = T["uint64"] - case TypeFloatField: - col = T["float64"] - case TypeDecimalField: - s := T["float64-decimal"] - if !strings.Contains(s, "%d") { - col = s - } else { - col = fmt.Sprintf(s, fi.digits, fi.decimals) - } - case TypeJSONField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["json"] - case TypeJsonbField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["jsonb"] - case RelForeignKey, RelOneToOne: - fieldType = fi.relModelInfo.fields.pk.fieldType - fieldSize = fi.relModelInfo.fields.pk.size - goto checkColumn - } - - return -} - -// create alter sql string. -func getColumnAddQuery(al *alias, fi *fieldInfo) string { - Q := al.DbBaser.TableQuote() - typ := getColumnTyp(al, fi) - - if !fi.null { - typ += " " + "NOT NULL" - } - - return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", - Q, fi.mi.table, Q, - Q, fi.column, Q, - typ, getColumnDefault(fi), - ) -} - -// create database creation string. -func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - T := al.DbBaser.DbTypes() - sep := fmt.Sprintf("%s, %s", Q, Q) - - tableIndexes = make(map[string][]dbIndex) - - for _, mi := range modelCache.allOrdered() { - sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) - sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) - - columns := make([]string, 0, len(mi.fields.fieldsDB)) - - sqlIndexes := [][]string{} - - for _, fi := range mi.fields.fieldsDB { - - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := getColumnTyp(al, fi) - - if fi.auto { - switch al.Driver { - case DRSqlite, DRPostgres: - column += T["auto"] - default: - column += col + " " + T["auto"] - } - } else if fi.pk { - column += col + " " + T["pk"] - } else { - column += col - - if !fi.null { - column += " " + "NOT NULL" - } - - //if fi.initial.String() != "" { - // column += " DEFAULT " + fi.initial.String() - //} - - // Append attribute DEFAULT - column += getColumnDefault(fi) - - if fi.unique { - column += " " + "UNIQUE" - } - - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) - } - } - - if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) - } - - if fi.description != "" && al.Driver != DRSqlite { - column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) - } - - columns = append(columns, column) - } - - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) - } - for _, names := range allnames { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) - } - } - column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) - columns = append(columns, column) - } - } - - sql += strings.Join(columns, ",\n") - sql += "\n)" - - if al.Driver == DRMySQL { - var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) - } - if engine == "" { - engine = al.Engine - } - sql += " ENGINE=" + engine - } - - sql += ";" - sqls = append(sqls, sql) - - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) - } - } - sqlIndexes = append(sqlIndexes, cols) - } - } - - for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") - cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - - index := dbIndex{} - index.Table = mi.table - index.Name = name - index.SQL = sql - - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) - } - - } - - return -} - -// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands -func getColumnDefault(fi *fieldInfo) string { - var ( - v, t, d string - ) - - // Skip default attribute if field is in relations - if fi.rel || fi.reverse { - return v - } - - t = " DEFAULT '%s' " - - // These defaults will be useful if there no config value orm:"default" and NOT NULL is on - switch fi.fieldType { - case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: - return v - - case TypeBitField, TypeSmallIntegerField, TypeIntegerField, - TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, - TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, - TypeDecimalField: - t = " DEFAULT %s " - d = "0" - case TypeBooleanField: - t = " DEFAULT %s " - d = "FALSE" - case TypeJSONField, TypeJsonbField: - d = "{}" - } - - if fi.colDefault { - if !fi.initial.Exist() { - v = fmt.Sprintf(t, "") - } else { - v = fmt.Sprintf(t, fi.initial.String()) - } - } else { - if !fi.null { - v = fmt.Sprintf(t, d) - } - } - - return v -} diff --git a/pkg/orm/models.go b/pkg/orm/models.go deleted file mode 100644 index c8fbcced..00000000 --- a/pkg/orm/models.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "reflect" - "sync" -) - -const ( - odCascade = "cascade" - odSetNULL = "set_null" - odSetDefault = "set_default" - odDoNothing = "do_nothing" - defaultStructTagName = "orm" - defaultStructTagDelim = ";" -) - -var ( - modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), - } -) - -// model info collection -type _modelCache struct { - sync.RWMutex // only used outsite for bootStrap - orders []string - cache map[string]*modelInfo - cacheByFullName map[string]*modelInfo - done bool -} - -// get all model info -func (mc *_modelCache) all() map[string]*modelInfo { - m := make(map[string]*modelInfo, len(mc.cache)) - for k, v := range mc.cache { - m[k] = v - } - return m -} - -// get ordered model info -func (mc *_modelCache) allOrdered() []*modelInfo { - m := make([]*modelInfo, 0, len(mc.orders)) - for _, table := range mc.orders { - m = append(m, mc.cache[table]) - } - return m -} - -// get model info by table name -func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { - mi, ok = mc.cache[table] - return -} - -// get model info by full name -func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { - mi, ok = mc.cacheByFullName[name] - return -} - -func (mc *_modelCache) getByMd(md interface{}) (*modelInfo, bool) { - val := reflect.ValueOf(md) - ind := reflect.Indirect(val) - typ := ind.Type() - name := getFullName(typ) - return mc.getByFullName(name) -} - -// set model info to collection -func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { - mii := mc.cache[table] - mc.cache[table] = mi - mc.cacheByFullName[mi.fullName] = mi - if mii == nil { - mc.orders = append(mc.orders, table) - } - return mii -} - -// clean all model info. -func (mc *_modelCache) clean() { - mc.orders = make([]string, 0) - mc.cache = make(map[string]*modelInfo) - mc.cacheByFullName = make(map[string]*modelInfo) - mc.done = false -} - -// ResetModelCache Clean model cache. Then you can re-RegisterModel. -// Common use this api for test case. -func ResetModelCache() { - modelCache.clean() -} diff --git a/pkg/orm/models_boot.go b/pkg/orm/models_boot.go deleted file mode 100644 index 8c56b3c4..00000000 --- a/pkg/orm/models_boot.go +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" - "runtime/debug" - "strings" -) - -// register models. -// PrefixOrSuffix means table name prefix or suffix. -// isPrefix whether the prefix is prefix or suffix -func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { - val := reflect.ValueOf(model) - typ := reflect.Indirect(val).Type() - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - // For this case: - // u := &User{} - // registerModel(&u) - if typ.Kind() == reflect.Ptr { - panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) - } - - table := getTableName(val) - - if PrefixOrSuffix != "" { - if isPrefix { - table = PrefixOrSuffix + table - } else { - table = table + PrefixOrSuffix - } - } - // models's fullname is pkgpath + struct name - name := getFullName(typ) - if _, ok := modelCache.getByFullName(name); ok { - fmt.Printf(" model `%s` repeat register, must be unique\n", name) - os.Exit(2) - } - - if _, ok := modelCache.get(table); ok { - fmt.Printf(" table name `%s` repeat register, must be unique\n", table) - os.Exit(2) - } - - mi := newModelInfo(val) - if mi.fields.pk == nil { - outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi - break outFor - } - } - } - - if mi.fields.pk == nil { - fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - os.Exit(2) - } - - } - - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true - - modelCache.set(table, mi) -} - -// bootstrap models -func bootStrap() { - if modelCache.done { - return - } - var ( - err error - models map[string]*modelInfo - ) - if dataBaseCache.getDefault() == nil { - err = fmt.Errorf("must have one register DataBase alias named `default`") - goto end - } - - // set rel and reverse model - // RelManyToMany set the relTable - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { - elm = elm.Elem() - } - // check the rel or reverse model already register - name := getFullName(elm) - mii, ok := modelCache.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) - goto end - } - fi.relModelInfo = mii - - switch fi.fieldType { - case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) - goto end - } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table - } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) - goto end - } - } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable - } - if v := modelCache.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) - goto end - } - fi.relTable = i.table - fi.relThroughModelInfo = i - } - - fi.relThroughModelInfo.isThrough = true - } - } - } - } - - // check the rel filed while the relModelInfo also has filed point to current model - // if not exist, add a new field to the relModelInfo - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { - inModel = true - break - } - } - if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne - } else { - ffi.fieldType = RelReverseMany - } - if !rmi.fields.Add(ffi) { - added := false - for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { - break - } - } - if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) - } - } - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { - case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi - } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - } - } - } - if fi.reverseFieldInfoTwo == nil { - err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) - goto end - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { - case RelReverseOne: - found := false - mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - break mForA - } - } - if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - case RelReverseMany: - found := false - mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - - break mForB - } - } - if !found { - mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { - found = true - - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi - - break mForC - } - } - } - if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - } - } - } - -end: - if err != nil { - fmt.Println(err) - debug.PrintStack() - os.Exit(2) - } -} - -// RegisterModel register models -func RegisterModel(models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModel must be run before BootStrap")) - } - RegisterModelWithPrefix("", models...) -} - -// RegisterModelWithPrefix register models with a prefix -func RegisterModelWithPrefix(prefix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(prefix, model, true) - } -} - -// RegisterModelWithSuffix register models with a suffix -func RegisterModelWithSuffix(suffix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(suffix, model, false) - } -} - -// BootStrap bootstrap models. -// make all model parsed and can not add more models -func BootStrap() { - modelCache.Lock() - defer modelCache.Unlock() - if modelCache.done { - return - } - bootStrap() - modelCache.done = true -} diff --git a/pkg/LICENSE b/pkg/server/web/LICENSE similarity index 100% rename from pkg/LICENSE rename to pkg/server/web/LICENSE diff --git a/pkg/admin.go b/pkg/server/web/admin.go similarity index 93% rename from pkg/admin.go rename to pkg/server/web/admin.go index 4d8b256f..f54ac9e5 100644 --- a/pkg/admin.go +++ b/pkg/server/web/admin.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" + context2 "context" "encoding/json" "fmt" "net/http" @@ -27,10 +28,12 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/astaxie/beego/pkg/grace" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/toolbox" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/infrastructure/governor" + "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/server/web/grace" + "github.com/astaxie/beego/pkg/task" ) // BeeAdminApp is the default adminApp used by admin module. @@ -79,7 +82,7 @@ func adminIndex(rw http.ResponseWriter, _ *http.Request) { // it's registered with url pattern "/qps" in admin module. func qpsIndex(rw http.ResponseWriter, _ *http.Request) { data := make(map[interface{}]interface{}) - data["Content"] = toolbox.StatisticsMap.GetMap() + data["Content"] = StatisticsMap.GetMap() // do html escape before display path, avoid xss if content, ok := (data["Content"]).(M); ok { @@ -271,7 +274,7 @@ func profIndex(rw http.ResponseWriter, r *http.Request) { data = make(map[interface{}]interface{}) result bytes.Buffer ) - toolbox.ProcessInput(command, &result) + governor.ProcessInput(command, &result) data["Content"] = template.HTMLEscapeString(result.String()) if format == "json" && command == "gc summary" { @@ -304,7 +307,7 @@ func healthcheck(rw http.ResponseWriter, r *http.Request) { } ) - for name, h := range toolbox.AdminCheckList { + for name, h := range governor.AdminCheckList { if err := h.Check(); err != nil { result = []string{ "error", @@ -375,11 +378,11 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { req.ParseForm() taskname := req.Form.Get("taskname") if taskname != "" { - if t, ok := toolbox.AdminTaskList[taskname]; ok { - if err := t.Run(); err != nil { + if t, ok := task.AdminTaskList[taskname]; ok { + if err := t.Run(nil); err != nil { data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", err))} } - data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus()))} + data["Message"] = []string{"success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus(nil)))} } else { data["Message"] = []string{"warning", template.HTMLEscapeString(fmt.Sprintf("there's no task which named: %s", taskname))} } @@ -395,12 +398,12 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) { "Last Time", "", } - for tname, tk := range toolbox.AdminTaskList { + for tname, tk := range task.AdminTaskList { result := []string{ template.HTMLEscapeString(tname), - template.HTMLEscapeString(tk.GetSpec()), - template.HTMLEscapeString(tk.GetStatus()), - template.HTMLEscapeString(tk.GetPrev().String()), + template.HTMLEscapeString(tk.GetSpec(nil)), + template.HTMLEscapeString(tk.GetStatus(nil)), + template.HTMLEscapeString(tk.GetPrev(context2.Background()).String()), } *resultList = append(*resultList, result) } @@ -433,8 +436,8 @@ func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { // Run adminApp http server. // Its addr is defined in configuration file as adminhttpaddr and adminhttpport. func (admin *adminApp) Run() { - if len(toolbox.AdminTaskList) > 0 { - toolbox.StartTask() + if len(task.AdminTaskList) > 0 { + task.StartTask() } addr := BConfig.Listen.AdminAddr diff --git a/pkg/admin_test.go b/pkg/server/web/admin_test.go similarity index 96% rename from pkg/admin_test.go rename to pkg/server/web/admin_test.go index 5094aeed..acc67aeb 100644 --- a/pkg/admin_test.go +++ b/pkg/server/web/admin_test.go @@ -1,4 +1,4 @@ -package beego +package web import ( "encoding/json" @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/toolbox" + "github.com/astaxie/beego/pkg/infrastructure/governor" ) type SampleDatabaseCheck struct { @@ -126,8 +126,8 @@ func TestWriteJSON(t *testing.T) { func TestHealthCheckHandlerDefault(t *testing.T) { endpointPath := "/healthcheck" - toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + governor.AddHealthCheck("database", &SampleDatabaseCheck{}) + governor.AddHealthCheck("cache", &SampleCacheCheck{}) req, err := http.NewRequest("GET", endpointPath, nil) if err != nil { @@ -187,8 +187,8 @@ func TestBuildHealthCheckResponseList(t *testing.T) { func TestHealthCheckHandlerReturnsJSON(t *testing.T) { - toolbox.AddHealthCheck("database", &SampleDatabaseCheck{}) - toolbox.AddHealthCheck("cache", &SampleCacheCheck{}) + governor.AddHealthCheck("database", &SampleDatabaseCheck{}) + governor.AddHealthCheck("cache", &SampleCacheCheck{}) req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) if err != nil { diff --git a/pkg/adminui.go b/pkg/server/web/adminui.go similarity index 99% rename from pkg/adminui.go rename to pkg/server/web/adminui.go index cdcdef33..de8c9455 100644 --- a/pkg/adminui.go +++ b/pkg/server/web/adminui.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web var indexTpl = ` {{define "content"}} @@ -21,7 +21,7 @@ var indexTpl = ` For detail usage please check our document:

-Toolbox +Toolbox

Live Monitor diff --git a/pkg/app.go b/pkg/server/web/app.go similarity index 97% rename from pkg/app.go rename to pkg/server/web/app.go index ea71ce4e..7511c7fe 100644 --- a/pkg/app.go +++ b/pkg/server/web/app.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "crypto/tls" @@ -29,9 +29,10 @@ import ( "golang.org/x/crypto/acme/autocert" - "github.com/astaxie/beego/pkg/grace" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/server/web/grace" ) var ( @@ -198,7 +199,7 @@ func (app *App) Run(mws ...MiddleWare) { pool.AppendCertsFromPEM(data) app.Server.TLSConfig = &tls.Config{ ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: tls.ClientAuthType(BConfig.Listen.ClientAuth), } } if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { @@ -491,15 +492,15 @@ func Handler(rootpath string, h http.Handler, options ...interface{}) *App { // The pos means action constant including // beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilter(pattern, pos, filter, opts...) return BeeApp } // InsertFilterChain adds a FilterFunc built by filterChain. // This filter will be executed before all filters. // the filter's behavior is like stack -func InsertFilterChain(pattern string, filterChain FilterChain, params ...bool) *App { - BeeApp.Handlers.InsertFilterChain(pattern, filterChain, params...) +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *App { + BeeApp.Handlers.InsertFilterChain(pattern, filterChain, opts...) return BeeApp } diff --git a/pkg/beego.go b/pkg/server/web/beego.go similarity index 99% rename from pkg/beego.go rename to pkg/server/web/beego.go index c08ae528..76e7b85e 100644 --- a/pkg/beego.go +++ b/pkg/server/web/beego.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "os" diff --git a/pkg/build_info.go b/pkg/server/web/build_info.go similarity index 98% rename from pkg/build_info.go rename to pkg/server/web/build_info.go index c31152ea..53351c11 100644 --- a/pkg/build_info.go +++ b/pkg/server/web/build_info.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web var ( BuildVersion string diff --git a/pkg/server/web/captcha/LICENSE b/pkg/server/web/captcha/LICENSE new file mode 100644 index 00000000..0ad73ae0 --- /dev/null +++ b/pkg/server/web/captcha/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Dmitry Chestnykh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/server/web/captcha/README.md b/pkg/server/web/captcha/README.md new file mode 100644 index 00000000..dbc2026b --- /dev/null +++ b/pkg/server/web/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplName = "index.tpl" +} + +func (this *MainController) Post() { + this.TplName = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +

+ {{create_captcha}} + +
+``` diff --git a/pkg/utils/captcha/captcha.go b/pkg/server/web/captcha/captcha.go similarity index 94% rename from pkg/utils/captcha/captcha.go rename to pkg/server/web/captcha/captcha.go index 62fc26cf..2ae1fb8f 100644 --- a/pkg/utils/captcha/captcha.go +++ b/pkg/server/web/captcha/captcha.go @@ -66,11 +66,12 @@ import ( "strings" "time" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/cache" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/client/cache" + "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) var ( @@ -261,10 +262,10 @@ func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { cpt := NewCaptcha(urlPrefix, store) // create filter for serve captcha image - beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) + web.InsertFilter(cpt.URLPrefix+"*", web.BeforeRouter, cpt.Handler) // add to template func map - beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) + web.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) return cpt } diff --git a/pkg/utils/captcha/image.go b/pkg/server/web/captcha/image.go similarity index 100% rename from pkg/utils/captcha/image.go rename to pkg/server/web/captcha/image.go diff --git a/pkg/utils/captcha/image_test.go b/pkg/server/web/captcha/image_test.go similarity index 96% rename from pkg/utils/captcha/image_test.go rename to pkg/server/web/captcha/image_test.go index 73d3361b..36cba386 100644 --- a/pkg/utils/captcha/image_test.go +++ b/pkg/server/web/captcha/image_test.go @@ -17,7 +17,7 @@ package captcha import ( "testing" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) type byteCounter struct { diff --git a/pkg/utils/captcha/siprng.go b/pkg/server/web/captcha/siprng.go similarity index 100% rename from pkg/utils/captcha/siprng.go rename to pkg/server/web/captcha/siprng.go diff --git a/pkg/utils/captcha/siprng_test.go b/pkg/server/web/captcha/siprng_test.go similarity index 100% rename from pkg/utils/captcha/siprng_test.go rename to pkg/server/web/captcha/siprng_test.go diff --git a/pkg/config.go b/pkg/server/web/config.go similarity index 73% rename from pkg/config.go rename to pkg/server/web/config.go index e8bde705..6e69a2fb 100644 --- a/pkg/config.go +++ b/pkg/server/web/config.go @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( + context2 "context" + "crypto/tls" "fmt" "os" "path/filepath" @@ -22,17 +24,18 @@ import ( "runtime" "strings" - "github.com/astaxie/beego/pkg/config" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/session" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/config" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/infrastructure/session" + + "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/server/web/context" ) // Config is the main struct for BConfig type Config struct { - AppName string //Application name - RunMode string //Running Mode: dev | prod + AppName string // Application name + RunMode string // Running Mode: dev | prod RouterCaseSensitive bool ServerName string RecoverPanic bool @@ -70,6 +73,7 @@ type Listen struct { AdminPort int EnableFcgi bool EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O + ClientAuth int } // WebConfig holds web related config @@ -112,8 +116,8 @@ type SessionConfig struct { // LogConfig holds Log related config type LogConfig struct { AccessLogs bool - EnableStaticLogs bool //log static files requests default: false - AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string + EnableStaticLogs bool // log static files requests default: false + AccessLogsFormat string // access log format: JSON_FORMAT, APACHE_FORMAT or empty string FileLineNum bool Outputs map[string]string // Store Adaptor : config } @@ -209,7 +213,7 @@ func newBConfig() *Config { RecoverFunc: recoverPanic, CopyRequestBody: false, EnableGzip: false, - MaxMemory: 1 << 26, //64MB + MaxMemory: 1 << 26, // 64MB EnableErrorsShow: true, EnableErrorsRender: true, Listen: Listen{ @@ -232,6 +236,7 @@ func newBConfig() *Config { AdminPort: 8088, EnableFcgi: false, EnableStdIo: false, + ClientAuth: int(tls.RequireAndVerifyClientCert), }, WebConfig: WebConfig{ AutoRender: true, @@ -257,7 +262,7 @@ func newBConfig() *Config { SessionGCMaxLifetime: 3600, SessionProviderConfig: "", SessionDisableHTTPOnly: false, - SessionCookieLifeTime: 0, //set cookie default is the browser life + SessionCookieLifeTime: 0, // set cookie default is the browser life SessionAutoSetCookie: true, SessionDomain: "", SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers @@ -291,11 +296,11 @@ func assignConfig(ac config.Configer) error { // set the run mode first if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { BConfig.RunMode = envRunMode - } else if runMode := ac.String("RunMode"); runMode != "" { + } else if runMode, err := ac.String(nil, "RunMode"); runMode != "" && err == nil { BConfig.RunMode = runMode } - if sd := ac.String("StaticDir"); sd != "" { + if sd, err := ac.String(nil, "StaticDir"); sd != "" && err == nil { BConfig.WebConfig.StaticDir = map[string]string{} sds := strings.Fields(sd) for _, v := range sds { @@ -307,7 +312,7 @@ func assignConfig(ac config.Configer) error { } } - if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { + if sgz, err := ac.String(nil, "StaticExtensionsToGzip"); sgz != "" && err == nil { extensions := strings.Split(sgz, ",") fileExts := []string{} for _, ext := range extensions { @@ -325,15 +330,15 @@ func assignConfig(ac config.Configer) error { } } - if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + if sfs, err := ac.Int(nil, "StaticCacheFileSize"); err == nil { BConfig.WebConfig.StaticCacheFileSize = sfs } - if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + if sfn, err := ac.Int(nil, "StaticCacheFileNum"); err == nil { BConfig.WebConfig.StaticCacheFileNum = sfn } - if lo := ac.String("LogOutputs"); lo != "" { + if lo, err := ac.String(nil, "LogOutputs"); lo != "" && err == nil { // if lo is not nil or empty // means user has set his own LogOutputs // clear the default setting to BConfig.Log.Outputs @@ -348,7 +353,7 @@ func assignConfig(ac config.Configer) error { } } - //init log + // init log logs.Reset() for adaptor, config := range BConfig.Log.Outputs { err := logs.SetLogger(adaptor, config) @@ -380,14 +385,14 @@ func assignSingleConfig(p interface{}, ac config.Configer) { name := pt.Field(i).Name switch pf.Kind() { case reflect.String: - pf.SetString(ac.DefaultString(name, pf.String())) + pf.SetString(ac.DefaultString(nil, name, pf.String())) case reflect.Int, reflect.Int64: - pf.SetInt(ac.DefaultInt64(name, pf.Int())) + pf.SetInt(ac.DefaultInt64(nil, name, pf.Int())) case reflect.Bool: - pf.SetBool(ac.DefaultBool(name, pf.Bool())) + pf.SetBool(ac.DefaultBool(nil, name, pf.Bool())) case reflect.Struct: default: - //do nothing here + // do nothing here } } @@ -423,105 +428,105 @@ func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, err return &beegoAppConfig{innerConfig: ac}, nil } -func (b *beegoAppConfig) Set(key, val string) error { - if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { - return b.innerConfig.Set(key, val) +func (b *beegoAppConfig) Set(ctx context2.Context, key, val string) error { + if err := b.innerConfig.Set(nil, BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(nil, key, val) } return nil } -func (b *beegoAppConfig) String(key string) string { - if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { - return v - } - return b.innerConfig.String(key) -} - -func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { - return v - } - return b.innerConfig.Strings(key) -} - -func (b *beegoAppConfig) Int(key string) (int, error) { - if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) String(ctx context2.Context, key string) (string, error) { + if v, err := b.innerConfig.String(nil, BConfig.RunMode+"::"+key); v != "" && err == nil { return v, nil } - return b.innerConfig.Int(key) + return b.innerConfig.String(nil, key) } -func (b *beegoAppConfig) Int64(key string) (int64, error) { - if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Strings(ctx context2.Context, key string) ([]string, error) { + if v, err := b.innerConfig.Strings(nil, BConfig.RunMode+"::"+key); len(v) > 0 && err == nil { return v, nil } - return b.innerConfig.Int64(key) + return b.innerConfig.Strings(nil, key) } -func (b *beegoAppConfig) Bool(key string) (bool, error) { - if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Int(ctx context2.Context, key string) (int, error) { + if v, err := b.innerConfig.Int(nil, BConfig.RunMode+"::"+key); err == nil { return v, nil } - return b.innerConfig.Bool(key) + return b.innerConfig.Int(nil, key) } -func (b *beegoAppConfig) Float(key string) (float64, error) { - if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { +func (b *beegoAppConfig) Int64(ctx context2.Context, key string) (int64, error) { + if v, err := b.innerConfig.Int64(nil, BConfig.RunMode+"::"+key); err == nil { return v, nil } - return b.innerConfig.Float(key) + return b.innerConfig.Int64(nil, key) } -func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { - if v := b.String(key); v != "" { +func (b *beegoAppConfig) Bool(ctx context2.Context, key string) (bool, error) { + if v, err := b.innerConfig.Bool(nil, BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Bool(nil, key) +} + +func (b *beegoAppConfig) Float(ctx context2.Context, key string) (float64, error) { + if v, err := b.innerConfig.Float(nil, BConfig.RunMode+"::"+key); err == nil { + return v, nil + } + return b.innerConfig.Float(nil, key) +} + +func (b *beegoAppConfig) DefaultString(ctx context2.Context, key string, defaultVal string) string { + if v, err := b.String(nil, key); v != "" && err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { - if v := b.Strings(key); len(v) != 0 { +func (b *beegoAppConfig) DefaultStrings(ctx context2.Context, key string, defaultVal []string) []string { + if v, err := b.Strings(ctx, key); len(v) != 0 && err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { - if v, err := b.Int(key); err == nil { +func (b *beegoAppConfig) DefaultInt(ctx context2.Context, key string, defaultVal int) int { + if v, err := b.Int(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { - if v, err := b.Int64(key); err == nil { +func (b *beegoAppConfig) DefaultInt64(ctx context2.Context, key string, defaultVal int64) int64 { + if v, err := b.Int64(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { - if v, err := b.Bool(key); err == nil { +func (b *beegoAppConfig) DefaultBool(ctx context2.Context, key string, defaultVal bool) bool { + if v, err := b.Bool(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { - if v, err := b.Float(key); err == nil { +func (b *beegoAppConfig) DefaultFloat(ctx context2.Context, key string, defaultVal float64) float64 { + if v, err := b.Float(ctx, key); err == nil { return v } return defaultVal } -func (b *beegoAppConfig) DIY(key string) (interface{}, error) { - return b.innerConfig.DIY(key) +func (b *beegoAppConfig) DIY(ctx context2.Context, key string) (interface{}, error) { + return b.innerConfig.DIY(nil, key) } -func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { - return b.innerConfig.GetSection(section) +func (b *beegoAppConfig) GetSection(ctx context2.Context, section string) (map[string]string, error) { + return b.innerConfig.GetSection(nil, section) } -func (b *beegoAppConfig) SaveConfigFile(filename string) error { - return b.innerConfig.SaveConfigFile(filename) +func (b *beegoAppConfig) SaveConfigFile(ctx context2.Context, filename string) error { + return b.innerConfig.SaveConfigFile(nil, filename) } diff --git a/pkg/config_test.go b/pkg/server/web/config_test.go similarity index 89% rename from pkg/config_test.go rename to pkg/server/web/config_test.go index c810a9e3..4961d3a9 100644 --- a/pkg/config_test.go +++ b/pkg/server/web/config_test.go @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "encoding/json" "reflect" "testing" - beeJson "github.com/astaxie/beego/pkg/config/json" + beeJson "github.com/astaxie/beego/pkg/infrastructure/config/json" ) func TestDefaults(t *testing.T) { @@ -111,12 +111,12 @@ func TestAssignConfig_02(t *testing.T) { func TestAssignConfig_03(t *testing.T) { jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) - ac.Set("AppName", "test_app") - ac.Set("RunMode", "online") - ac.Set("StaticDir", "download:down download2:down2") - ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") - ac.Set("StaticCacheFileSize", "87456") - ac.Set("StaticCacheFileNum", "1254") + ac.Set(nil, "AppName", "test_app") + ac.Set(nil, "RunMode", "online") + ac.Set(nil, "StaticDir", "download:down download2:down2") + ac.Set(nil, "StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set(nil, "StaticCacheFileSize", "87456") + ac.Set(nil, "StaticCacheFileNum", "1254") assignConfig(ac) t.Logf("%#v", BConfig) diff --git a/pkg/context/acceptencoder.go b/pkg/server/web/context/acceptencoder.go similarity index 100% rename from pkg/context/acceptencoder.go rename to pkg/server/web/context/acceptencoder.go diff --git a/pkg/context/acceptencoder_test.go b/pkg/server/web/context/acceptencoder_test.go similarity index 100% rename from pkg/context/acceptencoder_test.go rename to pkg/server/web/context/acceptencoder_test.go diff --git a/pkg/context/context.go b/pkg/server/web/context/context.go similarity index 99% rename from pkg/context/context.go rename to pkg/server/web/context/context.go index f7b325a9..78e0a6d6 100644 --- a/pkg/context/context.go +++ b/pkg/server/web/context/context.go @@ -35,7 +35,7 @@ import ( "strings" "time" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // Commonly used mime-types diff --git a/pkg/context/context_test.go b/pkg/server/web/context/context_test.go similarity index 100% rename from pkg/context/context_test.go rename to pkg/server/web/context/context_test.go diff --git a/pkg/context/input.go b/pkg/server/web/context/input.go similarity index 99% rename from pkg/context/input.go rename to pkg/server/web/context/input.go index 5ff85f43..a6fec774 100644 --- a/pkg/context/input.go +++ b/pkg/server/web/context/input.go @@ -29,7 +29,7 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" ) // Regexes for checking the accept headers @@ -361,7 +361,7 @@ func (input *BeegoInput) Cookie(key string) string { // Session returns current session item value by a given key. // if non-existed, return nil. func (input *BeegoInput) Session(key interface{}) interface{} { - return input.CruSession.Get(key) + return input.CruSession.Get(nil, key) } // CopyBody returns the raw request body data as bytes. diff --git a/pkg/context/input_test.go b/pkg/server/web/context/input_test.go similarity index 100% rename from pkg/context/input_test.go rename to pkg/server/web/context/input_test.go diff --git a/pkg/context/output.go b/pkg/server/web/context/output.go similarity index 99% rename from pkg/context/output.go rename to pkg/server/web/context/output.go index 0a530244..a6e83681 100644 --- a/pkg/context/output.go +++ b/pkg/server/web/context/output.go @@ -404,5 +404,5 @@ func stringsToJSON(str string) string { // Session sets session item value with given key. func (output *BeegoOutput) Session(name interface{}, value interface{}) { - output.Context.Input.CruSession.Set(name, value) + output.Context.Input.CruSession.Set(nil, name, value) } diff --git a/pkg/context/param/conv.go b/pkg/server/web/context/param/conv.go similarity index 94% rename from pkg/context/param/conv.go rename to pkg/server/web/context/param/conv.go index d96f964c..a96dacdd 100644 --- a/pkg/context/param/conv.go +++ b/pkg/server/web/context/param/conv.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - beecontext "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" + beecontext "github.com/astaxie/beego/pkg/server/web/context" ) // ConvertParams converts http method params to values that will be passed to the method controller as arguments diff --git a/pkg/context/param/methodparams.go b/pkg/server/web/context/param/methodparams.go similarity index 100% rename from pkg/context/param/methodparams.go rename to pkg/server/web/context/param/methodparams.go diff --git a/pkg/context/param/options.go b/pkg/server/web/context/param/options.go similarity index 100% rename from pkg/context/param/options.go rename to pkg/server/web/context/param/options.go diff --git a/pkg/context/param/parsers.go b/pkg/server/web/context/param/parsers.go similarity index 100% rename from pkg/context/param/parsers.go rename to pkg/server/web/context/param/parsers.go diff --git a/pkg/context/param/parsers_test.go b/pkg/server/web/context/param/parsers_test.go similarity index 100% rename from pkg/context/param/parsers_test.go rename to pkg/server/web/context/param/parsers_test.go diff --git a/pkg/context/renderer.go b/pkg/server/web/context/renderer.go similarity index 100% rename from pkg/context/renderer.go rename to pkg/server/web/context/renderer.go diff --git a/pkg/context/response.go b/pkg/server/web/context/response.go similarity index 100% rename from pkg/context/response.go rename to pkg/server/web/context/response.go diff --git a/pkg/controller.go b/pkg/server/web/controller.go similarity index 98% rename from pkg/controller.go rename to pkg/server/web/controller.go index f3989a76..2081e647 100644 --- a/pkg/controller.go +++ b/pkg/server/web/controller.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" @@ -28,9 +28,10 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/context/param" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/session" + + "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/context/param" ) var ( @@ -621,7 +622,7 @@ func (c *Controller) SetSession(name interface{}, value interface{}) { if c.CruSession == nil { c.StartSession() } - c.CruSession.Set(name, value) + c.CruSession.Set(nil, name, value) } // GetSession gets value from session. @@ -629,7 +630,7 @@ func (c *Controller) GetSession(name interface{}) interface{} { if c.CruSession == nil { c.StartSession() } - return c.CruSession.Get(name) + return c.CruSession.Get(nil, name) } // DelSession removes value from session. @@ -637,14 +638,14 @@ func (c *Controller) DelSession(name interface{}) { if c.CruSession == nil { c.StartSession() } - c.CruSession.Delete(name) + c.CruSession.Delete(nil, name) } // SessionRegenerateID regenerates session id for this session. // the session data have no changes. func (c *Controller) SessionRegenerateID() { if c.CruSession != nil { - c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + c.CruSession.SessionRelease(nil, c.Ctx.ResponseWriter) } c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) c.Ctx.Input.CruSession = c.CruSession @@ -652,7 +653,7 @@ func (c *Controller) SessionRegenerateID() { // DestroySession cleans session data and session cookie. func (c *Controller) DestroySession() { - c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession.Flush(nil) c.Ctx.Input.CruSession = nil GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) } diff --git a/pkg/controller_test.go b/pkg/server/web/controller_test.go similarity index 98% rename from pkg/controller_test.go rename to pkg/server/web/controller_test.go index 97f1e964..46da3629 100644 --- a/pkg/controller_test.go +++ b/pkg/server/web/controller_test.go @@ -12,19 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "math" + "os" + "path/filepath" "strconv" "testing" "github.com/stretchr/testify/assert" - "os" - "path/filepath" - - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) func TestGetInt(t *testing.T) { diff --git a/pkg/server/web/doc.go b/pkg/server/web/doc.go new file mode 100644 index 00000000..0ab10bfd --- /dev/null +++ b/pkg/server/web/doc.go @@ -0,0 +1,17 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/astaxie/beego/pkg" + + func main() { + beego.Run() + } + +more information: http://beego.me +*/ +package web diff --git a/pkg/error.go b/pkg/server/web/error.go similarity index 99% rename from pkg/error.go rename to pkg/server/web/error.go index aff984c0..b62fb70d 100644 --- a/pkg/error.go +++ b/pkg/server/web/error.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "fmt" @@ -23,8 +23,9 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/utils" + + "github.com/astaxie/beego/pkg/server/web/context" ) const ( diff --git a/pkg/error_test.go b/pkg/server/web/error_test.go similarity index 99% rename from pkg/error_test.go rename to pkg/server/web/error_test.go index 378aa953..2685a985 100644 --- a/pkg/error_test.go +++ b/pkg/server/web/error_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" diff --git a/pkg/filter.go b/pkg/server/web/filter.go similarity index 76% rename from pkg/filter.go rename to pkg/server/web/filter.go index 911cb848..9aab48d6 100644 --- a/pkg/filter.go +++ b/pkg/server/web/filter.go @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "strings" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) // FilterChain is different from pure FilterFunc @@ -43,24 +43,27 @@ type FilterRouter struct { // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func newFilterRouter(pattern string, routerCaseSensitive bool, filter FilterFunc, params ...bool) *FilterRouter { +func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + } + + fos := &filterOpts{ returnOnOutput: true, } - if !routerCaseSensitive { + + for _, o := range opts { + o(fos) + } + + if !fos.routerCaseSensitive { mr.pattern = strings.ToLower(pattern) } - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } + mr.returnOnOutput = fos.returnOnOutput + mr.resetParams = fos.resetParams mr.tree.AddRouter(pattern, true) return mr } @@ -103,3 +106,29 @@ func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { } return false } + +type filterOpts struct { + returnOnOutput bool + resetParams bool + routerCaseSensitive bool +} + +type FilterOpt func(opts *filterOpts) + +func WithReturnOnOutput(ret bool) FilterOpt { + return func(opts *filterOpts) { + opts.returnOnOutput = ret + } +} + +func WithResetParams(reset bool) FilterOpt { + return func(opts *filterOpts) { + opts.resetParams = reset + } +} + +func WithCaseSensitive(sensitive bool) FilterOpt { + return func(opts *filterOpts) { + opts.routerCaseSensitive = sensitive + } +} diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/server/web/filter/apiauth/apiauth.go similarity index 91% rename from pkg/plugins/apiauth/apiauth.go rename to pkg/server/web/filter/apiauth/apiauth.go index 7b1d4405..8944db63 100644 --- a/pkg/plugins/apiauth/apiauth.go +++ b/pkg/server/web/filter/apiauth/apiauth.go @@ -65,15 +65,15 @@ import ( "sort" "time" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) // AppIDToAppSecret gets appsecret through appid type AppIDToAppSecret func(string) string // APIBasicAuth uses the basic appid/appkey as the AppIdToAppSecret -func APIBasicAuth(appid, appkey string) beego.FilterFunc { +func APIBasicAuth(appid, appkey string) web.FilterFunc { ft := func(aid string) string { if aid == appid { return appkey @@ -83,13 +83,8 @@ func APIBasicAuth(appid, appkey string) beego.FilterFunc { return APISecretAuth(ft, 300) } -// APIBasicAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) beego.FilterFunc { - return APIBasicAuth(appid, appkey) -} - // APISecretAuth uses AppIdToAppSecret verify and -func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { +func APISecretAuth(f AppIDToAppSecret, timeout int) web.FilterFunc { return func(ctx *context.Context) { if ctx.Input.Query("appid") == "" { ctx.ResponseWriter.WriteHeader(403) diff --git a/pkg/server/web/filter/apiauth/apiauth_test.go b/pkg/server/web/filter/apiauth/apiauth_test.go new file mode 100644 index 00000000..1f56cb0f --- /dev/null +++ b/pkg/server/web/filter/apiauth/apiauth_test.go @@ -0,0 +1,20 @@ +package apiauth + +import ( + "net/url" + "testing" +) + +func TestSignature(t *testing.T) { + appsecret := "beego secret" + method := "GET" + RequestURL := "http://localhost/test/url" + params := make(url.Values) + params.Add("arg1", "hello") + params.Add("arg2", "beego") + + signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" + if Signature(appsecret, method, params, RequestURL) != signature { + t.Error("Signature error") + } +} diff --git a/pkg/plugins/auth/basic.go b/pkg/server/web/filter/auth/basic.go similarity index 94% rename from pkg/plugins/auth/basic.go rename to pkg/server/web/filter/auth/basic.go index d84b8df2..209cd97d 100644 --- a/pkg/plugins/auth/basic.go +++ b/pkg/server/web/filter/auth/basic.go @@ -40,14 +40,14 @@ import ( "net/http" "strings" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) var defaultRealm = "Authorization Required" // Basic is the http basic auth -func Basic(username string, password string) beego.FilterFunc { +func Basic(username string, password string) web.FilterFunc { secrets := func(user, pass string) bool { return user == username && pass == password } @@ -55,7 +55,7 @@ func Basic(username string, password string) beego.FilterFunc { } // NewBasicAuthenticator return the BasicAuth -func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { +func NewBasicAuthenticator(secrets SecretProvider, Realm string) web.FilterFunc { return func(ctx *context.Context) { a := &BasicAuth{Secrets: secrets, Realm: Realm} if username := a.CheckAuth(ctx.Request); username == "" { diff --git a/pkg/plugins/authz/authz.go b/pkg/server/web/filter/authz/authz.go similarity index 94% rename from pkg/plugins/authz/authz.go rename to pkg/server/web/filter/authz/authz.go index 47a20c8a..a3a8dca6 100644 --- a/pkg/plugins/authz/authz.go +++ b/pkg/server/web/filter/authz/authz.go @@ -42,14 +42,15 @@ package authz import ( "net/http" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" "github.com/casbin/casbin" + + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) // NewAuthorizer returns the authorizer. // Use a casbin enforcer as input -func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { +func NewAuthorizer(e *casbin.Enforcer) web.FilterFunc { return func(ctx *context.Context) { a := &BasicAuthorizer{enforcer: e} diff --git a/pkg/server/web/filter/authz/authz_model.conf b/pkg/server/web/filter/authz/authz_model.conf new file mode 100644 index 00000000..d1b3dbd7 --- /dev/null +++ b/pkg/server/web/filter/authz/authz_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file diff --git a/pkg/server/web/filter/authz/authz_policy.csv b/pkg/server/web/filter/authz/authz_policy.csv new file mode 100644 index 00000000..c062dd3e --- /dev/null +++ b/pkg/server/web/filter/authz/authz_policy.csv @@ -0,0 +1,7 @@ +p, alice, /dataset1/*, GET +p, alice, /dataset1/resource1, POST +p, bob, /dataset2/resource1, * +p, bob, /dataset2/resource2, GET +p, bob, /dataset2/folder1/*, POST +p, dataset1_admin, /dataset1/*, * +g, cathy, dataset1_admin \ No newline at end of file diff --git a/pkg/server/web/filter/authz/authz_test.go b/pkg/server/web/filter/authz/authz_test.go new file mode 100644 index 00000000..e50596b4 --- /dev/null +++ b/pkg/server/web/filter/authz/authz_test.go @@ -0,0 +1,109 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/casbin/casbin" + + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/filter/auth" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, user string, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.SetBasicAuth(user, "123") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code) + } +} + +func TestBasic(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilter("*", web.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) + testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) +} + +func TestPathWildcard(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilter("*", web.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) + testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) + testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) + testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) +} + +func TestRBAC(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilter("*", web.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") + handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(e)) + + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + + // delete all roles on user cathy, so cathy cannot access any resources now. + e.DeleteRolesForUser("cathy") + + testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) + testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) +} diff --git a/pkg/plugins/cors/cors.go b/pkg/server/web/filter/cors/cors.go similarity index 97% rename from pkg/plugins/cors/cors.go rename to pkg/server/web/filter/cors/cors.go index 18bf2bec..800eeded 100644 --- a/pkg/plugins/cors/cors.go +++ b/pkg/server/web/filter/cors/cors.go @@ -42,8 +42,8 @@ import ( "strings" "time" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) const ( @@ -187,7 +187,7 @@ func (o *Options) IsOriginAllowed(origin string) (allowed bool) { } // Allow enables CORS for requests those match the provided options. -func Allow(opts *Options) beego.FilterFunc { +func Allow(opts *Options) web.FilterFunc { // Allow default headers if nothing is specified. if len(opts.AllowHeaders) == 0 { opts.AllowHeaders = defaultAllowHeaders diff --git a/pkg/plugins/cors/cors_test.go b/pkg/server/web/filter/cors/cors_test.go similarity index 88% rename from pkg/plugins/cors/cors_test.go rename to pkg/server/web/filter/cors/cors_test.go index 664d35a7..60659fdd 100644 --- a/pkg/plugins/cors/cors_test.go +++ b/pkg/server/web/filter/cors/cors_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) // HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header @@ -55,8 +55,8 @@ func (gr *HTTPHeaderGuardRecorder) Header() http.Header { func Test_AllowAll(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, })) handler.Any("/foo", func(ctx *context.Context) { @@ -72,8 +72,8 @@ func Test_AllowAll(t *testing.T) { func Test_AllowRegexMatch(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, })) handler.Any("/foo", func(ctx *context.Context) { @@ -92,8 +92,8 @@ func Test_AllowRegexMatch(t *testing.T) { func Test_AllowRegexNoMatch(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowOrigins: []string{"https://*.foo.com"}, })) handler.Any("/foo", func(ctx *context.Context) { @@ -112,8 +112,8 @@ func Test_AllowRegexNoMatch(t *testing.T) { func Test_OtherHeaders(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowCredentials: true, AllowMethods: []string{"PATCH", "GET"}, @@ -156,8 +156,8 @@ func Test_OtherHeaders(t *testing.T) { func Test_DefaultAllowHeaders(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, })) handler.Any("/foo", func(ctx *context.Context) { @@ -175,8 +175,8 @@ func Test_DefaultAllowHeaders(t *testing.T) { func Test_Preflight(t *testing.T) { recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowMethods: []string{"PUT", "PATCH"}, AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, @@ -219,8 +219,8 @@ func Test_Preflight(t *testing.T) { func Benchmark_WithoutCORS(b *testing.B) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD handler.Any("/foo", func(ctx *context.Context) { ctx.Output.SetStatus(500) }) @@ -233,9 +233,9 @@ func Benchmark_WithoutCORS(b *testing.B) { func Benchmark_WithCORS(b *testing.B) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowCredentials: true, AllowMethods: []string{"PATCH", "GET"}, diff --git a/pkg/web/filter/opentracing/filter.go b/pkg/server/web/filter/opentracing/filter.go similarity index 92% rename from pkg/web/filter/opentracing/filter.go rename to pkg/server/web/filter/opentracing/filter.go index e6ee9150..dd5663f9 100644 --- a/pkg/web/filter/opentracing/filter.go +++ b/pkg/server/web/filter/opentracing/filter.go @@ -21,8 +21,8 @@ import ( opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" - beego "github.com/astaxie/beego/pkg" - beegoCtx "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web" + beegoCtx "github.com/astaxie/beego/pkg/server/web/context" ) // FilterChainBuilder provides an extension point that we can support more configurations if necessary @@ -31,7 +31,7 @@ type FilterChainBuilder struct { CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context) } -func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { +func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc { return func(ctx *beegoCtx.Context) { var ( spanCtx context.Context @@ -79,7 +79,7 @@ func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string { operationName := ctx.Input.URL() // it means that there is not any span, so we create a span as the root span. // TODO, if we support multiple servers, this need to be changed - route, found := beego.BeeApp.Handlers.FindRouter(ctx) + route, found := web.BeeApp.Handlers.FindRouter(ctx) if found { operationName = ctx.Input.Method() + "#" + route.GetPattern() } diff --git a/pkg/web/filter/opentracing/filter_test.go b/pkg/server/web/filter/opentracing/filter_test.go similarity index 96% rename from pkg/web/filter/opentracing/filter_test.go rename to pkg/server/web/filter/opentracing/filter_test.go index 750ea7a9..04f44324 100644 --- a/pkg/web/filter/opentracing/filter_test.go +++ b/pkg/server/web/filter/opentracing/filter_test.go @@ -22,7 +22,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) func TestFilterChainBuilder_FilterChain(t *testing.T) { diff --git a/pkg/web/filter/prometheus/filter.go b/pkg/server/web/filter/prometheus/filter.go similarity index 76% rename from pkg/web/filter/prometheus/filter.go rename to pkg/server/web/filter/prometheus/filter.go index 8f4b46e3..f4231c73 100644 --- a/pkg/web/filter/prometheus/filter.go +++ b/pkg/server/web/filter/prometheus/filter.go @@ -21,8 +21,8 @@ import ( "github.com/prometheus/client_golang/prometheus" - beego "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web" + "github.com/astaxie/beego/pkg/server/web/context" ) // FilterChainBuilder is an extension point, @@ -32,14 +32,14 @@ type FilterChainBuilder struct { } // FilterChain returns a FilterFunc. The filter will records some metrics -func (builder *FilterChainBuilder) FilterChain(next beego.FilterFunc) beego.FilterFunc { +func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc { summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", Subsystem: "http_request", ConstLabels: map[string]string{ - "server": beego.BConfig.ServerName, - "env": beego.BConfig.RunMode, - "appname": beego.BConfig.AppName, + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, }, Help: "The statics info for http request", }, []string{"pattern", "method", "status", "duration"}) @@ -62,14 +62,14 @@ func registerBuildInfo() { Subsystem: "build_info", Help: "The building information", ConstLabels: map[string]string{ - "appname": beego.BConfig.AppName, - "build_version": beego.BuildVersion, - "build_revision": beego.BuildGitRevision, - "build_status": beego.BuildStatus, - "build_tag": beego.BuildTag, - "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), - "go_version": beego.GoVersion, - "git_branch": beego.GitBranch, + "appname": web.BConfig.AppName, + "build_version": web.BuildVersion, + "build_revision": web.BuildGitRevision, + "build_status": web.BuildStatus, + "build_tag": web.BuildTag, + "build_time": strings.Replace(web.BuildTime, "--", " ", 1), + "go_version": web.GoVersion, + "git_branch": web.GitBranch, "start_time": time.Now().Format("2006-01-02 15:04:05"), }, }, []string{}) diff --git a/pkg/web/filter/prometheus/filter_test.go b/pkg/server/web/filter/prometheus/filter_test.go similarity index 95% rename from pkg/web/filter/prometheus/filter_test.go rename to pkg/server/web/filter/prometheus/filter_test.go index 822892bc..08887839 100644 --- a/pkg/web/filter/prometheus/filter_test.go +++ b/pkg/server/web/filter/prometheus/filter_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) func TestFilterChain(t *testing.T) { diff --git a/pkg/filter_chain_test.go b/pkg/server/web/filter_chain_test.go similarity index 94% rename from pkg/filter_chain_test.go rename to pkg/server/web/filter_chain_test.go index f1f86088..44d5f71e 100644 --- a/pkg/filter_chain_test.go +++ b/pkg/server/web/filter_chain_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) func TestControllerRegister_InsertFilterChain(t *testing.T) { diff --git a/pkg/filter_test.go b/pkg/server/web/filter_test.go similarity index 97% rename from pkg/filter_test.go rename to pkg/server/web/filter_test.go index 3a1bcb07..eea50534 100644 --- a/pkg/filter_test.go +++ b/pkg/server/web/filter_test.go @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" "net/http/httptest" "testing" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) var FilterUser = func(ctx *context.Context) { diff --git a/pkg/flash.go b/pkg/server/web/flash.go similarity index 99% rename from pkg/flash.go rename to pkg/server/web/flash.go index a6485a17..55f6435d 100644 --- a/pkg/flash.go +++ b/pkg/server/web/flash.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "fmt" diff --git a/pkg/flash_test.go b/pkg/server/web/flash_test.go similarity index 99% rename from pkg/flash_test.go rename to pkg/server/web/flash_test.go index d5e9608d..2deef54e 100644 --- a/pkg/flash_test.go +++ b/pkg/server/web/flash_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" diff --git a/pkg/fs.go b/pkg/server/web/fs.go similarity index 99% rename from pkg/fs.go rename to pkg/server/web/fs.go index 41cc6f6e..5457457a 100644 --- a/pkg/fs.go +++ b/pkg/server/web/fs.go @@ -1,4 +1,4 @@ -package beego +package web import ( "net/http" diff --git a/pkg/grace/grace.go b/pkg/server/web/grace/grace.go similarity index 100% rename from pkg/grace/grace.go rename to pkg/server/web/grace/grace.go diff --git a/pkg/grace/server.go b/pkg/server/web/grace/server.go similarity index 100% rename from pkg/grace/server.go rename to pkg/server/web/grace/server.go diff --git a/pkg/hooks.go b/pkg/server/web/hooks.go similarity index 83% rename from pkg/hooks.go rename to pkg/server/web/hooks.go index 3f778cdc..080b2006 100644 --- a/pkg/hooks.go +++ b/pkg/server/web/hooks.go @@ -1,14 +1,16 @@ -package beego +package web import ( + context2 "context" "encoding/json" "mime" "net/http" "path/filepath" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/session" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/infrastructure/session" + + "github.com/astaxie/beego/pkg/server/web/context" ) // register MIME type with content type @@ -47,9 +49,9 @@ func registerDefaultErrorHandler() error { func registerSession() error { if BConfig.WebConfig.Session.SessionOn { var err error - sessionConfig := AppConfig.String("sessionConfig") + sessionConfig, err := AppConfig.String(nil, "sessionConfig") conf := new(session.ManagerConfig) - if sessionConfig == "" { + if sessionConfig == "" || err != nil { conf.CookieName = BConfig.WebConfig.Session.SessionName conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime @@ -95,9 +97,9 @@ func registerAdmin() error { func registerGzip() error { if BConfig.EnableGzip { context.InitGzip( - AppConfig.DefaultInt("gzipMinLength", -1), - AppConfig.DefaultInt("gzipCompressLevel", -1), - AppConfig.DefaultStrings("includedMethods", []string{"GET"}), + AppConfig.DefaultInt(context2.Background(), "gzipMinLength", -1), + AppConfig.DefaultInt(context2.Background(), "gzipCompressLevel", -1), + AppConfig.DefaultStrings(context2.Background(), "includedMethods", []string{"GET"}), ) } return nil diff --git a/pkg/mime.go b/pkg/server/web/mime.go similarity index 99% rename from pkg/mime.go rename to pkg/server/web/mime.go index ca2878ab..9393e9c7 100644 --- a/pkg/mime.go +++ b/pkg/server/web/mime.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web var mimemaps = map[string]string{ ".3dm": "x-world/x-3dmf", diff --git a/pkg/namespace.go b/pkg/server/web/namespace.go similarity index 98% rename from pkg/namespace.go rename to pkg/server/web/namespace.go index bda18f4b..a792aa60 100644 --- a/pkg/namespace.go +++ b/pkg/server/web/namespace.go @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" "strings" - beecontext "github.com/astaxie/beego/pkg/context" + beecontext "github.com/astaxie/beego/pkg/server/web/context" ) type namespaceCond func(*beecontext.Context) bool @@ -91,7 +91,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { a = FinishRouter } for _, f := range filter { - n.handlers.InsertFilter("*", a, f) + n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true)) } return n } diff --git a/pkg/namespace_test.go b/pkg/server/web/namespace_test.go similarity index 98% rename from pkg/namespace_test.go rename to pkg/server/web/namespace_test.go index bdf33b4f..39d60041 100644 --- a/pkg/namespace_test.go +++ b/pkg/server/web/namespace_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" @@ -20,7 +20,7 @@ import ( "strconv" "testing" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) func TestNamespaceGet(t *testing.T) { diff --git a/pkg/utils/pagination/controller.go b/pkg/server/web/pagination/controller.go similarity index 79% rename from pkg/utils/pagination/controller.go rename to pkg/server/web/pagination/controller.go index b5b09a2f..530a72ff 100644 --- a/pkg/utils/pagination/controller.go +++ b/pkg/server/web/pagination/controller.go @@ -15,12 +15,13 @@ package pagination import ( - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/infrastructure/utils/pagination" + "github.com/astaxie/beego/pkg/server/web/context" ) // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). -func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { - paginator = NewPaginator(context.Request, per, nums) +func SetPaginator(context *context.Context, per int, nums int64) (paginator *pagination.Paginator) { + paginator = pagination.NewPaginator(context.Request, per, nums) context.Input.SetData("paginator", &paginator) return } diff --git a/pkg/parser.go b/pkg/server/web/parser.go similarity index 97% rename from pkg/parser.go rename to pkg/server/web/parser.go index bee45d7b..a4507010 100644 --- a/pkg/parser.go +++ b/pkg/server/web/parser.go @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( + "context" "encoding/json" "errors" "fmt" @@ -30,16 +31,17 @@ import ( "golang.org/x/tools/go/packages" - "github.com/astaxie/beego/pkg/context/param" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/infrastructure/utils" + "github.com/astaxie/beego/pkg/server/web/context/param" ) var globalRouterTemplate = `package {{.routersDir}} import ( "github.com/astaxie/beego/pkg" - "github.com/astaxie/beego/pkg/context/param"{{.globalimport}} + "github.com/astaxie/beego/pkg/server/web/context/param"{{.globalimport}} ) func init() { @@ -515,7 +517,7 @@ func genRouterCode(pkgRealpath string) { } defer f.Close() - routersDir := AppConfig.DefaultString("routersdir", "routers") + routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers") content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) @@ -584,7 +586,7 @@ func getpathTime(pkgRealpath string) (lastupdate int64, err error) { func getRouterDir(pkgRealpath string) string { dir := filepath.Dir(pkgRealpath) for { - routersDir := AppConfig.DefaultString("routersdir", "routers") + routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers") d := filepath.Join(dir, routersDir) if utils.FileExists(d) { return d diff --git a/pkg/policy.go b/pkg/server/web/policy.go similarity index 97% rename from pkg/policy.go rename to pkg/server/web/policy.go index 4af240f1..2099f99d 100644 --- a/pkg/policy.go +++ b/pkg/server/web/policy.go @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "strings" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) // PolicyFunc defines a policy function which is invoked before the controller handler is executed. diff --git a/pkg/router.go b/pkg/server/web/router.go similarity index 96% rename from pkg/router.go rename to pkg/server/web/router.go index 6b25d7e3..3dd19a6f 100644 --- a/pkg/router.go +++ b/pkg/server/web/router.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "errors" @@ -25,11 +25,11 @@ import ( "sync" "time" - beecontext "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/context/param" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/toolbox" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/infrastructure/utils" + beecontext "github.com/astaxie/beego/pkg/server/web/context" + "github.com/astaxie/beego/pkg/server/web/context/param" ) // default filter execution points @@ -148,7 +148,7 @@ func NewControllerRegister() *ControllerRegister { }, }, } - res.chainRoot = newFilterRouter("/*", false, res.serveHttp) + res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } @@ -262,7 +262,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { if comm, ok := GlobalControllerRouter[key]; ok { for _, a := range comm { for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) + p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) @@ -452,8 +452,9 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) // params is for: // 1. setting the returnOnOutput value (false allows multiple filters to execute) // 2. determining whether or not params need to be reset. -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := newFilterRouter(pattern, BConfig.RouterCaseSensitive, filter, params...) +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error { + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + mr := newFilterRouter(pattern, filter, opts...) return p.insertFilterRouter(pos, mr) } @@ -468,10 +469,11 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // // do something // } // } -func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, params ...bool) { +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { root := p.chainRoot filterFunc := chain(root.filterFunc) - p.chainRoot = newFilterRouter(pattern, BConfig.RouterCaseSensitive, filterFunc, params...) + opts = append(opts, WithCaseSensitive(BConfig.RouterCaseSensitive)) + p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) p.chainRoot.next = root } @@ -721,7 +723,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } defer func() { if ctx.Input.CruSession != nil { - ctx.Input.CruSession.SessionRelease(rw) + ctx.Input.CruSession.SessionRelease(nil, rw) } }() } @@ -906,7 +908,7 @@ Admin: if runRouter != nil { routerName = runRouter.Name() } - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) + go StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) } } diff --git a/pkg/router_test.go b/pkg/server/web/router_test.go similarity index 94% rename from pkg/router_test.go rename to pkg/server/web/router_test.go index 8a7862f6..33b75703 100644 --- a/pkg/router_test.go +++ b/pkg/server/web/router_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" @@ -21,8 +21,9 @@ import ( "strings" "testing" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" + + "github.com/astaxie/beego/pkg/server/web/context" ) type TestController struct { @@ -422,7 +423,7 @@ func TestInsertFilter(t *testing.T) { testName := "TestInsertFilter" mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true)) if !mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing no variadic params should set returnOnOutput to true", @@ -435,7 +436,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false)) if mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing false as 1st variadic param should set returnOnOutput to false", @@ -443,7 +444,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true)) if !mux.filters[BeforeRouter][0].resetParams { t.Errorf( "%s: passing true as 2nd variadic param should set resetParams to true", @@ -460,7 +461,7 @@ func TestParamResetFilter(t *testing.T) { mux := NewControllerRegister() - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true)) mux.Get(route, beegoHandleResetParams) @@ -513,8 +514,8 @@ func TestFilterBeforeExec(t *testing.T) { url := "/beforeExec" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -541,7 +542,7 @@ func TestFilterAfterExec(t *testing.T) { mux := NewControllerRegister() mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -569,10 +570,10 @@ func TestFilterFinishRouter(t *testing.T) { url := "/finishRouter" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -603,7 +604,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { url := "/finishRouterMultiFirstOnly" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) mux.Get(url, beegoFilterFunc) @@ -630,8 +631,8 @@ func TestFilterFinishRouterMulti(t *testing.T) { url := "/finishRouterMulti" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) diff --git a/pkg/staticfile.go b/pkg/server/web/staticfile.go similarity index 98% rename from pkg/staticfile.go rename to pkg/server/web/staticfile.go index f8b17fc5..7b9942f4 100644 --- a/pkg/staticfile.go +++ b/pkg/server/web/staticfile.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" @@ -26,9 +26,10 @@ import ( "sync" "time" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/logs" + "github.com/astaxie/beego/pkg/infrastructure/logs" lru "github.com/hashicorp/golang-lru" + + "github.com/astaxie/beego/pkg/server/web/context" ) var errNotStaticRequest = errors.New("request not a static file request") diff --git a/pkg/staticfile_test.go b/pkg/server/web/staticfile_test.go similarity index 99% rename from pkg/staticfile_test.go rename to pkg/server/web/staticfile_test.go index e46c13ec..0725a2f8 100644 --- a/pkg/staticfile_test.go +++ b/pkg/server/web/staticfile_test.go @@ -1,4 +1,4 @@ -package beego +package web import ( "bytes" diff --git a/pkg/toolbox/statistics.go b/pkg/server/web/statistics.go similarity index 85% rename from pkg/toolbox/statistics.go rename to pkg/server/web/statistics.go index fd73dfb3..ccc3a1fc 100644 --- a/pkg/toolbox/statistics.go +++ b/pkg/server/web/statistics.go @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package web import ( "fmt" "sync" "time" + + "github.com/astaxie/beego/pkg/infrastructure/utils" ) // Statistics struct @@ -100,13 +102,13 @@ func (m *URLMap) GetMap() map[string]interface{} { fmt.Sprintf("% -10s", kk), fmt.Sprintf("% -16d", vv.RequestNum), fmt.Sprintf("%d", vv.TotalTime), - fmt.Sprintf("% -16s", toS(vv.TotalTime)), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.TotalTime)), fmt.Sprintf("%d", vv.MaxTime), - fmt.Sprintf("% -16s", toS(vv.MaxTime)), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.MaxTime)), fmt.Sprintf("%d", vv.MinTime), - fmt.Sprintf("% -16s", toS(vv.MinTime)), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.MinTime)), fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), - fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), } resultLists = append(resultLists, result) } @@ -128,10 +130,10 @@ func (m *URLMap) GetMapData() []map[string]interface{} { "request_url": k, "method": kk, "times": vv.RequestNum, - "total_time": toS(vv.TotalTime), - "max_time": toS(vv.MaxTime), - "min_time": toS(vv.MinTime), - "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), + "total_time": utils.ToShortTimeFormat(vv.TotalTime), + "max_time": utils.ToShortTimeFormat(vv.MaxTime), + "min_time": utils.ToShortTimeFormat(vv.MinTime), + "avg_time": utils.ToShortTimeFormat(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), } resultLists = append(resultLists, result) } diff --git a/pkg/server/web/statistics_test.go b/pkg/server/web/statistics_test.go new file mode 100644 index 00000000..7c83e15a --- /dev/null +++ b/pkg/server/web/statistics_test.go @@ -0,0 +1,40 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "encoding/json" + "testing" + "time" +) + +func TestStatics(t *testing.T) { + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) + StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) + StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) + StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) + StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + t.Log(StatisticsMap.GetMap()) + + data := StatisticsMap.GetMapData() + b, err := json.Marshal(data) + if err != nil { + t.Errorf(err.Error()) + } + + t.Log(string(b)) +} diff --git a/pkg/swagger/swagger.go b/pkg/server/web/swagger/swagger.go similarity index 100% rename from pkg/swagger/swagger.go rename to pkg/server/web/swagger/swagger.go diff --git a/pkg/template.go b/pkg/server/web/template.go similarity index 99% rename from pkg/template.go rename to pkg/server/web/template.go index 8edd9dc1..a4b8db99 100644 --- a/pkg/template.go +++ b/pkg/server/web/template.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "errors" @@ -27,8 +27,8 @@ import ( "strings" "sync" - "github.com/astaxie/beego/pkg/logs" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/logs" + "github.com/astaxie/beego/pkg/infrastructure/utils" ) var ( diff --git a/pkg/template_test.go b/pkg/server/web/template_test.go similarity index 99% rename from pkg/template_test.go rename to pkg/server/web/template_test.go index 134c2cb2..b542494d 100644 --- a/pkg/template_test.go +++ b/pkg/server/web/template_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" diff --git a/pkg/templatefunc.go b/pkg/server/web/templatefunc.go similarity index 97% rename from pkg/templatefunc.go rename to pkg/server/web/templatefunc.go index 6f02b8d6..f3301e50 100644 --- a/pkg/templatefunc.go +++ b/pkg/server/web/templatefunc.go @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( + "context" "errors" "fmt" "html" @@ -58,11 +59,11 @@ func HTML2str(html string) string { re := regexp.MustCompile(`\<[\S\s]+?\>`) html = re.ReplaceAllStringFunc(html, strings.ToLower) - //remove STYLE + // remove STYLE re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") - //remove SCRIPT + // remove SCRIPT re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") @@ -85,7 +86,7 @@ func DateFormat(t time.Time, layout string) (datestring string) { var datePatterns = []string{ // year "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 + "y", "06", // A two digit representation of a year Examples: 99 or 03 // month "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 @@ -160,17 +161,17 @@ func NotNil(a interface{}) (isNil bool) { func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { switch returnType { case "String": - value = AppConfig.String(key) + value, err = AppConfig.String(context.Background(), key) case "Bool": - value, err = AppConfig.Bool(key) + value, err = AppConfig.Bool(context.Background(), key) case "Int": - value, err = AppConfig.Int(key) + value, err = AppConfig.Int(context.Background(), key) case "Int64": - value, err = AppConfig.Int64(key) + value, err = AppConfig.Int64(context.Background(), key) case "Float": - value, err = AppConfig.Float(key) + value, err = AppConfig.Float(context.Background(), key) case "DIY": - value, err = AppConfig.DIY(key) + value, err = AppConfig.DIY(context.Background(), key) default: err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY") } @@ -201,7 +202,7 @@ func Str2html(raw string) template.HTML { // Htmlquote returns quoted html string. func Htmlquote(text string) string { - //HTML编码为实体符号 + // HTML编码为实体符号 /* Encodes `text` for raw use in HTML. >>> htmlquote("<'&\\">") @@ -220,7 +221,7 @@ func Htmlquote(text string) string { // Htmlunquote returns unquoted html string. func Htmlunquote(text string) string { - //实体符号解释为HTML + // 实体符号解释为HTML /* Decodes `text` that's HTML quoted. >>> htmlunquote('<'&">') diff --git a/pkg/templatefunc_test.go b/pkg/server/web/templatefunc_test.go similarity index 99% rename from pkg/templatefunc_test.go rename to pkg/server/web/templatefunc_test.go index b4c19c2e..df5cfa40 100644 --- a/pkg/templatefunc_test.go +++ b/pkg/server/web/templatefunc_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "html/template" diff --git a/pkg/tree.go b/pkg/server/web/tree.go similarity index 99% rename from pkg/tree.go rename to pkg/server/web/tree.go index 785ba6a6..7213a0c6 100644 --- a/pkg/tree.go +++ b/pkg/server/web/tree.go @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "path" "regexp" "strings" - "github.com/astaxie/beego/pkg/context" - "github.com/astaxie/beego/pkg/utils" + "github.com/astaxie/beego/pkg/infrastructure/utils" + + "github.com/astaxie/beego/pkg/server/web/context" ) var ( diff --git a/pkg/tree_test.go b/pkg/server/web/tree_test.go similarity index 99% rename from pkg/tree_test.go rename to pkg/server/web/tree_test.go index 8758e0c0..d3091de0 100644 --- a/pkg/tree_test.go +++ b/pkg/server/web/tree_test.go @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "strings" "testing" - "github.com/astaxie/beego/pkg/context" + "github.com/astaxie/beego/pkg/server/web/context" ) type testinfo struct { diff --git a/pkg/unregroute_test.go b/pkg/server/web/unregroute_test.go similarity index 99% rename from pkg/unregroute_test.go rename to pkg/server/web/unregroute_test.go index 08b1b77b..c675ae7d 100644 --- a/pkg/unregroute_test.go +++ b/pkg/server/web/unregroute_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" diff --git a/pkg/toolbox/task.go b/pkg/task/task.go similarity index 91% rename from pkg/toolbox/task.go rename to pkg/task/task.go index fb2c5f16..bcadb956 100644 --- a/pkg/toolbox/task.go +++ b/pkg/task/task.go @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package task import ( + "context" "log" "math" "sort" @@ -82,17 +83,17 @@ type Schedule struct { } // TaskFunc task func type -type TaskFunc func() error +type TaskFunc func(ctx context.Context) error // Tasker task interface type Tasker interface { - GetSpec() string - GetStatus() string - Run() error - SetNext(time.Time) - GetNext() time.Time - SetPrev(time.Time) - GetPrev() time.Time + GetSpec(ctx context.Context) string + GetStatus(ctx context.Context) string + Run(ctx context.Context) error + SetNext(context.Context, time.Time) + GetNext(ctx context.Context) time.Time + SetPrev(context.Context, time.Time) + GetPrev(ctx context.Context) time.Time } // task error @@ -133,12 +134,12 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { } // GetSpec get spec string -func (t *Task) GetSpec() string { +func (t *Task) GetSpec(context.Context) string { return t.SpecStr } // GetStatus get current task status -func (t *Task) GetStatus() string { +func (t *Task) GetStatus(context.Context) string { var str string for _, v := range t.Errlist { str += v.t.String() + ":" + v.errinfo + "
" @@ -147,8 +148,8 @@ func (t *Task) GetStatus() string { } // Run run all tasks -func (t *Task) Run() error { - err := t.DoFunc() +func (t *Task) Run(ctx context.Context) error { + err := t.DoFunc(ctx) if err != nil { index := t.errCnt % t.ErrLimit t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} @@ -158,22 +159,22 @@ func (t *Task) Run() error { } // SetNext set next time for this task -func (t *Task) SetNext(now time.Time) { +func (t *Task) SetNext(ctx context.Context, now time.Time) { t.Next = t.Spec.Next(now) } // GetNext get the next call time of this task -func (t *Task) GetNext() time.Time { +func (t *Task) GetNext(context.Context) time.Time { return t.Next } // SetPrev set prev time of this task -func (t *Task) SetPrev(now time.Time) { +func (t *Task) SetPrev(ctx context.Context, now time.Time) { t.Prev = now } // GetPrev get prev time of this task -func (t *Task) GetPrev() time.Time { +func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } @@ -410,7 +411,7 @@ func StartTask() { func run() { now := time.Now().Local() for _, t := range AdminTaskList { - t.SetNext(now) + t.SetNext(nil, now) } for { @@ -420,30 +421,30 @@ func run() { taskLock.RUnlock() sortList.Sort() var effective time.Time - if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { + if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext(context.Background()).IsZero() { // If there are no entries yet, just sleep - it still handles new entries // and stop requests. effective = now.AddDate(10, 0, 0) } else { - effective = sortList.Vals[0].GetNext() + effective = sortList.Vals[0].GetNext(context.Background()) } select { case now = <-time.After(effective.Sub(now)): // Run every entry whose next time was this effective time. for _, e := range sortList.Vals { - if e.GetNext() != effective { + if e.GetNext(context.Background()) != effective { break } - go e.Run() - e.SetPrev(e.GetNext()) - e.SetNext(effective) + go e.Run(nil) + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(nil, effective) } continue case <-changed: now = time.Now().Local() taskLock.Lock() for _, t := range AdminTaskList { - t.SetNext(now) + t.SetNext(nil, now) } taskLock.Unlock() continue @@ -468,7 +469,7 @@ func StopTask() { func AddTask(taskname string, t Tasker) { taskLock.Lock() defer taskLock.Unlock() - t.SetNext(time.Now().Local()) + t.SetNext(nil, time.Now().Local()) AdminTaskList[taskname] = t if isstart { changed <- true @@ -511,13 +512,13 @@ func (ms *MapSorter) Sort() { func (ms *MapSorter) Len() int { return len(ms.Keys) } func (ms *MapSorter) Less(i, j int) bool { - if ms.Vals[i].GetNext().IsZero() { + if ms.Vals[i].GetNext(context.Background()).IsZero() { return false } - if ms.Vals[j].GetNext().IsZero() { + if ms.Vals[j].GetNext(context.Background()).IsZero() { return true } - return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) } func (ms *MapSorter) Swap(i, j int) { ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go new file mode 100644 index 00000000..488729dc --- /dev/null +++ b/pkg/task/task_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) + err := tk.Run(nil) + if err != nil { + t.Fatal(err) + } + AddTask("taska", tk) + StartTask() + time.Sleep(6 * time.Second) + StopTask() +} + +func TestSpec(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) + + AddTask("tk1", tk1) + AddTask("tk2", tk2) + AddTask("tk3", tk3) + StartTask() + defer StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func(ctx context.Context) error { + cnt++ + fmt.Printf("Hello, world! %d \n", cnt) + return errors.New(fmt.Sprintf("Hello, world! %d", cnt)) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200; i++ { + e := tk.Run(nil) + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 101", l[1].errinfo) +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh new file mode 100644 index 00000000..d34c05a3 --- /dev/null +++ b/scripts/prepare_etcd.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +etcdctl put current.float 1.23 +etcdctl put current.bool true +etcdctl put current.int 11 +etcdctl put current.string hello +etcdctl put current.serialize.name test +etcdctl put sub.sub.key1 sub.sub.key \ No newline at end of file diff --git a/scripts/test_docker_compose.yaml b/scripts/test_docker_compose.yaml index 54ca4097..f22b6deb 100644 --- a/scripts/test_docker_compose.yaml +++ b/scripts/test_docker_compose.yaml @@ -36,4 +36,20 @@ services: image: memcached ports: - "11211:11211" - + etcd: + command: > + sh -c " + etcdctl put current.float 1.23 + && etcdctl put current.bool true + && etcdctl put current.int 11 + && etcdctl put current.string hello + && etcdctl put current.serialize.name test + " + container_name: "beego-etcd" + environment: + - ALLOW_NONE_AUTHENTICATION=yes +# - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379 + image: bitnami/etcd + ports: + - "2379:2379" + - "2380:2380" \ No newline at end of file