1
0
mirror of https://github.com/astaxie/beego.git synced 2025-07-11 17:31:02 +00:00

353 Commits

Author SHA1 Message Date
469f283b68 beego:fix captcha filter router 2014-06-19 20:29:36 +08:00
085c362ffb beego:fix router expge 2014-06-18 23:32:47 +08:00
c3a07555c4 beego:change version to 1.3.0 2014-06-18 15:15:45 +08:00
2b8e411174 merger master httplib 2014-06-18 15:04:08 +08:00
720a77c1f9 beego:rever docs 2014-06-18 11:15:43 +08:00
b943b74fc5 beego:init GlobalDocApi 2014-06-18 11:09:47 +08:00
67be7b532d beego:doc move to swagger 2014-06-18 10:36:20 +08:00
8be83fe488 Merge pull request #642 from JessonChan/develop
ignore nil time
2014-06-17 15:25:03 +08:00
cff632f553 beego: swagger EnableDocs 2014-06-16 16:05:19 +08:00
4990d88861 Merge pull request #648 from redaready/develop
update chat example
2014-06-14 12:13:50 +08:00
7075ad8a28 update chat example 2014-06-14 00:21:26 +02:00
0e278ae358 beego:format the admin print route 2014-06-13 00:14:30 +08:00
e38a23b30e beego:admin add print method 2014-06-13 00:08:43 +08:00
117904be73 beego:fix the some regexp routes to different func 2014-06-12 23:08:05 +08:00
3b807845f2 beego:addtree support regexp 2014-06-12 20:50:29 +08:00
00b710e168 beego:namespace sub router add url to pattern 2014-06-11 23:51:19 +08:00
e25fcffbc0 config:json add support array 2014-06-11 22:47:11 +08:00
c13141b8bf beego:fix when user defined function equal to HTTP 2014-06-11 22:45:54 +08:00
7a7ff735e3 Merge pull request #644 from chrisport/develop
config: fix error when json config starts with an array
2014-06-11 22:02:28 +08:00
3b934bb910 config: fix error when json config starts with an array 2014-06-11 11:33:32 +03:00
aa275fb5ce beego:fix #639 2014-06-11 13:26:45 +08:00
deb553be7f beego:confgi support difference run mode section
runmode = dev
appname = doraemon
[dev]
httpport = 8880
sessionon = true

[prod]
httpport = 8888
sessionon = true

[test]
httpport = 8080
sessionon = false
2014-06-11 12:00:50 +08:00
3db9633ebd remove websocket logic because not support handler 2014-06-11 11:12:17 +08:00
2f8a70d548 beego: router support param has _ 2014-06-11 09:33:35 +08:00
7c0d0900ac beego:fix static file router 2014-06-11 01:19:39 +08:00
6809c97611 beego: improve performance 2014-06-11 01:11:32 +08:00
675643c68d beego: run mode support test 2014-06-10 22:47:48 +08:00
06f4bf493d ignore nil time 2014-06-10 22:10:58 +08:00
4786fb0948 beego:fix typo NewControllerRegister 2014-06-10 20:12:57 +08:00
fdb5672b7a beego:delete debug information 2014-06-10 18:10:32 +08:00
107a7a21c0 beego: dev mode print request router & pattern 2014-06-10 18:09:07 +08:00
dbebf8df4b beego:namespace support nest
ns := NewNamespace("/v3",
		NSAutoRouter(&TestController{}),
		NSNamespace("/shop",
			NSGet("/order/:id", func(ctx *context.Context) {
				ctx.Output.Body([]byte(ctx.Input.Param(":id")))
			}),
		),
	)
2014-06-10 17:11:02 +08:00
f7b01aab13 beego: modify the filter sequence 2014-06-10 11:02:41 +08:00
2570f075d9 beego:change ControllerComments exported 2014-06-09 17:46:13 +08:00
21cb8ea4a3 beego:AST code 2014-06-09 17:33:04 +08:00
6c8a7f1382 beego: router change to method Tree 2014-06-09 10:11:37 +08:00
e00eab7f49 beego: change to tree 2014-06-08 20:24:07 +08:00
bfabcfcb6b beego:router tree 2014-06-08 20:24:07 +08:00
f06ba52ede Merge pull request #633 from dlt/develop
fixed typo on constant applicationXml
2014-06-07 01:10:42 +08:00
fcae000a79 fixed typo on constant applicationXml 2014-06-06 13:56:34 -03:00
3e4c015982 Merge pull request #631 from curvesoft/master
cookiejar support
2014-06-04 23:02:52 +08:00
d689be30e8 remove httplib_test.php 2014-06-04 22:12:37 +08:00
7b110a0b73 remove httplib_test.php 2014-06-04 22:09:43 +08:00
e3033b57a6 1.gofmt httplib.go httplib_test.go
2.replace test url to http://httpbin.org functions
2014-06-04 21:15:24 +08:00
bd537554ea 1.gofmt httplib.go httplib_test.go
2.replace test url to http://httpbin.org functions
2014-06-04 21:04:50 +08:00
ebb3b91df9 1、增加cookiejar支持
2、增加Setting结构,便于统一设置请求参数
3、增加服务端测试php脚本
2014-06-03 21:20:10 +08:00
a65ad1a4bc fix the time test case 2014-05-31 14:28:44 +08:00
bdc01f52a0 Merge pull request #626 from mvpmvh/michael
Michael
2014-05-31 14:25:40 +08:00
a673a85d4a added tests config/json_test that test missing key usecases. created a template function to fetch AppConfig values 2014-05-30 23:48:23 -05:00
61008fe75c udpated timezone in templatefunc_test. changed error message to be more descriptive when tests fail 2014-05-30 14:12:21 -05:00
5dee6b7d19 beego: fix the namespace cond 2014-05-28 10:23:31 +08:00
f6c7a6bd32 beego: improve the admin router print 2014-05-27 17:27:22 +08:00
d2eece9a39 session: #620 make the session never read empty 2014-05-27 15:45:35 +08:00
c3a23b28ee beego: improve the RandomCreateBytes #620
when rand.Read is failed. will use the math/rand to generate the rand
bytes
2014-05-27 15:29:43 +08:00
9083927c6a beego: enhance the XSRFKEY from 15 to 32 #620 2014-05-27 15:00:10 +08:00
3f7e91e6a4 beego:fix *.* router bug 2014-05-26 10:15:56 +08:00
a2a6f47afa beego: support other config provider 2014-05-25 22:37:38 +08:00
23229ef9ef beego: BeegoServerName & beego.Run
BeegoServerName change to beegoServer+Version
beego.Run(“:8089”)
2014-05-25 22:35:20 +08:00
0d17d974cd beego: update namespace 2014-05-23 15:56:25 +08:00
17104c25a2 beego: Refactoring Filter & add comments 2014-05-20 18:47:41 +08:00
8b374d7f90 beego: add benchmark 2014-05-20 18:20:44 +08:00
fa3234147a httplib:drone can't upload file 2014-05-20 17:34:52 +08:00
33ad6c1370 beego: remove app funciont & fix #590
config := tls.Config{
    ClientAuth: tls.RequireAndVerifyClientCert,
    Certificates: []tls.Certificate{cert},
    ClientCAs: pool,
}
config.Rand = rand.Reader

beego.BeeApp.Server. TLSConfig = &config
2014-05-20 17:28:06 +08:00
04290dfc68 beego: delete hotupdate 2014-05-20 16:41:39 +08:00
03080b3ef2 beego:1.2.0 2014-05-20 15:53:41 +08:00
3f2a712ba8 beego:change default port 2014-05-20 15:40:05 +08:00
f215aa4810 beego: change the error tips 2014-05-20 15:34:27 +08:00
18a02d7d60 beego:support https & http listen 2014-05-20 15:30:17 +08:00
3f4d750dc4 utils: improve the file grep 2014-05-20 14:32:06 +08:00
9f01aeed31 beego:remove unused code 2014-05-19 18:52:48 +08:00
b45f0b9bf6 beego: fix #478 2014-05-17 02:56:50 +08:00
cf04ade603 merger master 2014-05-17 02:29:41 +08:00
92f6181616 beego: change the version to 1.2.0 2014-05-17 02:26:52 +08:00
9270a0504a beego: admin support link 2014-05-17 02:26:52 +08:00
1da37f6ce1 beego: controller add ServeFormatted
ServeFormatted serve Xml OR Json, depending on the value of the Accept
header
2014-05-17 02:26:52 +08:00
ef6d9b9a94 session: support memcache interface 2014-05-17 02:26:52 +08:00
c265786251 session:support struct.
gob.Register(v)
2014-05-17 02:26:51 +08:00
c5c806b58e beego: XSRF support Controller level fix #610
default value is true when you Enable Global XSRF, also can control in
the prepare function to change the value.
2014-05-17 02:26:51 +08:00
e657dcfd5f beego: support namespace
ns := beego.NewNamespace("/v1/api/")
ns.Cond(func(ctx *context.Context)bool{
	    if ctx.Input.Domain() == "www.beego.me" {
	    	return true
	    }
	    return false
	})
.Filter("before", Authenticate)
.Router("/order",	&admin.OrderController{})
.Get("/version",func (ctx *context.Context) {
	ctx.Output.Body([]byte("1.0.0"))
})
.Post("/login",func (ctx *context.Context) {
	if ctx.Query("username") == "admin" && ctx.Query("username") ==
"password" {

	}
})
.Namespace(
	NewNamespace("/shop").
		Get("/order/:id", func(ctx *context.Context) {
		ctx.Output.Body([]byte(ctx.Input.Param(":id")))
	}),
)
2014-05-17 02:26:51 +08:00
2ed9b2bffd orm: add test for unexported struct field 2014-05-17 02:26:51 +08:00
55ad951bce beego: support more router
//design model
	beego.Get(router, beego.FilterFunc)
	beego.Post(router, beego.FilterFunc)
	beego.Put(router, beego.FilterFunc)
	beego.Head(router, beego.FilterFunc)
	beego.Options(router, beego.FilterFunc)
	beego.Delete(router, beego.FilterFunc)
	beego.Handler(router, http.Handler)

//example

beego.Get("/user", func(ctx *context.Context) {
	ctx.Output.Body([]byte("Get userlist"))
})

beego.Post("/user", func(ctx *context.Context) {
	ctx.Output.Body([]byte("add userlist"))
})

beego.Delete("/user/:id", func(ctx *context.Context) {
	ctx.Output.Body([]byte([]byte(ctx.Input.Param(":id")))
})

import (
    "http"
    "github.com/gorilla/rpc"
    "github.com/gorilla/rpc/json"
)

func init() {
    s := rpc.NewServer()
    s.RegisterCodec(json.NewCodec(), "application/json")
    s.RegisterService(new(HelloService), "")
    beego.Handler("/rpc", s)
}
2014-05-17 02:26:51 +08:00
ef815bf5fc config: fix the import issue 2014-05-17 02:26:51 +08:00
6082a0af3e bug fixed 2014-05-17 02:26:51 +08:00
be30fb7937 refator func 2014-05-17 02:26:51 +08:00
f4e7d63e65 httplib support to set the protocol version for incoming requests 2014-05-17 02:26:51 +08:00
14688f240f httplib:support file upload 2014-05-17 02:26:51 +08:00
dce09837b9 fix the typo 2014-05-17 02:26:50 +08:00
3b9a404138 beego: support other analisys & fix typo 2014-05-17 02:26:50 +08:00
a6f55b59cf beego: add link in the admin console 2014-05-17 02:26:50 +08:00
c188cbbcb4 update all files License 2014-05-17 02:26:50 +08:00
4245521660 fix #576 2014-05-17 02:26:50 +08:00
05e5baaa9f beego:add post test case 2014-05-17 02:26:50 +08:00
54b92e9599 context:add Bind function
// Bind data from request.Form[key] to dest
// like
/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=
astaxie
// var id int  beegoInput.Bind(&id, "id")  id ==123
// var isok bool  beegoInput.Bind(&isok, "isok")  id ==true
// var ft float64  beegoInput.Bind(&ft, "ft")  ft ==1.2
// ol := make([]int, 0, 2)  beegoInput.Bind(&ol, "ol")  ol ==[1 2]
// ul := make([]string, 0, 2)  beegoInput.Bind(&ul, "ul")  ul ==[str
array]
// user struct{Name}  beegoInput.Bind(&user, "user")  user ==
{Name:"astaxie"}
2014-05-17 02:26:50 +08:00
aa68ffecec beego: support not-empty value in router fix #555 2014-05-17 02:26:50 +08:00
78991c81ab make Maxage work 2014-05-17 02:26:49 +08:00
348ff13857 update the error message 2014-05-17 02:26:49 +08:00
52817fb668 allow unexported fields on model structs 2014-05-17 02:26:49 +08:00
2c59ff1cc6 beego: admin support link 2014-05-17 02:20:48 +08:00
6bdf0838ce beego: controller add ServeFormatted
ServeFormatted serve Xml OR Json, depending on the value of the Accept
header
2014-05-17 01:26:59 +08:00
31a63c5d50 session: support memcache interface 2014-05-17 01:19:47 +08:00
237aaadd65 session:support struct.
gob.Register(v)
2014-05-17 00:43:51 +08:00
34ddcef1dc beego: XSRF support Controller level fix #610
default value is true when you Enable Global XSRF, also can control in
the prepare function to change the value.
2014-05-17 00:12:25 +08:00
f6ce2656db beego: support namespace
ns := beego.NewNamespace("/v1/api/")
ns.Cond(func(ctx *context.Context)bool{
	    if ctx.Input.Domain() == "www.beego.me" {
	    	return true
	    }
	    return false
	})
.Filter("before", Authenticate)
.Router("/order",	&admin.OrderController{})
.Get("/version",func (ctx *context.Context) {
	ctx.Output.Body([]byte("1.0.0"))
})
.Post("/login",func (ctx *context.Context) {
	if ctx.Query("username") == "admin" && ctx.Query("username") ==
"password" {

	}
})
.Namespace(
	NewNamespace("/shop").
		Get("/order/:id", func(ctx *context.Context) {
		ctx.Output.Body([]byte(ctx.Input.Param(":id")))
	}),
)
2014-05-16 23:47:29 +08:00
b647026dff orm: add test for unexported struct field 2014-05-16 13:14:15 +08:00
568c0c47f0 Merge pull request #542 from kylemcc/develop
orm: allow unexported fields on model structs
2014-05-16 13:11:55 +08:00
2629de28f2 beego: support more router
//design model
	beego.Get(router, beego.FilterFunc)
	beego.Post(router, beego.FilterFunc)
	beego.Put(router, beego.FilterFunc)
	beego.Head(router, beego.FilterFunc)
	beego.Options(router, beego.FilterFunc)
	beego.Delete(router, beego.FilterFunc)
	beego.Handler(router, http.Handler)

//example

beego.Get("/user", func(ctx *context.Context) {
	ctx.Output.Body([]byte("Get userlist"))
})

beego.Post("/user", func(ctx *context.Context) {
	ctx.Output.Body([]byte("add userlist"))
})

beego.Delete("/user/:id", func(ctx *context.Context) {
	ctx.Output.Body([]byte([]byte(ctx.Input.Param(":id")))
})

import (
    "http"
    "github.com/gorilla/rpc"
    "github.com/gorilla/rpc/json"
)

func init() {
    s := rpc.NewServer()
    s.RegisterCodec(json.NewCodec(), "application/json")
    s.RegisterService(new(HelloService), "")
    beego.Handler("/rpc", s)
}
2014-05-16 10:18:19 +08:00
10d2c7c328 config: fix the import issue 2014-05-16 10:18:19 +08:00
af7ac98bd6 Merge pull request #609 from JessonChan/develop
[important] bug fixed
2014-05-15 11:47:07 +08:00
6f78f1d4b2 bug fixed 2014-05-15 11:34:44 +08:00
9f95fd3309 Merge pull request #608 from JessonChan/develop
refactor func
2014-05-14 20:47:01 +08:00
74c309cefd refator func 2014-05-14 20:08:51 +08:00
b6d63c84ae Merge pull request #605 from francoishill/patch-6
Update app.go
2014-05-14 13:12:15 +08:00
bc2f1fb79d Update app.go 2014-05-13 17:19:50 +02:00
29e113a48a Merge pull request #597 from tobyzxj/develop
httplib support to set the protocol version for incoming requests
2014-05-09 16:25:33 +08:00
3caf1896d6 httplib support to set the protocol version for incoming requests 2014-05-09 15:48:50 +08:00
d5d5f23756 httplib:support file upload 2014-05-08 16:58:08 +08:00
46641ef3b6 fix the typo 2014-05-08 10:51:29 +08:00
8ed459512f Merge pull request #591 from luosangnanka/master
Update Session package README.md
2014-05-05 16:43:53 +08:00
25768f0109 Update README.md
Update the json format of session for file, redis, mysql, cookie, there are errors in these json string, such as after the param `ProviderConfig` and there is a lost of `"` in the line 61 of the `gclifetime`.
2014-05-05 16:21:50 +08:00
a56f67e073 Update README.md
Update the json format of session for file, redis, mysql, there are errors in these json string, after `ProviderConfig` params.
2014-05-05 16:15:49 +08:00
8164f9821d Update README.md
Update the json format of session for redis, there is an error in that json string
2014-05-05 16:13:03 +08:00
b2a69f505c beego: support other analisys & fix typo 2014-04-28 18:07:30 +08:00
e307bd7ba9 beego:hotfix for multipart/form-data 2014-04-15 05:05:53 +08:00
b2bd829d39 beego: add link in the admin console 2014-04-15 05:03:20 +08:00
f9b8617fa3 context: fix multipart/form-data 2014-04-15 05:02:50 +08:00
6c6e4ecfbc update all files License 2014-04-12 13:18:18 +08:00
8bcf03c652 fix #576 2014-04-11 16:08:43 +08:00
1ea449aa3a beego:add post test case 2014-04-10 22:33:32 +08:00
a99802b7d1 beego:query data from Form & params 2014-04-10 22:21:08 +08:00
b212ec8dab beego:query data from Form & params 2014-04-10 22:20:46 +08:00
3e16feb1e2 beego: fix flash errors 2014-04-10 18:16:08 +08:00
e50cbecf80 beego: fix flash errors 2014-04-10 18:14:18 +08:00
127b85bcaa context:add Bind function
// Bind data from request.Form[key] to dest
// like
/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=
astaxie
// var id int  beegoInput.Bind(&id, "id")  id ==123
// var isok bool  beegoInput.Bind(&isok, "isok")  id ==true
// var ft float64  beegoInput.Bind(&ft, "ft")  ft ==1.2
// ol := make([]int, 0, 2)  beegoInput.Bind(&ol, "ol")  ol ==[1 2]
// ul := make([]string, 0, 2)  beegoInput.Bind(&ul, "ul")  ul ==[str
array]
// user struct{Name}  beegoInput.Bind(&user, "user")  user ==
{Name:"astaxie"}
2014-04-10 00:31:16 +08:00
89dde6cd9d Merge branch 'develop' of https://github.com/astaxie/beego into develop 2014-04-09 21:43:32 +08:00
5a52949761 beego: support not-empty value in router fix #555 2014-04-09 21:42:57 +08:00
a78162e9e4 Merge pull request #573 from linluxiang/master
Make the Maxage Config of cookie session work
2014-04-09 14:42:29 +08:00
931e6162ac Merge pull request #574 from dz1984/fix/cache_test
Fix/cache test
2014-04-09 14:42:01 +08:00
1bb876f2df make Maxage work 2014-04-09 13:18:44 +08:00
18f70e6ee4 update the error message 2014-04-09 12:39:12 +08:00
82ca85dc65 release new version 1.1.4 2014-04-08 17:49:12 +08:00
f9cc9e9eb3 beego: release new version 1.1.4 2014-04-08 17:45:57 +08:00
18fee2ad9a beego: fixed serious Directory Traversal 2014-04-08 17:43:25 +08:00
4124760706 beego: filter the static file's url 2014-04-07 14:20:30 +08:00
3fe4f8c362 toolbox: modify the godocs 2014-04-06 01:05:20 +08:00
5a863b45f4 beego: BeeAdminApp private 2014-04-06 01:02:10 +08:00
3ad30d48b5 beego: fix the godoc 2014-04-06 00:53:18 +08:00
3255a43568 beego: move staticServer to New file 2014-04-06 00:18:21 +08:00
73d757e3f4 context: improve the formParse 2014-04-06 00:08:03 +08:00
deb28dd873 fix session test case 2014-04-04 10:16:34 +08:00
7f394feab5 beego: hot fix for TestBeegoInit can't parsefile 2014-04-04 10:04:36 +08:00
8cbea70e07 beego: hot fix for TestBeegoInit can't parsefile 2014-04-04 10:04:22 +08:00
f222f5b238 beego: hot fix for console logs & go run can't find file. 2014-04-04 09:57:57 +08:00
3f1de576e4 fix go run hello.go & console log 2014-04-04 09:57:51 +08:00
f48ca96a7e beego: fix log output when SetLogger has error 2014-04-04 09:57:45 +08:00
9421a21037 beego: fix log output when SetLogger has error 2014-04-04 09:57:37 +08:00
5c06cd090c beego: hot fix for console logs & go run can't find file. 2014-04-04 09:56:25 +08:00
fc982feeb9 fix go run hello.go & console log 2014-04-04 09:49:57 +08:00
31de651053 beego: fix log output when SetLogger has error 2014-04-04 09:49:57 +08:00
f4d62d3193 beego: fix dependency of cache / session sub package 2014-04-04 08:31:22 +08:00
acbdeb62e8 beego: fix log output when SetLogger has error 2014-04-04 08:22:26 +08:00
d58e9e6e12 beego: move dependency module to sub package 2014-04-03 23:41:48 +08:00
6497f29ed7 version 1.1.2 release 2014-04-03 15:56:31 +08:00
1705b42546 beego: change version from 1.1.1 to 1.1.2 2014-04-03 15:54:37 +08:00
5505cc09ed beego: move init to a fund & add a new fund TestBeegoInit
support Test with everything init
2014-04-03 15:07:20 +08:00
12e1ab0f80 beego: setLogger return error 2014-04-02 23:45:44 +08:00
9c5ceb70cc Logs: modify StartLogger to private 2014-04-02 23:43:37 +08:00
bf0b1af64f add workPath don't chdir when go run or go test 2014-04-01 18:08:00 +08:00
9c959fba4d fix string 2014-03-29 14:59:55 +08:00
5588bfc35e support filter to get router. get runController & runMethod 2014-03-29 14:55:34 +08:00
2f4acf46c6 modify the template file 2014-03-27 08:49:57 +08:00
c7437d7590 fix Cookie for session 2014-03-26 13:51:35 +08:00
4f819dbd9a Add a function SetLogFuncCall to enable 2014-03-26 00:06:25 +08:00
3f5fee2dc6 Logs support file & filenum 2014-03-25 23:48:58 +08:00
c7f16b5d5a Merge pull request #551 from steamonimo/develop
session provider for postgresql
2014-03-25 21:28:42 +08:00
8d1268c0a9 session provider for postgresql
This provider is based on the mysql provider: sess_mysql.go
2014-03-25 12:45:23 +01:00
c921b0aa5d fix #533 change the function name 2014-03-21 14:33:11 +08:00
589f97130c add w.Rotate 2014-03-21 14:33:11 +08:00
443aaadcce fix #533 change the function name 2014-03-21 14:24:00 +08:00
ff1938054a add w.Rotate 2014-03-21 14:07:03 +08:00
d79c297880 rollback: set httponly default is false. 2014-03-20 09:51:25 +08:00
65631e0522 fix orm test 2014-03-19 10:00:26 +08:00
a879e412a1 #514 2014-03-19 09:46:09 +08:00
4785ac14d7 allow unexported fields on model structs 2014-03-18 18:00:07 -05:00
50bc1ef757 rollback: set httponly default is false. 2014-03-17 12:27:04 +08:00
7bacb25725 Merge pull request #538 from admpub/patch-2
validators bug fixed
2014-03-14 15:44:05 +08:00
ad8418720f bug fixed 2014-03-14 14:47:52 +08:00
b59dae6fb8 Merge pull request #537 from jfolkins/develop
(REF#519) Enhancement: Allow developer to set FlashName and FlashSeperator values
2014-03-14 14:22:56 +08:00
4951314837 added FlashName,FlashSeperator, & Tests 2014-03-13 22:34:22 -07:00
8188873216 omit the data init 2014-03-14 12:00:53 +08:00
5d392b76c7 Merge pull request #531 from unphp/develop
Update router.go
2014-03-14 10:08:47 +08:00
c6a34b8efd Merge branch 'develop' of github.com:astaxie/beego into develop 2014-03-13 23:32:03 +08:00
95e67ba2c2 orm now support custom builtin types as model struct field or query args fix #489 2014-03-13 23:31:47 +08:00
439b1afb85 Merge branch 'release/1.1.1' 2014-03-12 21:25:41 +08:00
745e9fb0fb change version to 1.1.1 2014-03-12 21:24:23 +08:00
769f7c751b fix static file route 2014-03-12 21:06:20 +08:00
a8c2deb014 Merge pull request #530 from cnphpbb/develop
Update beego.go - Slice types exist trap
2014-03-12 18:56:12 +08:00
624f6258ee fix read / 2014-03-12 18:29:45 +08:00
43c977ab62 Update router.go
To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
2014-03-12 17:20:53 +08:00
6c92ca2a16 fix bug for static file like /static /static_js /static_css 2014-03-12 17:03:34 +08:00
0f015d75d2 Update beego.go - Modify GroupRouters all function
Modify the file beego.go. 
Type GroupRouters all function 
Slice types exist trap bug.
2014-03-12 16:50:13 +08:00
217c3a2e87 Merge pull request #508 from voidd/develop
Couchbase session provider
2014-03-12 16:23:41 +08:00
ff6120cb93 Merge branch 'develop' of https://github.com/astaxie/beego into develop
Conflicts:
	session/sess_file.go
2014-03-12 15:57:57 +08:00
53aaf3b4a9 orm add ResetModelCache api for test case 2014-03-12 15:56:05 +08:00
d5b5c18cf9 orm add GetDB api #433 2014-03-12 15:56:05 +08:00
cacdb3228d orm add operator between #518 2014-03-12 15:56:05 +08:00
d0949b64c6 fix issue#521, return error when init redis session 2014-03-12 15:56:05 +08:00
d49984d47d Fix basic auth plugin example.
NewBasicAuthenticator requires passing a second argument for Realm.
2014-03-12 15:56:05 +08:00
9f3af59250 add support for sql.Null* types
Change instructions for sqlite3 tests to use in memory db for much faster
2014-03-12 15:56:05 +08:00
57afd3d979 GetStrings action as GetString 2014-03-12 15:56:05 +08:00
f7430a2ce1 enhance the static file path. If user foget / path.Join will auto fix it. 2014-03-12 15:56:05 +08:00
7389f0507e fix #505 2014-03-12 15:56:04 +08:00
ee889e9975 skip cookie args when value is nil 2014-03-12 15:56:04 +08:00
d8b9db8d3e move SetSecureCookie / GetSecureCookie to *context.Context and alias in Controller 2014-03-12 15:56:04 +08:00
9b498feac7 update output.Cookie 2014-03-12 15:56:04 +08:00
69982c62c8 path default is /
httponly default true
seuce default not set
2014-03-12 15:56:04 +08:00
b405e19f56 delete MaxAge cookielifeTime replace 2014-03-12 15:56:04 +08:00
828235b4c1 httplib support set transport and proxy 2014-03-12 15:56:04 +08:00
430a0a971f orm insert skip auto_now_add when user custom a value 2014-03-12 15:56:04 +08:00
5be22a99a8 fix captcha urlPrefix 2014-03-12 15:56:04 +08:00
5d02b18db4 register interface to gob automatically 2014-03-12 15:56:04 +08:00
97b68bdd66 fix bug, can not remove session file 2014-03-12 15:56:04 +08:00
5583fa2054 orm add ResetModelCache api for test case 2014-03-10 20:52:04 +08:00
00a410ad1a orm add GetDB api #433 2014-03-10 20:50:54 +08:00
6ca30386b8 orm add operator between #518 2014-03-10 20:19:29 +08:00
03b17a2ca9 Merge pull request #522 from pengfei-xue/develop
fix issue#521, return error when init redis session
2014-03-07 16:45:34 +08:00
9957a867cd fix issue#521, return error when init redis session 2014-03-07 16:35:34 +08:00
f3ba41a991 Merge pull request #517 from zkirill/develop
Fix basic auth plugin example.
2014-03-03 09:58:40 +08:00
4befa1bc1b Fix basic auth plugin example.
NewBasicAuthenticator requires passing a second argument for Realm.
2014-03-02 17:44:23 -08:00
9e3ebc88c4 Merge pull request #513 from hobeone/develop
add support for sql.Null* types, thx hobeone
2014-02-28 12:17:58 +08:00
6e00cfb464 add support for sql.Null* types
Change instructions for sqlite3 tests to use in memory db for much faster
2014-02-27 19:53:35 -08:00
c358c18018 Merge pull request #511 from francoishill/patch-2
Update sess_file.go
2014-02-28 09:40:28 +08:00
adf2a590fc Update sess_file.go
Lock required to ensure the File sessions work correct.
2014-02-27 15:34:38 +02:00
edb8bac5bc Merge pull request #510 from jfolkins/fix_sessregenid
fix: added nil check on c.CruSession to prevent crash
2014-02-27 10:34:39 +08:00
47d7ac06b7 fix: added nil check on c.CruSession to prevent crash 2014-02-26 16:44:31 -08:00
d05270d2ec GetStrings action as GetString 2014-02-26 15:02:58 +08:00
04a19685ed enhance the static file path. If user foget / path.Join will auto fix it. 2014-02-26 14:44:41 +08:00
62555771d0 fix #505 2014-02-24 15:55:38 +08:00
9dc93cbab0 skip cookie args when value is nil 2014-02-22 14:40:18 +08:00
7f5fb871de move SetSecureCookie / GetSecureCookie to *context.Context and alias in Controller 2014-02-22 11:58:53 +08:00
03037170e1 update output.Cookie 2014-02-22 11:12:57 +08:00
002e0854ab path default is /
httponly default true
seuce default not set
2014-02-22 10:37:14 +08:00
2bc70f62ce delete MaxAge cookielifeTime replace 2014-02-22 01:04:47 +08:00
8bf0e67b79 httplib support set transport and proxy 2014-02-20 13:53:13 +08:00
b310be1fcf orm insert skip auto_now_add when user custom a value 2014-02-20 13:45:31 +08:00
a54353b51c fix captcha urlPrefix 2014-02-20 13:44:34 +08:00
04c2ba01bc Merge branch 'develop' of https://github.com/voidd/beego 2014-02-19 22:44:06 +04:00
296bcab425 couchbase session provider 2014-02-19 15:54:16 +04:00
060b321182 Merge pull request #497 from pengfei-xue/develop
register interface to gob automatically
2014-02-18 16:47:59 +08:00
05a0a4b046 register interface to gob automatically 2014-02-14 17:52:57 +08:00
8906d3e77c Merge pull request #494 from pengfei-xue/develop
fix bug, can not remove session file
2014-02-13 20:19:42 +08:00
e822642cb0 fix bug, can not remove session file 2014-02-13 18:24:05 +08:00
a38a4f0343 Merge pull request #492 from TimothyYe/master
Fix spelling mistake
2014-02-10 14:51:26 +08:00
e8a22660e4 Update error.go
Fix the spelling mistake of error page.
2014-02-10 12:55:53 +08:00
92196c602b Merge branch 'release/release1.1.0' 2014-02-10 11:39:03 +08:00
76222ac8d0 change 1.0.1 to 1.1.0 2014-02-10 11:33:53 +08:00
a184c23603 basic auth for plugin 2014-02-10 11:31:54 +08:00
1b778509c9 should copy the data direct. don't need range 2014-02-08 10:42:34 +08:00
c4250872ca controller data inherit the context's data 2014-02-07 17:25:56 +08:00
17dd72241b Merge pull request #491 from fuxiaohei/develop
add comments for testing, utils and validation packages
2014-02-07 16:25:05 +08:00
ce2984f09a add comments for testing, utils and validation packages 2014-02-07 16:07:31 +08:00
846d766499 Merge branch 'develop' of git://github.com/astaxie/beego 2014-02-07 15:34:01 +08:00
bbc71142d7 controller can controller whether render the template.
EnableReander default is true.
2014-02-07 00:38:58 +08:00
74804bc586 Merge pull request #490 from fuxiaohei/develop
add comments for session and toolbox package
2014-02-04 06:25:18 -08:00
1d08a54f44 add comments for toolbox packages 2014-01-29 19:12:00 +08:00
682544165f add comments for session packages, part 2 2014-01-29 18:15:09 +08:00
3f0ec5c0ca Merge branch 'develop' of git://github.com/astaxie/beego into develop 2014-01-29 01:06:49 +08:00
0e2872324f add comments for session packages, part 1 2014-01-29 01:05:56 +08:00
2fb575838d Merge pull request #474 from pengfei-xue/develop
fix bug, redis session doesnt work
2014-01-28 01:50:20 -08:00
ab8f8d532a Merge pull request #487 from cloudaice/log-feature
fixed bug: in logs package check if platform is windows
2014-01-27 20:50:25 -08:00
d93f112083 fixed bug: in logs package check if platform is windows 2014-01-28 11:26:43 +08:00
9384e87083 orm 1. add api: NewOrmWithDB, AddAliasWthDB; 2. RawSeter -> add api: RowsToMap, RowsToStruct; 3. RawSeter -> change api: Values, ValuesList, ValuesFlat add optional params comumns. 2014-01-27 01:48:00 +08:00
34eff4cc1f bugfix, delete the sid if it's values is empty
* regenerate sid, if the old key doesn't exists, set the new one directly
2014-01-25 10:55:49 +08:00
8296713ba4 Merge pull request #477 from kylemcc/read_or_create
Add a ReadOrCreate method:
2014-01-24 18:01:24 -08:00
d014ccfb8e bug fix, session stored in redis cannot be deleted 2014-01-23 19:28:58 +08:00
190039b6f8 Add a ReadOrCreate method:
m := &User{Name: "Kyle"}
// Returns a boolean indicating whether the object was created,
// the primary key of the object, or an error.
created, id, err := orm.ReadOrCreate(m, "Name")
2014-01-22 09:15:21 -06:00
edf7982567 Merge pull request #473 from cloudaice/logs-feature
diffrent level logs display diffrent color
2014-01-21 19:03:32 -08:00
1509a6b681 fix bug, redis session doesnt work 2014-01-21 18:48:16 +08:00
11e6c2829b diffrent level logs display diffrent color 2014-01-21 18:00:17 +08:00
38f93a7ba9 Merge pull request #470 from fuxiaohei/develop
add comments for orm and middleware packages.
2014-01-17 17:44:52 -08:00
6b5108ef92 Merge pull request #1 from fuxiaohei/develop
merge develop
2014-01-17 07:48:39 -08:00
828a306069 add comments for orm package, done 2014-01-17 23:28:54 +08:00
4c527dde65 add comments for orm packages, part 2 2014-01-17 17:25:17 +08:00
f5a5ebe16b add comments for orm packages, part 1 2014-01-17 17:04:15 +08:00
32799bc259 add comments for middleware packages, fix typo error 2014-01-17 16:03:01 +08:00
91d75e8925 add readme for captcha, and enhanced performance 2014-01-17 12:07:30 +08:00
3e40041219 Merge pull request #468 from cloudaice/patch-1
Update README.md
2014-01-16 17:28:13 -08:00
7d5ee0d692 Update README.md 2014-01-17 00:17:43 +08:00
91cbe1f29b add some comments for captcha 2014-01-16 21:34:59 +08:00
f419c12427 add captcha util 2014-01-16 20:53:35 +08:00
fee3c2b8f9 add Strings interface can return []string sep by ;
Example:
peers = one;Two;Three
2014-01-15 17:19:03 +08:00
b016102d34 add coding 2014-01-15 09:40:33 +08:00
c20e1ab1e2 Merge pull request #463 from NormanZhang/develop
Update SessionExist to close the db connection
2014-01-14 06:13:34 -08:00
dc767b65df Update SessionExist to close the db connection
close the mysql connection
2014-01-14 19:54:32 +08:00
63f19974cd Merge pull request #460 from pengfei-xue/develop
use connection pool for redis, support auto connection
2014-01-11 06:39:20 -08:00
6e9ba0ea7f fix SessionRegenerateID should release old SessionStore and release new SessionStore in router.go 2014-01-11 17:01:33 +08:00
3b99f37aa1 add a empty fake config Initialize AppConfig to avoid nil pointer runtime error. 2014-01-11 14:28:11 +08:00
e8f5c10488 Merge pull request #457 from luxuchu/patch-1
fix #453
2014-01-10 22:21:27 -08:00
cb55009c8b remove mutex 2014-01-10 20:31:43 +08:00
b64e70e7df use connection pool for redis cache 2014-01-10 18:31:15 +08:00
8d79f8387b #441 fix detect timezone in mysql 2014-01-10 16:50:03 +08:00
afadb3f6df Update beego.go 2014-01-10 13:31:08 +08:00
844412c302 fix #453 2014-01-09 21:37:50 +08:00
299cb9130b Merge pull request #454 from pengfei-xue/develop
support redis cache auto connection
2014-01-09 04:49:31 -08:00
0b42e5573b align memcache operations with redis 2014-01-09 18:50:30 +08:00
a369b15ef2 reset cache connection to nil, if err isio.EOF
* this will support auto-connection
2014-01-09 18:49:18 +08:00
e34f8c4634 add cookie test 2014-01-08 23:24:31 +08:00
d7f2c738c8 add attach file 2014-01-08 22:35:42 +08:00
d06c04277f support send mail 2014-01-08 22:31:26 +08:00
aa2fef0d36 update sessionRelease
1. mysql fix last access time not update
2. mysql & redid Release when data is empty
3. add maxlifetime distinct Gclifetime
2014-01-08 20:54:20 +08:00
b766f65c26 #436 support insert multi 2014-01-06 11:31:35 +08:00
6f3a759ba5 gmfim add lock. fix #445 2014-01-05 23:16:47 +08:00
338124e3fb fix #443 2014-01-05 15:43:48 +08:00
31bdb793cf make fix 2014-01-05 15:21:50 +08:00
9cbd475701 beego support new version session 2014-01-05 14:59:39 +08:00
481448fa90 modify session module
change a log
2014-01-05 14:48:36 +08:00
95c65de97c fix #440 2014-01-04 22:30:17 +08:00
ef79a2b484 fix #440 2014-01-04 00:04:15 +08:00
20cfece1ab Merge pull request #438 from Codonaut/error_page_improvements
Error page improvements
2014-01-02 07:17:49 -08:00
c433b7029f added back a <br> 2014-01-02 09:54:15 -05:00
f5cf2876dd Improved the language on the error pages 2014-01-02 09:53:09 -05:00
480aa521e5 fix #430 2014-01-01 20:50:06 +08:00
d57557dc55 add AutoRouterWithPrefix 2014-01-01 17:57:57 +08:00
803d91c077 support modules design!
// the follow code is write in modules:
// GR:=NewGroupRouters()
// GR.AddRouter("/login",&UserController,"get:Login")
// GR.AddRouter("/logout",&UserController,"get:Logout")
// GR.AddRouter("/register",&UserController,"get:Reg")
// the follow code is write in app:
// import "github.com/beego/modules/auth"
// AddRouterGroup("/admin", auth.GR)
2013-12-31 23:43:15 +08:00
62ee48dcbf Merge branch 'develop' of https://github.com/astaxie/beego into develop 2013-12-31 20:48:46 +08:00
1e57587fe9 support Hijacker #428 2013-12-31 20:47:48 +08:00
61c0b3e286 fix db locked 2013-12-31 09:55:29 +08:00
383a04f4c2 move initmime from beego.Run to hookfunc 2013-12-31 00:34:47 +08:00
eea272482b Merge pull request #425 from fuxiaohei/master
add comments in logs package.
2013-12-30 07:38:51 -08:00
94ad13c846 add comments in logs package 2013-12-30 23:32:57 +08:00
412a4a04de #384 2013-12-30 23:04:13 +08:00
e0e8fa6e2a fix #413 2013-12-30 22:51:54 +08:00
a1e29b0b75 Merge pull request #422 from pengfei-xue/devel
simplify condition test for trailing /
2013-12-30 04:58:50 -08:00
984b0cbf31 1. :all param default expr change from (.+) to (.*)
2. add hookfunc to support appstart hook
2013-12-30 15:06:51 +08:00
3118c6c23f Merge commit '7a3d05ebf3fd36ea7e534de64ad38c23367ac97f' 2013-12-30 11:37:20 +08:00
3a08eec1f9 simplify condition test for trailing / 2013-12-30 11:29:35 +08:00
ecfd11adb4 fix typo healthcheck url 2013-12-29 11:01:19 +08:00
95dc670eb4 fix #416 2013-12-28 23:06:20 +08:00
7a3d05ebf3 when pattern is /admin while the url is /admin/ should return 200. fix #416 2013-12-28 23:04:45 +08:00
62f54cbbee fix typo error 2013-12-28 20:14:36 +08:00
4d7f7ffa37 Merge pull request #418 from fuxiaohei/master
add comments for httplib package.
2013-12-27 16:52:29 -08:00
cb876268b5 add comments for httplib package. 2013-12-27 17:11:39 +08:00
094f2fbab8 Merge pull request #415 from fuxiaohei/master
add comments for context package.
2013-12-26 07:40:19 -08:00
2d77c4dc49 fix code with no need line 2013-12-26 00:44:49 +08:00
f535916fae add comments for context package. 2013-12-25 20:13:38 +08:00
673993fa2b Merge pull request #412 from fuxiaohei/master
add comment in config package.
2013-12-24 07:06:02 -08:00
6f3803ce8c Merge remote-tracking branch 'astaxie/master' 2013-12-24 21:59:37 +08:00
a1f6039d82 gofmt code 2013-12-24 21:59:00 +08:00
0183608a59 add comments for config package. 2013-12-24 21:57:33 +08:00
5b1afcdb5a add timeout description for file and memory cache. 2013-12-24 21:56:48 +08:00
053e7a6aa6 Merge remote-tracking branch 'astaxie/master' 2013-12-24 21:09:17 +08:00
ba94479efd Merge remote-tracking branch 'astaxie/master' 2013-12-24 13:05:09 +08:00
ba3a9bee4c Merge remote-tracking branch 'astaxie/master' 2013-12-22 16:25:08 +08:00
f96eec6dea fix a code broken when documenting 2013-12-22 15:31:49 +08:00
151 changed files with 10316 additions and 2461 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
.DS_Store .DS_Store
*.swp
*.swo

13
LICENSE Normal file
View File

@ -0,0 +1,13 @@
Copyright 2014 astaxie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
beego is licensed under the Apache Licence, Version 2.0 beego is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html). (http://www.apache.org/licenses/LICENSE-2.0.html).
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
## Use case [koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
- seocms: [seocms](https://github.com/chinakr/seocms)
- CMS: [toropress](https://github.com/insionng/toropress)

166
admin.go
View File

@ -1,16 +1,23 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
// BeeAdminApp is the default AdminApp used by admin module. // BeeAdminApp is the default adminApp used by admin module.
var BeeAdminApp *AdminApp var beeAdminApp *adminApp
// FilterMonitorFunc is default monitor filter when admin module is enable. // FilterMonitorFunc is default monitor filter when admin module is enable.
// if this func returns, admin module records qbs for this request by condition of this function logic. // if this func returns, admin module records qbs for this request by condition of this function logic.
@ -31,42 +38,43 @@ var BeeAdminApp *AdminApp
var FilterMonitorFunc func(string, string, time.Duration) bool var FilterMonitorFunc func(string, string, time.Duration) bool
func init() { func init() {
BeeAdminApp = &AdminApp{ beeAdminApp = &adminApp{
routers: make(map[string]http.HandlerFunc), routers: make(map[string]http.HandlerFunc),
} }
BeeAdminApp.Route("/", AdminIndex) beeAdminApp.Route("/", adminIndex)
BeeAdminApp.Route("/qps", QpsIndex) beeAdminApp.Route("/qps", qpsIndex)
BeeAdminApp.Route("/prof", ProfIndex) beeAdminApp.Route("/prof", profIndex)
BeeAdminApp.Route("/healthcheck", Healthcheck) beeAdminApp.Route("/healthcheck", healthcheck)
BeeAdminApp.Route("/task", TaskStatus) beeAdminApp.Route("/task", taskStatus)
BeeAdminApp.Route("/runtask", RunTask) beeAdminApp.Route("/runtask", runTask)
BeeAdminApp.Route("/listconf", ListConf) beeAdminApp.Route("/listconf", listConf)
FilterMonitorFunc = func(string, string, time.Duration) bool { return true } FilterMonitorFunc = func(string, string, time.Duration) bool { return true }
} }
// AdminIndex is the default http.Handler for admin module. // AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/". // it matches url pattern "/".
func AdminIndex(rw http.ResponseWriter, r *http.Request) { func adminIndex(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("Welcome to Admin Dashboard\n")) rw.Write([]byte("<html><head><title>beego admin dashboard</title></head><body>"))
rw.Write([]byte("There are servral functions:\n")) rw.Write([]byte("Welcome to Admin Dashboard<br>\n"))
rw.Write([]byte("1. Record all request and request time, http://localhost:8088/qps\n")) rw.Write([]byte("There are servral functions:<br>\n"))
rw.Write([]byte("2. Get runtime profiling data by the pprof, http://localhost:8088/prof\n")) rw.Write([]byte("1. Record all request and request time, <a href='/qps'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/qps</a><br>\n"))
rw.Write([]byte("3. Get healthcheck result from http://localhost:8088/prof\n")) rw.Write([]byte("2. Get runtime profiling data by the pprof, <a href='/prof'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/prof</a><br>\n"))
rw.Write([]byte("4. Get current task infomation from taskhttp://localhost:8088/task \n")) rw.Write([]byte("3. Get healthcheck result from <a href='/healthcheck'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/healthcheck</a><br>\n"))
rw.Write([]byte("5. To run a task passed a param http://localhost:8088/runtask\n")) rw.Write([]byte("4. Get current task infomation from task <a href='/task'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/task</a><br> \n"))
rw.Write([]byte("6. Get all confige & router infomation http://localhost:8088/listconf\n")) rw.Write([]byte("5. To run a task passed a param <a href='/runtask'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/runtask</a><br>\n"))
rw.Write([]byte("6. Get all confige & router infomation <a href='/listconf'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/listconf</a><br>\n"))
rw.Write([]byte("</body></html>"))
} }
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter. // QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module. // it's registered with url pattern "/qbs" in admin module.
func QpsIndex(rw http.ResponseWriter, r *http.Request) { func qpsIndex(rw http.ResponseWriter, r *http.Request) {
toolbox.StatisticsMap.GetMap(rw) toolbox.StatisticsMap.GetMap(rw)
} }
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. // ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
// it's registered with url pattern "/listconf" in admin module. // it's registered with url pattern "/listconf" in admin module.
func ListConf(rw http.ResponseWriter, r *http.Request) { func listConf(rw http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
command := r.Form.Get("command") command := r.Form.Get("command")
if command != "" { if command != "" {
@ -80,7 +88,7 @@ func ListConf(rw http.ResponseWriter, r *http.Request) {
fmt.Fprintln(rw, "StaticExtensionsToGzip:", StaticExtensionsToGzip) fmt.Fprintln(rw, "StaticExtensionsToGzip:", StaticExtensionsToGzip)
fmt.Fprintln(rw, "HttpAddr:", HttpAddr) fmt.Fprintln(rw, "HttpAddr:", HttpAddr)
fmt.Fprintln(rw, "HttpPort:", HttpPort) fmt.Fprintln(rw, "HttpPort:", HttpPort)
fmt.Fprintln(rw, "HttpTLS:", HttpTLS) fmt.Fprintln(rw, "HttpTLS:", EnableHttpTLS)
fmt.Fprintln(rw, "HttpCertFile:", HttpCertFile) fmt.Fprintln(rw, "HttpCertFile:", HttpCertFile)
fmt.Fprintln(rw, "HttpKeyFile:", HttpKeyFile) fmt.Fprintln(rw, "HttpKeyFile:", HttpKeyFile)
fmt.Fprintln(rw, "RecoverPanic:", RecoverPanic) fmt.Fprintln(rw, "RecoverPanic:", RecoverPanic)
@ -99,7 +107,6 @@ func ListConf(rw http.ResponseWriter, r *http.Request) {
fmt.Fprintln(rw, "MaxMemory:", MaxMemory) fmt.Fprintln(rw, "MaxMemory:", MaxMemory)
fmt.Fprintln(rw, "EnableGzip:", EnableGzip) fmt.Fprintln(rw, "EnableGzip:", EnableGzip)
fmt.Fprintln(rw, "DirectoryIndex:", DirectoryIndex) fmt.Fprintln(rw, "DirectoryIndex:", DirectoryIndex)
fmt.Fprintln(rw, "EnableHotUpdate:", EnableHotUpdate)
fmt.Fprintln(rw, "HttpServerTimeOut:", HttpServerTimeOut) fmt.Fprintln(rw, "HttpServerTimeOut:", HttpServerTimeOut)
fmt.Fprintln(rw, "ErrorsShow:", ErrorsShow) fmt.Fprintln(rw, "ErrorsShow:", ErrorsShow)
fmt.Fprintln(rw, "XSRFKEY:", XSRFKEY) fmt.Fprintln(rw, "XSRFKEY:", XSRFKEY)
@ -114,28 +121,13 @@ func ListConf(rw http.ResponseWriter, r *http.Request) {
fmt.Fprintln(rw, "AdminHttpPort:", AdminHttpPort) fmt.Fprintln(rw, "AdminHttpPort:", AdminHttpPort)
case "router": case "router":
fmt.Fprintln(rw, "Print all router infomation:") fmt.Fprintln(rw, "Print all router infomation:")
for _, router := range BeeApp.Handlers.fixrouters { for method, t := range BeeApp.Handlers.routers {
if router.hasMethod { fmt.Fprintln(rw)
fmt.Fprintln(rw, router.pattern, "----", router.methods, "----", router.controllerType.Name()) fmt.Fprintln(rw)
} else { fmt.Fprintln(rw, " Method:", method)
fmt.Fprintln(rw, router.pattern, "----", router.controllerType.Name()) printTree(rw, t)
}
}
for _, router := range BeeApp.Handlers.routers {
if router.hasMethod {
fmt.Fprintln(rw, router.pattern, "----", router.methods, "----", router.controllerType.Name())
} else {
fmt.Fprintln(rw, router.pattern, "----", router.controllerType.Name())
}
}
if BeeApp.Handlers.enableAuto {
for controllerName, methodObj := range BeeApp.Handlers.autoRouter {
fmt.Fprintln(rw, controllerName, "----")
for methodName, obj := range methodObj {
fmt.Fprintln(rw, " ", methodName, "-----", obj.Name())
}
}
} }
// @todo print routers
case "filter": case "filter":
fmt.Fprintln(rw, "Print all filter infomation:") fmt.Fprintln(rw, "Print all filter infomation:")
if BeeApp.Handlers.enableFilter { if BeeApp.Handlers.enableFilter {
@ -145,12 +137,6 @@ func ListConf(rw http.ResponseWriter, r *http.Request) {
fmt.Fprintln(rw, f.pattern, utils.GetFuncName(f.filterFunc)) fmt.Fprintln(rw, f.pattern, utils.GetFuncName(f.filterFunc))
} }
} }
fmt.Fprintln(rw, "AfterStatic:")
if bf, ok := BeeApp.Handlers.filters[AfterStatic]; ok {
for _, f := range bf {
fmt.Fprintln(rw, f.pattern, utils.GetFuncName(f.filterFunc))
}
}
fmt.Fprintln(rw, "BeforeExec:") fmt.Fprintln(rw, "BeforeExec:")
if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok { if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok {
for _, f := range bf { for _, f := range bf {
@ -174,49 +160,73 @@ func ListConf(rw http.ResponseWriter, r *http.Request) {
rw.Write([]byte("command not support")) rw.Write([]byte("command not support"))
} }
} else { } else {
rw.Write([]byte("ListConf support this command:\n")) rw.Write([]byte("<html><head><title>beego admin dashboard</title></head><body>"))
rw.Write([]byte("1. command=conf\n")) rw.Write([]byte("ListConf support this command:<br>\n"))
rw.Write([]byte("2. command=router\n")) rw.Write([]byte("1. <a href='?command=conf'>command=conf</a><br>\n"))
rw.Write([]byte("3. command=filter\n")) rw.Write([]byte("2. <a href='?command=router'>command=router</a><br>\n"))
rw.Write([]byte("3. <a href='?command=filter'>command=filter</a><br>\n"))
rw.Write([]byte("</body></html>"))
}
}
func printTree(rw http.ResponseWriter, t *Tree) {
for _, tr := range t.fixrouters {
printTree(rw, tr)
}
if t.wildcard != nil {
printTree(rw, t.wildcard)
}
for _, l := range t.leaves {
if v, ok := l.runObject.(*controllerInfo); ok {
if v.routerType == routerTypeBeego {
fmt.Fprintln(rw, v.pattern, v.methods, v.controllerType.Name())
} else if v.routerType == routerTypeRESTFul {
fmt.Fprintln(rw, v.pattern, v.methods)
} else if v.routerType == routerTypeHandler {
fmt.Fprintln(rw, v.pattern, "handler")
}
}
} }
} }
// ProfIndex is a http.Handler for showing profile command. // ProfIndex is a http.Handler for showing profile command.
// it's in url pattern "/prof" in admin module. // it's in url pattern "/prof" in admin module.
func ProfIndex(rw http.ResponseWriter, r *http.Request) { func profIndex(rw http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
command := r.Form.Get("command") command := r.Form.Get("command")
if command != "" { if command != "" {
toolbox.ProcessInput(command, rw) toolbox.ProcessInput(command, rw)
} else { } else {
rw.Write([]byte("request url like '/prof?command=lookup goroutine'\n")) rw.Write([]byte("<html><head><title>beego admin dashboard</title></head><body>"))
rw.Write([]byte("the command have below types:\n")) rw.Write([]byte("request url like '/prof?command=lookup goroutine'<br>\n"))
rw.Write([]byte("1. lookup goroutine\n")) rw.Write([]byte("the command have below types:<br>\n"))
rw.Write([]byte("2. lookup heap\n")) rw.Write([]byte("1. <a href='?command=lookup goroutine'>lookup goroutine</a><br>\n"))
rw.Write([]byte("3. lookup threadcreate\n")) rw.Write([]byte("2. <a href='?command=lookup heap'>lookup heap</a><br>\n"))
rw.Write([]byte("4. lookup block\n")) rw.Write([]byte("3. <a href='?command=lookup threadcreate'>lookup threadcreate</a><br>\n"))
rw.Write([]byte("5. start cpuprof\n")) rw.Write([]byte("4. <a href='?command=lookup block'>lookup block</a><br>\n"))
rw.Write([]byte("6. stop cpuprof\n")) rw.Write([]byte("5. <a href='?command=start cpuprof'>start cpuprof</a><br>\n"))
rw.Write([]byte("7. get memprof\n")) rw.Write([]byte("6. <a href='?command=stop cpuprof'>stop cpuprof</a><br>\n"))
rw.Write([]byte("8. gc summary\n")) rw.Write([]byte("7. <a href='?command=get memprof'>get memprof</a><br>\n"))
rw.Write([]byte("8. <a href='?command=gc summary'>gc summary</a><br>\n"))
rw.Write([]byte("</body></html>"))
} }
} }
// Healthcheck is a http.Handler calling health checking and showing the result. // Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module. // it's in "/healthcheck" pattern in admin module.
func Healthcheck(rw http.ResponseWriter, req *http.Request) { func healthcheck(rw http.ResponseWriter, req *http.Request) {
for name, h := range toolbox.AdminCheckList { for name, h := range toolbox.AdminCheckList {
if err := h.Check(); err != nil { if err := h.Check(); err != nil {
fmt.Fprintf(rw, "%s : ok\n", name)
} else {
fmt.Fprintf(rw, "%s : %s\n", name, err.Error()) fmt.Fprintf(rw, "%s : %s\n", name, err.Error())
} else {
fmt.Fprintf(rw, "%s : ok\n", name)
} }
} }
} }
// TaskStatus is a http.Handler with running task status (task name, status and the last execution). // TaskStatus is a http.Handler with running task status (task name, status and the last execution).
// it's in "/task" pattern in admin module. // it's in "/task" pattern in admin module.
func TaskStatus(rw http.ResponseWriter, req *http.Request) { func taskStatus(rw http.ResponseWriter, req *http.Request) {
for tname, tk := range toolbox.AdminTaskList { for tname, tk := range toolbox.AdminTaskList {
fmt.Fprintf(rw, "%s:%s:%s", tname, tk.GetStatus(), tk.GetPrev().String()) fmt.Fprintf(rw, "%s:%s:%s", tname, tk.GetStatus(), tk.GetPrev().String())
} }
@ -224,7 +234,7 @@ func TaskStatus(rw http.ResponseWriter, req *http.Request) {
// RunTask is a http.Handler to run a Task from the "query string. // RunTask is a http.Handler to run a Task from the "query string.
// the request url likes /runtask?taskname=sendmail. // the request url likes /runtask?taskname=sendmail.
func RunTask(rw http.ResponseWriter, req *http.Request) { func runTask(rw http.ResponseWriter, req *http.Request) {
req.ParseForm() req.ParseForm()
taskname := req.Form.Get("taskname") taskname := req.Form.Get("taskname")
if t, ok := toolbox.AdminTaskList[taskname]; ok { if t, ok := toolbox.AdminTaskList[taskname]; ok {
@ -232,25 +242,25 @@ func RunTask(rw http.ResponseWriter, req *http.Request) {
if err != nil { if err != nil {
fmt.Fprintf(rw, "%v", err) fmt.Fprintf(rw, "%v", err)
} }
fmt.Fprintf(rw, "%s run success,Now the Status is %s", t.GetStatus()) fmt.Fprintf(rw, "%s run success,Now the Status is %s", taskname, t.GetStatus())
} else { } else {
fmt.Fprintf(rw, "there's no task which named:%s", taskname) fmt.Fprintf(rw, "there's no task which named:%s", taskname)
} }
} }
// AdminApp is an http.HandlerFunc map used as BeeAdminApp. // adminApp is an http.HandlerFunc map used as beeAdminApp.
type AdminApp struct { type adminApp struct {
routers map[string]http.HandlerFunc routers map[string]http.HandlerFunc
} }
// Route adds http.HandlerFunc to AdminApp with url pattern. // Route adds http.HandlerFunc to adminApp with url pattern.
func (admin *AdminApp) Route(pattern string, f http.HandlerFunc) { func (admin *adminApp) Route(pattern string, f http.HandlerFunc) {
admin.routers[pattern] = f admin.routers[pattern] = f
} }
// Run AdminApp http server. // Run adminApp http server.
// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. // Its addr is defined in configuration file as adminhttpaddr and adminhttpport.
func (admin *AdminApp) Run() { func (admin *adminApp) Run() {
if len(toolbox.AdminTaskList) > 0 { if len(toolbox.AdminTaskList) > 0 {
toolbox.StartTask() toolbox.StartTask()
} }

136
app.go
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
@ -16,12 +22,13 @@ type FilterFunc func(*context.Context)
// App defines beego application with a new PatternServeMux. // App defines beego application with a new PatternServeMux.
type App struct { type App struct {
Handlers *ControllerRegistor Handlers *ControllerRegistor
Server *http.Server
} }
// NewApp returns a new beego application. // NewApp returns a new beego application.
func NewApp() *App { func NewApp() *App {
cr := NewControllerRegistor() cr := NewControllerRegister()
app := &App{Handlers: cr} app := &App{Handlers: cr, Server: &http.Server{}}
return app return app
} }
@ -39,6 +46,7 @@ func (app *App) Run() {
err error err error
l net.Listener l net.Listener
) )
endRunning := make(chan bool, 1)
if UseFcgi { if UseFcgi {
if HttpPort == 0 { if HttpPort == 0 {
@ -51,114 +59,36 @@ func (app *App) Run() {
} }
err = fcgi.Serve(l, app.Handlers) err = fcgi.Serve(l, app.Handlers)
} else { } else {
if EnableHotUpdate { app.Server.Addr = addr
server := &http.Server{ app.Server.Handler = app.Handlers
Handler: app.Handlers, app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
ReadTimeout: time.Duration(HttpServerTimeOut) * time.Second, app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
WriteTimeout: time.Duration(HttpServerTimeOut) * time.Second,
} if EnableHttpTLS {
laddr, err := net.ResolveTCPAddr("tcp", addr) go func() {
if nil != err { if HttpsPort != 0 {
BeeLogger.Critical("ResolveTCPAddr:", err) app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
}
l, err = GetInitListener(laddr)
theStoppable = newStoppable(l)
err = server.Serve(theStoppable)
theStoppable.wg.Wait()
CloseSelf()
} else {
s := &http.Server{
Addr: addr,
Handler: app.Handlers,
ReadTimeout: time.Duration(HttpServerTimeOut) * time.Second,
WriteTimeout: time.Duration(HttpServerTimeOut) * time.Second,
}
if HttpTLS {
err = s.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
} else {
err = s.ListenAndServe()
} }
err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
if err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
} }
}()
} }
if EnableHttpListen {
go func() {
err := app.Server.ListenAndServe()
if err != nil { if err != nil {
BeeLogger.Critical("ListenAndServe: ", err) BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
} }
}
// Router adds a url-patterned controller handler. <-endRunning
// The path argument supports regex rules and specific placeholders.
// The c argument needs a controller handler implemented beego.ControllerInterface.
// The mapping methods argument only need one string to define custom router rules.
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router(“/api/:id([0-9]+)“, &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func (app *App) Router(path string, c ControllerInterface, mappingMethods ...string) *App {
app.Handlers.Add(path, c, mappingMethods...)
return app
}
// AutoRouter adds beego-defined controller handler.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func (app *App) AutoRouter(c ControllerInterface) *App {
app.Handlers.AddAuto(c)
return app
}
// UrlFor creates a url with another registered controller handler with params.
// The endpoint is formed as path.controller.name to defined the controller method which will run.
// The values need key-pair data to assign into controller method.
func (app *App) UrlFor(endpoint string, values ...string) string {
return app.Handlers.UrlFor(endpoint, values...)
}
// [Deprecated] use InsertFilter.
// Filter adds a FilterFunc under pattern condition and named action.
// The actions contains BeforeRouter,AfterStatic,BeforeExec,AfterExec and FinishRouter.
func (app *App) Filter(pattern, action string, filter FilterFunc) *App {
app.Handlers.AddFilter(pattern, action, filter)
return app
}
// InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including
// beego.BeforeRouter, beego.AfterStatic, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
func (app *App) InsertFilter(pattern string, pos int, filter FilterFunc) *App {
app.Handlers.InsertFilter(pattern, pos, filter)
return app
}
// SetViewsPath sets view directory path in beego application.
// it returns beego application self.
func (app *App) SetViewsPath(path string) *App {
ViewsPath = path
return app
}
// SetStaticPath sets static directory path and proper url pattern in beego application.
// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
// it returns beego application self.
func (app *App) SetStaticPath(url string, path string) *App {
StaticDir[url] = path
return app
}
// DelStaticPath removes the static folder setting in this url pattern in beego application.
// it returns beego application self.
func (app *App) DelStaticPath(url string) *App {
delete(StaticDir, url)
return app
} }

352
beego.go
View File

@ -1,9 +1,17 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
"net/http" "net/http"
"os"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"github.com/astaxie/beego/middleware" "github.com/astaxie/beego/middleware"
@ -11,12 +19,128 @@ import (
) )
// beego web framework version. // beego web framework version.
const VERSION = "1.0.1" const VERSION = "1.3.0"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
type groupRouter struct {
pattern string
controller ControllerInterface
mappingMethods string
}
// RouterGroups which will store routers
type GroupRouters []groupRouter
// Get a new GroupRouters
func NewGroupRouters() GroupRouters {
return make(GroupRouters, 0)
}
// Add Router in the GroupRouters
// it is for plugin or module to register router
func (gr *GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter
if len(mappingMethod) > 0 {
newRG = groupRouter{
pattern,
c,
mappingMethod[0],
}
} else {
newRG = groupRouter{
pattern,
c,
"",
}
}
*gr = append(*gr, newRG)
}
func (gr *GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{
"",
c,
"",
}
*gr = append(*gr, newRG)
}
// AddGroupRouter with the prefix
// it will register the router in BeeApp
// the follow code is write in modules:
// GR:=NewGroupRouters()
// GR.AddRouter("/login",&UserController,"get:Login")
// GR.AddRouter("/logout",&UserController,"get:Logout")
// GR.AddRouter("/register",&UserController,"get:Reg")
// the follow code is write in app:
// import "github.com/beego/modules/auth"
// AddRouterGroup("/admin", auth.GR)
func AddGroupRouter(prefix string, groups GroupRouters) *App {
for _, v := range groups {
if v.pattern == "" {
BeeApp.Handlers.AddAutoPrefix(prefix, v.controller)
} else if v.mappingMethods != "" {
BeeApp.Handlers.Add(prefix+v.pattern, v.controller, v.mappingMethods)
} else {
BeeApp.Handlers.Add(prefix+v.pattern, v.controller)
}
}
return BeeApp
}
// Router adds a patterned controller handler to BeeApp. // Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router. // it's an alias method of App.Router.
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
BeeApp.Router(rootpath, c, mappingMethods...) BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
return BeeApp
}
// Router add list from
// usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
// type BankAccount struct{
// beego.Controller
// }
//
// register the function
// func (b *BankAccount)Mapping(){
// b.Mapping("ShowAccount" , b.ShowAccount)
// b.Mapping("ModifyAccount", b.ModifyAccount)
//}
//
// //@router /account/:id [get]
// func (b *BankAccount) ShowAccount(){
// //logic
// }
//
//
// //@router /account/:id [post]
// func (b *BankAccount) ModifyAccount(){
// //logic
// }
//
// the comments @router url methodlist
// url support all the function Router's pattern
// methodlist [get post head put delete options *]
func Include(cList ...ControllerInterface) *App {
BeeApp.Handlers.Include(cList...)
return BeeApp return BeeApp
} }
@ -31,8 +155,109 @@ func RESTRouter(rootpath string, c ControllerInterface) *App {
// AutoRouter adds defined controller handler to BeeApp. // AutoRouter adds defined controller handler to BeeApp.
// it's same to App.AutoRouter. // it's same to App.AutoRouter.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func AutoRouter(c ControllerInterface) *App { func AutoRouter(c ControllerInterface) *App {
BeeApp.AutoRouter(c) BeeApp.Handlers.AddAuto(c)
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.Handlers.AddAutoPrefix(prefix, c)
return BeeApp
}
// register router for Get method
// usage:
// beego.Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Get(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Get(rootpath, f)
return BeeApp
}
// register router for Post method
// usage:
// beego.Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Post(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Post(rootpath, f)
return BeeApp
}
// register router for Delete method
// usage:
// beego.Delete("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Delete(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Delete(rootpath, f)
return BeeApp
}
// register router for Put method
// usage:
// beego.Put("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Put(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Put(rootpath, f)
return BeeApp
}
// register router for Head method
// usage:
// beego.Head("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Head(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Head(rootpath, f)
return BeeApp
}
// register router for Options method
// usage:
// beego.Options("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Options(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Options(rootpath, f)
return BeeApp
}
// register router for Patch method
// usage:
// beego.Patch("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Patch(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Patch(rootpath, f)
return BeeApp
}
// register router for all method
// usage:
// beego.Any("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Any(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Any(rootpath, f)
return BeeApp
}
// register router for own Handler
// usage:
// beego.Handler("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
BeeApp.Handlers.Handler(rootpath, h, options...)
return BeeApp return BeeApp
} }
@ -45,72 +270,102 @@ func Errorhandler(err string, h http.HandlerFunc) *App {
return BeeApp return BeeApp
} }
// SetViewsPath sets view directory to BeeApp. // SetViewsPath sets view directory path in beego application.
// it's alias of App.SetViewsPath.
func SetViewsPath(path string) *App { func SetViewsPath(path string) *App {
BeeApp.SetViewsPath(path) ViewsPath = path
return BeeApp return BeeApp
} }
// SetStaticPath sets static directory and url prefix to BeeApp. // SetStaticPath sets static directory path and proper url pattern in beego application.
// it's alias of App.SetStaticPath. // if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
func SetStaticPath(url string, path string) *App { func SetStaticPath(url string, path string) *App {
if !strings.HasPrefix(url, "/") { if !strings.HasPrefix(url, "/") {
url = "/" + url url = "/" + url
} }
url = strings.TrimRight(url, "/")
StaticDir[url] = path StaticDir[url] = path
return BeeApp return BeeApp
} }
// DelStaticPath removes the static folder setting in this url pattern in beego application. // DelStaticPath removes the static folder setting in this url pattern in beego application.
// it's alias of App.DelStaticPath.
func DelStaticPath(url string) *App { func DelStaticPath(url string) *App {
delete(StaticDir, url) delete(StaticDir, url)
return BeeApp return BeeApp
} }
// [Deprecated] use InsertFilter.
// Filter adds a FilterFunc under pattern condition and named action.
// The actions contains BeforeRouter,AfterStatic,BeforeExec,AfterExec and FinishRouter.
// it's alias of App.Filter.
func AddFilter(pattern, action string, filter FilterFunc) *App {
BeeApp.Filter(pattern, action, filter)
return BeeApp
}
// InsertFilter adds a FilterFunc with pattern condition and action constant. // InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including // The pos means action constant including
// beego.BeforeRouter, beego.AfterStatic, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. // beego.BeforeRouter, beego.AfterStatic, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
// it's alias of App.InsertFilter.
func InsertFilter(pattern string, pos int, filter FilterFunc) *App { func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
BeeApp.InsertFilter(pattern, pos, filter) BeeApp.Handlers.InsertFilter(pattern, pos, filter)
return BeeApp return BeeApp
} }
// The hookfunc will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
}
// Run beego application. // Run beego application.
// it's alias of App.Run. // beego.Run() default run on HttpPort
func Run() { // beego.Run(":8089")
// beego.Run("127.0.0.1:8089")
func Run(params ...string) {
if len(params) > 0 && params[0] != "" {
strs := strings.Split(params[0], ":")
if len(strs) > 0 && strs[0] != "" {
HttpAddr = strs[0]
}
if len(strs) > 1 && strs[1] != "" {
HttpPort, _ = strconv.Atoi(strs[1])
}
}
initBeforeHttpRun()
if EnableAdmin {
go beeAdminApp.Run()
}
BeeApp.Run()
}
func initBeforeHttpRun() {
// if AppConfigPath not In the conf/app.conf reParse config // if AppConfigPath not In the conf/app.conf reParse config
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") { if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
err := ParseConfig() err := ParseConfig()
if err != nil { if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") {
// configuration is critical to app, panic here if parse failed // configuration is critical to app, panic here if parse failed
panic(err) panic(err)
} }
} }
//init mime // do hooks function
initMime() for _, hk := range hooks {
err := hk()
if err != nil {
panic(err)
}
}
if SessionOn { if SessionOn {
GlobalSessions, _ = session.NewManager(SessionProvider, var err error
SessionName, sessionConfig := AppConfig.String("sessionConfig")
SessionGCMaxLifetime, if sessionConfig == "" {
SessionSavePath, sessionConfig = `{"cookieName":"` + SessionName + `",` +
HttpTLS, `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
SessionHashFunc, `"providerConfig":"` + SessionSavePath + `",` +
SessionHashKey, `"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` +
SessionCookieLifeTime) `"sessionIDHashFunc":"` + SessionHashFunc + `",` +
`"sessionIDHashKey":"` + SessionHashKey + `",` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC() go GlobalSessions.GC()
} }
@ -123,11 +378,32 @@ func Run() {
middleware.VERSION = VERSION middleware.VERSION = VERSION
middleware.AppName = AppName middleware.AppName = AppName
middleware.RegisterErrorHander() middleware.RegisterErrorHandler()
if EnableAdmin { for u, _ := range StaticDir {
go BeeAdminApp.Run() Get(u+"/*", serverStaticRouter)
}
if EnableDocs {
Get("/docs/*", serverDocs)
} }
}
BeeApp.Run()
// this function is for test package init
func TestBeegoInit(apppath string) {
AppPath = apppath
RunMode = "test"
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
err := ParseConfig()
if err != nil && !os.IsNotExist(err) {
// for init if doesn't have app.conf will not panic
Info(err)
}
os.Chdir(AppPath)
initBeforeHttpRun()
}
func init() {
hooks = make([]hookfunc, 0)
//init mime
AddAPPStartHook(initMime)
} }

2
cache/README.md vendored
View File

@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter ## Memcache adapter
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Configure like this: Configure like this:

6
cache/cache.go vendored
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (

58
cache/cache_test.go vendored
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (
@ -5,7 +11,7 @@ import (
"time" "time"
) )
func Test_cache(t *testing.T) { func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`) bm, err := NewCache("memory", `{"interval":20}`)
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
@ -40,7 +46,7 @@ func Test_cache(t *testing.T) {
} }
if err = bm.Decr("astaxie"); err != nil { if err = bm.Decr("astaxie"); err != nil {
t.Error("Incr Error", err) t.Error("Decr Error", err)
} }
if v := bm.Get("astaxie"); v.(int) != 1 { if v := bm.Get("astaxie"); v.(int) != 1 {
@ -51,3 +57,51 @@ func Test_cache(t *testing.T) {
t.Error("delete err") t.Error("delete err")
} }
} }
func TestFileCache(t *testing.T) {
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Decr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
}

6
cache/conv.go vendored
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (

6
cache/conv_test.go vendored
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (

33
cache/file.go vendored
View File

@ -1,8 +1,9 @@
/** // Beego (http://beego.me/)
* package: file // @description beego is an open-source, high-performance web framework for the Go programming language.
* User: gouki // @link http://github.com/astaxie/beego for the canonical source repository
* Date: 2013-10-22 - 14:22 // @license http://github.com/astaxie/beego/blob/master/LICENSE
*/ // @authors astaxie
package cache package cache
import ( import (
@ -47,10 +48,11 @@ type FileCache struct {
EmbedExpiry int EmbedExpiry int
} }
// Create new file cache with default directory and suffix. // Create new file cache with no config.
// the level and expiry need set in method StartAndGC as config string. // the level and expiry need set in method StartAndGC as config string.
func NewFileCache() *FileCache { func NewFileCache() *FileCache {
return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
return &FileCache{}
} }
// Start and begin gc for file cache. // Start and begin gc for file cache.
@ -60,6 +62,7 @@ func (this *FileCache) StartAndGC(config string) error {
var cfg map[string]string var cfg map[string]string
json.Unmarshal([]byte(config), &cfg) json.Unmarshal([]byte(config), &cfg)
//fmt.Println(cfg) //fmt.Println(cfg)
//fmt.Println(config)
if _, ok := cfg["CachePath"]; !ok { if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath cfg["CachePath"] = FileCachePath
} }
@ -134,7 +137,7 @@ func (this *FileCache) Get(key string) interface{} {
return "" return ""
} }
var to FileCacheItem var to FileCacheItem
Gob_decode([]byte(filedata), &to) Gob_decode(filedata, &to)
if to.Expired < time.Now().Unix() { if to.Expired < time.Now().Unix() {
return "" return ""
} }
@ -142,13 +145,16 @@ func (this *FileCache) Get(key string) interface{} {
} }
// Put value into file cache. // Put value into file cache.
// timeout means how long to keep this file, unit of second. // timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
func (this *FileCache) Put(key string, val interface{}, timeout int64) error { func (this *FileCache) Put(key string, val interface{}, timeout int64) error {
gob.Register(val)
filename := this.getCacheFileName(key) filename := this.getCacheFileName(key)
var item FileCacheItem var item FileCacheItem
item.Data = val item.Data = val
if timeout == FileCacheEmbedExpiry { if timeout == FileCacheEmbedExpiry {
item.Expired = time.Now().Unix() + (86400 * 365 * 10) //10年 item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years
} else { } else {
item.Expired = time.Now().Unix() + timeout item.Expired = time.Now().Unix() + timeout
} }
@ -175,7 +181,7 @@ func (this *FileCache) Delete(key string) error {
func (this *FileCache) Incr(key string) error { func (this *FileCache) Incr(key string) error {
data := this.Get(key) data := this.Get(key)
var incr int var incr int
fmt.Println(reflect.TypeOf(data).Name()) //fmt.Println(reflect.TypeOf(data).Name())
if reflect.TypeOf(data).Name() != "int" { if reflect.TypeOf(data).Name() != "int" {
incr = 0 incr = 0
} else { } else {
@ -208,8 +214,7 @@ func (this *FileCache) IsExist(key string) bool {
// Clean cached files. // Clean cached files.
// not implemented. // not implemented.
func (this *FileCache) ClearAll() error { func (this *FileCache) ClearAll() error {
//this.CachePath .递归删除 //this.CachePath
return nil return nil
} }
@ -269,7 +274,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
} }
// Gob decodes file cache item. // Gob decodes file cache item.
func Gob_decode(data []byte, to interface{}) error { func Gob_decode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf) dec := gob.NewDecoder(buf)
return dec.Decode(&to) return dec.Decode(&to)

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (
@ -5,6 +11,8 @@ import (
"errors" "errors"
"github.com/beego/memcache" "github.com/beego/memcache"
"github.com/astaxie/beego/cache"
) )
// Memcache adapter. // Memcache adapter.
@ -21,7 +29,11 @@ func NewMemCache() *MemcacheCache {
// get value from memcache. // get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} { func (rc *MemcacheCache) Get(key string) interface{} {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
v, err := rc.c.Get(key) v, err := rc.c.Get(key)
if err != nil { if err != nil {
@ -39,7 +51,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
// put value to memcache. only support string. // put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
v, ok := val.(string) v, ok := val.(string)
if !ok { if !ok {
@ -55,7 +71,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
// delete value in memcache. // delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error { func (rc *MemcacheCache) Delete(key string) error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
_, err := rc.c.Delete(key) _, err := rc.c.Delete(key)
return err return err
@ -76,7 +96,11 @@ func (rc *MemcacheCache) Decr(key string) error {
// check value exists in memcache. // check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool { func (rc *MemcacheCache) IsExist(key string) bool {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
} }
v, err := rc.c.Get(key) v, err := rc.c.Get(key)
if err != nil { if err != nil {
@ -87,13 +111,16 @@ func (rc *MemcacheCache) IsExist(key string) bool {
} else { } else {
return true return true
} }
return true
} }
// clear all cached in memcache. // clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error { func (rc *MemcacheCache) ClearAll() error {
if rc.c == nil { if rc.c == nil {
rc.c = rc.connectInit() var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
} }
err := rc.c.FlushAll() err := rc.c.FlushAll()
return err return err
@ -109,22 +136,25 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
rc.c = rc.connectInit() var err error
if rc.c == nil { if rc.c != nil {
rc.c, err = rc.connectInit()
if err != nil {
return errors.New("dial tcp conn error") return errors.New("dial tcp conn error")
} }
}
return nil return nil
} }
// connect to memcache and keep the connection. // connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() *memcache.Connection { func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
c, err := memcache.Connect(rc.conninfo) c, err := memcache.Connect(rc.conninfo)
if err != nil { if err != nil {
return nil return nil, err
} }
return c return c, nil
} }
func init() { func init() {
Register("memcache", NewMemCache()) cache.Register("memcache", NewMemCache())
} }

9
cache/memory.go vendored
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (
@ -26,7 +32,7 @@ type MemoryCache struct {
lock sync.RWMutex lock sync.RWMutex
dur time.Duration dur time.Duration
items map[string]*MemoryItem items map[string]*MemoryItem
Every int // run an expiration check Every cloc; time Every int // run an expiration check Every clock time
} }
// NewMemoryCache returns a new MemoryCache. // NewMemoryCache returns a new MemoryCache.
@ -52,6 +58,7 @@ func (bc *MemoryCache) Get(name string) interface{} {
} }
// Put cache to memory. // Put cache to memory.
// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute).
func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error { func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error {
bc.lock.Lock() bc.lock.Lock()
defer bc.lock.Unlock() defer bc.lock.Unlock()

View File

@ -1,10 +1,19 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package cache package cache
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"time"
"github.com/beego/redigo/redis" "github.com/beego/redigo/redis"
"github.com/astaxie/beego/cache"
) )
var ( var (
@ -14,7 +23,7 @@ var (
// Redis cache adapter. // Redis cache adapter.
type RedisCache struct { type RedisCache struct {
c redis.Conn p *redis.Pool // redis connection pool
conninfo string conninfo string
key string key string
} }
@ -24,107 +33,62 @@ func NewRedisCache() *RedisCache {
return &RedisCache{key: DefaultKey} return &RedisCache{key: DefaultKey}
} }
// actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
return c.Do(commandName, args...)
}
// Get cache from redis. // Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} { func (rc *RedisCache) Get(key string) interface{} {
if rc.c == nil { v, err := rc.do("HGET", rc.key, key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return nil
}
}
v, err := rc.c.Do("HGET", rc.key, key)
if err != nil { if err != nil {
return nil return nil
} }
return v return v
} }
// put cache to redis. // put cache to redis.
// timeout is ignored. // timeout is ignored.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil { _, err := rc.do("HSET", rc.key, key, val)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HSET", rc.key, key, val)
return err return err
} }
// delete cache in redis. // delete cache in redis.
func (rc *RedisCache) Delete(key string) error { func (rc *RedisCache) Delete(key string) error {
if rc.c == nil { _, err := rc.do("HDEL", rc.key, key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HDEL", rc.key, key)
return err return err
} }
// check cache exist in redis. // check cache exist in redis.
func (rc *RedisCache) IsExist(key string) bool { func (rc *RedisCache) IsExist(key string) bool {
if rc.c == nil { v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
if err != nil { if err != nil {
return false return false
} }
return v return v
} }
// increase counter in redis. // increase counter in redis.
func (rc *RedisCache) Incr(key string) error { func (rc *RedisCache) Incr(key string) error {
if rc.c == nil { _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
if err != nil {
return err
}
return nil
} }
// decrease counter in redis. // decrease counter in redis.
func (rc *RedisCache) Decr(key string) error { func (rc *RedisCache) Decr(key string) error {
if rc.c == nil { _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
if err != nil {
return err
}
return nil
} }
// clean all cache in redis. delete this redis collection. // clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error { func (rc *RedisCache) ClearAll() error {
if rc.c == nil { _, err := rc.do("DEL", rc.key)
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("DEL", rc.key)
return err return err
} }
@ -135,34 +99,44 @@ func (rc *RedisCache) ClearAll() error {
func (rc *RedisCache) StartAndGC(config string) error { func (rc *RedisCache) StartAndGC(config string) error {
var cf map[string]string var cf map[string]string
json.Unmarshal([]byte(config), &cf) json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok { if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey cf["key"] = DefaultKey
} }
if _, ok := cf["conn"]; !ok { if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
rc.key = cf["key"] rc.key = cf["key"]
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
var err error rc.connectInit()
rc.c, err = rc.connectInit()
if err != nil { c := rc.p.Get()
defer c.Close()
if err := c.Err(); err != nil {
return err return err
} }
if rc.c == nil {
return errors.New("dial tcp conn error")
}
return nil return nil
} }
// connect to redis. // connect to redis.
func (rc *RedisCache) connectInit() (redis.Conn, error) { func (rc *RedisCache) connectInit() {
// initialize a new pool
rc.p = &redis.Pool{
MaxIdle: 3,
IdleTimeout: 180 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo) c, err := redis.Dial("tcp", rc.conninfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil
},
}
} }
func init() { func init() {
Register("redis", NewRedisCache()) cache.Register("redis", NewRedisCache())
} }

268
config.go
View File

@ -1,29 +1,40 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
"errors"
"fmt"
"html/template" "html/template"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
) )
var ( var (
BeeApp *App // beego application BeeApp *App // beego application
AppName string AppName string
AppPath string AppPath string
workPath string
AppConfigPath string AppConfigPath string
StaticDir map[string]string StaticDir map[string]string
TemplateCache map[string]*template.Template // template caching map TemplateCache map[string]*template.Template // template caching map
StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc) StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc)
EnableHttpListen bool
HttpAddr string HttpAddr string
HttpPort int HttpPort int
HttpTLS bool EnableHttpTLS bool
HttpsPort int
HttpCertFile string HttpCertFile string
HttpKeyFile string HttpKeyFile string
RecoverPanic bool // flag of auto recover panic RecoverPanic bool // flag of auto recover panic
@ -40,11 +51,11 @@ var (
SessionHashFunc string // session hash generation func. SessionHashFunc string // session hash generation func.
SessionHashKey string // session hash salt string. SessionHashKey string // session hash salt string.
SessionCookieLifeTime int // the life time of session id in cookie. SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
UseFcgi bool UseFcgi bool
MaxMemory int64 MaxMemory int64
EnableGzip bool // flag of enable gzip EnableGzip bool // flag of enable gzip
DirectoryIndex bool // flag of display directory index. default is false. DirectoryIndex bool // flag of display directory index. default is false.
EnableHotUpdate bool // flag of hot update checking by app self. default is false.
HttpServerTimeOut int64 HttpServerTimeOut int64
ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template. ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template.
XSRFKEY string // xsrf hash salt string. XSRFKEY string // xsrf hash salt string.
@ -57,15 +68,32 @@ var (
EnableAdmin bool // flag of enable admin module to log every request info. EnableAdmin bool // flag of enable admin module to log every request info.
AdminHttpAddr string // http server configurations for admin module. AdminHttpAddr string // http server configurations for admin module.
AdminHttpPort int AdminHttpPort int
FlashName string // name of the flash variable found in response header and cookie
FlashSeperator string // used to seperate flash key:value
AppConfigProvider string // config provider
EnableDocs bool // enable generate docs & server docs API Swagger
) )
func init() { func init() {
// create beego application // create beego application
BeeApp = NewApp() BeeApp = NewApp()
workPath, _ = os.Getwd()
workPath, _ = filepath.Abs(workPath)
// initialize default configurations // initialize default configurations
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if workPath != AppPath {
if utils.FileExists(AppConfigPath) {
os.Chdir(AppPath) os.Chdir(AppPath)
} else {
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
}
}
AppConfigProvider = "ini"
StaticDir = make(map[string]string) StaticDir = make(map[string]string)
StaticDir["/static"] = "static" StaticDir["/static"] = "static"
@ -75,9 +103,13 @@ func init() {
TemplateCache = make(map[string]*template.Template) TemplateCache = make(map[string]*template.Template)
// set this to 0.0.0.0 to make this app available to externally // set this to 0.0.0.0 to make this app available to externally
EnableHttpListen = true //default enable http Listen
HttpAddr = "" HttpAddr = ""
HttpPort = 8080 HttpPort = 8080
HttpsPort = 10443
AppName = "beego" AppName = "beego"
RunMode = "dev" //default runmod RunMode = "dev" //default runmod
@ -96,6 +128,7 @@ func init() {
SessionHashFunc = "sha1" SessionHashFunc = "sha1"
SessionHashKey = "beegoserversessionkey" SessionHashKey = "beegoserversessionkey"
SessionCookieLifeTime = 0 //set cookie default is the brower life SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false UseFcgi = false
@ -103,8 +136,6 @@ func init() {
EnableGzip = false EnableGzip = false
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
HttpServerTimeOut = 0 HttpServerTimeOut = 0
ErrorsShow = true ErrorsShow = true
@ -115,19 +146,25 @@ func init() {
TemplateLeft = "{{" TemplateLeft = "{{"
TemplateRight = "}}" TemplateRight = "}}"
BeegoServerName = "beegoServer" BeegoServerName = "beegoServer:" + VERSION
EnableAdmin = false EnableAdmin = false
AdminHttpAddr = "127.0.0.1" AdminHttpAddr = "127.0.0.1"
AdminHttpPort = 8088 AdminHttpPort = 8088
FlashName = "BEEGO_FLASH"
FlashSeperator = "BEEGOFLASH"
runtime.GOMAXPROCS(runtime.NumCPU()) runtime.GOMAXPROCS(runtime.NumCPU())
// init BeeLogger // init BeeLogger
BeeLogger = logs.NewLogger(10000) BeeLogger = logs.NewLogger(10000)
BeeLogger.SetLogger("console", "") err := BeeLogger.SetLogger("console", "")
if err != nil {
fmt.Println("init console log error:", err)
}
err := ParseConfig() err = ParseConfig()
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
// for init if doesn't have app.conf will not panic // for init if doesn't have app.conf will not panic
Info(err) Info(err)
@ -137,153 +174,168 @@ func init() {
// ParseConfig parsed default config file. // ParseConfig parsed default config file.
// now only support ini, next will support json. // now only support ini, next will support json.
func ParseConfig() (err error) { func ParseConfig() (err error) {
AppConfig, err = config.NewConfig("ini", AppConfigPath) AppConfig, err = config.NewConfig(AppConfigProvider, AppConfigPath)
if err != nil { if err != nil {
AppConfig = config.NewFakeConfig()
return err return err
} else { } else {
HttpAddr = AppConfig.String("HttpAddr")
if v, err := AppConfig.Int("HttpPort"); err == nil { if v, err := getConfig("string", "HttpAddr"); err == nil {
HttpPort = v HttpAddr = v.(string)
} }
if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil { if v, err := getConfig("int", "HttpPort"); err == nil {
MaxMemory = maxmemory HttpPort = v.(int)
} }
if appname := AppConfig.String("AppName"); appname != "" { if v, err := getConfig("bool", "EnableHttpListen"); err == nil {
AppName = appname EnableHttpListen = v.(bool)
} }
if runmode := AppConfig.String("RunMode"); runmode != "" { if maxmemory, err := getConfig("int64", "MaxMemory"); err == nil {
RunMode = runmode MaxMemory = maxmemory.(int64)
} }
if autorender, err := AppConfig.Bool("AutoRender"); err == nil { if appname, _ := getConfig("string", "AppName"); appname != "" {
AutoRender = autorender AppName = appname.(string)
} }
if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil { if runmode, _ := getConfig("string", "RunMode"); runmode != "" {
RecoverPanic = autorecover RunMode = runmode.(string)
} }
if views := AppConfig.String("ViewsPath"); views != "" { if autorender, err := getConfig("bool", "AutoRender"); err == nil {
ViewsPath = views AutoRender = autorender.(bool)
} }
if sessionon, err := AppConfig.Bool("SessionOn"); err == nil { if autorecover, err := getConfig("bool", "RecoverPanic"); err == nil {
SessionOn = sessionon RecoverPanic = autorecover.(bool)
} }
if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" { if views, _ := getConfig("string", "ViewsPath"); views != "" {
SessionProvider = sessProvider ViewsPath = views.(string)
} }
if sessName := AppConfig.String("SessionName"); sessName != "" { if sessionon, err := getConfig("bool", "SessionOn"); err == nil {
SessionName = sessName SessionOn = sessionon.(bool)
} }
if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" { if sessProvider, _ := getConfig("string", "SessionProvider"); sessProvider != "" {
SessionSavePath = sesssavepath SessionProvider = sessProvider.(string)
} }
if sesshashfunc := AppConfig.String("SessionHashFunc"); sesshashfunc != "" { if sessName, _ := getConfig("string", "SessionName"); sessName != "" {
SessionHashFunc = sesshashfunc SessionName = sessName.(string)
} }
if sesshashkey := AppConfig.String("SessionHashKey"); sesshashkey != "" { if sesssavepath, _ := getConfig("string", "SessionSavePath"); sesssavepath != "" {
SessionHashKey = sesshashkey SessionSavePath = sesssavepath.(string)
} }
if sessMaxLifeTime, err := AppConfig.Int("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 { if sesshashfunc, _ := getConfig("string", "SessionHashFunc"); sesshashfunc != "" {
int64val, _ := strconv.ParseInt(strconv.Itoa(sessMaxLifeTime), 10, 64) SessionHashFunc = sesshashfunc.(string)
SessionGCMaxLifetime = int64val
} }
if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 { if sesshashkey, _ := getConfig("string", "SessionHashKey"); sesshashkey != "" {
SessionCookieLifeTime = sesscookielifetime SessionHashKey = sesshashkey.(string)
} }
if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil { if sessMaxLifeTime, err := getConfig("int64", "SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 {
UseFcgi = usefcgi SessionGCMaxLifetime = sessMaxLifeTime.(int64)
} }
if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil { if sesscookielifetime, err := getConfig("int", "SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 {
EnableGzip = enablegzip SessionCookieLifeTime = sesscookielifetime.(int)
} }
if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil { if usefcgi, err := getConfig("bool", "UseFcgi"); err == nil {
DirectoryIndex = directoryindex UseFcgi = usefcgi.(bool)
} }
if hotupdate, err := AppConfig.Bool("HotUpdate"); err == nil { if enablegzip, err := getConfig("bool", "EnableGzip"); err == nil {
EnableHotUpdate = hotupdate EnableGzip = enablegzip.(bool)
} }
if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil { if directoryindex, err := getConfig("bool", "DirectoryIndex"); err == nil {
HttpServerTimeOut = timeout DirectoryIndex = directoryindex.(bool)
} }
if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil { if timeout, err := getConfig("int64", "HttpServerTimeOut"); err == nil {
ErrorsShow = errorsshow HttpServerTimeOut = timeout.(int64)
} }
if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil { if errorsshow, err := getConfig("bool", "ErrorsShow"); err == nil {
CopyRequestBody = copyrequestbody ErrorsShow = errorsshow.(bool)
} }
if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" { if copyrequestbody, err := getConfig("bool", "CopyRequestBody"); err == nil {
XSRFKEY = xsrfkey CopyRequestBody = copyrequestbody.(bool)
} }
if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil { if xsrfkey, _ := getConfig("string", "XSRFKEY"); xsrfkey != "" {
EnableXSRF = enablexsrf XSRFKEY = xsrfkey.(string)
} }
if expire, err := AppConfig.Int("XSRFExpire"); err == nil { if enablexsrf, err := getConfig("bool", "EnableXSRF"); err == nil {
XSRFExpire = expire EnableXSRF = enablexsrf.(bool)
} }
if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" { if expire, err := getConfig("int", "XSRFExpire"); err == nil {
TemplateLeft = tplleft XSRFExpire = expire.(int)
} }
if tplright := AppConfig.String("TemplateRight"); tplright != "" { if tplleft, _ := getConfig("string", "TemplateLeft"); tplleft != "" {
TemplateRight = tplright TemplateLeft = tplleft.(string)
} }
if httptls, err := AppConfig.Bool("HttpTLS"); err == nil { if tplright, _ := getConfig("string", "TemplateRight"); tplright != "" {
HttpTLS = httptls TemplateRight = tplright.(string)
} }
if certfile := AppConfig.String("HttpCertFile"); certfile != "" { if httptls, err := getConfig("bool", "EnableHttpTLS"); err == nil {
HttpCertFile = certfile EnableHttpTLS = httptls.(bool)
} }
if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" { if httpsport, err := getConfig("int", "HttpsPort"); err == nil {
HttpKeyFile = keyfile HttpsPort = httpsport.(int)
} }
if serverName := AppConfig.String("BeegoServerName"); serverName != "" { if certfile, _ := getConfig("string", "HttpCertFile"); certfile != "" {
BeegoServerName = serverName HttpCertFile = certfile.(string)
} }
if sd := AppConfig.String("StaticDir"); sd != "" { if keyfile, _ := getConfig("string", "HttpKeyFile"); keyfile != "" {
HttpKeyFile = keyfile.(string)
}
if serverName, _ := getConfig("string", "BeegoServerName"); serverName != "" {
BeegoServerName = serverName.(string)
}
if flashname, _ := getConfig("string", "FlashName"); flashname != "" {
FlashName = flashname.(string)
}
if flashseperator, _ := getConfig("string", "FlashSeperator"); flashseperator != "" {
FlashSeperator = flashseperator.(string)
}
if sd, _ := getConfig("string", "StaticDir"); sd != "" {
for k := range StaticDir { for k := range StaticDir {
delete(StaticDir, k) delete(StaticDir, k)
} }
sds := strings.Fields(sd) sds := strings.Fields(sd.(string))
for _, v := range sds { for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
StaticDir["/"+url2fsmap[0]] = url2fsmap[1] StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
} else { } else {
StaticDir["/"+url2fsmap[0]] = url2fsmap[0] StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
} }
} }
} }
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" { if sgz, _ := getConfig("string", "StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",") extensions := strings.Split(sgz.(string), ",")
if len(extensions) > 0 { if len(extensions) > 0 {
StaticExtensionsToGzip = []string{} StaticExtensionsToGzip = []string{}
for _, ext := range extensions { for _, ext := range extensions {
@ -299,17 +351,63 @@ func ParseConfig() (err error) {
} }
} }
if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil { if enableadmin, err := getConfig("bool", "EnableAdmin"); err == nil {
EnableAdmin = enableadmin EnableAdmin = enableadmin.(bool)
} }
if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" { if adminhttpaddr, _ := getConfig("string", "AdminHttpAddr"); adminhttpaddr != "" {
AdminHttpAddr = adminhttpaddr AdminHttpAddr = adminhttpaddr.(string)
} }
if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil { if adminhttpport, err := getConfig("int", "AdminHttpPort"); err == nil {
AdminHttpPort = adminhttpport AdminHttpPort = adminhttpport.(int)
}
if enabledocs, err := getConfig("bool", "EnableDocs"); err == nil {
EnableDocs = enabledocs.(bool)
} }
} }
return nil return nil
} }
func getConfig(typ, key string) (interface{}, error) {
switch typ {
case "string":
v := AppConfig.String(RunMode + "::" + key)
if v == "" {
v = AppConfig.String(key)
}
return v, nil
case "strings":
v := AppConfig.Strings(RunMode + "::" + key)
if len(v) == 0 {
v = AppConfig.Strings(key)
}
return v, nil
case "int":
v, err := AppConfig.Int(RunMode + "::" + key)
if err != nil || v == 0 {
return AppConfig.Int(key)
}
return v, nil
case "bool":
v, err := AppConfig.Bool(RunMode + "::" + key)
if err != nil {
return AppConfig.Bool(key)
}
return v, nil
case "int64":
v, err := AppConfig.Int64(RunMode + "::" + key)
if err != nil || v == 0 {
return AppConfig.Int64(key)
}
return v, nil
case "float":
v, err := AppConfig.Float(RunMode + "::" + key)
if err != nil || v == 0 {
return AppConfig.Float(key)
}
return v, nil
}
return "", errors.New("not support type")
}

View File

@ -1,12 +1,20 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
"fmt" "fmt"
) )
// ConfigContainer defines how to get and set value from configuration raw data.
type ConfigContainer interface { type ConfigContainer interface {
Set(key, val string) error Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error) Int(key string) (int, error)
Int64(key string) (int64, error) Int64(key string) (int64, error)
Bool(key string) (bool, error) Bool(key string) (bool, error)
@ -14,6 +22,7 @@ type ConfigContainer interface {
DIY(key string) (interface{}, error) DIY(key string) (interface{}, error)
} }
// Config is the adapter interface for parsing config file to get raw data to ConfigContainer.
type Config interface { type Config interface {
Parse(key string) (ConfigContainer, error) Parse(key string) (ConfigContainer, error)
} }
@ -33,8 +42,8 @@ func Register(name string, adapter Config) {
adapters[name] = adapter adapters[name] = adapter
} }
// adapterNamer is ini/json/xml/yaml // adapterName is ini/json/xml/yaml.
// filename is the config file path // filename is the config file path.
func NewConfig(adapterName, fileaname string) (ConfigContainer, error) { func NewConfig(adapterName, fileaname string) (ConfigContainer, error) {
adapter, ok := adapters[adapterName] adapter, ok := adapters[adapterName]
if !ok { if !ok {

68
config/fake.go Normal file
View File

@ -0,0 +1,68 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config
import (
"errors"
"strconv"
"strings"
)
type fakeConfigContainer struct {
data map[string]string
}
func (c *fakeConfigContainer) getData(key string) string {
key = strings.ToLower(key)
return c.data[key]
}
func (c *fakeConfigContainer) Set(key, val string) error {
key = strings.ToLower(key)
c.data[key] = val
return nil
}
func (c *fakeConfigContainer) String(key string) string {
return c.getData(key)
}
func (c *fakeConfigContainer) Strings(key string) []string {
return strings.Split(c.getData(key), ";")
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getData(key))
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64)
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.getData(key))
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64)
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
key = strings.ToLower(key)
if v, ok := c.data[key]; ok {
return v, nil
}
return nil, errors.New("key not find")
}
var _ ConfigContainer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer {
return &fakeConfigContainer{
data: make(map[string]string),
}
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
@ -13,21 +19,21 @@ import (
) )
var ( var (
DEFAULT_SECTION = "default" DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section,
bNumComment = []byte{'#'} // number sign bNumComment = []byte{'#'} // number signal
bSemComment = []byte{';'} // semicolon bSemComment = []byte{';'} // semicolon signal
bEmpty = []byte{} bEmpty = []byte{}
bEqual = []byte{'='} bEqual = []byte{'='} // equal signal
bDQuote = []byte{'"'} bDQuote = []byte{'"'} // quote signal
sectionStart = []byte{'['} sectionStart = []byte{'['} // section start signal
sectionEnd = []byte{']'} sectionEnd = []byte{']'} // section end signal
) )
// IniConfig implements Config to parse ini file.
type IniConfig struct { type IniConfig struct {
} }
// ParseFile creates a new Config and parses the file configuration from the // ParseFile creates a new Config and parses the file configuration from the named file.
// named file.
func (ini *IniConfig) Parse(name string) (ConfigContainer, error) { func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
file, err := os.Open(name) file, err := os.Open(name)
if err != nil { if err != nil {
@ -106,11 +112,12 @@ func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
return cfg, nil return cfg, nil
} }
// A Config represents the configuration. // A Config represents the ini configuration.
// When set and get value, support key as section:name type.
type IniConfigContainer struct { type IniConfigContainer struct {
filename string filename string
data map[string]map[string]string //section=> key:val data map[string]map[string]string // section=> key:val
sectionComment map[string]string //sction : comment sectionComment map[string]string // section : comment
keycomment map[string]string // id: []{comment, key...}; id 1 is for main comment. keycomment map[string]string // id: []{comment, key...}; id 1 is for main comment.
sync.RWMutex sync.RWMutex
} }
@ -127,6 +134,7 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getdata(key)) return strconv.Atoi(c.getdata(key))
} }
// Int64 returns the int64 value for a given key.
func (c *IniConfigContainer) Int64(key string) (int64, error) { func (c *IniConfigContainer) Int64(key string) (int64, error) {
key = strings.ToLower(key) key = strings.ToLower(key)
return strconv.ParseInt(c.getdata(key), 10, 64) return strconv.ParseInt(c.getdata(key), 10, 64)
@ -144,7 +152,14 @@ func (c *IniConfigContainer) String(key string) string {
return c.getdata(key) return c.getdata(key)
} }
// Strings returns the []string value for a given key.
func (c *IniConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key. // WriteValue writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.
func (c *IniConfigContainer) Set(key, value string) error { func (c *IniConfigContainer) Set(key, value string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -169,6 +184,7 @@ func (c *IniConfigContainer) Set(key, value string) error {
return nil return nil
} }
// DIY returns the raw value by a given key.
func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) {
key = strings.ToLower(key) key = strings.ToLower(key)
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
@ -177,7 +193,7 @@ func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) {
return v, errors.New("key not find") return v, errors.New("key not find")
} }
//section.key or key // section.key or key
func (c *IniConfigContainer) getdata(key string) string { func (c *IniConfigContainer) getdata(key string) string {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
@ -19,6 +25,7 @@ copyrequestbody = true
key1="asta" key1="asta"
key2 = "xie" key2 = "xie"
CaseInsensitive = true CaseInsensitive = true
peers = one;two;three
` `
func TestIni(t *testing.T) { func TestIni(t *testing.T) {
@ -78,4 +85,11 @@ func TestIni(t *testing.T) {
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true { if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
t.Fatal("get demo.caseinsensitive error") t.Fatal("get demo.caseinsensitive error")
} }
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
t.Fatal("get strings error", data)
} else if data[0] != "one" {
t.Fatal("get first params error not equat to one")
}
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
@ -9,9 +15,11 @@ import (
"sync" "sync"
) )
// JsonConfig is a json config parser and implements Config interface.
type JsonConfig struct { type JsonConfig struct {
} }
// Parse returns a ConfigContainer with parsed json config map.
func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) { func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
@ -27,16 +35,24 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
} }
err = json.Unmarshal(content, &x.data) err = json.Unmarshal(content, &x.data)
if err != nil { if err != nil {
var wrappingArray []interface{}
err2 := json.Unmarshal(content, &wrappingArray)
if err2 != nil {
return nil, err return nil, err
} }
x.data["rootArray"] = wrappingArray
}
return x, nil return x, nil
} }
// A Config represents the json configuration.
// Only when get value, support key as section:name type.
type JsonConfigContainer struct { type JsonConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.RWMutex sync.RWMutex
} }
// Bool returns the boolean value for a given key.
func (c *JsonConfigContainer) Bool(key string) (bool, error) { func (c *JsonConfigContainer) Bool(key string) (bool, error) {
val := c.getdata(key) val := c.getdata(key)
if val != nil { if val != nil {
@ -48,9 +64,9 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) {
} else { } else {
return false, errors.New("not exist key:" + key) return false, errors.New("not exist key:" + key)
} }
} }
// Int returns the integer value for a given key.
func (c *JsonConfigContainer) Int(key string) (int, error) { func (c *JsonConfigContainer) Int(key string) (int, error) {
val := c.getdata(key) val := c.getdata(key)
if val != nil { if val != nil {
@ -64,6 +80,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) {
} }
} }
// Int64 returns the int64 value for a given key.
func (c *JsonConfigContainer) Int64(key string) (int64, error) { func (c *JsonConfigContainer) Int64(key string) (int64, error) {
val := c.getdata(key) val := c.getdata(key)
if val != nil { if val != nil {
@ -77,6 +94,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) {
} }
} }
// Float returns the float value for a given key.
func (c *JsonConfigContainer) Float(key string) (float64, error) { func (c *JsonConfigContainer) Float(key string) (float64, error) {
val := c.getdata(key) val := c.getdata(key)
if val != nil { if val != nil {
@ -90,6 +108,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) {
} }
} }
// String returns the string value for a given key.
func (c *JsonConfigContainer) String(key string) string { func (c *JsonConfigContainer) String(key string) string {
val := c.getdata(key) val := c.getdata(key)
if val != nil { if val != nil {
@ -103,6 +122,12 @@ func (c *JsonConfigContainer) String(key string) string {
} }
} }
// Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error { func (c *JsonConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -110,6 +135,7 @@ func (c *JsonConfigContainer) Set(key, val string) error {
return nil return nil
} }
// DIY returns the raw value by a given key.
func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) { func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
val := c.getdata(key) val := c.getdata(key)
if val != nil { if val != nil {
@ -119,7 +145,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
} }
} }
//section.key or key // section.key or key
func (c *JsonConfigContainer) getdata(key string) interface{} { func (c *JsonConfigContainer) getdata(key string) interface{} {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
@ -27,6 +33,53 @@ var jsoncontext = `{
} }
}` }`
var jsoncontextwitharray = `[
{
"url": "user",
"serviceAPI": "http://www.test.com/user"
},
{
"url": "employee",
"serviceAPI": "http://www.test.com/employee"
}
]`
func TestJsonStartsWithArray(t *testing.T) {
f, err := os.Create("testjsonWithArray.conf")
if err != nil {
t.Fatal(err)
}
_, err = f.WriteString(jsoncontextwitharray)
if err != nil {
f.Close()
t.Fatal(err)
}
f.Close()
defer os.Remove("testjsonWithArray.conf")
jsonconf, err := NewConfig("json", "testjsonWithArray.conf")
if err != nil {
t.Fatal(err)
}
rootArray, err := jsonconf.DIY("rootArray")
if (err != nil) {
t.Error("array does not exist as element")
}
rootArrayCasted := rootArray.([]interface{})
if (rootArrayCasted == nil) {
t.Error("array from root is nil")
}else {
elem := rootArrayCasted[0].(map[string]interface{})
if elem["url"] != "user" || elem["serviceAPI"] != "http://www.test.com/user" {
t.Error("array[0] values are not valid")
}
elem2 := rootArrayCasted[1].(map[string]interface{})
if elem2["url"] != "employee" || elem2["serviceAPI"] != "http://www.test.com/employee" {
t.Error("array[1] values are not valid")
}
}
}
func TestJson(t *testing.T) { func TestJson(t *testing.T) {
f, err := os.Create("testjson.conf") f, err := os.Create("testjson.conf")
if err != nil { if err != nil {
@ -94,4 +147,28 @@ func TestJson(t *testing.T) {
t.Fatal("get host err") t.Fatal("get host err")
} }
} }
if _, err := jsonconf.Int("unknown"); err == nil {
t.Error("unknown keys should return an error when expecting an Int")
}
if _, err := jsonconf.Int64("unknown"); err == nil {
t.Error("unknown keys should return an error when expecting an Int64")
}
if _, err := jsonconf.Float("unknown"); err == nil {
t.Error("unknown keys should return an error when expecting a Float")
}
if _, err := jsonconf.DIY("unknown"); err == nil {
t.Error("unknown keys should return an error when expecting an interface{}")
}
if val := jsonconf.String("unknown"); val != "" {
t.Error("unknown keys should return an empty string when expecting a String")
}
if _, err := jsonconf.Bool("unknown"); err == nil {
t.Error("unknown keys should return an error when expecting a Bool")
}
} }

View File

@ -1,4 +1,8 @@
//xml parse should incluce in <config></config> tags // Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
@ -7,15 +11,21 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/astaxie/beego/config"
"github.com/beego/x2j" "github.com/beego/x2j"
) )
// XmlConfig is a xml config parser and implements Config interface.
// xml configurations should be included in <config></config> tag.
// only support key/value pair as <key>value</key> as each item.
type XMLConfig struct { type XMLConfig struct {
} }
func (xmls *XMLConfig) Parse(filename string) (ConfigContainer, error) { // Parse returns a ConfigContainer with parsed xml config map.
func (xmls *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -36,27 +46,33 @@ func (xmls *XMLConfig) Parse(filename string) (ConfigContainer, error) {
return x, nil return x, nil
} }
// A Config represents the xml configuration.
type XMLConfigContainer struct { type XMLConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.Mutex sync.Mutex
} }
// Bool returns the boolean value for a given key.
func (c *XMLConfigContainer) Bool(key string) (bool, error) { func (c *XMLConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.data[key].(string)) return strconv.ParseBool(c.data[key].(string))
} }
// Int returns the integer value for a given key.
func (c *XMLConfigContainer) Int(key string) (int, error) { func (c *XMLConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.data[key].(string)) return strconv.Atoi(c.data[key].(string))
} }
// Int64 returns the int64 value for a given key.
func (c *XMLConfigContainer) Int64(key string) (int64, error) { func (c *XMLConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.data[key].(string), 10, 64) return strconv.ParseInt(c.data[key].(string), 10, 64)
} }
// Float returns the float value for a given key.
func (c *XMLConfigContainer) Float(key string) (float64, error) { func (c *XMLConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.data[key].(string), 64) return strconv.ParseFloat(c.data[key].(string), 64)
} }
// String returns the string value for a given key.
func (c *XMLConfigContainer) String(key string) string { func (c *XMLConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok { if v, ok := c.data[key].(string); ok {
return v return v
@ -64,6 +80,12 @@ func (c *XMLConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error { func (c *XMLConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -71,6 +93,7 @@ func (c *XMLConfigContainer) Set(key, val string) error {
return nil return nil
} }
// DIY returns the raw value by a given key.
func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) { func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
return v, nil return v, nil
@ -79,5 +102,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
} }
func init() { func init() {
Register("xml", &XMLConfig{}) config.Register("xml", &XMLConfig{})
} }

View File

@ -1,8 +1,16 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
"os" "os"
"testing" "testing"
"github.com/astaxie/beego/config"
) )
//xml parse should incluce in <config></config> tags //xml parse should incluce in <config></config> tags
@ -30,7 +38,7 @@ func TestXML(t *testing.T) {
} }
f.Close() f.Close()
defer os.Remove("testxml.conf") defer os.Remove("testxml.conf")
xmlconf, err := NewConfig("xml", "testxml.conf") xmlconf, err := config.NewConfig("xml", "testxml.conf")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
@ -7,15 +13,19 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"strings"
"sync" "sync"
"github.com/astaxie/beego/config"
"github.com/beego/goyaml2" "github.com/beego/goyaml2"
) )
// YAMLConfig is a yaml config parser and implements Config interface.
type YAMLConfig struct { type YAMLConfig struct {
} }
func (yaml *YAMLConfig) Parse(filename string) (ConfigContainer, error) { // Parse returns a ConfigContainer with parsed yaml config map.
func (yaml *YAMLConfig) Parse(filename string) (config.ConfigContainer, error) {
y := &YAMLConfigContainer{ y := &YAMLConfigContainer{
data: make(map[string]interface{}), data: make(map[string]interface{}),
} }
@ -27,7 +37,8 @@ func (yaml *YAMLConfig) Parse(filename string) (ConfigContainer, error) {
return y, nil return y, nil
} }
// Reader读取YAML // Read yaml file to map.
// if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
err = nil err = nil
f, err := os.Open(path) f, err := os.Open(path)
@ -68,11 +79,13 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
return return
} }
// A Config represents the yaml configuration.
type YAMLConfigContainer struct { type YAMLConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.Mutex sync.Mutex
} }
// Bool returns the boolean value for a given key.
func (c *YAMLConfigContainer) Bool(key string) (bool, error) { func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key].(bool); ok { if v, ok := c.data[key].(bool); ok {
return v, nil return v, nil
@ -80,6 +93,7 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
return false, errors.New("not bool value") return false, errors.New("not bool value")
} }
// Int returns the integer value for a given key.
func (c *YAMLConfigContainer) Int(key string) (int, error) { func (c *YAMLConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok { if v, ok := c.data[key].(int64); ok {
return int(v), nil return int(v), nil
@ -87,6 +101,7 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) {
return 0, errors.New("not int value") return 0, errors.New("not int value")
} }
// Int64 returns the int64 value for a given key.
func (c *YAMLConfigContainer) Int64(key string) (int64, error) { func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok { if v, ok := c.data[key].(int64); ok {
return v, nil return v, nil
@ -94,6 +109,7 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
return 0, errors.New("not bool value") return 0, errors.New("not bool value")
} }
// Float returns the float value for a given key.
func (c *YAMLConfigContainer) Float(key string) (float64, error) { func (c *YAMLConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok { if v, ok := c.data[key].(float64); ok {
return v, nil return v, nil
@ -101,6 +117,7 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) {
return 0.0, errors.New("not float64 value") return 0.0, errors.New("not float64 value")
} }
// String returns the string value for a given key.
func (c *YAMLConfigContainer) String(key string) string { func (c *YAMLConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok { if v, ok := c.data[key].(string); ok {
return v return v
@ -108,6 +125,12 @@ func (c *YAMLConfigContainer) String(key string) string {
return "" return ""
} }
// Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error { func (c *YAMLConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -115,6 +138,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error {
return nil return nil
} }
// DIY returns the raw value by a given key.
func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) { func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
return v, nil return v, nil
@ -123,5 +147,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
} }
func init() { func init() {
Register("yaml", &YAMLConfig{}) config.Register("yaml", &YAMLConfig{})
} }

View File

@ -1,8 +1,16 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package config package config
import ( import (
"os" "os"
"testing" "testing"
"github.com/astaxie/beego/config"
) )
var yamlcontext = ` var yamlcontext = `
@ -27,7 +35,7 @@ func TestYaml(t *testing.T) {
} }
f.Close() f.Close()
defer os.Remove("testyaml.conf") defer os.Remove("testyaml.conf")
yamlconf, err := NewConfig("yaml", "testyaml.conf") yamlconf, err := config.NewConfig("yaml", "testyaml.conf")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

21
config_test.go Normal file
View File

@ -0,0 +1,21 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego
import (
"testing"
)
func TestDefaults(t *testing.T) {
if FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
if FlashSeperator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}

View File

@ -1,11 +1,26 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package context package context
import ( import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net/http" "net/http"
"strconv"
"strings"
"time"
"github.com/astaxie/beego/middleware" "github.com/astaxie/beego/middleware"
) )
// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides some api to operate request and response more easily.
type Context struct { type Context struct {
Input *BeegoInput Input *BeegoInput
Output *BeegoOutput Output *BeegoOutput
@ -13,11 +28,16 @@ type Context struct {
ResponseWriter http.ResponseWriter ResponseWriter http.ResponseWriter
} }
// Redirect does redirection to localurl with http header status code.
// It sends http response header directly.
func (ctx *Context) Redirect(status int, localurl string) { func (ctx *Context) Redirect(status int, localurl string) {
ctx.Output.Header("Location", localurl) ctx.Output.Header("Location", localurl)
ctx.Output.SetStatus(status) ctx.Output.SetStatus(status)
} }
// Abort stops this request.
// if middleware.ErrorMaps exists, panic body.
// if middleware.HTTPExceptionMaps exists, panic HTTPException struct with status and body string.
func (ctx *Context) Abort(status int, body string) { func (ctx *Context) Abort(status int, body string) {
ctx.Output.SetStatus(status) ctx.Output.SetStatus(status)
// first panic from ErrorMaps, is is user defined error functions. // first panic from ErrorMaps, is is user defined error functions.
@ -35,14 +55,58 @@ func (ctx *Context) Abort(status int, body string) {
panic(body) panic(body)
} }
// Write string to response body.
// it sends response body.
func (ctx *Context) WriteString(content string) { func (ctx *Context) WriteString(content string) {
ctx.Output.Body([]byte(content)) ctx.Output.Body([]byte(content))
} }
// Get cookie from request by a given key.
// It's alias of BeegoInput.Cookie.
func (ctx *Context) GetCookie(key string) string { func (ctx *Context) GetCookie(key string) string {
return ctx.Input.Cookie(key) return ctx.Input.Cookie(key)
} }
// Set cookie for response.
// It's alias of BeegoOutput.Cookie.
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...) ctx.Output.Cookie(name, value, others...)
} }
// Get secure cookie from request by a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key)
if val == "" {
return "", false
}
parts := strings.SplitN(val, "|", 3)
if len(parts) != 3 {
return "", false
}
vs := parts[0]
timestamp := parts[1]
sig := parts[2]
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
return "", false
}
res, _ := base64.URLEncoding.DecodeString(vs)
return string(res), true
}
// Set Secure cookie for response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
sig := fmt.Sprintf("%02x", h.Sum(nil))
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
ctx.Output.Cookie(name, cookie, others...)
}

View File

@ -1,23 +1,37 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package context package context
import ( import (
"bytes" "bytes"
"errors"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"reflect"
"strconv" "strconv"
"strings" "strings"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
// BeegoInput operates the http request header ,data ,cookie and body.
// it also contains router params and current session.
type BeegoInput struct { type BeegoInput struct {
CruSession session.SessionStore CruSession session.SessionStore
Params map[string]string Params map[string]string
Data map[interface{}]interface{} Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
Request *http.Request Request *http.Request
RequestBody []byte RequestBody []byte
RunController reflect.Type
RunMethod string
} }
// NewInput return BeegoInput generated by http.Request.
func NewInput(req *http.Request) *BeegoInput { func NewInput(req *http.Request) *BeegoInput {
return &BeegoInput{ return &BeegoInput{
Params: make(map[string]string), Params: make(map[string]string),
@ -26,22 +40,27 @@ func NewInput(req *http.Request) *BeegoInput {
} }
} }
// Protocol returns request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string { func (input *BeegoInput) Protocol() string {
return input.Request.Proto return input.Request.Proto
} }
// Uri returns full request url with query string, fragment.
func (input *BeegoInput) Uri() string { func (input *BeegoInput) Uri() string {
return input.Request.RequestURI return input.Request.RequestURI
} }
// Url returns request url path (without query string, fragment).
func (input *BeegoInput) Url() string { func (input *BeegoInput) Url() string {
return input.Request.URL.String() return input.Request.URL.String()
} }
// Site returns base site url as scheme://domain type.
func (input *BeegoInput) Site() string { func (input *BeegoInput) Site() string {
return input.Scheme() + "://" + input.Domain() return input.Scheme() + "://" + input.Domain()
} }
// Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string { func (input *BeegoInput) Scheme() string {
if input.Request.URL.Scheme != "" { if input.Request.URL.Scheme != "" {
return input.Request.URL.Scheme return input.Request.URL.Scheme
@ -52,10 +71,14 @@ func (input *BeegoInput) Scheme() string {
} }
} }
// Domain returns host name.
// Alias of Host method.
func (input *BeegoInput) Domain() string { func (input *BeegoInput) Domain() string {
return input.Host() return input.Host()
} }
// Host returns host name.
// if no host info in request, return localhost.
func (input *BeegoInput) Host() string { func (input *BeegoInput) Host() string {
if input.Request.Host != "" { if input.Request.Host != "" {
hostParts := strings.Split(input.Request.Host, ":") hostParts := strings.Split(input.Request.Host, ":")
@ -67,44 +90,90 @@ func (input *BeegoInput) Host() string {
return "localhost" return "localhost"
} }
// Method returns http request method.
func (input *BeegoInput) Method() string { func (input *BeegoInput) Method() string {
return input.Request.Method return input.Request.Method
} }
// Is returns boolean of this request is on given method, such as Is("POST").
func (input *BeegoInput) Is(method string) bool { func (input *BeegoInput) Is(method string) bool {
return input.Method() == method return input.Method() == method
} }
// Is this a GET method request?
func (input *BeegoInput) IsGet() bool {
return input.Is("GET")
}
// Is this a POST method request?
func (input *BeegoInput) IsPost() bool {
return input.Is("POST")
}
// Is this a Head method request?
func (input *BeegoInput) IsHead() bool {
return input.Is("HEAD")
}
// Is this a OPTIONS method request?
func (input *BeegoInput) IsOptions() bool {
return input.Is("OPTIONS")
}
// Is this a PUT method request?
func (input *BeegoInput) IsPut() bool {
return input.Is("PUT")
}
// Is this a DELETE method request?
func (input *BeegoInput) IsDelete() bool {
return input.Is("DELETE")
}
// Is this a PATCH method request?
func (input *BeegoInput) IsPatch() bool {
return input.Is("PATCH")
}
// IsAjax returns boolean of this request is generated by ajax.
func (input *BeegoInput) IsAjax() bool { func (input *BeegoInput) IsAjax() bool {
return input.Header("X-Requested-With") == "XMLHttpRequest" return input.Header("X-Requested-With") == "XMLHttpRequest"
} }
// IsSecure returns boolean of this request is in https.
func (input *BeegoInput) IsSecure() bool { func (input *BeegoInput) IsSecure() bool {
return input.Scheme() == "https" return input.Scheme() == "https"
} }
// IsSecure returns boolean of this request is in webSocket.
func (input *BeegoInput) IsWebsocket() bool { func (input *BeegoInput) IsWebsocket() bool {
return input.Header("Upgrade") == "websocket" return input.Header("Upgrade") == "websocket"
} }
// IsSecure returns boolean of whether file uploads in this request or not..
func (input *BeegoInput) IsUpload() bool { func (input *BeegoInput) IsUpload() bool {
return input.Request.MultipartForm != nil return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
} }
// IP returns request client ip.
// if in proxy, return first proxy id.
// if error, return 127.0.0.1.
func (input *BeegoInput) IP() string { func (input *BeegoInput) IP() string {
ips := input.Proxy() ips := input.Proxy()
if len(ips) > 0 && ips[0] != "" { if len(ips) > 0 && ips[0] != "" {
return ips[0] rip := strings.Split(ips[0], ":")
return rip[0]
} }
ip := strings.Split(input.Request.RemoteAddr, ":") ip := strings.Split(input.Request.RemoteAddr, ":")
if len(ip) > 0 { if len(ip) > 0 {
if ip[0] != "["{ if ip[0] != "[" {
return ip[0] return ip[0]
} }
} }
return "127.0.0.1" return "127.0.0.1"
} }
// Proxy returns proxy client ips slice.
func (input *BeegoInput) Proxy() []string { func (input *BeegoInput) Proxy() []string {
if ips := input.Header("X-Forwarded-For"); ips != "" { if ips := input.Header("X-Forwarded-For"); ips != "" {
return strings.Split(ips, ",") return strings.Split(ips, ",")
@ -112,15 +181,20 @@ func (input *BeegoInput) Proxy() []string {
return []string{} return []string{}
} }
// Refer returns http referer header.
func (input *BeegoInput) Refer() string { func (input *BeegoInput) Refer() string {
return input.Header("Referer") return input.Header("Referer")
} }
// SubDomains returns sub domain string.
// if aa.bb.domain.com, returns aa.bb .
func (input *BeegoInput) SubDomains() string { func (input *BeegoInput) SubDomains() string {
parts := strings.Split(input.Host(), ".") parts := strings.Split(input.Host(), ".")
return strings.Join(parts[len(parts)-2:], ".") return strings.Join(parts[len(parts)-2:], ".")
} }
// Port returns request client port.
// when error or empty, return 80.
func (input *BeegoInput) Port() int { func (input *BeegoInput) Port() int {
parts := strings.Split(input.Request.Host, ":") parts := strings.Split(input.Request.Host, ":")
if len(parts) == 2 { if len(parts) == 2 {
@ -130,10 +204,12 @@ func (input *BeegoInput) Port() int {
return 80 return 80
} }
// UserAgent returns request client user agent string.
func (input *BeegoInput) UserAgent() string { func (input *BeegoInput) UserAgent() string {
return input.Header("User-Agent") return input.Header("User-Agent")
} }
// Param returns router param by a given key.
func (input *BeegoInput) Param(key string) string { func (input *BeegoInput) Param(key string) string {
if v, ok := input.Params[key]; ok { if v, ok := input.Params[key]; ok {
return v return v
@ -141,15 +217,24 @@ func (input *BeegoInput) Param(key string) string {
return "" return ""
} }
// Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string { func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" {
return val
}
if input.Request.Form == nil {
input.Request.ParseForm() input.Request.ParseForm()
}
return input.Request.Form.Get(key) return input.Request.Form.Get(key)
} }
// Header returns request header item string by a given string.
func (input *BeegoInput) Header(key string) string { func (input *BeegoInput) Header(key string) string {
return input.Request.Header.Get(key) return input.Request.Header.Get(key)
} }
// Cookie returns request cookie item string by a given key.
// if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string { func (input *BeegoInput) Cookie(key string) string {
ck, err := input.Request.Cookie(key) ck, err := input.Request.Cookie(key)
if err != nil { if err != nil {
@ -158,11 +243,13 @@ func (input *BeegoInput) Cookie(key string) string {
return ck.Value return ck.Value
} }
// Session returns current session item value by a given key.
func (input *BeegoInput) Session(key interface{}) interface{} { func (input *BeegoInput) Session(key interface{}) interface{} {
return input.CruSession.Get(key) return input.CruSession.Get(key)
} }
func (input *BeegoInput) Body() []byte { // Body returns the raw request body data as bytes.
func (input *BeegoInput) CopyBody() []byte {
requestbody, _ := ioutil.ReadAll(input.Request.Body) requestbody, _ := ioutil.ReadAll(input.Request.Body)
input.Request.Body.Close() input.Request.Body.Close()
bf := bytes.NewBuffer(requestbody) bf := bytes.NewBuffer(requestbody)
@ -171,6 +258,7 @@ func (input *BeegoInput) Body() []byte {
return requestbody return requestbody
} }
// GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} { func (input *BeegoInput) GetData(key interface{}) interface{} {
if v, ok := input.Data[key]; ok { if v, ok := input.Data[key]; ok {
return v return v
@ -178,6 +266,262 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
return nil return nil
} }
// SetData stores data with given key in this context.
// This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) { func (input *BeegoInput) SetData(key, val interface{}) {
input.Data[key] = val input.Data[key] = val
} }
// parseForm or parseMultiForm based on Content-type
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
if err := input.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
} else if err := input.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
return nil
}
// Bind data from request.Form[key] to dest
// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie
// var id int beegoInput.Bind(&id, "id") id ==123
// var isok bool beegoInput.Bind(&isok, "isok") id ==true
// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2
// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2]
// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array]
// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"}
func (input *BeegoInput) Bind(dest interface{}, key string) error {
value := reflect.ValueOf(dest)
if value.Kind() != reflect.Ptr {
return errors.New("beego: non-pointer passed to Bind: " + key)
}
value = value.Elem()
if !value.CanSet() {
return errors.New("beego: non-settable variable passed to Bind: " + key)
}
rv := input.bind(key, value.Type())
if !rv.IsValid() {
return errors.New("beego: reflect value is empty")
}
value.Set(rv)
return nil
}
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(reflect.TypeOf(0))
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindInt(val, typ)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindUint(val, typ)
case reflect.Float32, reflect.Float64:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindFloat(val, typ)
case reflect.String:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindString(val, typ)
case reflect.Bool:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindBool(val, typ)
case reflect.Slice:
rv = input.bindSlice(&input.Request.Form, key, typ)
case reflect.Struct:
rv = input.bindStruct(&input.Request.Form, key, typ)
case reflect.Ptr:
rv = input.bindPoint(key, typ)
case reflect.Map:
rv = input.bindMap(&input.Request.Form, key, typ)
}
return rv
}
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(reflect.TypeOf(0))
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rv = input.bindInt(val, typ)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
rv = input.bindUint(val, typ)
case reflect.Float32, reflect.Float64:
rv = input.bindFloat(val, typ)
case reflect.String:
rv = input.bindString(val, typ)
case reflect.Bool:
rv = input.bindBool(val, typ)
case reflect.Slice:
rv = input.bindSlice(&url.Values{"": {val}}, "", typ)
case reflect.Struct:
rv = input.bindStruct(&url.Values{"": {val}}, "", typ)
case reflect.Ptr:
rv = input.bindPoint(val, typ)
case reflect.Map:
rv = input.bindMap(&url.Values{"": {val}}, "", typ)
}
return rv
}
func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value {
intValue, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return reflect.Zero(typ)
}
pValue := reflect.New(typ)
pValue.Elem().SetInt(intValue)
return pValue.Elem()
}
func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value {
uintValue, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return reflect.Zero(typ)
}
pValue := reflect.New(typ)
pValue.Elem().SetUint(uintValue)
return pValue.Elem()
}
func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value {
floatValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return reflect.Zero(typ)
}
pValue := reflect.New(typ)
pValue.Elem().SetFloat(floatValue)
return pValue.Elem()
}
func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value {
return reflect.ValueOf(val)
}
func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value {
val = strings.TrimSpace(strings.ToLower(val))
switch val {
case "true", "on", "1":
return reflect.ValueOf(true)
}
return reflect.ValueOf(false)
}
type sliceValue struct {
index int // Index extracted from brackets. If -1, no index was provided.
value reflect.Value // the bound value for this slice element.
}
func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value {
maxIndex := -1
numNoIndex := 0
sliceValues := []sliceValue{}
for reqKey, vals := range *params {
if !strings.HasPrefix(reqKey, key+"[") {
continue
}
// Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey)
index := -1
leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key)
if rightBracket > leftBracket+1 {
index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket])
}
subKeyIndex := rightBracket + 1
// Handle the indexed case.
if index > -1 {
if index > maxIndex {
maxIndex = index
}
sliceValues = append(sliceValues, sliceValue{
index: index,
value: input.bind(reqKey[:subKeyIndex], typ.Elem()),
})
continue
}
// It's an un-indexed element. (e.g. element[])
numNoIndex += len(vals)
for _, val := range vals {
// Unindexed values can only be direct-bound.
sliceValues = append(sliceValues, sliceValue{
index: -1,
value: input.bindValue(val, typ.Elem()),
})
}
}
resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex)
for _, sv := range sliceValues {
if sv.index != -1 {
resultArray.Index(sv.index).Set(sv.value)
} else {
resultArray = reflect.Append(resultArray, sv.value)
}
}
return resultArray
}
func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value {
result := reflect.New(typ).Elem()
fieldValues := make(map[string]reflect.Value)
for reqKey, val := range *params {
if !strings.HasPrefix(reqKey, key+".") {
continue
}
fieldName := reqKey[len(key)+1:]
if _, ok := fieldValues[fieldName]; !ok {
// Time to bind this field. Get it and make sure we can set it.
fieldValue := result.FieldByName(fieldName)
if !fieldValue.IsValid() {
continue
}
if !fieldValue.CanSet() {
continue
}
boundVal := input.bindValue(val[0], fieldValue.Type())
fieldValue.Set(boundVal)
fieldValues[fieldName] = boundVal
}
}
return result
}
func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value {
return input.bind(key, typ.Elem()).Addr()
}
func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value {
var (
result = reflect.MakeMap(typ)
keyType = typ.Key()
valueType = typ.Elem()
)
for paramName, values := range *params {
if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' {
continue
}
key := paramName[len(key)+1 : len(paramName)-1]
result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType))
}
return result
}

64
context/input_test.go Normal file
View File

@ -0,0 +1,64 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package context
import (
"fmt"
"net/http"
"testing"
)
func TestParse(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput(r)
beegoInput.ParseFormOrMulitForm(1 << 20)
var id int
err := beegoInput.Bind(&id, "id")
if id != 123 || err != nil {
t.Fatal("id should has int value")
}
fmt.Println(id)
var isok bool
err = beegoInput.Bind(&isok, "isok")
if !isok || err != nil {
t.Fatal("isok should be true")
}
fmt.Println(isok)
var float float64
err = beegoInput.Bind(&float, "ft")
if float != 1.2 || err != nil {
t.Fatal("float should be equal to 1.2")
}
fmt.Println(float)
ol := make([]int, 0, 2)
err = beegoInput.Bind(&ol, "ol")
if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 {
t.Fatal("ol should has two elements")
}
fmt.Println(ol)
ul := make([]string, 0, 2)
err = beegoInput.Bind(&ul, "ul")
if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" {
t.Fatal("ul should has two elements")
}
fmt.Println(ul)
type User struct {
Name string
}
user := User{}
err = beegoInput.Bind(&user, "user")
if err != nil || user.Name != "astaxie" {
t.Fatal("user should has name")
}
fmt.Println(user)
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package context package context
import ( import (
@ -17,20 +23,27 @@ import (
"strings" "strings"
) )
// BeegoOutput does work for sending response header.
type BeegoOutput struct { type BeegoOutput struct {
Context *Context Context *Context
Status int Status int
EnableGzip bool EnableGzip bool
} }
// NewOutput returns new BeegoOutput.
// it contains nothing now.
func NewOutput() *BeegoOutput { func NewOutput() *BeegoOutput {
return &BeegoOutput{} return &BeegoOutput{}
} }
// Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) { func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val) output.Context.ResponseWriter.Header().Set(key, val)
} }
// Body sets response body content.
// if EnableGzip, compress content string.
// it sends out response body directly.
func (output *BeegoOutput) Body(content []byte) { func (output *BeegoOutput) Body(content []byte) {
output_writer := output.Context.ResponseWriter.(io.Writer) output_writer := output.Context.ResponseWriter.(io.Writer)
if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" { if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" {
@ -64,43 +77,83 @@ func (output *BeegoOutput) Body(content []byte) {
} }
} }
// Cookie sets cookie value via given key.
// others are ordered as cookie's max age time, path,domain, secure and httponly.
func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) {
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
if len(others) > 0 { if len(others) > 0 {
switch others[0].(type) { switch v := others[0].(type) {
case int: case int:
if others[0].(int) > 0 { if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int)) fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if others[0].(int) < 0 { } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") fmt.Fprintf(&b, "; Max-Age=0")
} }
case int64: case int64:
if others[0].(int64) > 0 { if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64)) fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if others[0].(int64) < 0 { } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") fmt.Fprintf(&b, "; Max-Age=0")
} }
case int32: case int32:
if others[0].(int32) > 0 { if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32)) fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if others[0].(int32) < 0 { } else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") fmt.Fprintf(&b, "; Max-Age=0")
} }
} }
} }
// the settings below
// Path, Domain, Secure, HttpOnly
// can use nil skip set
// default "/"
if len(others) > 1 { if len(others) > 1 {
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string))) if v, ok := others[1].(string); ok && len(v) > 0 {
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v))
} }
} else {
fmt.Fprintf(&b, "; Path=%s", "/")
}
// default empty
if len(others) > 2 { if len(others) > 2 {
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string))) if v, ok := others[2].(string); ok && len(v) > 0 {
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v))
} }
}
// default empty
if len(others) > 3 { if len(others) > 3 {
var secure bool
switch v := others[3].(type) {
case bool:
secure = v
default:
if others[3] != nil {
secure = true
}
}
if secure {
fmt.Fprintf(&b, "; Secure") fmt.Fprintf(&b, "; Secure")
} }
}
// default false. for session cookie default true
httponly := false
if len(others) > 4 { if len(others) > 4 {
if v, ok := others[4].(bool); ok && v {
// HttpOnly = true
httponly = true
}
}
if httponly {
fmt.Fprintf(&b, "; HttpOnly") fmt.Fprintf(&b, "; HttpOnly")
} }
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
} }
@ -116,6 +169,8 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v) return cookieValueSanitizer.Replace(v)
} }
// Json writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error { func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error {
output.Header("Content-Type", "application/json;charset=UTF-8") output.Header("Content-Type", "application/json;charset=UTF-8")
var content []byte var content []byte
@ -136,6 +191,7 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e
return nil return nil
} }
// Jsonp writes jsonp to response body.
func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error { func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript;charset=UTF-8") output.Header("Content-Type", "application/javascript;charset=UTF-8")
var content []byte var content []byte
@ -161,6 +217,7 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
return nil return nil
} }
// Xml writes xml string to response body.
func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error { func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/xml;charset=UTF-8") output.Header("Content-Type", "application/xml;charset=UTF-8")
var content []byte var content []byte
@ -178,10 +235,16 @@ func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
return nil return nil
} }
func (output *BeegoOutput) Download(file string) { // Download forces response for download file.
// it prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) {
output.Header("Content-Description", "File Transfer") output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream") output.Header("Content-Type", "application/octet-stream")
if len(filename) > 0 && filename[0] != "" {
output.Header("Content-Disposition", "attachment; filename="+filename[0])
} else {
output.Header("Content-Disposition", "attachment; filename="+filepath.Base(file)) output.Header("Content-Disposition", "attachment; filename="+filepath.Base(file))
}
output.Header("Content-Transfer-Encoding", "binary") output.Header("Content-Transfer-Encoding", "binary")
output.Header("Expires", "0") output.Header("Expires", "0")
output.Header("Cache-Control", "must-revalidate") output.Header("Cache-Control", "must-revalidate")
@ -189,6 +252,8 @@ func (output *BeegoOutput) Download(file string) {
http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file) http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
} }
// ContentType sets the content type from ext string.
// MIME type is given in mime package.
func (output *BeegoOutput) ContentType(ext string) { func (output *BeegoOutput) ContentType(ext string) {
if !strings.HasPrefix(ext, ".") { if !strings.HasPrefix(ext, ".") {
ext = "." + ext ext = "." + ext
@ -199,43 +264,63 @@ func (output *BeegoOutput) ContentType(ext string) {
} }
} }
// SetStatus sets response status code.
// It writes response header directly.
func (output *BeegoOutput) SetStatus(status int) { func (output *BeegoOutput) SetStatus(status int) {
output.Context.ResponseWriter.WriteHeader(status) output.Context.ResponseWriter.WriteHeader(status)
output.Status = status output.Status = status
} }
// IsCachable returns boolean of this request is cached.
// HTTP 304 means cached.
func (output *BeegoOutput) IsCachable(status int) bool { func (output *BeegoOutput) IsCachable(status int) bool {
return output.Status >= 200 && output.Status < 300 || output.Status == 304 return output.Status >= 200 && output.Status < 300 || output.Status == 304
} }
// IsEmpty returns boolean of this request is empty.
// HTTP 201204 and 304 means empty.
func (output *BeegoOutput) IsEmpty(status int) bool { func (output *BeegoOutput) IsEmpty(status int) bool {
return output.Status == 201 || output.Status == 204 || output.Status == 304 return output.Status == 201 || output.Status == 204 || output.Status == 304
} }
// IsOk returns boolean of this request runs well.
// HTTP 200 means ok.
func (output *BeegoOutput) IsOk(status int) bool { func (output *BeegoOutput) IsOk(status int) bool {
return output.Status == 200 return output.Status == 200
} }
// IsSuccessful returns boolean of this request runs successfully.
// HTTP 2xx means ok.
func (output *BeegoOutput) IsSuccessful(status int) bool { func (output *BeegoOutput) IsSuccessful(status int) bool {
return output.Status >= 200 && output.Status < 300 return output.Status >= 200 && output.Status < 300
} }
// IsRedirect returns boolean of this request is redirection header.
// HTTP 301,302,307 means redirection.
func (output *BeegoOutput) IsRedirect(status int) bool { func (output *BeegoOutput) IsRedirect(status int) bool {
return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307
} }
// IsForbidden returns boolean of this request is forbidden.
// HTTP 403 means forbidden.
func (output *BeegoOutput) IsForbidden(status int) bool { func (output *BeegoOutput) IsForbidden(status int) bool {
return output.Status == 403 return output.Status == 403
} }
// IsNotFound returns boolean of this request is not found.
// HTTP 404 means forbidden.
func (output *BeegoOutput) IsNotFound(status int) bool { func (output *BeegoOutput) IsNotFound(status int) bool {
return output.Status == 404 return output.Status == 404
} }
// IsClient returns boolean of this request client sends error data.
// HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool { func (output *BeegoOutput) IsClientError(status int) bool {
return output.Status >= 400 && output.Status < 500 return output.Status >= 400 && output.Status < 500
} }
// IsServerError returns boolean of this server handler errors.
// HTTP 5xx means server internal error.
func (output *BeegoOutput) IsServerError(status int) bool { func (output *BeegoOutput) IsServerError(status int) bool {
return output.Status >= 500 && output.Status < 600 return output.Status >= 500 && output.Status < 600
} }
@ -254,6 +339,7 @@ func stringsToJson(str string) string {
return jsons return jsons
} }
// Sessions sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) { func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value) output.Context.Input.CruSession.Set(name, value)
} }

View File

@ -1,13 +1,14 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
"bytes" "bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors" "errors"
"fmt"
"html/template" "html/template"
"io" "io"
"io/ioutil" "io/ioutil"
@ -18,17 +19,33 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
//commonly used mime-types
const (
applicationJson = "application/json"
applicationXml = "application/xml"
textXml = "text/xml"
) )
var ( var (
// custom error when user stop request handler manually. // custom error when user stop request handler manually.
USERSTOPRUN = errors.New("User stop run") USERSTOPRUN = errors.New("User stop run")
GlobalControllerRouter map[string][]ControllerComments = make(map[string][]ControllerComments) //pkgpath+controller:comments
) )
// store the comment for the controller method
type ControllerComments struct {
Method string
Router string
AllowHTTPMethods []string
Params []map[string]string
}
// Controller defines some basic http request handler operations, such as // Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf. // http context, template and view, session and xsrf.
type Controller struct { type Controller struct {
@ -45,6 +62,9 @@ type Controller struct {
CruSession session.SessionStore CruSession session.SessionStore
XSRFExpire int XSRFExpire int
AppController interface{} AppController interface{}
EnableRender bool
EnableXSRF bool
methodMapping map[string]func() //method:routertree
} }
// ControllerInterface is an interface to uniform all controller handler. // ControllerInterface is an interface to uniform all controller handler.
@ -62,11 +82,12 @@ type ControllerInterface interface {
Render() error Render() error
XsrfToken() string XsrfToken() string
CheckXsrfCookie() bool CheckXsrfCookie() bool
HandlerFunc(fn string) bool
URLMapping()
} }
// Init generates default values of controller operations. // Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Data = make(map[interface{}]interface{})
c.Layout = "" c.Layout = ""
c.TplNames = "" c.TplNames = ""
c.controllerName = controllerName c.controllerName = controllerName
@ -74,6 +95,10 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.Ctx = ctx c.Ctx = ctx
c.TplExt = "tpl" c.TplExt = "tpl"
c.AppController = app c.AppController = app
c.EnableRender = true
c.EnableXSRF = true
c.Data = ctx.Input.Data
c.methodMapping = make(map[string]func())
} }
// Prepare runs after Init before request function execution. // Prepare runs after Init before request function execution.
@ -121,8 +146,29 @@ func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
} }
// call function fn
func (c *Controller) HandlerFunc(fnname string) bool {
if v, ok := c.methodMapping[fnname]; ok {
v()
return true
} else {
return false
}
}
// URLMapping register the internal Controller router.
func (c *Controller) URLMapping() {
}
func (c *Controller) Mapping(method string, fn func()) {
c.methodMapping[method] = fn
}
// Render sends the response with rendered template bytes as text/html type. // Render sends the response with rendered template bytes as text/html type.
func (c *Controller) Render() error { func (c *Controller) Render() error {
if !c.EnableRender {
return nil
}
rb, err := c.RenderBytes() rb, err := c.RenderBytes()
if err != nil { if err != nil {
@ -140,7 +186,7 @@ func (c *Controller) RenderString() (string, error) {
return string(b), e return string(b), e
} }
// RenderBytes returns the bytes of renderd tempate string. Do not send out response. // RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) { func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout //if the controller has set layout, then first get the tplname's content set the content to the layout
if c.Layout != "" { if c.Layout != "" {
@ -153,7 +199,6 @@ func (c *Controller) RenderBytes() ([]byte, error) {
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
if _, ok := BeeTemplates[c.TplNames]; !ok { if _, ok := BeeTemplates[c.TplNames]; !ok {
panic("can't find templatefile in the path:" + c.TplNames) panic("can't find templatefile in the path:" + c.TplNames)
return []byte{}, errors.New("can't find templatefile in the path:" + c.TplNames)
} }
err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data) err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data)
if err != nil { if err != nil {
@ -165,7 +210,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
if c.LayoutSections != nil { if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections { for sectionName, sectionTpl := range c.LayoutSections {
if (sectionTpl == "") { if sectionTpl == "" {
c.Data[sectionName] = "" c.Data[sectionName] = ""
continue continue
} }
@ -199,7 +244,6 @@ func (c *Controller) RenderBytes() ([]byte, error) {
ibytes := bytes.NewBufferString("") ibytes := bytes.NewBufferString("")
if _, ok := BeeTemplates[c.TplNames]; !ok { if _, ok := BeeTemplates[c.TplNames]; !ok {
panic("can't find templatefile in the path:" + c.TplNames) panic("can't find templatefile in the path:" + c.TplNames)
return []byte{}, errors.New("can't find templatefile in the path:" + c.TplNames)
} }
err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data) err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data)
if err != nil { if err != nil {
@ -209,7 +253,6 @@ func (c *Controller) RenderBytes() ([]byte, error) {
icontent, _ := ioutil.ReadAll(ibytes) icontent, _ := ioutil.ReadAll(ibytes)
return icontent, nil return icontent, nil
} }
return []byte{}, nil
} }
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
@ -243,7 +286,6 @@ func (c *Controller) UrlFor(endpoint string, values ...string) string {
} else { } else {
return UrlFor(endpoint, values...) return UrlFor(endpoint, values...)
} }
return ""
} }
// ServeJson sends a json response with encoding charset. // ServeJson sends a json response with encoding charset.
@ -283,12 +325,22 @@ func (c *Controller) ServeXml() {
c.Ctx.Output.Xml(c.Data["xml"], hasIndent) c.Ctx.Output.Xml(c.Data["xml"], hasIndent)
} }
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header
func (c *Controller) ServeFormatted() {
accept := c.Ctx.Input.Header("Accept")
switch accept {
case applicationJson:
c.ServeJson()
case applicationXml, textXml:
c.ServeXml()
default:
c.ServeJson()
}
}
// Input returns the input data map from POST or PUT request body and query string. // Input returns the input data map from POST or PUT request body and query string.
func (c *Controller) Input() url.Values { func (c *Controller) Input() url.Values {
ct := c.Ctx.Request.Header.Get("Content-Type") if c.Ctx.Request.Form == nil {
if strings.Contains(ct, "multipart/form-data") {
c.Ctx.Request.ParseMultipartForm(MaxMemory) //64MB
} else {
c.Ctx.Request.ParseForm() c.Ctx.Request.ParseForm()
} }
return c.Ctx.Request.Form return c.Ctx.Request.Form
@ -301,17 +353,17 @@ func (c *Controller) ParseForm(obj interface{}) error {
// GetString returns the input value by key string. // GetString returns the input value by key string.
func (c *Controller) GetString(key string) string { func (c *Controller) GetString(key string) string {
return c.Input().Get(key) return c.Ctx.Input.Query(key)
} }
// GetStrings returns the input string slice by key string. // GetStrings returns the input string slice by key string.
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection. // it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
func (c *Controller) GetStrings(key string) []string { func (c *Controller) GetStrings(key string) []string {
r := c.Ctx.Request f := c.Input()
if r.Form == nil { if f == nil {
return []string{} return []string{}
} }
vs := r.Form[key] vs := f[key]
if len(vs) > 0 { if len(vs) > 0 {
return vs return vs
} }
@ -320,17 +372,17 @@ func (c *Controller) GetStrings(key string) []string {
// GetInt returns input value as int64. // GetInt returns input value as int64.
func (c *Controller) GetInt(key string) (int64, error) { func (c *Controller) GetInt(key string) (int64, error) {
return strconv.ParseInt(c.Input().Get(key), 10, 64) return strconv.ParseInt(c.Ctx.Input.Query(key), 10, 64)
} }
// GetBool returns input value as bool. // GetBool returns input value as bool.
func (c *Controller) GetBool(key string) (bool, error) { func (c *Controller) GetBool(key string) (bool, error) {
return strconv.ParseBool(c.Input().Get(key)) return strconv.ParseBool(c.Ctx.Input.Query(key))
} }
// GetFloat returns input value as float64. // GetFloat returns input value as float64.
func (c *Controller) GetFloat(key string) (float64, error) { func (c *Controller) GetFloat(key string) (float64, error) {
return strconv.ParseFloat(c.Input().Get(key), 64) return strconv.ParseFloat(c.Ctx.Input.Query(key), 64)
} }
// GetFile returns the file data in file upload field named as key. // GetFile returns the file data in file upload field named as key.
@ -391,12 +443,16 @@ func (c *Controller) DelSession(name interface{}) {
// SessionRegenerateID regenerates session id for this session. // SessionRegenerateID regenerates session id for this session.
// the session data have no changes. // the session data have no changes.
func (c *Controller) SessionRegenerateID() { func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
}
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request) c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession c.Ctx.Input.CruSession = c.CruSession
} }
// DestroySession cleans session data and session cookie. // DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() { func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush()
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
} }
@ -407,40 +463,12 @@ func (c *Controller) IsAjax() bool {
// GetSecureCookie returns decoded cookie value from encoded browser cookie values. // GetSecureCookie returns decoded cookie value from encoded browser cookie values.
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) { func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
val := c.Ctx.GetCookie(key) return c.Ctx.GetSecureCookie(Secret, key)
if val == "" {
return "", false
}
parts := strings.SplitN(val, "|", 3)
if len(parts) != 3 {
return "", false
}
vs := parts[0]
timestamp := parts[1]
sig := parts[2]
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
return "", false
}
res, _ := base64.URLEncoding.DecodeString(vs)
return string(res), true
} }
// SetSecureCookie puts value into cookie after encoded the value. // SetSecureCookie puts value into cookie after encoded the value.
func (c *Controller) SetSecureCookie(Secret, name, val string, age int64) { func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(val)) c.Ctx.SetSecureCookie(Secret, name, value, others...)
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
h := hmac.New(sha1.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
sig := fmt.Sprintf("%02x", h.Sum(nil))
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
c.Ctx.SetCookie(name, cookie, age, "/")
} }
// XsrfToken creates a xsrf token string and returns. // XsrfToken creates a xsrf token string and returns.
@ -454,7 +482,7 @@ func (c *Controller) XsrfToken() string {
} else { } else {
expire = int64(XSRFExpire) expire = int64(XSRFExpire)
} }
token = getRandomString(15) token = string(utils.RandomCreateBytes(32))
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire) c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
} }
c._xsrf_token = token c._xsrf_token = token
@ -466,6 +494,9 @@ func (c *Controller) XsrfToken() string {
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" // the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf". // or in form field value named as "_xsrf".
func (c *Controller) CheckXsrfCookie() bool { func (c *Controller) CheckXsrfCookie() bool {
if !c.EnableXSRF {
return true
}
token := c.GetString("_xsrf") token := c.GetString("_xsrf")
if token == "" { if token == "" {
token = c.Ctx.Request.Header.Get("X-Xsrftoken") token = c.Ctx.Request.Header.Get("X-Xsrftoken")
@ -491,14 +522,3 @@ func (c *Controller) XsrfFormHtml() string {
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) { func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
return c.controllerName, c.actionName return c.controllerName, c.actionName
} }
// getRandomString returns random string.
func getRandomString(n int) string {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
return string(bytes)
}

38
docs.go Normal file
View File

@ -0,0 +1,38 @@
package beego
import (
"encoding/json"
"github.com/astaxie/beego/context"
)
var GlobalDocApi map[string]interface{}
func init() {
if EnableDocs {
GlobalDocApi = make(map[string]interface{})
}
}
func serverDocs(ctx *context.Context) {
var obj interface{}
if splat := ctx.Input.Param(":splat"); splat == "" {
obj = GlobalDocApi["Root"]
} else {
if v, ok := GlobalDocApi[splat]; ok {
obj = v
}
}
if obj != nil {
bt, err := json.Marshal(obj)
if err != nil {
ctx.Output.SetStatus(504)
return
}
ctx.Output.Header("Content-Type", "application/json;charset=UTF-8")
ctx.Output.Header("Access-Control-Allow-Origin", "*")
ctx.Output.Body(bt)
return
}
ctx.Output.SetStatus(404)
}

View File

@ -1,7 +1,14 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package controllers package controllers
import ( import (
"encoding/json" "encoding/json"
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/models" "github.com/astaxie/beego/example/beeapi/models"
) )

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package main package main
import ( import (

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package models package models
import ( import (

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package controllers package controllers
import ( import (

View File

@ -1,12 +1,19 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package controllers package controllers
import ( import (
"github.com/astaxie/beego"
"github.com/garyburd/go-websocket/websocket"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"time" "time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
) )
const ( const (
@ -53,9 +60,9 @@ func (c *connection) readPump() {
break break
} }
switch op { switch op {
case websocket.OpPong: case websocket.PongMessage:
c.ws.SetReadDeadline(time.Now().Add(readWait)) c.ws.SetReadDeadline(time.Now().Add(readWait))
case websocket.OpText: case websocket.TextMessage:
message, err := ioutil.ReadAll(r) message, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
break break
@ -82,14 +89,14 @@ func (c *connection) writePump() {
select { select {
case message, ok := <-c.send: case message, ok := <-c.send:
if !ok { if !ok {
c.write(websocket.OpClose, []byte{}) c.write(websocket.CloseMessage, []byte{})
return return
} }
if err := c.write(websocket.OpText, message); err != nil { if err := c.write(websocket.TextMessage, message); err != nil {
return return
} }
case <-ticker.C: case <-ticker.C:
if err := c.write(websocket.OpPing, []byte{}); err != nil { if err := c.write(websocket.PingMessage, []byte{}); err != nil {
return return
} }
} }
@ -142,8 +149,13 @@ type WSController struct {
beego.Controller beego.Controller
} }
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
func (this *WSController) Get() { func (this *WSController) Get() {
ws, err := websocket.Upgrade(this.Ctx.ResponseWriter, this.Ctx.Request.Header, nil, 1024, 1024) ws, err := upgrader.Upgrade(this.Ctx.ResponseWriter, this.Ctx.Request,nil)
if _, ok := err.(websocket.HandshakeError); ok { if _, ok := err.(websocket.HandshakeError); ok {
http.Error(this.Ctx.ResponseWriter, "Not a websocket handshake", 400) http.Error(this.Ctx.ResponseWriter, "Not a websocket handshake", 400)
return return

View File

@ -1,3 +1,8 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package main package main
import ( import (

149
filter.go
View File

@ -1,148 +1,29 @@
package beego // Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
import ( package beego
"regexp"
"strings"
)
// FilterRouter defines filter operation before controller handler execution. // FilterRouter defines filter operation before controller handler execution.
// it can match patterned url and do filter function when action arrives. // it can match patterned url and do filter function when action arrives.
type FilterRouter struct { type FilterRouter struct {
pattern string
regex *regexp.Regexp
filterFunc FilterFunc filterFunc FilterFunc
hasregex bool tree *Tree
params map[int]string pattern string
parseParams map[string]string
} }
// ValidRouter check current request is valid for this filter. // ValidRouter check current request is valid for this filter.
// if matched, returns parsed params in this request by defined filter router pattern. // if matched, returns parsed params in this request by defined filter router pattern.
func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) { func (f *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
if mr.pattern == "" { isok, params := f.tree.Match(router)
return true, nil if isok == nil {
}
if mr.pattern == "*" {
return true, nil
}
if router == mr.pattern {
return true, nil
}
if mr.hasregex {
if !mr.regex.MatchString(router) {
return false, nil return false, nil
} }
matches := mr.regex.FindStringSubmatch(router) if isok, ok := isok.(bool); ok {
if len(matches) > 0 { return isok, params
if len(matches[0]) == len(router) { } else {
params := make(map[string]string)
for i, match := range matches[1:] {
params[mr.params[i]] = match
}
return true, params
}
}
}
return false, nil return false, nil
} }
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
mr := new(FilterRouter)
mr.params = make(map[int]string)
mr.filterFunc = filter
parts := strings.Split(pattern, "/")
j := 0
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
//a user may choose to override the default expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
expr = part[index:]
part = part[:index]
//match /user/:id:int ([0-9]+)
//match /post/:username:string ([\w]+)
} else if lindex := strings.LastIndex(part, ":"); lindex != 0 {
switch part[lindex:] {
case ":int":
expr = "([0-9]+)"
part = part[:lindex]
case ":string":
expr = `([\w]+)`
part = part[:lindex]
}
}
mr.params[j] = part
parts[i] = expr
j++
}
if strings.HasPrefix(part, "*") {
expr := "(.+)"
if part == "*.*" {
mr.params[j] = ":path"
parts[i] = "([^.]+).([^.]+)"
j++
mr.params[j] = ":ext"
j++
} else {
mr.params[j] = ":splat"
parts[i] = expr
j++
}
}
//url like someprefix:id(xxx).html
if strings.Contains(part, ":") && strings.Contains(part, "(") && strings.Contains(part, ")") {
var out []rune
var start bool
var startexp bool
var param []rune
var expt []rune
for _, v := range part {
if start {
if v != '(' {
param = append(param, v)
continue
}
}
if startexp {
if v != ')' {
expt = append(expt, v)
continue
}
}
if v == ':' {
param = make([]rune, 0)
param = append(param, ':')
start = true
} else if v == '(' {
startexp = true
start = false
mr.params[j] = string(param)
j++
expt = make([]rune, 0)
expt = append(expt, '(')
} else if v == ')' {
startexp = false
expt = append(expt, ')')
out = append(out, expt...)
} else {
out = append(out, v)
}
}
parts[i] = string(out)
}
}
if j != 0 {
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
//TODO add error handling here to avoid panic
panic(regexErr)
}
mr.regex = regex
mr.hasregex = true
}
mr.pattern = pattern
return mr
} }

60
filter_test.go Normal file
View File

@ -0,0 +1,60 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/context"
)
var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
}
func TestFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser)
handler.Add("/person/:last/:first", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am astaXie" {
t.Errorf("user define func can't run")
}
}
var FilterAdminUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am admin"))
}
// Filter pattern /admin/:all
// all url like /admin/ /admin/xie will all get filter
func TestPatternTwo(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/ can't run")
}
}
func TestPatternThree(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/astaxie can't run")
}
}

View File

@ -1,25 +0,0 @@
package beego
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/context"
)
var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
}
func TestFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/person/:last/:first", "AfterStatic", FilterUser)
handler.Add("/person/:last/:first", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am astaXie" {
t.Errorf("user define func can't run")
}
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
@ -6,9 +12,6 @@ import (
"strings" "strings"
) )
// the separation string when encoding flash data.
const BEEGO_FLASH_SEP = "#BEEGOFLASH#"
// FlashData is a tools to maintain data when using across request. // FlashData is a tools to maintain data when using across request.
type FlashData struct { type FlashData struct {
Data map[string]string Data map[string]string
@ -54,29 +57,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data c.Data["flash"] = fd.Data
var flashValue string var flashValue string
for key, value := range fd.Data { for key, value := range fd.Data {
flashValue += "\x00" + key + BEEGO_FLASH_SEP + value + "\x00" flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
} }
c.Ctx.SetCookie("BEEGO_FLASH", url.QueryEscape(flashValue), 0, "/") c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
} }
// ReadFromRequest parsed flash data from encoded values in cookie. // ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData { func ReadFromRequest(c *Controller) *FlashData {
flash := &FlashData{ flash := NewFlash()
Data: make(map[string]string), if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
}
if cookie, err := c.Ctx.Request.Cookie("BEEGO_FLASH"); err == nil {
v, _ := url.QueryUnescape(cookie.Value) v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00") vals := strings.Split(v, "\x00")
for _, v := range vals { for _, v := range vals {
if len(v) > 0 { if len(v) > 0 {
kv := strings.Split(v, BEEGO_FLASH_SEP) kv := strings.Split(v, "\x23"+FlashSeperator+"\x23")
if len(kv) == 2 { if len(kv) == 2 {
flash.Data[kv[0]] = kv[1] flash.Data[kv[0]] = kv[1]
} }
} }
} }
//read one time then delete it //read one time then delete it
c.Ctx.SetCookie("BEEGO_FLASH", "", -1, "/") c.Ctx.SetCookie(FlashName, "", -1, "/")
} }
c.Data["flash"] = flash.Data c.Data["flash"] = flash.Data
return flash return flash

46
flash_test.go Normal file
View File

@ -0,0 +1,46 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type TestFlashController struct {
Controller
}
func (this *TestFlashController) TestWriteFlash() {
flash := NewFlash()
flash.Notice("TestFlashString")
flash.Store(&this.Controller)
// we choose to serve json because we don't want to load a template html file
this.ServeJson(true)
}
func TestFlashHeader(t *testing.T) {
// create fake GET request
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// setup the handler
handler := NewControllerRegister()
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
handler.ServeHTTP(w, r)
// get the Set-Cookie value
sc := w.Header().Get("Set-Cookie")
// match for the expected header
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
// validate the assertion
if res != true {
t.Errorf("TestFlashHeader() unable to validate flash message")
}
}

View File

@ -60,3 +60,21 @@ some http request need setcookie. So set it like this:
cookie.Value = "astaxie" cookie.Value = "astaxie"
httplib.Get("http://beego.me/").SetCookie(cookie) httplib.Get("http://beego.me/").SetCookie(cookie)
## upload file
httplib support mutil file upload, use `b.PostFile()`
b:=httplib.Post("http://beego.me/")
b.Param("username","astaxie")
b.Param("password","123456")
b.PostFile("uploadfile1", "httplib.pdf")
b.PostFile("uploadfile2", "httplib.txt")
str, err := b.String()
if err != nil {
t.Fatal(err)
}
fmt.Println(str)
## set HTTP version
some servers need to specify the protocol version of HTTP
httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1")

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package httplib package httplib
import ( import (
@ -7,98 +13,201 @@ import (
"encoding/xml" "encoding/xml"
"io" "io"
"io/ioutil" "io/ioutil"
"mime/multipart"
"net" "net"
"net/http" "net/http"
"net/http/cookiejar"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
) )
var defaultUserAgent = "beegoServer" var defaultSetting = BeegoHttpSettings{false, "beegoServer", 60 * time.Second, 60 * time.Second, nil, nil, nil, false}
var defaultCookieJar http.CookieJar
var settingMutex sync.Mutex
// createDefaultCookieJar creates a global cookiejar to store cookies.
func createDefaultCookie() {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultCookieJar, _ = cookiejar.New(nil)
}
// Overwrite default settings
func SetDefaultSetting(setting BeegoHttpSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
if defaultSetting.ConnectTimeout == 0 {
defaultSetting.ConnectTimeout = 60 * time.Second
}
if defaultSetting.ReadWriteTimeout == 0 {
defaultSetting.ReadWriteTimeout = 60 * time.Second
}
}
// Get returns *BeegoHttpRequest with GET method.
func Get(url string) *BeegoHttpRequest { func Get(url string) *BeegoHttpRequest {
var req http.Request var req http.Request
req.Method = "GET" req.Method = "GET"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting}
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
} }
// Post returns *BeegoHttpRequest with POST method.
func Post(url string) *BeegoHttpRequest { func Post(url string) *BeegoHttpRequest {
var req http.Request var req http.Request
req.Method = "POST" req.Method = "POST"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting}
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
} }
// Put returns *BeegoHttpRequest with PUT method.
func Put(url string) *BeegoHttpRequest { func Put(url string) *BeegoHttpRequest {
var req http.Request var req http.Request
req.Method = "PUT" req.Method = "PUT"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting}
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
} }
// Delete returns *BeegoHttpRequest DELETE GET method.
func Delete(url string) *BeegoHttpRequest { func Delete(url string) *BeegoHttpRequest {
var req http.Request var req http.Request
req.Method = "DELETE" req.Method = "DELETE"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting}
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
} }
// Head returns *BeegoHttpRequest with HEAD method.
func Head(url string) *BeegoHttpRequest { func Head(url string) *BeegoHttpRequest {
var req http.Request var req http.Request
req.Method = "HEAD" req.Method = "HEAD"
req.Header = http.Header{} req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent) return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting}
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
} }
// BeegoHttpSettings
type BeegoHttpSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
TlsClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
EnableCookie bool
}
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
type BeegoHttpRequest struct { type BeegoHttpRequest struct {
url string url string
req *http.Request req *http.Request
params map[string]string params map[string]string
showdebug bool files map[string]string
connectTimeout time.Duration setting BeegoHttpSettings
readWriteTimeout time.Duration
tlsClientConfig *tls.Config
} }
// Change request settings
func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest {
b.setting = setting
return b
}
// SetEnableCookie sets enable/disable cookiejar
func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest {
b.setting.EnableCookie = enable
return b
}
// SetUserAgent sets User-Agent header field
func (b *BeegoHttpRequest) SetAgent(useragent string) *BeegoHttpRequest {
b.setting.UserAgent = useragent
return b
}
// Debug sets show debug or not when executing request.
func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest { func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest {
b.showdebug = isdebug b.setting.ShowDebug = isdebug
return b return b
} }
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest { func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest {
b.connectTimeout = connectTimeout b.setting.ConnectTimeout = connectTimeout
b.readWriteTimeout = readWriteTimeout b.setting.ReadWriteTimeout = readWriteTimeout
return b return b
} }
// SetTLSClientConfig sets tls connection configurations if visiting https url.
func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest { func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest {
b.tlsClientConfig = config b.setting.TlsClientConfig = config
return b return b
} }
// Header add header item string in request.
func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest { func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest {
b.req.Header.Set(key, value) b.req.Header.Set(key, value)
return b return b
} }
// Set the protocol version for incoming requests.
// Client requests always use HTTP/1.1.
func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
if len(vers) == 0 {
vers = "HTTP/1.1"
}
major, minor, ok := http.ParseHTTPVersion(vers)
if ok {
b.req.Proto = vers
b.req.ProtoMajor = major
b.req.ProtoMinor = minor
}
return b
}
// SetCookie add cookie into request.
func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest { func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
b.req.Header.Add("Cookie", cookie.String()) b.req.Header.Add("Cookie", cookie.String())
return b return b
} }
// Set transport to
func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
b.setting.Transport = transport
return b
}
// Set http proxy
// example:
//
// func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil
// }
func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
b.setting.Proxy = proxy
return b
}
// Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2...
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest { func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
b.params[key] = value b.params[key] = value
return b return b
} }
func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest {
b.files[formname] = filename
return b
}
// Body adds request raw body.
// it supports string and []byte.
func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest { func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
switch t := data.(type) { switch t := data.(type) {
case string: case string:
@ -134,9 +243,38 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
b.url = b.url + "?" + paramBody b.url = b.url + "?" + paramBody
} }
} else if b.req.Method == "POST" && b.req.Body == nil && len(paramBody) > 0 { } else if b.req.Method == "POST" && b.req.Body == nil && len(paramBody) > 0 {
if len(b.files) > 0 {
bodyBuf := &bytes.Buffer{}
bodyWriter := multipart.NewWriter(bodyBuf)
for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil {
return nil, err
}
fh, err := os.Open(filename)
if err != nil {
return nil, err
}
//iocopy
_, err = io.Copy(fileWriter, fh)
fh.Close()
if err != nil {
return nil, err
}
}
for k, v := range b.params {
bodyWriter.WriteField(k, v)
}
contentType := bodyWriter.FormDataContentType()
bodyWriter.Close()
b.Header("Content-Type", contentType)
b.req.Body = ioutil.NopCloser(bodyBuf)
b.req.ContentLength = int64(bodyBuf.Len())
} else {
b.Header("Content-Type", "application/x-www-form-urlencoded") b.Header("Content-Type", "application/x-www-form-urlencoded")
b.Body(paramBody) b.Body(paramBody)
} }
}
url, err := url.Parse(b.url) url, err := url.Parse(b.url)
if url.Scheme == "" { if url.Scheme == "" {
@ -148,7 +286,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
} }
b.req.URL = url b.req.URL = url
if b.showdebug { if b.setting.ShowDebug {
dump, err := httputil.DumpRequest(b.req, true) dump, err := httputil.DumpRequest(b.req, true)
if err != nil { if err != nil {
println(err.Error()) println(err.Error())
@ -156,12 +294,49 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
println(string(dump)) println(string(dump))
} }
client := &http.Client{ trans := b.setting.Transport
Transport: &http.Transport{
TLSClientConfig: b.tlsClientConfig, if trans == nil {
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout), // create default transport
}, trans = &http.Transport{
TLSClientConfig: b.setting.TlsClientConfig,
Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
} }
} else {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TlsClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
}
if t.Dial == nil {
t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
}
}
}
var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
}
jar = defaultCookieJar
} else {
jar = nil
}
client := &http.Client{
Transport: trans,
Jar: jar,
}
if b.setting.UserAgent != "" {
b.req.Header.Set("User-Agent", b.setting.UserAgent)
}
resp, err := client.Do(b.req) resp, err := client.Do(b.req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -169,6 +344,8 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
return resp, nil return resp, nil
} }
// String returns the body string in response.
// it calls Response inner.
func (b *BeegoHttpRequest) String() (string, error) { func (b *BeegoHttpRequest) String() (string, error) {
data, err := b.Bytes() data, err := b.Bytes()
if err != nil { if err != nil {
@ -178,6 +355,8 @@ func (b *BeegoHttpRequest) String() (string, error) {
return string(data), nil return string(data), nil
} }
// Bytes returns the body []byte in response.
// it calls Response inner.
func (b *BeegoHttpRequest) Bytes() ([]byte, error) { func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
resp, err := b.getResponse() resp, err := b.getResponse()
if err != nil { if err != nil {
@ -194,6 +373,8 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
return data, nil return data, nil
} }
// ToFile saves the body data in response to one file.
// it calls Response inner.
func (b *BeegoHttpRequest) ToFile(filename string) error { func (b *BeegoHttpRequest) ToFile(filename string) error {
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {
@ -216,6 +397,8 @@ func (b *BeegoHttpRequest) ToFile(filename string) error {
return nil return nil
} }
// ToJson returns the map that marshals from the body bytes as json in response .
// it calls Response inner.
func (b *BeegoHttpRequest) ToJson(v interface{}) error { func (b *BeegoHttpRequest) ToJson(v interface{}) error {
data, err := b.Bytes() data, err := b.Bytes()
if err != nil { if err != nil {
@ -228,6 +411,8 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error {
return nil return nil
} }
// ToXml returns the map that marshals from the body bytes as xml in response .
// it calls Response inner.
func (b *BeegoHttpRequest) ToXML(v interface{}) error { func (b *BeegoHttpRequest) ToXML(v interface{}) error {
data, err := b.Bytes() data, err := b.Bytes()
if err != nil { if err != nil {
@ -240,10 +425,12 @@ func (b *BeegoHttpRequest) ToXML(v interface{}) error {
return nil return nil
} }
// Response executes request client gets response mannually.
func (b *BeegoHttpRequest) Response() (*http.Response, error) { func (b *BeegoHttpRequest) Response() (*http.Response, error) {
return b.getResponse() return b.getResponse()
} }
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
return func(netw, addr string) (net.Conn, error) { return func(netw, addr string) (net.Conn, error) {
conn, err := net.DialTimeout(netw, addr, cTimeout) conn, err := net.DialTimeout(netw, addr, cTimeout)

View File

@ -1,12 +1,19 @@
// Beego (http://beego.me)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package httplib package httplib
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"testing" "testing"
) )
func TestGetUrl(t *testing.T) { func TestGetUrl(t *testing.T) {
resp, err := Get("http://beego.me/").Debug(true).Response() resp, err := Get("http://beego.me").Debug(true).Response()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -22,7 +29,7 @@ func TestGetUrl(t *testing.T) {
t.Fatal("data is no") t.Fatal("data is no")
} }
str, err := Get("http://beego.me/").String() str, err := Get("http://beego.me").String()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -30,3 +37,64 @@ func TestGetUrl(t *testing.T) {
t.Fatal("has no info") t.Fatal("has no info")
} }
} }
func ExamplePost(t *testing.T) {
b := Post("http://beego.me/").Debug(true)
b.Param("username", "astaxie")
b.Param("password", "hello")
b.PostFile("uploadfile", "httplib_test.go")
str, err := b.String()
if err != nil {
t.Fatal(err)
}
fmt.Println(str)
}
func TestSimpleGetString(t *testing.T) {
fmt.Println("TestSimpleGetString==========================================")
html, err := Get("http://httpbin.org/headers").SetAgent("beegoooooo").String()
if err != nil {
t.Fatal(err)
}
fmt.Println(html)
fmt.Println("TestSimpleGetString==========================================")
}
func TestSimpleGetStringWithDefaultCookie(t *testing.T) {
fmt.Println("TestSimpleGetStringWithDefaultCookie==========================================")
html, err := Get("http://httpbin.org/cookies/set?k1=v1").SetEnableCookie(true).String()
if err != nil {
t.Fatal(err)
}
fmt.Println(html)
html, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String()
if err != nil {
t.Fatal(err)
}
fmt.Println(html)
fmt.Println("TestSimpleGetStringWithDefaultCookie==========================================")
}
func TestDefaultSetting(t *testing.T) {
fmt.Println("TestDefaultSetting==========================================")
var def BeegoHttpSettings
def.EnableCookie = true
//def.ShowDebug = true
def.UserAgent = "UserAgent"
//def.ConnectTimeout = 60*time.Second
//def.ReadWriteTimeout = 60*time.Second
def.Transport = nil //http.DefaultTransport
SetDefaultSetting(def)
html, err := Get("http://httpbin.org/headers").String()
if err != nil {
t.Fatal(err)
}
fmt.Println(html)
html, err = Get("http://httpbin.org/headers").String()
if err != nil {
t.Fatal(err)
}
fmt.Println(html)
fmt.Println("TestDefaultSetting==========================================")
}

19
log.go
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
@ -22,12 +28,21 @@ func SetLevel(l int) {
BeeLogger.SetLevel(l) BeeLogger.SetLevel(l)
} }
func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3)
}
// logger references the used application logger. // logger references the used application logger.
var BeeLogger *logs.BeeLogger var BeeLogger *logs.BeeLogger
// SetLogger sets a new logger. // SetLogger sets a new logger.
func SetLogger(adaptername string, config string) { func SetLogger(adaptername string, config string) error {
BeeLogger.SetLogger(adaptername, config) err := BeeLogger.SetLogger(adaptername, config)
if err != nil {
return err
}
return nil
} }
// Trace logs a message at trace level. // Trace logs a message at trace level.

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (
@ -7,6 +13,8 @@ import (
"net" "net"
) )
// ConnWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type ConnWriter struct { type ConnWriter struct {
lg *log.Logger lg *log.Logger
innerWriter io.WriteCloser innerWriter io.WriteCloser
@ -17,12 +25,15 @@ type ConnWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface { func NewConn() LoggerInterface {
conn := new(ConnWriter) conn := new(ConnWriter)
conn.Level = LevelTrace conn.Level = LevelTrace
return conn return conn
} }
// init connection writer with json config.
// json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error { func (c *ConnWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c) err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil { if err != nil {
@ -31,6 +42,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in connection.
// if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error { func (c *ConnWriter) WriteMsg(msg string, level int) error {
if level < c.Level { if level < c.Level {
return nil return nil
@ -49,10 +62,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil return nil
} }
// implementing method. empty.
func (c *ConnWriter) Flush() { func (c *ConnWriter) Flush() {
} }
// destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() { func (c *ConnWriter) Destroy() {
if c.innerWriter == nil { if c.innerWriter == nil {
return return

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (

View File

@ -1,16 +1,44 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (
"encoding/json" "encoding/json"
"log" "log"
"os" "os"
"runtime"
) )
type Brush func(string) string
func NewBrush(color string) Brush {
pre := "\033["
reset := "\033[0m"
return func(text string) string {
return pre + color + "m" + text + reset
}
}
var colors = []Brush{
NewBrush("1;36"), // Trace cyan
NewBrush("1;34"), // Debug blue
NewBrush("1;32"), // Info green
NewBrush("1;33"), // Warn yellow
NewBrush("1;31"), // Error red
NewBrush("1;35"), // Critical purple
}
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct { type ConsoleWriter struct {
lg *log.Logger lg *log.Logger
Level int `json:"level"` Level int `json:"level"`
} }
// create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface { func NewConsole() LoggerInterface {
cw := new(ConsoleWriter) cw := new(ConsoleWriter)
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime) cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
@ -18,7 +46,12 @@ func NewConsole() LoggerInterface {
return cw return cw
} }
// init console logger.
// jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error { func (c *ConsoleWriter) Init(jsonconfig string) error {
if len(jsonconfig) == 0 {
return nil
}
err := json.Unmarshal([]byte(jsonconfig), c) err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil { if err != nil {
return err return err
@ -26,18 +59,25 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error { func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
if level < c.Level { if level < c.Level {
return nil return nil
} }
if goos := runtime.GOOS; goos == "windows" {
c.lg.Println(msg) c.lg.Println(msg)
} else {
c.lg.Println(colors[level](msg))
}
return nil return nil
} }
// implementing method. empty.
func (c *ConsoleWriter) Destroy() { func (c *ConsoleWriter) Destroy() {
} }
// implementing method. empty.
func (c *ConsoleWriter) Flush() { func (c *ConsoleWriter) Flush() {
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (
@ -6,6 +12,7 @@ import (
func TestConsole(t *testing.T) { func TestConsole(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "") log.SetLogger("console", "")
log.Trace("trace") log.Trace("trace")
log.Info("info") log.Info("info")
@ -23,6 +30,7 @@ func TestConsole(t *testing.T) {
func BenchmarkConsole(b *testing.B) { func BenchmarkConsole(b *testing.B) {
log := NewLogger(10000) log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "") log.SetLogger("console", "")
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
log.Trace("trace") log.Trace("trace")

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (
@ -13,6 +19,8 @@ import (
"time" "time"
) )
// FileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct { type FileLogWriter struct {
*log.Logger *log.Logger
mw *MuxWriter mw *MuxWriter
@ -28,7 +36,7 @@ type FileLogWriter struct {
// Rotate daily // Rotate daily
Daily bool `json:"daily"` Daily bool `json:"daily"`
Maxdays int64 `json:"maxdays` Maxdays int64 `json:"maxdays"`
daily_opendate int daily_opendate int
Rotate bool `json:"rotate"` Rotate bool `json:"rotate"`
@ -38,17 +46,20 @@ type FileLogWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// an *os.File writer with locker.
type MuxWriter struct { type MuxWriter struct {
sync.Mutex sync.Mutex
fd *os.File fd *os.File
} }
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) { func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
return l.fd.Write(b) return l.fd.Write(b)
} }
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) { func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil { if l.fd != nil {
l.fd.Close() l.fd.Close()
@ -56,6 +67,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
l.fd = fd l.fd = fd
} }
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface { func NewFileWriter() LoggerInterface {
w := &FileLogWriter{ w := &FileLogWriter{
Filename: "", Filename: "",
@ -73,15 +85,16 @@ func NewFileWriter() LoggerInterface {
return w return w
} }
// jsonconfig like this // Init file logger with json config.
//{ // jsonconfig like:
// {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxlines":10000, // "maxlines":10000,
// "maxsize":1<<30, // "maxsize":1<<30,
// "daily":true, // "daily":true,
// "maxdays":15, // "maxdays":15,
// "rotate":true // "rotate":true
//} // }
func (w *FileLogWriter) Init(jsonconfig string) error { func (w *FileLogWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), w) err := json.Unmarshal([]byte(jsonconfig), w)
if err != nil { if err != nil {
@ -90,11 +103,12 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
if len(w.Filename) == 0 { if len(w.Filename) == 0 {
return errors.New("jsonconfig must have filename") return errors.New("jsonconfig must have filename")
} }
err = w.StartLogger() err = w.startLogger()
return err return err
} }
func (w *FileLogWriter) StartLogger() error { // start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) startLogger() error {
fd, err := w.createLogFile() fd, err := w.createLogFile()
if err != nil { if err != nil {
return err return err
@ -110,9 +124,9 @@ func (w *FileLogWriter) StartLogger() error {
func (w *FileLogWriter) docheck(size int) { func (w *FileLogWriter) docheck(size int) {
w.startLock.Lock() w.startLock.Lock()
defer w.startLock.Unlock() defer w.startLock.Unlock()
if (w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) || if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) ||
(w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) || (w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) ||
(w.Daily && time.Now().Day() != w.daily_opendate) { (w.Daily && time.Now().Day() != w.daily_opendate)) {
if err := w.DoRotate(); err != nil { if err := w.DoRotate(); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
return return
@ -122,6 +136,7 @@ func (w *FileLogWriter) docheck(size int) {
w.maxsize_cursize += size w.maxsize_cursize += size
} }
// write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error { func (w *FileLogWriter) WriteMsg(msg string, level int) error {
if level < w.Level { if level < w.Level {
return nil return nil
@ -158,6 +173,8 @@ func (w *FileLogWriter) initFd() error {
return nil return nil
} }
// DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2
func (w *FileLogWriter) DoRotate() error { func (w *FileLogWriter) DoRotate() error {
_, err := os.Lstat(w.Filename) _, err := os.Lstat(w.Filename)
if err == nil { // file exists if err == nil { // file exists
@ -188,7 +205,7 @@ func (w *FileLogWriter) DoRotate() error {
} }
// re-start logger // re-start logger
err = w.StartLogger() err = w.startLogger()
if err != nil { if err != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", err) return fmt.Errorf("Rotate StartLogger: %s\n", err)
} }
@ -211,10 +228,14 @@ func (w *FileLogWriter) deleteOldLog() {
}) })
} }
// destroy file logger, close file writer.
func (w *FileLogWriter) Destroy() { func (w *FileLogWriter) Destroy() {
w.mw.fd.Close() w.mw.fd.Close()
} }
// flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *FileLogWriter) Flush() { func (w *FileLogWriter) Flush() {
w.mw.fd.Sync() w.mw.fd.Sync()
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (

View File

@ -1,11 +1,20 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (
"fmt" "fmt"
"path"
"runtime"
"sync" "sync"
) )
const ( const (
// log message levels
LevelTrace = iota LevelTrace = iota
LevelDebug LevelDebug
LevelInfo LevelInfo
@ -16,6 +25,7 @@ const (
type loggerType func() LoggerInterface type loggerType func() LoggerInterface
// LoggerInterface defines the behavior of a log provider.
type LoggerInterface interface { type LoggerInterface interface {
Init(config string) error Init(config string) error
WriteMsg(msg string, level int) error WriteMsg(msg string, level int) error
@ -38,9 +48,13 @@ func Register(name string, log loggerType) {
adapters[name] = log adapters[name] = log
} }
// BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers.
type BeeLogger struct { type BeeLogger struct {
lock sync.Mutex lock sync.Mutex
level int level int
enableFuncCallDepth bool
loggerFuncCallDepth int
msg chan *logMsg msg chan *logMsg
outputs map[string]LoggerInterface outputs map[string]LoggerInterface
} }
@ -50,29 +64,39 @@ type logMsg struct {
msg string msg string
} }
// config need to be correct JSON as string: {"interval":360} // NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan.
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger { func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.loggerFuncCallDepth = 2
bl.msg = make(chan *logMsg, channellen) bl.msg = make(chan *logMsg, channellen)
bl.outputs = make(map[string]LoggerInterface) bl.outputs = make(map[string]LoggerInterface)
//bl.SetLogger("console", "") // default output to console //bl.SetLogger("console", "") // default output to console
go bl.StartLogger() go bl.startLogger()
return bl return bl
} }
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error { func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
if log, ok := adapters[adaptername]; ok { if log, ok := adapters[adaptername]; ok {
lg := log() lg := log()
lg.Init(config) err := lg.Init(config)
bl.outputs[adaptername] = lg bl.outputs[adaptername] = lg
return nil if err != nil {
fmt.Println("logs.BeeLogger.SetLogger: " + err.Error())
return err
}
} else { } else {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername) return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
} }
return nil
} }
// remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error { func (bl *BeeLogger) DelLogger(adaptername string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
@ -91,16 +115,40 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
} }
lm := new(logMsg) lm := new(logMsg)
lm.level = loglevel lm.level = loglevel
if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
if ok {
_, filename := path.Split(file)
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
} else {
lm.msg = msg lm.msg = msg
}
} else {
lm.msg = msg
}
bl.msg <- lm bl.msg <- lm
return nil return nil
} }
// set log message level.
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
func (bl *BeeLogger) SetLevel(l int) { func (bl *BeeLogger) SetLevel(l int) {
bl.level = l bl.level = l
} }
func (bl *BeeLogger) StartLogger() { // set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
// enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
// start logger chan reading.
// when chan is full, write logs.
func (bl *BeeLogger) startLogger() {
for { for {
select { select {
case bm := <-bl.msg: case bm := <-bl.msg:
@ -111,43 +159,50 @@ func (bl *BeeLogger) StartLogger() {
} }
} }
// log trace level message.
func (bl *BeeLogger) Trace(format string, v ...interface{}) { func (bl *BeeLogger) Trace(format string, v ...interface{}) {
msg := fmt.Sprintf("[T] "+format, v...) msg := fmt.Sprintf("[T] "+format, v...)
bl.writerMsg(LevelTrace, msg) bl.writerMsg(LevelTrace, msg)
} }
// log debug level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) { func (bl *BeeLogger) Debug(format string, v ...interface{}) {
msg := fmt.Sprintf("[D] "+format, v...) msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg) bl.writerMsg(LevelDebug, msg)
} }
// log info level message.
func (bl *BeeLogger) Info(format string, v ...interface{}) { func (bl *BeeLogger) Info(format string, v ...interface{}) {
msg := fmt.Sprintf("[I] "+format, v...) msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInfo, msg) bl.writerMsg(LevelInfo, msg)
} }
// log warn level message.
func (bl *BeeLogger) Warn(format string, v ...interface{}) { func (bl *BeeLogger) Warn(format string, v ...interface{}) {
msg := fmt.Sprintf("[W] "+format, v...) msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarn, msg) bl.writerMsg(LevelWarn, msg)
} }
// log error level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) { func (bl *BeeLogger) Error(format string, v ...interface{}) {
msg := fmt.Sprintf("[E] "+format, v...) msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg) bl.writerMsg(LevelError, msg)
} }
// log critical level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) { func (bl *BeeLogger) Critical(format string, v ...interface{}) {
msg := fmt.Sprintf("[C] "+format, v...) msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg) bl.writerMsg(LevelCritical, msg)
} }
//flush all chan data // flush all chan data.
func (bl *BeeLogger) Flush() { func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs { for _, l := range bl.outputs {
l.Flush() l.Flush()
} }
} }
// close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() { func (bl *BeeLogger) Close() {
for { for {
if len(bl.msg) > 0 { if len(bl.msg) > 0 {

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (
@ -12,7 +18,7 @@ const (
subjectPhrase = "Diagnostic message from server" subjectPhrase = "Diagnostic message from server"
) )
// smtpWriter is used to send emails via given SMTP-server. // smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct { type SmtpWriter struct {
Username string `json:"Username"` Username string `json:"Username"`
Password string `json:"password"` Password string `json:"password"`
@ -22,10 +28,21 @@ type SmtpWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create smtp writer.
func NewSmtpWriter() LoggerInterface { func NewSmtpWriter() LoggerInterface {
return &SmtpWriter{Level: LevelTrace} return &SmtpWriter{Level: LevelTrace}
} }
// init smtp writer with json config.
// config like:
// {
// "Username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SmtpWriter) Init(jsonconfig string) error { func (s *SmtpWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil { if err != nil {
@ -34,6 +51,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil return nil
} }
// write message in smtp writer.
// it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error { func (s *SmtpWriter) WriteMsg(msg string, level int) error {
if level < s.Level { if level < s.Level {
return nil return nil
@ -65,9 +84,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
return err return err
} }
// implementing method. empty.
func (s *SmtpWriter) Flush() { func (s *SmtpWriter) Flush() {
return return
} }
// implementing method. empty.
func (s *SmtpWriter) Destroy() { func (s *SmtpWriter) Destroy() {
return return
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package logs package logs
import ( import (

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
@ -5,20 +11,21 @@ import (
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"errors" "errors"
//"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
) )
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo) var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file. // OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable. // it's used for serve static file if gzip enable.
func OpenMemZipFile(path string, zip string) (*MemFile, error) { func openMemZipFile(path string, zip string) (*memFile, error) {
osfile, e := os.Open(path) osfile, e := os.Open(path)
if e != nil { if e != nil {
return nil, e return nil, e
@ -32,15 +39,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
modtime := osfileinfo.ModTime() modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size() fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path] cfi, ok := gmfim[zip+":"+path]
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize { lock.RUnlock()
//fmt.Printf("read %s file %s from cache\n", zip, path) if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) {
} else {
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
var content []byte var content []byte
if zip == "gzip" { if zip == "gzip" {
//将文件内容压缩到zipbuf中
var zipbuf bytes.Buffer var zipbuf bytes.Buffer
gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression) gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression)
if e != nil { if e != nil {
@ -51,13 +55,11 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
if e != nil { if e != nil {
return nil, e return nil, e
} }
//读zipbuf到content
content, e = ioutil.ReadAll(&zipbuf) content, e = ioutil.ReadAll(&zipbuf)
if e != nil { if e != nil {
return nil, e return nil, e
} }
} else if zip == "deflate" { } else if zip == "deflate" {
//将文件内容压缩到zipbuf中
var zipbuf bytes.Buffer var zipbuf bytes.Buffer
deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression) deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression)
if e != nil { if e != nil {
@ -68,7 +70,6 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
if e != nil { if e != nil {
return nil, e return nil, e
} }
//将zipbuf读入到content
content, e = ioutil.ReadAll(&zipbuf) content, e = ioutil.ReadAll(&zipbuf)
if e != nil { if e != nil {
return nil, e return nil, e
@ -80,16 +81,17 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
} }
} }
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize} cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi gmfim[zip+":"+path] = cfi
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
} }
return &MemFile{fi: cfi, offset: 0}, nil return &memFile{fi: cfi, offset: 0}, nil
} }
// MemFileInfo contains a compressed file bytes and file information. // MemFileInfo contains a compressed file bytes and file information.
// it implements os.FileInfo interface. // it implements os.FileInfo interface.
type MemFileInfo struct { type memFileInfo struct {
os.FileInfo os.FileInfo
modTime time.Time modTime time.Time
content []byte content []byte
@ -98,62 +100,62 @@ type MemFileInfo struct {
} }
// Name returns the compressed filename. // Name returns the compressed filename.
func (fi *MemFileInfo) Name() string { func (fi *memFileInfo) Name() string {
return fi.Name() return fi.Name()
} }
// Size returns the raw file content size, not compressed size. // Size returns the raw file content size, not compressed size.
func (fi *MemFileInfo) Size() int64 { func (fi *memFileInfo) Size() int64 {
return fi.contentSize return fi.contentSize
} }
// Mode returns file mode. // Mode returns file mode.
func (fi *MemFileInfo) Mode() os.FileMode { func (fi *memFileInfo) Mode() os.FileMode {
return fi.Mode() return fi.Mode()
} }
// ModTime returns the last modified time of raw file. // ModTime returns the last modified time of raw file.
func (fi *MemFileInfo) ModTime() time.Time { func (fi *memFileInfo) ModTime() time.Time {
return fi.modTime return fi.modTime
} }
// IsDir returns the compressing file is a directory or not. // IsDir returns the compressing file is a directory or not.
func (fi *MemFileInfo) IsDir() bool { func (fi *memFileInfo) IsDir() bool {
return fi.IsDir() return fi.IsDir()
} }
// return nil. implement the os.FileInfo interface method. // return nil. implement the os.FileInfo interface method.
func (fi *MemFileInfo) Sys() interface{} { func (fi *memFileInfo) Sys() interface{} {
return nil return nil
} }
// MemFile contains MemFileInfo and bytes offset when reading. // MemFile contains MemFileInfo and bytes offset when reading.
// it implements io.Reader,io.ReadCloser and io.Seeker. // it implements io.Reader,io.ReadCloser and io.Seeker.
type MemFile struct { type memFile struct {
fi *MemFileInfo fi *memFileInfo
offset int64 offset int64
} }
// Close memfile. // Close memfile.
func (f *MemFile) Close() error { func (f *memFile) Close() error {
return nil return nil
} }
// Get os.FileInfo of memfile. // Get os.FileInfo of memfile.
func (f *MemFile) Stat() (os.FileInfo, error) { func (f *memFile) Stat() (os.FileInfo, error) {
return f.fi, nil return f.fi, nil
} }
// read os.FileInfo of files in directory of memfile. // read os.FileInfo of files in directory of memfile.
// it returns empty slice. // it returns empty slice.
func (f *MemFile) Readdir(count int) ([]os.FileInfo, error) { func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
infos := []os.FileInfo{} infos := []os.FileInfo{}
return infos, nil return infos, nil
} }
// Read bytes from the compressed file bytes. // Read bytes from the compressed file bytes.
func (f *MemFile) Read(p []byte) (n int, err error) { func (f *memFile) Read(p []byte) (n int, err error) {
if len(f.fi.content)-int(f.offset) >= len(p) { if len(f.fi.content)-int(f.offset) >= len(p) {
n = len(p) n = len(p)
} else { } else {
@ -169,7 +171,7 @@ var errWhence = errors.New("Seek: invalid whence")
var errOffset = errors.New("Seek: invalid offset") var errOffset = errors.New("Seek: invalid offset")
// Read bytes from the compressed file bytes by seeker. // Read bytes from the compressed file bytes by seeker.
func (f *MemFile) Seek(offset int64, whence int) (ret int64, err error) { func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) {
switch whence { switch whence {
default: default:
return 0, errWhence return 0, errWhence
@ -189,7 +191,7 @@ func (f *MemFile) Seek(offset int64, whence int) (ret int64, err error) {
// GetAcceptEncodingZip returns accept encoding format in http header. // GetAcceptEncodingZip returns accept encoding format in http header.
// zip is first, then deflate if both accepted. // zip is first, then deflate if both accepted.
// If no accepted, return empty string. // If no accepted, return empty string.
func GetAcceptEncodingZip(r *http.Request) string { func getAcceptEncodingZip(r *http.Request) string {
ss := r.Header.Get("Accept-Encoding") ss := r.Header.Get("Accept-Encoding")
ss = strings.ToLower(ss) ss = strings.ToLower(ss)
if strings.Contains(ss, "gzip") { if strings.Contains(ss, "gzip") {
@ -199,24 +201,4 @@ func GetAcceptEncodingZip(r *http.Request) string {
} else { } else {
return "" return ""
} }
return ""
}
// CloseZWriter closes the io.Writer after compressing static file.
func CloseZWriter(zwriter io.Writer) {
if zwriter == nil {
return
}
switch zwriter.(type) {
case *gzip.Writer:
zwriter.(*gzip.Writer).Close()
case *flate.Writer:
zwriter.(*flate.Writer).Close()
//其他情况不close, 保持和默认(非压缩)行为一致
/*
case io.WriteCloser:
zwriter.(io.WriteCloser).Close()
*/
}
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package middleware package middleware
import ( import (
@ -61,6 +67,7 @@ var tpl = `
</html> </html>
` `
// render default application error page with error and stack string.
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) { func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl) t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string) data := make(map[string]string)
@ -71,6 +78,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
data["Stack"] = Stack data["Stack"] = Stack
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
data["GoVersion"] = runtime.Version() data["GoVersion"] = runtime.Version()
rw.WriteHeader(500)
t.Execute(rw, data) t.Execute(rw, data)
} }
@ -166,7 +174,7 @@ var errtpl = `
{{.Content}} {{.Content}}
<a href="/" title="Home" class="button">Go Home</a><br /> <a href="/" title="Home" class="button">Go Home</a><br />
<br>power by beego {{.BeegoVersion}} <br>Powered by beego {{.BeegoVersion}}
</div> </div>
</div> </div>
</div> </div>
@ -174,18 +182,19 @@ var errtpl = `
</html> </html>
` `
// map of http handlers for each error string.
var ErrorMaps map[string]http.HandlerFunc var ErrorMaps map[string]http.HandlerFunc
func init() { func init() {
ErrorMaps = make(map[string]http.HandlerFunc) ErrorMaps = make(map[string]http.HandlerFunc)
} }
//404 // show 404 notfound error.
func NotFound(rw http.ResponseWriter, r *http.Request) { func NotFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Page Not Found" data["Title"] = "Page Not Found"
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." + data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The page has moved" + "<br>The page has moved" +
@ -198,28 +207,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//401 // show 401 unauthorized error.
func Unauthorized(rw http.ResponseWriter, r *http.Request) { func Unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Unauthorized" data["Title"] = "Unauthorized"
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." + data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Check the credentials that you supplied" + "<br>The credentials you supplied are incorrect" +
"<br>Check the address for errors" + "<br>There are errors in the website address" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusUnauthorized) //rw.WriteHeader(http.StatusUnauthorized)
t.Execute(rw, data) t.Execute(rw, data)
} }
//403 // show 403 forbidden error.
func Forbidden(rw http.ResponseWriter, r *http.Request) { func Forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Forbidden" data["Title"] = "Forbidden"
data["Content"] = template.HTML("<br>The Page You have requested forbidden." + data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Your address may be blocked" + "<br>Your address may be blocked" +
@ -231,12 +240,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//503 // show 503 service unavailable error.
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Service Unavailable" data["Title"] = "Service Unavailable"
data["Content"] = template.HTML("<br>The Page You have requested unavailable." + data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br><br>The page is overloaded" + "<br><br>The page is overloaded" +
@ -247,30 +256,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data) t.Execute(rw, data)
} }
//500 // show 500 internal server error.
func InternalServerError(rw http.ResponseWriter, r *http.Request) { func InternalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := make(map[string]interface{})
data["Title"] = "Internal Server Error" data["Title"] = "Internal Server Error"
data["Content"] = template.HTML("<br>The Page You have requested has down now." + data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" + "<br><br><ul>" +
"<br>simply try again later" + "<br>Please try again later and report the error to the website administrator" +
"<br>you should report the fault to the website administrator" + "<br></ul>")
"</ul>")
data["BeegoVersion"] = VERSION data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusInternalServerError) //rw.WriteHeader(http.StatusInternalServerError)
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 500 internal error with simple text string.
func SimpleServerError(rw http.ResponseWriter, r *http.Request) { func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} }
// add http handler for given error string.
func Errorhandler(err string, h http.HandlerFunc) { func Errorhandler(err string, h http.HandlerFunc) {
ErrorMaps[err] = h ErrorMaps[err] = h
} }
func RegisterErrorHander() { // register default error http handlers, 404,401,403,500 and 503.
func RegisterErrorHandler() {
if _, ok := ErrorMaps["404"]; !ok { if _, ok := ErrorMaps["404"]; !ok {
ErrorMaps["404"] = NotFound ErrorMaps["404"] = NotFound
} }
@ -292,6 +303,8 @@ func RegisterErrorHander() {
} }
} }
// show error string as simple text message.
// if error string is empty, show 500 error as default.
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) { func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
if h, ok := ErrorMaps[errcode]; ok { if h, ok := ErrorMaps[errcode]; ok {
isint, err := strconv.Atoi(errcode) isint, err := strconv.Atoi(errcode)

View File

@ -1,17 +1,26 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package middleware package middleware
import "fmt" import "fmt"
// http exceptions
type HTTPException struct { type HTTPException struct {
StatusCode int // http status code 4xx, 5xx StatusCode int // http status code 4xx, 5xx
Description string Description string
} }
// return http exception error string, e.g. "400 Bad Request".
func (e *HTTPException) Error() string { func (e *HTTPException) Error() string {
// return `status description`, e.g. `400 Bad Request`
return fmt.Sprintf("%d %s", e.StatusCode, e.Description) return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
} }
// map of http exceptions for each http status code int.
// defined 400,401,403,404,405,500,502,503 and 504 default.
var HTTPExceptionMaps map[int]HTTPException var HTTPExceptionMaps map[int]HTTPException
func init() { func init() {

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package middleware package middleware
//import ( //import (

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
@ -544,8 +550,9 @@ var mimemaps map[string]string = map[string]string{
".mustache": "text/html", ".mustache": "text/html",
} }
func initMime() { func initMime() error {
for k, v := range mimemaps { for k, v := range mimemaps {
mime.AddExtensionType(k, v) mime.AddExtensionType(k, v)
} }
return nil
} }

378
namespace.go Normal file
View File

@ -0,0 +1,378 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego
import (
"net/http"
beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/middleware"
)
type namespaceCond func(*beecontext.Context) bool
type innnerNamespace func(*Namespace)
// Namespace is store all the info
type Namespace struct {
prefix string
handlers *ControllerRegistor
}
// get new Namespace
func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
ns := &Namespace{
prefix: prefix,
handlers: NewControllerRegister(),
}
for _, p := range params {
p(ns)
}
return ns
}
// set condtion function
// if cond return true can run this namespace, else can't
// usage:
// ns.Cond(func (ctx *context.Context) bool{
// if ctx.Input.Domain() == "api.beego.me" {
// return true
// }
// return false
// })
// Cond as the first filter
func (n *Namespace) Cond(cond namespaceCond) *Namespace {
fn := func(ctx *beecontext.Context) {
if !cond(ctx) {
middleware.Exception("405", ctx.ResponseWriter, ctx.Request, "Method not allowed")
}
}
if v, ok := n.handlers.filters[BeforeRouter]; ok {
mr := new(FilterRouter)
mr.tree = NewTree()
mr.pattern = "*"
mr.filterFunc = fn
mr.tree.AddRouter("*", true)
n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...)
} else {
n.handlers.InsertFilter("*", BeforeRouter, fn)
}
return n
}
// add filter in the Namespace
// action has before & after
// FilterFunc
// usage:
// Filter("before", func (ctx *context.Context){
// _, ok := ctx.Input.Session("uid").(int)
// if !ok && ctx.Request.RequestURI != "/login" {
// ctx.Redirect(302, "/login")
// }
// })
func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
var a int
if action == "before" {
a = BeforeRouter
} else if action == "after" {
a = FinishRouter
}
for _, f := range filter {
n.handlers.InsertFilter("*", a, f)
}
return n
}
// same as beego.Rourer
// refer: https://godoc.org/github.com/astaxie/beego#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...)
return n
}
// same as beego.AutoRouter
// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
n.handlers.AddAuto(c)
return n
}
// same as beego.AutoPrefix
// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
n.handlers.AddAutoPrefix(prefix, c)
return n
}
// same as beego.Get
// refer: https://godoc.org/github.com/astaxie/beego#Get
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
n.handlers.Get(rootpath, f)
return n
}
// same as beego.Post
// refer: https://godoc.org/github.com/astaxie/beego#Post
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
n.handlers.Post(rootpath, f)
return n
}
// same as beego.Delete
// refer: https://godoc.org/github.com/astaxie/beego#Delete
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
n.handlers.Delete(rootpath, f)
return n
}
// same as beego.Put
// refer: https://godoc.org/github.com/astaxie/beego#Put
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
n.handlers.Put(rootpath, f)
return n
}
// same as beego.Head
// refer: https://godoc.org/github.com/astaxie/beego#Head
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
n.handlers.Head(rootpath, f)
return n
}
// same as beego.Options
// refer: https://godoc.org/github.com/astaxie/beego#Options
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
n.handlers.Options(rootpath, f)
return n
}
// same as beego.Patch
// refer: https://godoc.org/github.com/astaxie/beego#Patch
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
n.handlers.Patch(rootpath, f)
return n
}
// same as beego.Any
// refer: https://godoc.org/github.com/astaxie/beego#Any
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
n.handlers.Any(rootpath, f)
return n
}
// same as beego.Handler
// refer: https://godoc.org/github.com/astaxie/beego#Handler
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
n.handlers.Handler(rootpath, h)
return n
}
// add include class
// refer: https://godoc.org/github.com/astaxie/beego#Include
func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
n.handlers.Include(cList...)
return n
}
// nest Namespace
// usage:
//ns := beego.NewNamespace(“/v1”).
//Namespace(
// beego.NewNamespace("/shop").
// Get("/:id", func(ctx *context.Context) {
// ctx.Output.Body([]byte("shopinfo"))
// }),
// beego.NewNamespace("/order").
// Get("/:id", func(ctx *context.Context) {
// ctx.Output.Body([]byte("orderinfo"))
// }),
// beego.NewNamespace("/crm").
// Get("/:id", func(ctx *context.Context) {
// ctx.Output.Body([]byte("crminfo"))
// }),
//)
func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
for _, ni := range ns {
for k, v := range ni.handlers.routers {
if t, ok := n.handlers.routers[k]; ok {
addPrefix(v, ni.prefix)
n.handlers.routers[k].AddTree(ni.prefix, v)
} else {
t = NewTree()
t.AddTree(ni.prefix, v)
addPrefix(t, ni.prefix)
n.handlers.routers[k] = t
}
}
if n.handlers.enableFilter {
for pos, filterList := range ni.handlers.filters {
for _, mr := range filterList {
t := NewTree()
t.AddTree(ni.prefix, mr.tree)
mr.tree = t
n.handlers.insertFilterRouter(pos, mr)
}
}
}
}
return n
}
// register Namespace into beego.Handler
// support multi Namespace
func AddNamespace(nl ...*Namespace) {
for _, n := range nl {
for k, v := range n.handlers.routers {
if t, ok := BeeApp.Handlers.routers[k]; ok {
addPrefix(v, n.prefix)
BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
} else {
t = NewTree()
t.AddTree(n.prefix, v)
addPrefix(t, n.prefix)
BeeApp.Handlers.routers[k] = t
}
}
if n.handlers.enableFilter {
for pos, filterList := range n.handlers.filters {
for _, mr := range filterList {
t := NewTree()
t.AddTree(n.prefix, mr.tree)
mr.tree = t
BeeApp.Handlers.insertFilterRouter(pos, mr)
}
}
}
}
}
func addPrefix(t *Tree, prefix string) {
for _, v := range t.fixrouters {
addPrefix(v, prefix)
}
if t.wildcard != nil {
addPrefix(t.wildcard, prefix)
}
for _, l := range t.leaves {
if c, ok := l.runObject.(*controllerInfo); ok {
c.pattern = prefix + c.pattern
}
}
}
// Namespace Condition
func NSCond(cond namespaceCond) innnerNamespace {
return func(ns *Namespace) {
ns.Cond(cond)
}
}
// Namespace BeforeRouter filter
func NSBefore(filiterList ...FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Filter("before", filiterList...)
}
}
// Namespace FinishRouter filter
func NSAfter(filiterList ...FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Filter("after", filiterList...)
}
}
// Namespace Include ControllerInterface
func NSInclude(cList ...ControllerInterface) innnerNamespace {
return func(ns *Namespace) {
ns.Include(cList...)
}
}
// Namespace Router
func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) innnerNamespace {
return func(ns *Namespace) {
ns.Router(rootpath, c, mappingMethods...)
}
}
// Namespace Get
func NSGet(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Get(rootpath, f)
}
}
// Namespace Post
func NSPost(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Post(rootpath, f)
}
}
// Namespace Head
func NSHead(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Head(rootpath, f)
}
}
// Namespace Put
func NSPut(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Put(rootpath, f)
}
}
// Namespace Delete
func NSDelete(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Delete(rootpath, f)
}
}
// Namespace Any
func NSAny(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Any(rootpath, f)
}
}
// Namespace Options
func NSOptions(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Options(rootpath, f)
}
}
// Namespace Patch
func NSPatch(rootpath string, f FilterFunc) innnerNamespace {
return func(ns *Namespace) {
ns.Patch(rootpath, f)
}
}
//Namespace AutoRouter
func NSAutoRouter(c ControllerInterface) innnerNamespace {
return func(ns *Namespace) {
ns.AutoRouter(c)
}
}
// Namespace AutoPrefix
func NSAutoPrefix(prefix string, c ControllerInterface) innnerNamespace {
return func(ns *Namespace) {
ns.AutoPrefix(prefix, c)
}
}
// Namespace add sub Namespace
func NSNamespace(prefix string, params ...innnerNamespace) innnerNamespace {
return func(ns *Namespace) {
n := NewNamespace(prefix, params...)
ns.Namespace(n)
}
}

163
namespace_test.go Normal file
View File

@ -0,0 +1,163 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/astaxie/beego/context"
)
func TestNamespaceGet(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/user", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Get("/user", func(ctx *context.Context) {
ctx.Output.Body([]byte("v1_user"))
})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "v1_user" {
t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String())
}
}
func TestNamespacePost(t *testing.T) {
r, _ := http.NewRequest("POST", "/v1/user/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Post("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNest(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/admin/order", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Namespace(
NewNamespace("/admin").
Get("/order", func(ctx *context.Context) {
ctx.Output.Body([]byte("order"))
}),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "order" {
t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNestParam(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Namespace(
NewNamespace("/admin").
Get("/order/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
}),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouter(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/api/list", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Router("/api/list", &TestController{}, "*:List")
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceAutoFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/test/list", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.AutoRouter(&TestController{})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
}
}
func TestNamespaceFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/user/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Filter("before", func(ctx *context.Context) {
ctx.Output.Body([]byte("this is Filter"))
}).
Get("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "this is Filter" {
t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceCond(t *testing.T) {
r, _ := http.NewRequest("GET", "/v2/test/list", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v2")
ns.Cond(func(ctx *context.Context) bool {
if ctx.Input.Domain() == "beego.me" {
return true
}
return false
}).
AutoRouter(&TestController{})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Code != 405 {
t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code))
}
}
func TestNamespaceInside(t *testing.T) {
r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v3",
NSAutoRouter(&TestController{}),
NSNamespace("/shop",
NSGet("/order/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
}),
),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String())
}
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -16,6 +22,7 @@ var (
commands = make(map[string]commander) commands = make(map[string]commander)
) )
// print help.
func printHelp(errs ...string) { func printHelp(errs ...string) {
content := `orm command usage: content := `orm command usage:
@ -31,6 +38,7 @@ func printHelp(errs ...string) {
os.Exit(2) os.Exit(2)
} }
// listen for orm command and then run it if command arguments passed.
func RunCommand() { func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" { if len(os.Args) < 2 || os.Args[1] != "orm" {
return return
@ -58,6 +66,7 @@ func RunCommand() {
} }
} }
// sync database struct command interface.
type commandSyncDb struct { type commandSyncDb struct {
al *alias al *alias
force bool force bool
@ -66,6 +75,7 @@ type commandSyncDb struct {
rtOnError bool rtOnError bool
} }
// parse orm command line arguments.
func (d *commandSyncDb) Parse(args []string) { func (d *commandSyncDb) Parse(args []string) {
var name string var name string
@ -78,6 +88,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
// run orm line command.
func (d *commandSyncDb) Run() error { func (d *commandSyncDb) Run() error {
var drops []string var drops []string
if d.force { if d.force {
@ -208,10 +219,12 @@ func (d *commandSyncDb) Run() error {
return nil return nil
} }
// database creation commander interface implement.
type commandSqlAll struct { type commandSqlAll struct {
al *alias al *alias
} }
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) { func (d *commandSqlAll) Parse(args []string) {
var name string var name string
@ -222,6 +235,7 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name) d.al = getDbAlias(name)
} }
// run orm line command.
func (d *commandSqlAll) Run() error { func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSql(d.al)
var all []string var all []string
@ -243,6 +257,10 @@ func init() {
commands["sqlall"] = new(commandSqlAll) commands["sqlall"] = new(commandSqlAll)
} }
// run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
func RunSyncdb(name string, force bool, verbose bool) error { func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap() BootStrap()

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -12,6 +18,7 @@ type dbIndex struct {
Sql string Sql string
} }
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) { func getDbDropSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
@ -26,6 +33,7 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls return sqls
} }
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) { func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes() T := al.DbBaser.DbTypes()
fieldType := fi.fieldType fieldType := fi.fieldType
@ -79,6 +87,7 @@ checkColumn:
return return
} }
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string { func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote() Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi) typ := getColumnTyp(al, fi)
@ -90,6 +99,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
} }
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")

276
orm/db.go
View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -15,7 +21,7 @@ const (
) )
var ( var (
ErrMissPK = errors.New("missed pk value") ErrMissPK = errors.New("missed pk value") // missing pk error
) )
var ( var (
@ -35,7 +41,7 @@ var (
"istartswith": true, "istartswith": true,
"iendswith": true, "iendswith": true,
"in": true, "in": true,
// "range": true, "between": true,
// "year": true, // "year": true,
// "month": true, // "month": true,
// "day": true, // "day": true,
@ -45,13 +51,22 @@ var (
} }
) )
// an instance of dbBaser interface/
type dbBase struct { type dbBase struct {
ins dbBaser ins dbBaser
} }
// check dbBase implements dbBaser interface.
var _ dbBaser = new(dbBase) var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { // get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
var columns []string
if names != nil {
columns = *names
}
for _, column := range cols { for _, column := range cols {
var fi *fieldInfo var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil { if fi, _ = mi.fields.GetByAny(column); fi != nil {
@ -64,14 +79,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if names != nil {
columns = append(columns, column) columns = append(columns, column)
}
values = append(values, value) values = append(values, value)
} }
if names != nil {
*names = columns
}
return return
} }
// get one field value in struct column as interface.
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
var value interface{} var value interface{}
if fi.pk { if fi.pk {
@ -84,28 +109,60 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
switch fi.fieldType { switch fi.fieldType {
case TypeBooleanField: case TypeBooleanField:
if nb, ok := field.Interface().(sql.NullBool); ok {
value = nil
if nb.Valid {
value = nb.Bool
}
} else {
value = field.Bool() value = field.Bool()
}
case TypeCharField, TypeTextField: case TypeCharField, TypeTextField:
if ns, ok := field.Interface().(sql.NullString); ok {
value = nil
if ns.Valid {
value = ns.String
}
} else {
value = field.String() value = field.String()
}
case TypeFloatField, TypeDecimalField: case TypeFloatField, TypeDecimalField:
if nf, ok := field.Interface().(sql.NullFloat64); ok {
value = nil
if nf.Valid {
value = nf.Float64
}
} else {
vu := field.Interface() vu := field.Interface()
if _, ok := vu.(float32); ok { if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64() value, _ = StrTo(ToStr(vu)).Float64()
} else { } else {
value = field.Float() value = field.Float()
} }
}
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
value = field.Interface() value = field.Interface()
if t, ok := value.(time.Time); ok { if t, ok := value.(time.Time); ok {
d.ins.TimeToDB(&t, tz) d.ins.TimeToDB(&t, tz)
if t.IsZero() {
value = nil
} else {
value = t value = t
} }
}
default: default:
switch { switch {
case fi.fieldType&IsPostiveIntegerField > 0: case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint() value = field.Uint()
case fi.fieldType&IsIntegerField > 0: case fi.fieldType&IsIntegerField > 0:
if ni, ok := field.Interface().(sql.NullInt64); ok {
value = nil
if ni.Valid {
value = ni.Int64
}
} else {
value = field.Int() value = field.Int()
}
case fi.fieldType&IsRelField > 0: case fi.fieldType&IsRelField > 0:
if field.IsNil() { if field.IsNil() {
value = nil value = nil
@ -125,6 +182,11 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert { if fi.auto_now || fi.auto_now_add && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
break
}
}
tnow := time.Now() tnow := time.Now()
d.ins.TimeToDB(&tnow, tz) d.ins.TimeToDB(&tnow, tz)
value = tnow value = tnow
@ -140,6 +202,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
return value, nil return value, nil
} }
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
@ -165,8 +228,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
return stmt, query, err return stmt, query, err
} }
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -185,6 +249,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
} }
} }
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
var whereCols []string var whereCols []string
var args []interface{} var args []interface{}
@ -192,7 +257,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// if specify cols length > 0, then use it for where condition. // if specify cols length > 0, then use it for where condition.
if len(cols) > 0 { if len(cols) > 0 {
var err error var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil { if err != nil {
return err return err
} }
@ -202,7 +268,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if ok == false { if ok == false {
return ErrMissPK return ErrMissPK
} }
whereCols = append(whereCols, pkColumn) whereCols = []string{pkColumn}
args = append(args, pkValue) args = append(args, pkValue)
} }
@ -243,16 +309,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return nil return nil
} }
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) names := make([]string, 0, len(mi.fields.dbcols)-1)
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return d.InsertValue(q, mi, names, values) return d.InsertValue(q, mi, false, names, values)
} }
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { // multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
values []interface{}
names []string
)
// typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len()
for i := 1; i <= length; i++ {
ind := reflect.Indirect(sind.Index(i - 1))
// Is this needed ?
// if !ind.Type().AssignableTo(typ) {
// return cnt, ErrArgs
// }
if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return cnt, err
}
values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus)
} else {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return cnt, err
}
if len(vus) != len(names) {
return cnt, ErrArgs
}
nums += copy(values[nums:], vus)
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
cnt += num
nums = 0
}
}
return cnt, nil
}
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
marks := make([]string, len(names)) marks := make([]string, len(names))
@ -264,36 +391,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
qmarks := strings.Join(marks, ", ") qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep) columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := q.Exec(query, values...); err == nil { if res, err := q.Exec(query, values...); err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId() return res.LastInsertId()
} else { } else {
return 0, err return 0, err
} }
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} }
} }
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
return 0, ErrMissPK return 0, ErrMissPK
} }
var setNames []string
// if specify cols length is zero, then commit all columns. // if specify cols length is zero, then commit all columns.
if len(cols) == 0 { if len(cols) == 0 {
cols = mi.fields.dbcols cols = mi.fields.dbcols
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
} else {
setNames = make([]string, 0, len(cols))
} }
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -314,9 +456,10 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
} else { } else {
return 0, err return 0, err
} }
return 0, nil
} }
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if ok == false {
@ -355,9 +498,10 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
} else { } else {
return 0, err return 0, err
} }
return 0, nil
} }
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
@ -430,9 +574,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} else { } else {
return 0, err return 0, err
} }
return 0, nil
} }
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
@ -459,8 +604,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
return nil return nil
} }
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.skipEnd = true
if qs != nil { if qs != nil {
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
} }
@ -486,6 +634,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
rs = r rs = r
} }
defer rs.Close()
var ref interface{} var ref interface{}
args = make([]interface{}, 0) args = make([]interface{}, 0)
@ -528,10 +678,9 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} else { } else {
return 0, err return 0, err
} }
return 0, nil
} }
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
@ -640,6 +789,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
slice := ind slice := ind
var cnt int64 var cnt int64
@ -739,6 +890,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
return cnt, nil return cnt, nil
} }
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
@ -759,6 +911,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return return
} }
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" sql := ""
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
@ -768,13 +921,19 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
} }
arg := params[0] arg := params[0]
if operator == "in" { switch operator {
case "in":
marks := make([]string, len(params)) marks := make([]string, len(params))
for i, _ := range marks { for i, _ := range marks {
marks[i] = "?" marks[i] = "?"
} }
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
} else { case "between":
if len(params) != 2 {
panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params)))
}
sql = "BETWEEN ? AND ?"
default:
if len(params) > 1 { if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
} }
@ -812,10 +971,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
return sql, params return sql, params
} }
// gernerate sql string with inner function, such as UPPER(text).
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
// default not use // default not use
} }
// set values to struct column.
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols { for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -837,6 +998,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
} }
} }
// convert value from database result to value following in field type.
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil { if val == nil {
return nil, nil return nil, nil
@ -989,6 +1151,7 @@ end:
} }
// set one value to struct column field.
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType fieldType := fi.fieldType
@ -998,18 +1161,38 @@ setValue:
switch { switch {
case fieldType == TypeBooleanField: case fieldType == TypeBooleanField:
if isNative { if isNative {
if nb, ok := field.Interface().(sql.NullBool); ok {
if value == nil {
nb.Valid = false
} else {
nb.Bool = value.(bool)
nb.Valid = true
}
field.Set(reflect.ValueOf(nb))
} else {
if value == nil { if value == nil {
value = false value = false
} }
field.SetBool(value.(bool)) field.SetBool(value.(bool))
} }
}
case fieldType == TypeCharField || fieldType == TypeTextField: case fieldType == TypeCharField || fieldType == TypeTextField:
if isNative { if isNative {
if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil {
ns.Valid = false
} else {
ns.String = value.(string)
ns.Valid = true
}
field.Set(reflect.ValueOf(ns))
} else {
if value == nil { if value == nil {
value = "" value = ""
} }
field.SetString(value.(string)) field.SetString(value.(string))
} }
}
case fieldType == TypeDateField || fieldType == TypeDateTimeField: case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative { if isNative {
if value == nil { if value == nil {
@ -1027,19 +1210,40 @@ setValue:
} }
} else { } else {
if isNative { if isNative {
if ni, ok := field.Interface().(sql.NullInt64); ok {
if value == nil {
ni.Valid = false
} else {
ni.Int64 = value.(int64)
ni.Valid = true
}
field.Set(reflect.ValueOf(ni))
} else {
if value == nil { if value == nil {
value = int64(0) value = int64(0)
} }
field.SetInt(value.(int64)) field.SetInt(value.(int64))
} }
} }
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField: case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if isNative { if isNative {
if nf, ok := field.Interface().(sql.NullFloat64); ok {
if value == nil {
nf.Valid = false
} else {
nf.Float64 = value.(float64)
nf.Valid = true
}
field.Set(reflect.ValueOf(nf))
} else {
if value == nil { if value == nil {
value = float64(0) value = float64(0)
} }
field.SetFloat(value.(float64)) field.SetFloat(value.(float64))
} }
}
case fieldType&IsRelField > 0: case fieldType&IsRelField > 0:
if value != nil { if value != nil {
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
@ -1063,6 +1267,7 @@ setValue:
return value, nil return value, nil
} }
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var ( var (
@ -1150,6 +1355,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
var ( var (
cnt int64 cnt int64
columns []string columns []string
@ -1228,6 +1435,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil return cnt, nil
} }
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
return 0, nil
}
// flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool { func (d *dbBase) SupportUpdateJoin() bool {
return true return true
} }
@ -1236,30 +1448,37 @@ func (d *dbBase) MaxLimit() uint64 {
return 18446744073709551615 return 18446744073709551615
} }
// return quote.
func (d *dbBase) TableQuote() string { func (d *dbBase) TableQuote() string {
return "`" return "`"
} }
// replace value placeholer in parametered sql string.
func (d *dbBase) ReplaceMarks(query *string) { func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing // default use `?` as mark, do nothing
} }
// flag of RETURNING sql.
func (d *dbBase) HasReturningID(*modelInfo, *string) bool { func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false return false
} }
// convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)
} }
// convert time to db.
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz) *t = t.In(tz)
} }
// get database types.
func (d *dbBase) DbTypes() map[string]string { func (d *dbBase) DbTypes() map[string]string {
return nil return nil
} }
// gt all tables.
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
tables := make(map[string]bool) tables := make(map[string]bool)
query := d.ins.ShowTablesQuery() query := d.ins.ShowTablesQuery()
@ -1268,6 +1487,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, err return tables, err
} }
defer rows.Close()
for rows.Next() { for rows.Next() {
var table string var table string
err := rows.Scan(&table) err := rows.Scan(&table)
@ -1282,6 +1503,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, nil return tables, nil
} }
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string) columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
@ -1290,6 +1512,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, err return columns, err
} }
defer rows.Close()
for rows.Next() { for rows.Next() {
var ( var (
name string name string
@ -1306,18 +1530,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, nil return columns, nil
} }
// not implement.
func (d *dbBase) OperatorSql(operator string) string { func (d *dbBase) OperatorSql(operator string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) ShowTablesQuery() string { func (d *dbBase) ShowTablesQuery() string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) ShowColumnsQuery(table string) string { func (d *dbBase) ShowColumnsQuery(table string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool { func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement) panic(ErrNotImplement)
} }

View File

@ -1,35 +1,45 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"reflect" "reflect"
"sync" "sync"
"time" "time"
) )
// database driver constant int.
type DriverType int type DriverType int
const ( const (
_ DriverType = iota _ DriverType = iota // int enum type
DR_MySQL DR_MySQL // mysql
DR_Sqlite DR_Sqlite // sqlite
DR_Oracle DR_Oracle // oracle
DR_Postgres DR_Postgres // pgsql
) )
// database driver string.
type driver string type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType { func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d)) a, _ := dataBaseCache.get(string(d))
return a.Driver return a.Driver
} }
// get name of current driver
func (d driver) Name() string { func (d driver) Name() string {
return string(d) return string(d)
} }
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver) var _ Driver = new(driver)
var ( var (
@ -47,11 +57,13 @@ var (
} }
) )
// database alias cacher.
type _dbCache struct { type _dbCache struct {
mux sync.RWMutex mux sync.RWMutex
cache map[string]*alias cache map[string]*alias
} }
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) { func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock() ac.mux.Lock()
defer ac.mux.Unlock() defer ac.mux.Unlock()
@ -62,6 +74,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
return return
} }
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) { func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock() ac.mux.RLock()
defer ac.mux.RUnlock() defer ac.mux.RUnlock()
@ -69,6 +82,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
return return
} }
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) { func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default") al, _ = ac.get("default")
return return
@ -87,57 +101,29 @@ type alias struct {
Engine string Engine string
} }
// Setting the database connect params. Use the database driver self dataSource args. func detectTZ(al *alias) {
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DataSource = dataSource
var (
err error
)
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end
}
if dataBaseCache.add(aliasName, al) == false {
err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
goto end
}
al.DB, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
// orm timezone system match database // orm timezone system match database
// default use Local // default use Local
al.TZ = time.Local al.TZ = time.Local
if al.DriverName == "sphinx" {
return
}
switch al.Driver { switch al.Driver {
case DR_MySQL: case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone") row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
if tz == "SYSTEM" { if len(tz) >= 8 {
tz = "" if tz[0] != '-' {
row = al.DB.QueryRow("SELECT @@system_time_zone") tz = "+" + tz
row.Scan(&tz)
t, err := time.Parse("MST", tz)
if err == nil {
al.TZ = t.Location()
} }
} else { t, err := time.Parse("-07:00:00", tz)
t, err := time.Parse("-07:00", tz)
if err == nil { if err == nil {
al.TZ = t.Location() al.TZ = t.Location()
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
@ -163,8 +149,64 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
loc, err := time.LoadLocation(tz) loc, err := time.LoadLocation(tz)
if err == nil { if err == nil {
al.TZ = loc al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }
} }
}
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
al := new(alias)
al.Name = aliasName
al.DriverName = driverName
al.DB = db
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
if dataBaseCache.add(aliasName, al) == false {
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
}
return al, nil
}
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
// Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db)
if err != nil {
goto end
}
al.DataSource = dataSource
detectTZ(al)
for i, v := range params { for i, v := range params {
switch i { switch i {
@ -175,39 +217,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
} }
} }
err = al.DB.Ping()
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
end: end:
if err != nil { if err != nil {
fmt.Println(err.Error()) if db != nil {
os.Exit(2) db.Close()
} }
DebugLog.Println(err.Error())
}
return err
} }
// Register a database driver use specify driver name, this can be definition the driver is which database type. // Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) { func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false { if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ drivers[driverName] = typ
} else { } else {
if t != typ { if t != typ {
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName) return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
os.Exit(2)
} }
} }
return nil
} }
// Change the database default used timezone // Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) { func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz al.TZ = tz
} else { } else {
fmt.Sprintf("DataBase name `%s` not registered\n", aliasName) return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
os.Exit(2)
} }
return nil
} }
// Change the max idle conns for *sql.DB, use specify database alias name // Change the max idle conns for *sql.DB, use specify database alias name
@ -226,3 +266,19 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
} }
} }
// Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
if len(aliasNames) > 0 {
name = aliasNames[0]
} else {
name = "default"
}
if al, ok := dataBaseCache.get(name); ok {
return al.DB, nil
} else {
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
}

View File

@ -1,9 +1,16 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"fmt" "fmt"
) )
// mysql operators.
var mysqlOperators = map[string]string{ var mysqlOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ?", "iexact": "LIKE ?",
@ -21,6 +28,7 @@ var mysqlOperators = map[string]string{
"iendswith": "LIKE ?", "iendswith": "LIKE ?",
} }
// mysql column field types.
var mysqlTypes = map[string]string{ var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -41,29 +49,35 @@ var mysqlTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
} }
// mysql dbBaser implementation.
type dbBaseMysql struct { type dbBaseMysql struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseMysql) var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string { func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator] return mysqlOperators[operator]
} }
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string { func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes return mysqlTypes
} }
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string { func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
} }
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string { func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table) "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
} }
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
@ -72,6 +86,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0 return cnt > 0
} }
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser { func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql) b := new(dbBaseMysql)
b.ins = b b.ins = b

View File

@ -1,11 +1,19 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
// oracle dbBaser
type dbBaseOracle struct { type dbBaseOracle struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseOracle) var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser { func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle) b := new(dbBaseOracle)
b.ins = b b.ins = b

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -5,6 +11,7 @@ import (
"strconv" "strconv"
) )
// postgresql operators.
var postgresOperators = map[string]string{ var postgresOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "= UPPER(?)", "iexact": "= UPPER(?)",
@ -20,6 +27,7 @@ var postgresOperators = map[string]string{
"iendswith": "LIKE UPPER(?)", "iendswith": "LIKE UPPER(?)",
} }
// postgresql column field types.
var postgresTypes = map[string]string{ var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY", "auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -40,16 +48,19 @@ var postgresTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)", "float64-decimal": "numeric(%d, %d)",
} }
// postgresql dbBaser.
type dbBasePostgres struct { type dbBasePostgres struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBasePostgres) var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string { func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator] return postgresOperators[operator]
} }
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator { switch operator {
case "contains", "startswith", "endswith": case "contains", "startswith", "endswith":
@ -59,6 +70,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
} }
} }
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool { func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false return false
} }
@ -67,10 +79,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0 return 0
} }
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string { func (d *dbBasePostgres) TableQuote() string {
return `"` return `"`
} }
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) { func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query q := *query
num := 0 num := 0
@ -97,6 +112,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
*query = string(data) *query = string(data)
} }
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if query != nil { if query != nil {
@ -107,18 +123,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return return
} }
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string { func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
} }
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string { func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
} }
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string { func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes return postgresTypes
} }
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query) row := db.QueryRow(query)
@ -127,6 +147,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0 return cnt > 0
} }
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser { func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres) b := new(dbBasePostgres)
b.ins = b b.ins = b

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -5,6 +11,7 @@ import (
"fmt" "fmt"
) )
// sqlite operators.
var sqliteOperators = map[string]string{ var sqliteOperators = map[string]string{
"exact": "= ?", "exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'", "iexact": "LIKE ? ESCAPE '\\'",
@ -20,6 +27,7 @@ var sqliteOperators = map[string]string{
"iendswith": "LIKE ? ESCAPE '\\'", "iendswith": "LIKE ? ESCAPE '\\'",
} }
// sqlite column types.
var sqliteTypes = map[string]string{ var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
@ -40,38 +48,47 @@ var sqliteTypes = map[string]string{
"float64-decimal": "decimal", "float64-decimal": "decimal",
} }
// sqlite dbBaser.
type dbBaseSqlite struct { type dbBaseSqlite struct {
dbBase dbBase
} }
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string { func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
} }
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol) *leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
} }
} }
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool { func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false return false
} }
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 { func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807 return 9223372036854775807
} }
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string { func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes return sqliteTypes
} }
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string { func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'" return "SELECT name FROM sqlite_master WHERE type = 'table'"
} }
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table) query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query) rows, err := db.Query(query)
@ -92,10 +109,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
return columns, nil return columns, nil
} }
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table) return fmt.Sprintf("pragma table_info('%s')", table)
} }
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table) query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query) rows, err := db.Query(query)
@ -113,6 +132,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false return false
} }
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser { func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite) b := new(dbBaseSqlite)
b.ins = b b.ins = b

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -6,6 +12,7 @@ import (
"time" "time"
) )
// table info struct.
type dbTable struct { type dbTable struct {
id int id int
index string index string
@ -18,13 +25,17 @@ type dbTable struct {
jtl *dbTable jtl *dbTable
} }
// tables collection struct, contains some tables.
type dbTables struct { type dbTables struct {
tablesM map[string]*dbTable tablesM map[string]*dbTable
tables []*dbTable tables []*dbTable
mi *modelInfo mi *modelInfo
base dbBaser base dbBaser
skipEnd bool
} }
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok { if j, ok := t.tablesM[name]; ok {
@ -41,6 +52,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name] return t.tablesM[name]
} }
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false { if _, ok := t.tablesM[name]; ok == false {
@ -53,11 +65,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name], false return t.tablesM[name], false
} }
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) { func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name] j, ok := t.tablesM[name]
return j, ok return j, ok
} }
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany { if depth < 0 || fi.fieldType == RelManyToMany {
return related return related
@ -78,6 +93,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
return related return related
} }
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) { func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels) relsNum := len(rels)
@ -111,7 +127,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
names = append(names, fi.name) names = append(names, fi.name)
mmi = fi.relModelInfo mmi = fi.relModelInfo
if fi.null { if fi.null || t.skipEnd {
inner = false inner = false
} }
@ -139,6 +155,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
} }
} }
// generate join string.
func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote() Q := t.base.TableQuote()
@ -185,9 +202,12 @@ func (t *dbTables) getJoinSql() (join string) {
return return
} }
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var ( var (
jtl *dbTable jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi mmi = mi
) )
@ -196,9 +216,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
inner := true inner := true
loopFor:
for i, ex := range exprs { for i, ex := range exprs {
fi, ok := mmi.fields.GetByAny(ex) var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok { if ok {
@ -216,17 +249,33 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
mmi = fi.reverseFieldInfo.mi mmi = fi.reverseFieldInfo.mi
} }
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (fi.mi.isThrough == false || num != i) { if isRel && (fi.mi.isThrough == false || num != i) {
if fi.null { if fi.null || t.skipEnd {
inner = false inner = false
} }
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
goto loopEnd
}
jt, _ := t.add(names, mmi, fi, inner) jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl jt.jtl = jtl
jtl = jt jtl = jt
} }
if num == i { }
if num != i {
continue
}
loopEnd:
if i == 0 || jtl == nil { if i == 0 || jtl == nil {
index = "T0" index = "T0"
} else { } else {
@ -252,7 +301,8 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
name = info.name name = info.name
} }
} }
}
break loopFor
} else { } else {
index = "" index = ""
@ -267,6 +317,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return return
} }
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
return return
@ -331,6 +382,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return return
} }
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) { func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 { if len(orders) == 0 {
return return
@ -359,6 +411,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
return return
} }
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 { if limit == 0 {
limit = int64(DefaultRowsLimit) limit = int64(DefaultRowsLimit)
@ -381,6 +434,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
return return
} }
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables { func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{} tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable) tables.tablesM = make(map[string]*dbTable)

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -6,15 +12,16 @@ import (
"time" "time"
) )
// get table alias.
func getDbAlias(name string) *alias { func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
return al return al
} else { } else {
panic(fmt.Errorf("unknown DataBase alias name %s", name)) panic(fmt.Errorf("unknown DataBase alias name %s", name))
} }
return nil
} }
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk fi := mi.fields.pk
@ -37,6 +44,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return return
} }
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor: outFor:
@ -48,9 +56,16 @@ outFor:
continue continue
} }
switch v := arg.(type) { kind := val.Kind()
case []byte: if kind == reflect.Ptr {
case string: val = val.Elem()
kind = val.Kind()
arg = val.Interface()
}
switch kind {
case reflect.String:
v := val.String()
if fi != nil { if fi != nil {
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time var t time.Time
@ -75,16 +90,20 @@ outFor:
} }
} }
arg = v arg = v
case time.Time: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if fi != nil && fi.fieldType == TypeDateField { arg = val.Int()
arg = v.In(tz).Format(format_Date) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
} else { arg = val.Uint()
arg = v.In(tz).Format(format_DateTime) case reflect.Float32:
} arg, _ = StrTo(ToStr(arg)).Float64()
default: case reflect.Float64:
kind := val.Kind() arg = val.Float()
switch kind { case reflect.Bool:
arg = val.Bool()
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := arg.([]byte); ok {
continue outFor
}
var args []interface{} var args []interface{}
for i := 0; i < val.Len(); i++ { for i := 0; i < val.Len(); i++ {
@ -107,16 +126,19 @@ outFor:
params = append(params, p...) params = append(params, p...)
} }
continue outFor continue outFor
case reflect.Struct:
case reflect.Ptr, reflect.Struct: if v, ok := arg.(time.Time); ok {
ind := reflect.Indirect(val) if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
if ind.Kind() == reflect.Struct { } else {
typ := ind.Type() arg = v.In(tz).Format(format_DateTime)
}
} else {
typ := val.Type()
name := getFullName(typ) name := getFullName(typ)
var value interface{} var value interface{}
if mmi, ok := modelCache.getByFN(name); ok { if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, ind); exist { if _, vu, exist := getExistPk(mmi, val); exist {
value = vu value = vu
} }
} }
@ -125,11 +147,9 @@ outFor:
if arg == nil { if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
} }
} else {
arg = ind.Interface()
}
} }
} }
params = append(params, arg) params = append(params, arg)
} }
return return

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -41,6 +47,7 @@ var (
} }
) )
// model info collection
type _modelCache struct { type _modelCache struct {
sync.RWMutex sync.RWMutex
orders []string orders []string
@ -49,6 +56,7 @@ type _modelCache struct {
done bool done bool
} }
// get all model info
func (mc *_modelCache) all() map[string]*modelInfo { func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache)) m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache { for k, v := range mc.cache {
@ -57,6 +65,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m return m
} }
// get orderd model info
func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders)) m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders { for _, table := range mc.orders {
@ -65,16 +74,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
return m return m
} }
// get model info by table name
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table] mi, ok = mc.cache[table]
return return
} }
// get model info by field name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name] mi, ok = mc.cacheByFN[name]
return return
} }
// set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table] mii := mc.cache[table]
mc.cache[table] = mi mc.cache[table] = mi
@ -85,9 +97,16 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
return mii return mii
} }
// clean all model info.
func (mc *_modelCache) clean() { func (mc *_modelCache) clean() {
mc.orders = make([]string, 0) mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo) mc.cache = make(map[string]*modelInfo)
mc.cacheByFN = make(map[string]*modelInfo) mc.cacheByFN = make(map[string]*modelInfo)
mc.done = false mc.done = false
} }
// Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -8,7 +14,9 @@ import (
"strings" "strings"
) )
func registerModel(model interface{}, prefix string) { // register models.
// prefix means table name prefix.
func registerModel(prefix string, model interface{}) {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
@ -67,6 +75,7 @@ func registerModel(model interface{}, prefix string) {
modelCache.set(table, info) modelCache.set(table, info)
} }
// boostrap models
func bootStrap() { func bootStrap() {
if modelCache.done { if modelCache.done {
return return
@ -281,27 +290,24 @@ end:
} }
} }
// register models
func RegisterModel(models ...interface{}) { func RegisterModel(models ...interface{}) {
if modelCache.done { RegisterModelWithPrefix("", models...)
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
}
for _, model := range models {
registerModel(model, "")
}
} }
// register model with a prefix // register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) { func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done { if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap")) panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
} }
for _, model := range models { for _, model := range models {
registerModel(model, prefix) registerModel(prefix, model)
} }
} }
// bootrap models.
// make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {
return return

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -9,6 +15,7 @@ import (
var errSkipField = errors.New("skip field") var errSkipField = errors.New("skip field")
// field info collection
type fields struct { type fields struct {
pk *fieldInfo pk *fieldInfo
columns map[string]*fieldInfo columns map[string]*fieldInfo
@ -23,6 +30,7 @@ type fields struct {
dbcols []string dbcols []string
} }
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) { func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil { if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi f.columns[fi.column] = fi
@ -49,14 +57,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
return true return true
} }
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo { func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name] return f.fields[name]
} }
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo { func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column] return f.columns[column]
} }
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) { func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok { if fi, ok := f.fields[name]; ok {
return fi, ok return fi, ok
@ -70,6 +81,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
return nil, false return nil, false
} }
// create new field info collection
func newFields() *fields { func newFields() *fields {
f := new(fields) f := new(fields)
f.fields = make(map[string]*fieldInfo) f.fields = make(map[string]*fieldInfo)
@ -79,6 +91,7 @@ func newFields() *fields {
return f return f
} }
// single field info
type fieldInfo struct { type fieldInfo struct {
mi *modelInfo mi *modelInfo
fieldIndex int fieldIndex int
@ -115,6 +128,7 @@ type fieldInfo struct {
onDelete string onDelete string
} }
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) { func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var ( var (
tag string tag string

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -7,6 +13,7 @@ import (
"reflect" "reflect"
) )
// single model info
type modelInfo struct { type modelInfo struct {
pkg string pkg string
name string name string
@ -20,6 +27,7 @@ type modelInfo struct {
isThrough bool isThrough bool
} }
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) { func newModelInfo(val reflect.Value) (info *modelInfo) {
var ( var (
err error err error
@ -41,6 +49,9 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
for i := 0; i < ind.NumField(); i++ { for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i) field := ind.Field(i)
sf = ind.Type().Field(i) sf = ind.Type().Field(i)
if sf.PkgPath != "" {
continue
}
fi, err = newFieldInfo(info, field, sf) fi, err = newFieldInfo(info, field, sf)
if err != nil { if err != nil {
@ -79,6 +90,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
return return
} }
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo) info = new(modelInfo)
info.fields = newFields() info.fields = newFields()

View File

@ -1,6 +1,13 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -82,7 +89,6 @@ func (e *JsonField) SetRaw(value interface{}) error {
default: default:
return fmt.Errorf("<JsonField.SetRaw> unknown value `%v`", value) return fmt.Errorf("<JsonField.SetRaw> unknown value `%v`", value)
} }
return nil
} }
func (e *JsonField) RawValue() interface{} { func (e *JsonField) RawValue() interface{} {
@ -121,7 +127,7 @@ type DataNull struct {
Char string `orm:"null;size(50)"` Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"` Text string `orm:"null;type(text)"`
Date time.Time `orm:"null;type(date)"` Date time.Time `orm:"null;type(date)"`
DateTime time.Time `orm:"null;column(datetime)""` DateTime time.Time `orm:"null;column(datetime)"`
Byte byte `orm:"null"` Byte byte `orm:"null"`
Rune rune `orm:"null"` Rune rune `orm:"null"`
Int int `orm:"null"` Int int `orm:"null"`
@ -137,6 +143,49 @@ type DataNull struct {
Float32 float32 `orm:"null"` Float32 float32 `orm:"null"`
Float64 float64 `orm:"null"` Float64 float64 `orm:"null"`
Decimal float64 `orm:"digits(8);decimals(4);null"` Decimal float64 `orm:"digits(8);decimals(4);null"`
NullString sql.NullString `orm:"null"`
NullBool sql.NullBool `orm:"null"`
NullFloat64 sql.NullFloat64 `orm:"null"`
NullInt64 sql.NullInt64 `orm:"null"`
}
type String string
type Boolean bool
type Byte byte
type Rune rune
type Int int
type Int8 int8
type Int16 int16
type Int32 int32
type Int64 int64
type Uint uint
type Uint8 uint8
type Uint16 uint16
type Uint32 uint32
type Uint64 uint64
type Float32 float64
type Float64 float64
type DataCustom struct {
Id int
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
Byte Byte
Rune Rune
Int Int
Int8 Int8
Int16 Int16
Int32 Int32
Int64 Int64
Uint Uint
Uint8 Uint8
Uint16 Uint16
Uint32 Uint32
Uint64 Uint64
Float32 Float32
Float64 Float64
Decimal Float64 `orm:"digits(8);decimals(4)"`
} }
// only for mysql // only for mysql
@ -150,7 +199,7 @@ type User struct {
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"` Email string `orm:"size(100)"`
Password string `orm:"size(100)"` Password string `orm:"size(100)"`
Status int16 Status int16 `orm:"column(Status)"`
IsStaff bool IsStaff bool
IsActive bool `orm:"default(1)"` IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"` Created time.Time `orm:"auto_now_add;type(date)"`
@ -161,6 +210,8 @@ type User struct {
Nums int Nums int
Langs SliceStringField `orm:"size(100)"` Langs SliceStringField `orm:"size(100)"`
Extra JsonField `orm:"type(text)"` Extra JsonField `orm:"type(text)"`
unexport bool `orm:"-"`
unexport_ bool
} }
func (u *User) TableIndex() [][]string { func (u *User) TableIndex() [][]string {
@ -303,9 +354,8 @@ go test -v github.com/astaxie/beego/orm
#### Sqlite3 #### Sqlite3
touch /path/to/orm_test.db
export ORM_DRIVER=sqlite3 export ORM_DRIVER=sqlite3
export ORM_SOURCE=/path/to/orm_test.db export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/astaxie/beego/orm go test -v github.com/astaxie/beego/orm

View File

@ -1,16 +1,25 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time" "time"
) )
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string { func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name() return typ.PkgPath() + "." + typ.Name()
} }
// get table name. method, or field name. auto snaked.
func getTableName(val reflect.Value) string { func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
fun := val.MethodByName("TableName") fun := val.MethodByName("TableName")
@ -26,6 +35,7 @@ func getTableName(val reflect.Value) string {
return snakeString(ind.Type().Name()) return snakeString(ind.Type().Name())
} }
// get table engine, mysiam or innodb.
func getTableEngine(val reflect.Value) string { func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine") fun := val.MethodByName("TableEngine")
if fun.IsValid() { if fun.IsValid() {
@ -40,6 +50,7 @@ func getTableEngine(val reflect.Value) string {
return "" return ""
} }
// get table index from method.
func getTableIndex(val reflect.Value) [][]string { func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex") fun := val.MethodByName("TableIndex")
if fun.IsValid() { if fun.IsValid() {
@ -56,6 +67,7 @@ func getTableIndex(val reflect.Value) [][]string {
return nil return nil
} }
// get table unique from method
func getTableUnique(val reflect.Value) [][]string { func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique") fun := val.MethodByName("TableUnique")
if fun.IsValid() { if fun.IsValid() {
@ -72,8 +84,8 @@ func getTableUnique(val reflect.Value) [][]string {
return nil return nil
} }
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col)
column := col column := col
if col == "" { if col == "" {
column = snakeString(sf.Name) column = snakeString(sf.Name)
@ -89,6 +101,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
return column return column
} }
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) { func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val) elm := reflect.Indirect(val)
switch elm.Kind() { switch elm.Kind() {
@ -114,20 +127,27 @@ func getFieldType(val reflect.Value) (ft int, err error) {
ft = TypeBooleanField ft = TypeBooleanField
case reflect.String: case reflect.String:
ft = TypeCharField ft = TypeCharField
case reflect.Invalid:
default: default:
if elm.CanInterface() { switch elm.Interface().(type) {
if _, ok := elm.Interface().(time.Time); ok { case sql.NullInt64:
ft = TypeBigIntegerField
case sql.NullFloat64:
ft = TypeFloatField
case sql.NullBool:
ft = TypeBooleanField
case sql.NullString:
ft = TypeCharField
case time.Time:
ft = TypeDateTimeField ft = TypeDateTimeField
} }
} }
}
if ft&IsFieldType == 0 { if ft&IsFieldType == 0 {
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
} }
return return
} }
// parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool) attr := make(map[string]bool)
tag := make(map[string]string) tag := make(map[string]string)

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -25,6 +31,7 @@ var (
ErrMultiRows = errors.New("<QuerySeter> return multi rows") ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found") ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed") ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement") ErrNotImplement = errors.New("have not implement")
) )
@ -39,11 +46,12 @@ type orm struct {
var _ Ormer = new(orm) var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { // get model info and model reflect value
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind = reflect.Indirect(val) ind = reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
if val.Kind() != reflect.Ptr { if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
} }
name := getFullName(typ) name := getFullName(typ)
@ -53,6 +61,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
} }
// get field info from model info by given field name
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name) fi, ok := mi.fields.GetByAny(name)
if !ok { if !ok {
@ -61,8 +70,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
return fi return fi
} }
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error { func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return err return err
@ -70,13 +80,35 @@ func (o *orm) Read(md interface{}, cols ...string) error {
return nil return nil
} }
// Try to read a row from the database, or insert one if it doesn't exist
func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err == ErrNoRows {
// Create
id, err := o.Insert(md)
return (err == nil), id, err
}
return false, ind.Field(mi.fields.pk.fieldIndex).Int(), err
}
// insert model data to database
func (o *orm) Insert(md interface{}) (int64, error) { func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return id, err return id, err
} }
if id > 0 {
o.setPk(mi, ind, id)
return id, nil
}
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
@ -84,12 +116,47 @@ func (o *orm) Insert(md interface{}) (int64, error) {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id) ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
} }
} }
}
return id, nil
} }
// insert some models to database
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
}
default:
return cnt, ErrArgs
}
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i)
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt += 1
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) { func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil { if err != nil {
return num, err return num, err
@ -97,26 +164,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
return num, nil return num, nil
} }
// delete model in database
func (o *orm) Delete(md interface{}) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil { if err != nil {
return num, err return num, err
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { o.setPk(mi, ind, 0)
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
} }
return num, nil return num, nil
} }
// create a models to models queryer
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
switch { switch {
@ -129,6 +192,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
return newQueryM2M(md, o, mi, fi, ind) return newQueryM2M(md, o, mi, fi, ind)
} }
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name) _, fi, ind, qseter := o.queryRelated(md, name)
@ -190,14 +261,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
return nums, err return nums, err
} }
// return a QuerySeter for related models to md model.
// it can do all, update, delete in QuerySeter.
// example:
// qs := orm.QueryRelated(post,"Tag")
// qs.All(&[]*Tag{})
//
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ? // is this api needed ?
_, _, _, qs := o.queryRelated(md, name) _, _, _, qs := o.queryRelated(md, name)
return qs return qs
} }
// get QuerySeter for related models to md model
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind) _, _, exist := getExistPk(mi, ind)
@ -221,12 +299,13 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
} }
if qs == nil { if qs == nil {
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field")) panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
} }
return mi, fi, ind, qs return mi, fi, ind, qs
} }
// get reverse relation QuerySeter
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType { switch fi.fieldType {
case RelReverseOne, RelReverseMany: case RelReverseOne, RelReverseMany:
@ -247,6 +326,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
return q return q
} }
// get relation QuerySeter
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType { switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany: case RelOneToOne, RelForeignKey, RelManyToMany:
@ -266,6 +346,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
return q return q
} }
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := "" name := ""
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
@ -285,6 +368,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return return
} }
// switch to another registered database driver by given name.
func (o *orm) Using(name string) error { func (o *orm) Using(name string) error {
if o.isTx { if o.isTx {
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db")) panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
@ -302,6 +386,7 @@ func (o *orm) Using(name string) error {
return nil return nil
} }
// begin transaction
func (o *orm) Begin() error { func (o *orm) Begin() error {
if o.isTx { if o.isTx {
return ErrTxHasBegan return ErrTxHasBegan
@ -320,6 +405,7 @@ func (o *orm) Begin() error {
return nil return nil
} }
// commit transaction
func (o *orm) Commit() error { func (o *orm) Commit() error {
if o.isTx == false { if o.isTx == false {
return ErrTxDone return ErrTxDone
@ -334,6 +420,7 @@ func (o *orm) Commit() error {
return err return err
} }
// rollback transaction
func (o *orm) Rollback() error { func (o *orm) Rollback() error {
if o.isTx == false { if o.isTx == false {
return ErrTxDone return ErrTxDone
@ -348,14 +435,21 @@ func (o *orm) Rollback() error {
return err return err
} }
// return a raw query seter for raw sql string.
func (o *orm) Raw(query string, args ...interface{}) RawSeter { func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args) return newRawSet(o, query, args)
} }
// return current using database Driver
func (o *orm) Driver() Driver { func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
func (o *orm) GetDB() dbQuerier {
panic(ErrNotImplement)
}
// create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once
@ -366,3 +460,30 @@ func NewOrm() Ormer {
} }
return o return o
} }
// create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias
if dr, ok := drivers[driverName]; ok {
al = new(alias)
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
al.Name = aliasName
al.DriverName = driverName
o := new(orm)
o.alias = al
if Debug {
o.db = newDbQueryLog(o.alias, db)
} else {
o.db = db
}
return o, nil
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -18,15 +24,19 @@ type condValue struct {
isCond bool isCond bool
} }
// condition struct.
// work for WHERE conditions.
type Condition struct { type Condition struct {
params []condValue params []condValue
} }
// return new condition struct
func NewCondition() *Condition { func NewCondition() *Condition {
c := &Condition{} c := &Condition{}
return c return c
} }
// add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty")) panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -35,6 +45,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty")) panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -43,6 +54,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -54,6 +66,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
// add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty")) panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -62,6 +75,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty")) panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -70,6 +84,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -81,10 +96,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c return c
} }
// check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool { func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
// clone a condition
func (c Condition) clone() *Condition { func (c Condition) clone() *Condition {
return &c return &c
} }

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -13,6 +19,7 @@ type Log struct {
*log.Logger *log.Logger
} }
// set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", 1e9)
@ -40,6 +47,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
DebugLog.Println(con) DebugLog.Println(con)
} }
// statement query logger struct.
// if dev mode, use stmtQueryLog, or use stmtQuerier.
type stmtQueryLog struct { type stmtQueryLog struct {
alias *alias alias *alias
query string query string
@ -84,6 +93,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
return d return d
} }
// database query logger struct.
// if dev mode, use dbQueryLog, or use dbQuerier.
type dbQueryLog struct { type dbQueryLog struct {
alias *alias alias *alias
db dbQuerier db dbQuerier

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -5,6 +11,7 @@ import (
"reflect" "reflect"
) )
// an insert queryer struct
type insertSet struct { type insertSet struct {
mi *modelInfo mi *modelInfo
orm *orm orm *orm
@ -14,6 +21,7 @@ type insertSet struct {
var _ Inserter = new(insertSet) var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) { func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
@ -44,6 +52,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
return id, nil return id, nil
} }
// close insert queryer statement
func (o *insertSet) Close() error { func (o *insertSet) Close() error {
if o.closed { if o.closed {
return ErrStmtClosed return ErrStmtClosed
@ -52,6 +61,7 @@ func (o *insertSet) Close() error {
return o.stmt.Close() return o.stmt.Close()
} }
// create new insert queryer.
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet) bi := new(insertSet)
bi.orm = orm bi.orm = orm

View File

@ -1,9 +1,16 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"reflect" "reflect"
) )
// model to model struct
type queryM2M struct { type queryM2M struct {
md interface{} md interface{}
mi *modelInfo mi *modelInfo
@ -12,6 +19,13 @@ type queryM2M struct {
ind reflect.Value ind reflect.Value
} }
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) { func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
mi := fi.relThroughModelInfo mi := fi.relThroughModelInfo
@ -44,7 +58,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column} names := []string{mfi.column, rfi.column}
var nums int64 values := make([]interface{}, 0, len(models)*2)
for _, md := range models { for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md)) ind := reflect.Indirect(reflect.ValueOf(md))
@ -59,18 +74,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
} }
} }
values := []interface{}{v1, v2} values = append(values, v1, v2)
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
} }
nums += 1 return dbase.InsertValue(orm.db, mi, true, names, values)
}
return nums, nil
} }
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -82,17 +93,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return nums, nil return nums, nil
} }
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool { func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md). return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist() Filter(fi.reverseFieldInfoTwo.name, md).Exist()
} }
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) { func (o *queryM2M) Clear() (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
} }
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) { func (o *queryM2M) Count() (int64, error) {
fi := o.fi fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
@ -100,6 +114,7 @@ func (o *queryM2M) Count() (int64, error) {
var _ QueryM2Mer = new(queryM2M) var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M) qm2m := new(queryM2M)
qm2m.md = md qm2m.md = md

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -18,6 +24,10 @@ const (
Col_Except Col_Except
) )
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} { func ColValue(opt operator, value interface{}) interface{} {
switch opt { switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except: case Col_Add, Col_Minus, Col_Multiply, Col_Except:
@ -34,6 +44,7 @@ func ColValue(opt operator, value interface{}) interface{} {
return val return val
} }
// real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *modelInfo
cond *Condition cond *Condition
@ -47,6 +58,7 @@ type querySet struct {
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
// add condition expression to QuerySeter.
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
@ -55,6 +67,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
@ -63,10 +76,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// set offset number
func (o *querySet) setOffset(num interface{}) { func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num) o.offset = ToInt64(num)
} }
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit) o.limit = ToInt64(limit)
if len(args) > 0 { if len(args) > 0 {
@ -75,16 +91,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
return &o return &o
} }
// add OFFSET value
func (o querySet) Offset(offset interface{}) QuerySeter { func (o querySet) Offset(offset interface{}) QuerySeter {
o.setOffset(offset) o.setOffset(offset)
return &o return &o
} }
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter { func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs o.orders = exprs
return &o return &o
} }
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string var related []string
if len(params) == 0 { if len(params) == 0 {
@ -105,36 +126,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
return &o return &o
} }
// set condition to QuerySeter.
func (o querySet) SetCond(cond *Condition) QuerySeter { func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond o.cond = cond
return &o return &o
} }
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool { func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0 return cnt > 0
} }
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) { func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
} }
// execute delete
func (o *querySet) Delete() (int64, error) { func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
} }
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) { func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi) return newInsertSet(o.orm, o.mi)
} }
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) { func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
} }
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error { func (o *querySet) One(container interface{}, cols ...string) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil { if err != nil {
@ -149,18 +184,54 @@ func (o *querySet) One(container interface{}, cols ...string) error {
return nil return nil
} }
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
} }
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
} }
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
panic(ErrNotImplement)
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)
o.mi = mi o.mi = mi

View File

@ -1,13 +1,19 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
) )
// raw sql string prepared statement
type rawPrepare struct { type rawPrepare struct {
rs *rawSet rs *rawSet
stmt stmtQuerier stmt stmtQuerier
@ -45,6 +51,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
return o, nil return o, nil
} }
// raw query seter
type rawSet struct { type rawSet struct {
query string query string
args []interface{} args []interface{}
@ -53,11 +60,13 @@ type rawSet struct {
var _ RawSeter = new(rawSet) var _ RawSeter = new(rawSet)
// set args for every query
func (o rawSet) SetArgs(args ...interface{}) RawSeter { func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args o.args = args
return &o return &o
} }
// execute raw sql and return sql.Result
func (o *rawSet) Exec() (sql.Result, error) { func (o *rawSet) Exec() (sql.Result, error) {
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
@ -66,6 +75,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
return o.orm.db.Exec(query, args...) return o.orm.db.Exec(query, args...)
} }
// set field value to row container
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
switch ind.Kind() { switch ind.Kind() {
case reflect.Bool: case reflect.Bool:
@ -164,65 +174,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} }
} }
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) { // set field value in loop for slice container
sIdxes := *sIdxesPtr func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
refs := *refsPtr
if typ.Kind() == reflect.Struct {
if typ.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
} else {
idxs := []int{}
outFor:
for idx := 0; idx < typ.NumField(); idx++ {
ctyp := typ.Field(idx)
tag := ctyp.Tag.Get(defaultStructTagName)
for _, v := range strings.Split(tag, defaultStructTagDelim) {
if v == "-" {
continue outFor
}
}
tp := ctyp.Type
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
if tp.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
} else if tp.Kind() != reflect.Struct {
var ref interface{}
refs = append(refs, &ref)
} else {
// skip other type
continue
}
idxs = append(idxs, idx)
}
sIdxes = append(sIdxes, idxs)
}
} else {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
}
*sIdxesPtr = sIdxes
*refsPtr = refs
}
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr nInds := *nIndsPtr
cur := 0 cur := 0
for i, idxs := range sIdxes { for i := 0; i < len(sInds); i++ {
sInd := sInds[i] sInd := sInds[i]
eTyp := eTyps[i] eTyp := eTyps[i]
@ -258,32 +215,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
o.setFieldValue(ind, value) o.setFieldValue(ind, value)
} }
cur++ cur++
} else {
hasValue := false
for _, idx := range idxs {
tind := ind.Field(idx)
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if value != nil {
hasValue = true
}
if tind.Kind() == reflect.Ptr {
if value == nil {
tindV := reflect.New(tind.Type()).Elem()
tind.Set(tindV)
} else {
tindV := reflect.New(tind.Type().Elem())
o.setFieldValue(tindV.Elem(), value)
tind.Set(tindV)
}
} else {
o.setFieldValue(tind, value)
}
cur++
}
if hasValue == false && isPtr {
val = reflect.New(val.Type()).Elem()
}
} }
} else { } else {
value := reflect.ValueOf(refs[cur]).Elem().Interface() value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil { if isPtr && value == nil {
@ -312,16 +245,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
} }
} }
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error { func (o *rawSet) QueryRow(containers ...interface{}) error {
if len(containers) == 0 {
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
}
refs := make([]interface{}, 0, len(containers)) refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0) sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0) eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -335,44 +266,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if typ.Kind() == reflect.Ptr { if typ.Kind() == reflect.Ptr {
typ = typ.Elem() typ = typ.Elem()
} }
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind) sInds = append(sInds, ind)
eTyps = append(eTyps, etyp) eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes) if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
} }
query := o.query query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query) o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
row := o.orm.db.QueryRow(query, args...) rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err := row.Scan(refs...); err == sql.ErrNoRows { if err == sql.ErrNoRows {
return ErrNoRows return ErrNoRows
} else if err != nil { }
return err
}
defer rows.Close()
if rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err return err
} }
nInds := make([]reflect.Value, len(sInds)) nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true) o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds { for i, sInd := range sInds {
nInd := nInds[i] nInd := nInds[i]
sInd.Set(nInd) sInd.Set(nInd)
} }
}
} else {
return ErrNoRows
}
return nil return nil
} }
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0) refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0) sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0) eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
sInd := reflect.Indirect(val) sInd := reflect.Indirect(val)
@ -389,7 +399,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInds = append(sInds, sInd) sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp) eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes) if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
} }
query := o.query query := o.query
@ -401,30 +424,107 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
return 0, err return 0, err
} }
nInds := make([]reflect.Value, len(sInds)) defer rows.Close()
var cnt int64 var cnt int64
nInds := make([]reflect.Value, len(sInds))
sInd := sInds[0]
for rows.Next() { for rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil { if err := rows.Scan(refs...); err != nil {
return 0, err return 0, err
} }
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0) if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++ cnt++
} }
if cnt > 0 { if cnt > 0 {
if structMode {
sInds[0].Set(sInd)
} else {
for i, sInd := range sInds { for i, sInd := range sInds {
nInd := nInds[i] nInd := nInds[i]
sInd.Set(nInd) sInd.Set(nInd)
} }
} }
}
return cnt, nil return cnt, nil
} }
func (o *rawSet) readValues(container interface{}) (int64, error) { func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
var ( var (
maps []Params maps []Params
lists []ParamsList lists []ParamsList
@ -455,21 +555,41 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
rs = r rs = r
} }
defer rs.Close()
var ( var (
refs []interface{} refs []interface{}
cnt int64 cnt int64
cols []string cols []string
indexs []int
) )
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if columns, err := rs.Columns(); err != nil { if columns, err := rs.Columns(); err != nil {
return 0, err return 0, err
} else { } else {
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else {
indexs = make([]int, 0, len(columns))
}
cols = columns cols = columns
refs = make([]interface{}, len(cols)) refs = make([]interface{}, len(cols))
for i, _ := range refs { for i, _ := range refs {
var ref sql.NullString var ref sql.NullString
refs[i] = &ref refs[i] = &ref
if len(needCols) > 0 {
for _, c := range needCols {
if c == cols[i] {
indexs = append(indexs, i)
}
}
} else {
indexs = append(indexs, i)
}
} }
} }
} }
@ -481,7 +601,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
switch typ { switch typ {
case 1: case 1:
params := make(Params, len(cols)) params := make(Params, len(cols))
for i, ref := range refs { for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid { if value.Valid {
params[cols[i]] = value.String params[cols[i]] = value.String
@ -492,7 +613,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
maps = append(maps, params) maps = append(maps, params)
case 2: case 2:
params := make(ParamsList, 0, len(cols)) params := make(ParamsList, 0, len(cols))
for _, ref := range refs { for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid { if value.Valid {
params = append(params, value.String) params = append(params, value.String)
@ -502,7 +624,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
} }
lists = append(lists, params) lists = append(lists, params)
case 3: case 3:
for _, ref := range refs { for _, i := range indexs {
ref := refs[i]
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
if value.Valid { if value.Valid {
list = append(list, value.String) list = append(list, value.String)
@ -527,18 +650,166 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
return cnt, nil return cnt, nil
} }
func (o *rawSet) Values(container *[]Params) (int64, error) { func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
return o.readValues(container) var (
maps Params
ind *reflect.Value
)
typ := 0
switch container.(type) {
case *Params:
typ = 1
default:
typ = 2
vl := reflect.ValueOf(container)
id := reflect.Indirect(vl)
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
}
ind = &id
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
var (
refs []interface{}
cnt int64
cols []string
)
var (
keyIndex = -1
valueIndex = -1
)
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
return 0, err
} else {
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
if keyCol == cols[i] {
keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
} else {
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 {
switch typ {
case 1:
maps = make(Params)
}
}
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
switch typ {
case 1:
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
if value.Valid {
maps[key] = value.String
} else {
maps[key] = nil
}
default:
if id := ind.FieldByName(camelString(key)); id.IsValid() {
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
}
}
cnt++
}
if typ == 1 {
v, _ := container.(*Params)
*v = maps
}
return cnt, nil
} }
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { // query data to []map[string]interface
return o.readValues(container) func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
return o.readValues(container, cols)
} }
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { // query data to [][]interface
return o.readValues(container) func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
} }
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
return o.readValues(container, cols)
}
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(result, keyCol, valueCol)
}
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
}
// return prepared raw statement for used in times.
func (o *rawSet) Prepare() (RawPreparer, error) { func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o) return newRawPreparer(o)
} }

View File

@ -1,7 +1,14 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
"bytes" "bytes"
"database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -138,8 +145,17 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
} }
} }
func TestGetDB(t *testing.T) {
if db, err := GetDB(); err != nil {
throwFailNow(t, err)
} else {
err = db.Ping()
throwFailNow(t, err)
}
}
func TestSyncDb(t *testing.T) { func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))
@ -155,7 +171,7 @@ func TestSyncDb(t *testing.T) {
} }
func TestRegisterModels(t *testing.T) { func TestRegisterModels(t *testing.T) {
RegisterModel(new(Data), new(DataNull)) RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User)) RegisterModel(new(User))
RegisterModel(new(Profile)) RegisterModel(new(Profile))
RegisterModel(new(Post)) RegisterModel(new(Post))
@ -258,12 +274,78 @@ func TestNullDataTypes(t *testing.T) {
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(d.NullBool.Valid, false))
throwFail(t, AssertIs(d.NullString.Valid, false))
throwFail(t, AssertIs(d.NullInt64.Valid, false))
throwFail(t, AssertIs(d.NullFloat64.Valid, false))
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err) throwFail(t, err)
d = DataNull{Id: 2} d = DataNull{Id: 2}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
d = DataNull{
DateTime: time.Now(),
NullString: sql.NullString{String: "test", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
NullInt64: sql.NullInt64{Int64: 42, Valid: true},
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
}
id, err = dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 3))
d = DataNull{Id: 3}
err = dORM.Read(&d)
throwFail(t, err)
throwFail(t, AssertIs(d.NullBool.Valid, true))
throwFail(t, AssertIs(d.NullBool.Bool, true))
throwFail(t, AssertIs(d.NullString.Valid, true))
throwFail(t, AssertIs(d.NullString.String, "test"))
throwFail(t, AssertIs(d.NullInt64.Valid, true))
throwFail(t, AssertIs(d.NullInt64.Int64, 42))
throwFail(t, AssertIs(d.NullFloat64.Valid, true))
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
}
func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
e.Set(reflect.ValueOf(value).Convert(e.Type()))
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
vu := e.Interface()
value = reflect.ValueOf(value).Convert(e.Type()).Interface()
throwFail(t, AssertIs(vu == value, true), value, vu)
}
} }
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
@ -519,6 +601,10 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", String("slene")).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", "slene").Count() num, err = qs.Filter("user_name__exact", "slene").Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
@ -559,11 +645,11 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))
num, err = qs.Filter("status__lt", 3).Count() num, err = qs.Filter("status__lt", Uint(3)).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("status__lte", 3).Count() num, err = qs.Filter("status__lte", Int(3)).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))
@ -619,6 +705,14 @@ func TestOperators(t *testing.T) {
num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count() num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("id__between", 2, 3).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("id__between", []int{2, 3}).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
} }
func TestSetCond(t *testing.T) { func TestSetCond(t *testing.T) {
@ -1322,58 +1416,6 @@ func TestRawQueryRow(t *testing.T) {
} }
} }
type Tmp struct {
Skip0 string
Id int
Char *string
Skip1 int `orm:"-"`
Date time.Time
DateTime time.Time
}
Boolean = false
Text = ""
Int64 = 0
Uint = 0
tmp := new(Tmp)
cols = []string{
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
}
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
values = []interface{}{
tmp, &Boolean, &Text, &Int64, &Uint,
}
err = dORM.Raw(query, 1).QueryRow(values...)
throwFailNow(t, err)
for _, col := range cols {
switch col {
case "id":
throwFail(t, AssertIs(tmp.Id, data_values[col]))
case "char":
c := tmp.Char
throwFail(t, AssertIs(*c, data_values[col]))
case "date":
v := tmp.Date.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
case "datetime":
v := tmp.DateTime.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
case "boolean":
throwFail(t, AssertIs(Boolean, data_values[col]))
case "text":
throwFail(t, AssertIs(Text, data_values[col]))
case "int64":
throwFail(t, AssertIs(Int64, data_values[col]))
case "uint":
throwFail(t, AssertIs(Uint, data_values[col]))
}
}
var ( var (
uid int uid int
status *int status *int
@ -1381,7 +1423,7 @@ func TestRawQueryRow(t *testing.T) {
) )
cols = []string{ cols = []string{
"id", "status", "profile_id", "id", "Status", "profile_id",
} }
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q)
err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid) err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid)
@ -1394,22 +1436,13 @@ func TestRawQueryRow(t *testing.T) {
func TestQueryRows(t *testing.T) { func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
var datas []*Data var datas []*Data
var dids []int
sep := fmt.Sprintf("%s, %s", Q, Q) query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q) num, err := dORM.Raw(query).QueryRows(&datas)
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas), 1)) throwFailNow(t, AssertIs(len(datas), 1))
throwFailNow(t, AssertIs(len(dids), 1))
throwFailNow(t, AssertIs(dids[0], 1))
ind := reflect.Indirect(reflect.ValueOf(datas[0])) ind := reflect.Indirect(reflect.ValueOf(datas[0]))
@ -1427,97 +1460,50 @@ func TestQueryRows(t *testing.T) {
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
type Tmp struct { var datas2 []Data
Id int
Name string query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
Skiped0 string `orm:"-"` num, err = dORM.Raw(query).QueryRows(&datas2)
Pid *int throwFailNow(t, err)
Skiped1 Data throwFailNow(t, AssertIs(num, 1))
Skiped2 *Data throwFailNow(t, AssertIs(len(datas2), 1))
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
} }
var ( var ids []int
ids []int var usernames []string
userNames []string query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
profileIds1 []int num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
profileIds2 []*int
createds []time.Time
updateds []time.Time
tmps1 []*Tmp
tmps2 []Tmp
)
cols = []string{
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(len(ids), 3))
var users []User throwFailNow(t, AssertIs(ids[0], 2))
dORM.QueryTable("user").OrderBy("Id").All(&users) throwFailNow(t, AssertIs(usernames[0], "slene"))
throwFailNow(t, AssertIs(ids[1], 3))
for i := 0; i < 3; i++ { throwFailNow(t, AssertIs(usernames[1], "astaxie"))
id := ids[i] throwFailNow(t, AssertIs(ids[2], 4))
name := userNames[i] throwFailNow(t, AssertIs(usernames[2], "nobody"))
pid1 := profileIds1[i]
pid2 := profileIds2[i]
created := createds[i]
updated := updateds[i]
user := users[i]
throwFailNow(t, AssertIs(id, user.Id))
throwFailNow(t, AssertIs(name, user.UserName))
if user.Profile != nil {
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(pid1, 0))
throwFailNow(t, AssertIs(pid2, nil))
}
throwFailNow(t, AssertIs(created, user.Created, test_Date))
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
tmp := tmps1[i]
tmp1 := *tmp
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
if user.Profile != nil {
pid := tmp1.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp1.Pid, nil))
}
tmp2 := tmps2[i]
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
if user.Profile != nil {
pid := tmp2.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp2.Pid, nil))
}
}
type Sec struct {
Id int
Name string
}
var tmp []*Sec
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
num, err = dORM.Raw(query).QueryRows(&tmp)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(tmp[0], nil))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
var maps []Params var maps []Params
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sstatus%s = ?", Q, Q, Q, Q, Q, Q) query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q)
num, err := dORM.Raw(query, 1).Values(&maps) num, err := dORM.Raw(query, 1).Values(&maps)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
@ -1669,6 +1655,31 @@ func TestDelete(t *testing.T) {
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 6))
qs = dORM.QueryTable("post")
num, err = qs.Filter("Id", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 4))
qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
@ -1724,3 +1735,41 @@ func TestTransaction(t *testing.T) {
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
} }
func TestReadOrCreate(t *testing.T) {
u := &User{
UserName: "Kyle",
Email: "kylemcc@gmail.com",
Password: "other_pass",
Status: 7,
IsStaff: false,
IsActive: true,
}
created, pk, err := dORM.ReadOrCreate(u, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.UserName, "Kyle"))
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
throwFail(t, AssertIs(u.Password, "other_pass"))
throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.Id, u.Id))
throwFail(t, AssertIs(pk, u.Id))
throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password))
throwFail(t, AssertIs(nu.Status, u.Status))
throwFail(t, AssertIs(nu.IsStaff, u.IsStaff))
throwFail(t, AssertIs(nu.IsActive, u.IsActive))
dORM.Delete(u)
}

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -6,11 +12,13 @@ import (
"time" "time"
) )
// database driver
type Driver interface { type Driver interface {
Name() string Name() string
Type() DriverType Type() DriverType
} }
// field info
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
@ -18,9 +26,12 @@ type Fielder interface {
RawValue() interface{} RawValue() interface{}
} }
// orm struct
type Ormer interface { type Ormer interface {
Read(interface{}, ...string) error Read(interface{}, ...string) error
ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error) Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error) LoadRelated(interface{}, string, ...interface{}) (int64, error)
@ -32,13 +43,16 @@ type Ormer interface {
Rollback() error Rollback() error
Raw(string, ...interface{}) RawSeter Raw(string, ...interface{}) RawSeter
Driver() Driver Driver() Driver
GetDB() dbQuerier
} }
// insert prepared statement
type Inserter interface { type Inserter interface {
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
Close() error Close() error
} }
// query seter
type QuerySeter interface { type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
@ -57,8 +71,11 @@ type QuerySeter interface {
Values(*[]Params, ...string) (int64, error) Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
} }
// model to model query struct
type QueryM2Mer interface { type QueryM2Mer interface {
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
@ -67,22 +84,27 @@ type QueryM2Mer interface {
Count() (int64, error) Count() (int64, error)
} }
// raw query statement
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (sql.Result, error) Exec(...interface{}) (sql.Result, error)
Close() error Close() error
} }
// raw query seter
type RawSeter interface { type RawSeter interface {
Exec() (sql.Result, error) Exec() (sql.Result, error)
QueryRow(...interface{}) error QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error) QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error) Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList) (int64, error) ValuesFlat(*ParamsList, ...string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
Prepare() (RawPreparer, error) Prepare() (RawPreparer, error)
} }
// statement querier
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
@ -90,6 +112,7 @@ type stmtQuerier interface {
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
} }
// db querier
type dbQuerier interface { type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
@ -97,19 +120,31 @@ type dbQuerier interface {
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
} }
// type DB interface {
// Begin() (*sql.Tx, error)
// Prepare(query string) (stmtQuerier, error)
// Exec(query string, args ...interface{}) (sql.Result, error)
// Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryRow(query string, args ...interface{}) *sql.Row
// }
// transaction beginner
type txer interface { type txer interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
} }
// transaction ending
type txEnder interface { type txEnder interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
// base database struct
type dbBaser interface { type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
@ -123,6 +158,7 @@ type dbBaser interface {
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
MaxLimit() uint64 MaxLimit() uint64
TableQuote() string TableQuote() string
ReplaceMarks(*string) ReplaceMarks(*string)

View File

@ -1,3 +1,9 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors slene
package orm package orm
import ( import (
@ -10,6 +16,7 @@ import (
type StrTo string type StrTo string
// set string
func (f *StrTo) Set(v string) { func (f *StrTo) Set(v string) {
if v != "" { if v != "" {
*f = StrTo(v) *f = StrTo(v)
@ -18,77 +25,93 @@ func (f *StrTo) Set(v string) {
} }
} }
// clean string
func (f *StrTo) Clear() { func (f *StrTo) Clear() {
*f = StrTo(0x1E) *f = StrTo(0x1E)
} }
// check string exist
func (f StrTo) Exist() bool { func (f StrTo) Exist() bool {
return string(f) != string(0x1E) return string(f) != string(0x1E)
} }
// string to bool
func (f StrTo) Bool() (bool, error) { func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String()) return strconv.ParseBool(f.String())
} }
// string to float32
func (f StrTo) Float32() (float32, error) { func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32) v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err return float32(v), err
} }
// string to float64
func (f StrTo) Float64() (float64, error) { func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64) return strconv.ParseFloat(f.String(), 64)
} }
// string to int
func (f StrTo) Int() (int, error) { func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err return int(v), err
} }
// string to int8
func (f StrTo) Int8() (int8, error) { func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8) v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err return int8(v), err
} }
// string to int16
func (f StrTo) Int16() (int16, error) { func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16) v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err return int16(v), err
} }
// string to int32
func (f StrTo) Int32() (int32, error) { func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err return int32(v), err
} }
// string to int64
func (f StrTo) Int64() (int64, error) { func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64) v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err return int64(v), err
} }
// string to uint
func (f StrTo) Uint() (uint, error) { func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err return uint(v), err
} }
// string to uint8
func (f StrTo) Uint8() (uint8, error) { func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8) v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err return uint8(v), err
} }
// string to uint16
func (f StrTo) Uint16() (uint16, error) { func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16) v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err return uint16(v), err
} }
// string to uint31
func (f StrTo) Uint32() (uint32, error) { func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err return uint32(v), err
} }
// string to uint64
func (f StrTo) Uint64() (uint64, error) { func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64) v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err return uint64(v), err
} }
// string to string
func (f StrTo) String() string { func (f StrTo) String() string {
if f.Exist() { if f.Exist() {
return string(f) return string(f)
@ -96,6 +119,7 @@ func (f StrTo) String() string {
return "" return ""
} }
// interface to string
func ToStr(value interface{}, args ...int) (s string) { func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) { switch v := value.(type) {
case bool: case bool:
@ -134,6 +158,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s return s
} }
// interface to int64
func ToInt64(value interface{}) (d int64) { func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value) val := reflect.ValueOf(value)
switch value.(type) { switch value.(type) {
@ -147,6 +172,7 @@ func ToInt64(value interface{}) (d int64) {
return return
} }
// snake string, XxYy to xx_yy
func snakeString(s string) string { func snakeString(s string) string {
data := make([]byte, 0, len(s)*2) data := make([]byte, 0, len(s)*2)
j := false j := false
@ -164,6 +190,7 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:len(data)])) return strings.ToLower(string(data[:len(data)]))
} }
// camel string, xx_yy to XxYy
func camelString(s string) string { func camelString(s string) string {
data := make([]byte, 0, len(s)) data := make([]byte, 0, len(s))
j := false j := false
@ -190,6 +217,7 @@ func camelString(s string) string {
type argString []string type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) { func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -201,6 +229,7 @@ func (a argString) Get(i int, args ...string) (r string) {
type argInt []int type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) { func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -213,6 +242,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
type argAny []interface{} type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) { func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) { if i >= 0 && i < len(a) {
r = a[i] r = a[i]
@ -223,15 +253,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
return return
} }
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) { func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err return tp, err
} }
// format time string
func timeFormat(t time.Time, format string) string { func timeFormat(t time.Time, format string) string {
return t.Format(format) return t.Format(format)
} }
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type { func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() { switch v.Kind() {
case reflect.Ptr: case reflect.Ptr:
@ -239,5 +272,4 @@ func indirectType(v reflect.Type) reflect.Type {
default: default:
return v return v
} }
return v
} }

187
parser.go Normal file
View File

@ -0,0 +1,187 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego
import (
"encoding/json"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path"
"strings"
"github.com/astaxie/beego/utils"
)
var globalRouterTemplate = `package routers
import (
"github.com/astaxie/beego"
)
func init() {
{{.globalinfo}}
}
`
var (
lastupdateFilename string = "lastupdate.tmp"
pkgLastupdate map[string]int64
genInfoList map[string][]ControllerComments
)
func init() {
pkgLastupdate = make(map[string]int64)
genInfoList = make(map[string][]ControllerComments)
}
func parserPkg(pkgRealpath, pkgpath string) error {
if !compareFile(pkgRealpath) {
Info(pkgRealpath + " don't has updated")
return nil
}
fileSet := token.NewFileSet()
astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool {
name := info.Name()
return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
}, parser.ParseComments)
if err != nil {
return err
}
for _, pkg := range astPkgs {
for _, fl := range pkg.Files {
for _, d := range fl.Decls {
switch specDecl := d.(type) {
case *ast.FuncDecl:
parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(specDecl.Recv.List[0].Type.(*ast.StarExpr).X), pkgpath)
}
}
}
}
genRouterCode()
savetoFile(pkgRealpath)
return nil
}
func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error {
if comments != nil && comments.List != nil {
for _, c := range comments.List {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
elements := strings.TrimLeft(t, "@router ")
e1 := strings.SplitN(elements, " ", 2)
if len(e1) < 1 {
return errors.New("you should has router infomation")
}
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
cc.Method = funcName
cc.Router = e1[0]
if len(e1) == 2 && e1[1] != "" {
e1 = strings.SplitN(e1[1], " ", 2)
if len(e1) >= 1 {
cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",")
} else {
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
}
} else {
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
}
if len(e1) == 2 && e1[1] != "" {
keyval := strings.Split(strings.Trim(e1[1], "[]"), " ")
for _, kv := range keyval {
kk := strings.Split(kv, ":")
cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]})
}
}
genInfoList[key] = append(genInfoList[key], cc)
}
}
}
return nil
}
func genRouterCode() {
os.Mkdir(path.Join(AppPath, "routers"), 0755)
Info("generate router from comments")
var globalinfo string
for k, cList := range genInfoList {
for _, c := range cList {
allmethod := "nil"
if len(c.AllowHTTPMethods) > 0 {
allmethod = "[]string{"
for _, m := range c.AllowHTTPMethods {
allmethod += `"` + m + `",`
}
allmethod = strings.TrimRight(allmethod, ",") + "}"
}
params := "nil"
if len(c.Params) > 0 {
params = "[]map[string]string{"
for _, p := range c.Params {
for k, v := range p {
params = params + `map[string]string{` + k + `:"` + v + `"},`
}
}
params = strings.TrimRight(params, ",") + "}"
}
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
"` + strings.TrimSpace(c.Method) + `",
"` + c.Router + `",
` + allmethod + `,
` + params + `})
`
}
}
if globalinfo != "" {
f, err := os.Create(path.Join(AppPath, "routers", "commentsRouter.go"))
if err != nil {
panic(err)
}
defer f.Close()
f.WriteString(strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1))
}
}
func compareFile(pkgRealpath string) bool {
if utils.FileExists(path.Join(AppPath, lastupdateFilename)) {
content, err := ioutil.ReadFile(path.Join(AppPath, lastupdateFilename))
if err != nil {
return true
}
json.Unmarshal(content, &pkgLastupdate)
ft, err := os.Lstat(pkgRealpath)
if err != nil {
return true
}
if v, ok := pkgLastupdate[pkgRealpath]; ok {
if ft.ModTime().UnixNano() >= v {
return false
}
}
}
return true
}
func savetoFile(pkgRealpath string) {
ft, err := os.Lstat(pkgRealpath)
if err != nil {
return
}
pkgLastupdate[pkgRealpath] = ft.ModTime().UnixNano()
d, err := json.Marshal(pkgLastupdate)
if err != nil {
return
}
ioutil.WriteFile(path.Join(AppPath, lastupdateFilename), d, os.ModePerm)
}

80
plugins/auth/basic.go Normal file
View File

@ -0,0 +1,80 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package auth
// Example:
// func SecretAuth(username, password string) bool {
// if username == "astaxie" && password == "helloBeego" {
// return true
// }
// return false
// }
// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm")
// beego.AddFilter("*","AfterStatic",authPlugin)
import (
"encoding/base64"
"net/http"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/context"
)
func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuth{Secrets: secrets, Realm: Realm}
if username := a.CheckAuth(ctx.Request); username == "" {
a.RequireAuth(ctx.ResponseWriter, ctx.Request)
}
}
}
type SecretProvider func(user, pass string) bool
type BasicAuth struct {
Secrets SecretProvider
Realm string
}
/*
Checks the username/password combination from the request. Returns
either an empty string (authentication failed) or the name of the
authenticated user.
Supports MD5 and SHA1 password entries
*/
func (a *BasicAuth) CheckAuth(r *http.Request) string {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 || s[0] != "Basic" {
return ""
}
b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
return ""
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return ""
}
if a.Secrets(pair[0], pair[1]) {
return pair[0]
}
return ""
}
/*
http.Handler for BasicAuth which initiates the authentication process
(or requires reauthentication).
*/
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
w.WriteHeader(401)
w.Write([]byte("401 Unauthorized\n"))
}

165
reload.go
View File

@ -1,165 +0,0 @@
// Zero-downtime restarts in Go.
package beego
import (
"errors"
"fmt"
"log"
"net"
"os"
"os/exec"
"os/signal"
"reflect"
"strconv"
"sync"
//"syscall"
)
const (
// An environment variable when restarting application http listener.
FDKey = "BEEGO_HOT_FD"
)
// Export an error equivalent to net.errClosing for use with Accept during
// a graceful exit.
var ErrClosing = errors.New("use of closed network connection")
var ErrInitStart = errors.New("init from")
// Allows for us to notice when the connection is closed.
type conn struct {
net.Conn
wg *sync.WaitGroup
isclose bool
lock sync.Mutex
}
// Close current processing connection.
func (c conn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
err := c.Conn.Close()
if !c.isclose && err == nil {
c.wg.Done()
c.isclose = true
}
return err
}
type stoppableListener struct {
net.Listener
count int64
stopped bool
wg sync.WaitGroup
}
var theStoppable *stoppableListener
func newStoppable(l net.Listener) (sl *stoppableListener) {
sl = &stoppableListener{Listener: l}
// this goroutine monitors the channel. Can't do this in
// Accept (below) because once it enters sl.Listener.Accept()
// it blocks. We unblock it by closing the fd it is trying to
// accept(2) on.
go func() {
WaitSignal(l)
sl.stopped = true
sl.Listener.Close()
}()
return
}
// Set stopped Listener to accept requests again.
// it returns the accepted and closable connection or error.
func (sl *stoppableListener) Accept() (c net.Conn, err error) {
c, err = sl.Listener.Accept()
if err != nil {
return
}
sl.wg.Add(1)
// Wrap the returned connection, so that we can observe when
// it is closed.
c = conn{Conn: c, wg: &sl.wg}
return
}
// Listener waits signal to kill or interrupt then restart.
func WaitSignal(l net.Listener) error {
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt, os.Kill)
for {
sig := <-ch
log.Println(sig.String())
switch sig {
case os.Kill:
return nil
case os.Interrupt:
err := Restart(l)
if nil != err {
return err
}
return nil
}
}
return nil // It'll never get here.
}
// Kill current running os process.
func CloseSelf() error {
ppid := os.Getpid()
if ppid == 1 { // init provided sockets, for example systemd
return nil
}
p, err := os.FindProcess(ppid)
if err != nil {
return err
}
return p.Kill()
}
// Re-exec this image without dropping the listener passed to this function.
func Restart(l net.Listener) error {
argv0, err := exec.LookPath(os.Args[0])
if nil != err {
return err
}
wd, err := os.Getwd()
if nil != err {
return err
}
v := reflect.ValueOf(l).Elem().FieldByName("fd").Elem()
fd := uintptr(v.FieldByName("sysfd").Int())
allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr},
os.NewFile(fd, string(v.FieldByName("sysfile").String())))
p, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{
Dir: wd,
Env: append(os.Environ(), fmt.Sprintf("%s=%d", FDKey, fd)),
Files: allFiles,
})
if nil != err {
return err
}
log.Printf("spawned child %d\n", p.Pid)
return nil
}
// Get current net.Listen in running process.
func GetInitListener(tcpaddr *net.TCPAddr) (l net.Listener, err error) {
countStr := os.Getenv(FDKey)
if countStr == "" {
return net.ListenTCP("tcp", tcpaddr)
}
count, err := strconv.Atoi(countStr)
if err != nil {
return nil, err
}
f := os.NewFile(uintptr(count), "listen socket")
l, err = net.FileListener(f)
if err != nil {
return nil, err
}
return l, nil
}

1251
router.go

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,17 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package beego package beego
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/astaxie/beego/context"
) )
type TestController struct { type TestController struct {
@ -15,6 +23,14 @@ func (this *TestController) Get() {
this.Ctx.Output.Body([]byte("ok")) this.Ctx.Output.Body([]byte("ok"))
} }
func (this *TestController) Post() {
this.Ctx.Output.Body([]byte(this.Ctx.Input.Query(":name")))
}
func (this *TestController) Param() {
this.Ctx.Output.Body([]byte(this.Ctx.Input.Query(":name")))
}
func (this *TestController) List() { func (this *TestController) List() {
this.Ctx.Output.Body([]byte("i am list")) this.Ctx.Output.Body([]byte("i am list"))
} }
@ -31,6 +47,15 @@ func (this *TestController) GetUrl() {
this.Ctx.Output.Body([]byte(this.UrlFor(".Myext"))) this.Ctx.Output.Body([]byte(this.UrlFor(".Myext")))
} }
func (t *TestController) GetParams() {
t.Ctx.WriteString(t.Ctx.Input.Query(":last") + "+" +
t.Ctx.Input.Query(":first") + "+" + t.Ctx.Input.Query("learn"))
}
func (t *TestController) GetManyRouter() {
t.Ctx.WriteString(t.Ctx.Input.Query(":id") + t.Ctx.Input.Query(":page"))
}
type ResStatus struct { type ResStatus struct {
Code int Code int
Msg string Msg string
@ -51,21 +76,45 @@ func (this *JsonController) Get() {
} }
func TestUrlFor(t *testing.T) { func TestUrlFor(t *testing.T) {
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}) handler.Add("/person/:last/:first", &TestController{}, "*:Param")
handler.AddAuto(&TestController{}) handler.AddAuto(&TestController{})
if handler.UrlFor("TestController.List") != "/api/list" { if handler.UrlFor("TestController.List") != "/api/list" {
Info(handler.UrlFor("TestController.List"))
t.Errorf("TestController.List must equal to /api/list") t.Errorf("TestController.List must equal to /api/list")
} }
if handler.UrlFor("TestController.Get", ":last", "xie", ":first", "asta") != "/person/xie/asta" { if handler.UrlFor("TestController.Param", ":last", "xie", ":first", "asta") != "/person/xie/asta" {
t.Errorf("TestController.Get must equal to /person/xie/asta") t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + handler.UrlFor("TestController.Param", ":last", "xie", ":first", "asta"))
} }
if handler.UrlFor("TestController.Myext") != "/Test/Myext" { if handler.UrlFor("TestController.Myext") != "/test/myext" {
t.Errorf("TestController.Myext must equal to /Test/Myext") t.Errorf("TestController.Myext must equal to /test/myext")
} }
if handler.UrlFor("TestController.GetUrl") != "/Test/GetUrl" { if handler.UrlFor("TestController.GetUrl") != "/test/geturl" {
t.Errorf("TestController.GetUrl must equal to /Test/GetUrl") t.Errorf("TestController.GetUrl must equal to /test/geturl")
}
}
func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.UrlFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" {
Info(handler.UrlFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
}
if handler.UrlFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" {
Info(handler.UrlFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
}
if handler.UrlFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" {
Info(handler.UrlFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
} }
} }
@ -73,7 +122,7 @@ func TestUserFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/api/list", nil) r, _ := http.NewRequest("GET", "/api/list", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List") handler.Add("/api/list", &TestController{}, "*:List")
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" { if w.Body.String() != "i am list" {
@ -81,11 +130,23 @@ func TestUserFunc(t *testing.T) {
} }
} }
func TestPostFunc(t *testing.T) {
r, _ := http.NewRequest("POST", "/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/:name", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "astaxie" {
t.Errorf("post func should astaxie")
}
}
func TestAutoFunc(t *testing.T) { func TestAutoFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/test/list", nil) r, _ := http.NewRequest("GET", "/test/list", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.AddAuto(&TestController{}) handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" { if w.Body.String() != "i am list" {
@ -97,7 +158,7 @@ func TestAutoFuncParams(t *testing.T) {
r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil) r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.AddAuto(&TestController{}) handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Body.String() != "20091112" { if w.Body.String() != "20091112" {
@ -109,7 +170,7 @@ func TestAutoExtFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/test/myext.json", nil) r, _ := http.NewRequest("GET", "/test/myext.json", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.AddAuto(&TestController{}) handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Body.String() != "json" { if w.Body.String() != "json" {
@ -122,22 +183,12 @@ func TestRouteOk(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.Add("/person/:last/:first", &TestController{}) handler.Add("/person/:last/:first", &TestController{}, "get:GetParams")
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
body := w.Body.String()
lastNameParam := r.URL.Query().Get(":last") if body != "anderson+thomas+kungfu" {
firstNameParam := r.URL.Query().Get(":first") t.Errorf("url param set to [%s];", body)
learnParam := r.URL.Query().Get("learn")
if lastNameParam != "anderson" {
t.Errorf("url param set to [%s]; want [%s]", lastNameParam, "anderson")
}
if firstNameParam != "thomas" {
t.Errorf("url param set to [%s]; want [%s]", firstNameParam, "thomas")
}
if learnParam != "kungfu" {
t.Errorf("url param set to [%s]; want [%s]", learnParam, "kungfu")
} }
} }
@ -146,18 +197,14 @@ func TestManyRoute(t *testing.T) {
r, _ := http.NewRequest("GET", "/beego32-12.html", nil) r, _ := http.NewRequest("GET", "/beego32-12.html", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}) handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter")
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
id := r.URL.Query().Get(":id") body := w.Body.String()
page := r.URL.Query().Get(":page")
if id != "32" { if body != "3212" {
t.Errorf("url param set to [%s]; want [%s]", id, "32") t.Errorf("url param set to [%s];", body)
}
if page != "12" {
t.Errorf("url param set to [%s]; want [%s]", page, "12")
} }
} }
@ -165,7 +212,7 @@ func TestNotFound(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != http.StatusNotFound { if w.Code != http.StatusNotFound {
@ -179,7 +226,7 @@ func TestStatic(t *testing.T) {
r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil) r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Code != 404 { if w.Code != 404 {
@ -191,10 +238,120 @@ func TestPrepare(t *testing.T) {
r, _ := http.NewRequest("GET", "/json/list", nil) r, _ := http.NewRequest("GET", "/json/list", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegistor() handler := NewControllerRegister()
handler.Add("/json/list", &JsonController{}) handler.Add("/json/list", &JsonController{})
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Body.String() != `"prepare"` { if w.Body.String() != `"prepare"` {
t.Errorf(w.Body.String() + "user define func can't run") t.Errorf(w.Body.String() + "user define func can't run")
} }
} }
func TestAutoPrefix(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.AddAutoPrefix("/admin", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestAutoPrefix can't run")
}
}
func TestRouterGet(t *testing.T) {
r, _ := http.NewRequest("GET", "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Get("/user", func(ctx *context.Context) {
ctx.Output.Body([]byte("Get userlist"))
})
handler.ServeHTTP(w, r)
if w.Body.String() != "Get userlist" {
t.Errorf("TestRouterGet can't run")
}
}
func TestRouterPost(t *testing.T) {
r, _ := http.NewRequest("POST", "/user/123", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Post("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
handler.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestRouterPost can't run")
}
}
func sayhello(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("sayhello"))
}
func TestRouterHandler(t *testing.T) {
r, _ := http.NewRequest("POST", "/sayhi", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Handler("/sayhi", http.HandlerFunc(sayhello))
handler.ServeHTTP(w, r)
if w.Body.String() != "sayhello" {
t.Errorf("TestRouterHandler can't run")
}
}
//
// Benchmarks NewApp:
//
func beegoFilterFunc(ctx *context.Context) {
ctx.WriteString("hello")
}
type AdminController struct {
Controller
}
func (a *AdminController) Get() {
a.Ctx.WriteString("hello")
}
func TestRouterFunc(t *testing.T) {
mux := NewControllerRegister()
mux.Get("/action", beegoFilterFunc)
mux.Post("/action", beegoFilterFunc)
rw, r := testRequest("GET", "/action")
mux.ServeHTTP(rw, r)
if rw.Body.String() != "hello" {
t.Errorf("TestRouterFunc can't run")
}
}
func BenchmarkFunc(b *testing.B) {
mux := NewControllerRegister()
mux.Get("/action", beegoFilterFunc)
rw, r := testRequest("GET", "/action")
b.ResetTimer()
for i := 0; i < b.N; i++ {
mux.ServeHTTP(rw, r)
}
}
func BenchmarkController(b *testing.B) {
mux := NewControllerRegister()
mux.Add("/action", &AdminController{})
rw, r := testRequest("GET", "/action")
b.ResetTimer()
for i := 0; i < b.N; i++ {
mux.ServeHTTP(rw, r)
}
}
func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) {
request, _ := http.NewRequest(method, path, nil)
recorder := httptest.NewRecorder()
return recorder, request
}

View File

@ -28,21 +28,21 @@ Then in you web app init the global session manager
* Use **memory** as provider: * Use **memory** as provider:
func init() { func init() {
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **file** as provider, the last param is the path where you want file to be stored: * Use **file** as provider, the last param is the path where you want file to be stored:
func init() { func init() {
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() { func init() {
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC() go globalSessions.GC()
} }
@ -50,15 +50,24 @@ Then in you web app init the global session manager
func init() { func init() {
globalSessions, _ = session.NewManager( globalSessions, _ = session.NewManager(
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) { func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r) sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease() defer sess.SessionRelease(w)
username := sess.Get("username") username := sess.Get("username")
fmt.Println(username) fmt.Println(username)
if r.Method == "GET" { if r.Method == "GET" {
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
Writing a provider is easy. You only need to define two struct types Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition. (Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider as good example. Maybe you will find the **memory** provider is a good example.
type SessionStore interface { type SessionStore interface {
Set(key, value interface{}) error //set session value Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)

View File

@ -0,0 +1,211 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package session
import (
"net/http"
"strings"
"sync"
"github.com/couchbaselabs/go-couchbase"
"github.com/astaxie/beego/session"
)
var couchbpder = &CouchbaseProvider{}
type CouchbaseSessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
type CouchbaseProvider struct {
maxlifetime int64
savePath string
pool string
bucket string
b *couchbase.Bucket
}
func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
} else {
return nil
}
}
func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
func (cs *CouchbaseSessionStore) Flush() error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
func (cs *CouchbaseSessionStore) SessionID() string {
return cs.sid
}
func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
defer cs.b.Close()
// if rs.values is empty, return directly
if len(cs.values) < 1 {
cs.b.Delete(cs.sid)
return
}
bo, err := session.EncodeGob(cs.values)
if err != nil {
return
}
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
}
pool, err := c.GetPool(cp.pool)
if err != nil {
return nil
}
bucket, err := pool.GetBucket(cp.bucket)
if err != nil {
return nil
}
return bucket
}
// init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
cp.savePath = configs[0]
}
if len(configs) > 1 {
cp.pool = configs[1]
}
if len(configs) > 2 {
cp.bucket = configs[2]
}
return nil
}
// read couchbase session by sid
func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, error) {
cp.b = cp.getBucket()
var doc []byte
err := cp.b.Get(sid, &doc)
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
func (cp *CouchbaseProvider) SessionExist(sid string) bool {
cp.b = cp.getBucket()
defer cp.b.Close()
var doc []byte
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false
} else {
return true
}
}
func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
cp.b = cp.getBucket()
var doc []byte
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
cp.b.Set(sid, int(cp.maxlifetime), "")
} else {
err := cp.b.Delete(oldsid)
if err != nil {
return nil, err
}
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
}
err := cp.b.Get(sid, &doc)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
cp.b.Delete(sid)
return nil
}
func (cp *CouchbaseProvider) SessionGC() {
return
}
func (cp *CouchbaseProvider) SessionAll() int {
return 0
}
func init() {
session.Register("couchbase", couchbpder)
}

View File

@ -0,0 +1,212 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package session
import (
"net/http"
"sync"
"github.com/astaxie/beego/session"
"github.com/beego/memcache"
)
var mempder = &MemProvider{}
// memcache session store
type MemcacheSessionStore struct {
c *memcache.Connection
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// set value in memcache session
func (rs *MemcacheSessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// get value in memcache session
func (rs *MemcacheSessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
} else {
return nil
}
}
// delete value in memcache session
func (rs *MemcacheSessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// clear all values in memcache session
func (rs *MemcacheSessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// get redis session id
func (rs *MemcacheSessionStore) SessionID() string {
return rs.sid
}
// save session values to redis
func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) {
defer rs.c.Close()
// if rs.values is empty, return directly
if len(rs.values) < 1 {
rs.c.Delete(rs.sid)
return
}
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
rs.c.Set(rs.sid, 0, uint64(rs.maxlifetime), b)
}
// redis session provider
type MemProvider struct {
maxlifetime int64
savePath string
poolsize int
password string
}
// init redis session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
rp.savePath = savePath
return nil
}
// read redis session by sid
func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
conn, err := rp.connectInit()
if err != nil {
return nil, err
}
kvs, err := conn.Get(sid)
if err != nil {
return nil, err
}
var contain []byte
if len(kvs) > 0 {
contain = kvs[0].Value
}
var kv map[interface{}]interface{}
if len(contain) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(contain)
if err != nil {
return nil, err
}
}
rs := &MemcacheSessionStore{c: conn, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// check redis session exist by sid
func (rp *MemProvider) SessionExist(sid string) bool {
conn, err := rp.connectInit()
if err != nil {
return false
}
defer conn.Close()
if kvs, err := conn.Get(sid); err != nil || len(kvs) == 0 {
return false
} else {
return true
}
}
// generate new sid for redis session
func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
conn, err := rp.connectInit()
if err != nil {
return nil, err
}
var contain []byte
if kvs, err := conn.Get(sid); err != nil || len(kvs) == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
conn.Set(sid, 0, uint64(rp.maxlifetime), []byte(""))
} else {
conn.Delete(oldsid)
conn.Set(sid, 0, uint64(rp.maxlifetime), kvs[0].Value)
contain = kvs[0].Value
}
var kv map[interface{}]interface{}
if len(contain) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(contain)
if err != nil {
return nil, err
}
}
rs := &MemcacheSessionStore{c: conn, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// delete redis session by id
func (rp *MemProvider) SessionDestroy(sid string) error {
conn, err := rp.connectInit()
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Delete(sid)
if err != nil {
return err
}
return nil
}
// Impelment method, no used.
func (rp *MemProvider) SessionGC() {
return
}
// @todo
func (rp *MemProvider) SessionAll() int {
return 0
}
// connect to memcache and keep the connection.
func (rp *MemProvider) connectInit() (*memcache.Connection, error) {
c, err := memcache.Connect(rp.savePath)
if err != nil {
return nil, err
}
return c, nil
}
func init() {
session.Register("memcache", mempder)
}

View File

@ -1,22 +1,33 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package session package session
//CREATE TABLE `session` ( // mysql session support need create table as sql:
// CREATE TABLE `session` (
// `session_key` char(64) NOT NULL, // `session_key` char(64) NOT NULL,
// `session_data` blob, // session_data` blob,
// `session_expiry` int(11) unsigned NOT NULL, // `session_expiry` int(11) unsigned NOT NULL,
// PRIMARY KEY (`session_key`) // PRIMARY KEY (`session_key`)
//) ENGINE=MyISAM DEFAULT CHARSET=utf8; // ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
import ( import (
"database/sql" "database/sql"
"net/http"
"sync" "sync"
"time" "time"
"github.com/astaxie/beego/session"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
) )
var mysqlpder = &MysqlProvider{} var mysqlpder = &MysqlProvider{}
// mysql session store
type MysqlSessionStore struct { type MysqlSessionStore struct {
c *sql.DB c *sql.DB
sid string sid string
@ -24,6 +35,8 @@ type MysqlSessionStore struct {
values map[interface{}]interface{} values map[interface{}]interface{}
} }
// set value in mysql session.
// it is temp value in map.
func (st *MysqlSessionStore) Set(key, value interface{}) error { func (st *MysqlSessionStore) Set(key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -31,6 +44,7 @@ func (st *MysqlSessionStore) Set(key, value interface{}) error {
return nil return nil
} }
// get value from mysql session
func (st *MysqlSessionStore) Get(key interface{}) interface{} { func (st *MysqlSessionStore) Get(key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
@ -39,9 +53,9 @@ func (st *MysqlSessionStore) Get(key interface{}) interface{} {
} else { } else {
return nil return nil
} }
return nil
} }
// delete value in mysql session
func (st *MysqlSessionStore) Delete(key interface{}) error { func (st *MysqlSessionStore) Delete(key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -49,6 +63,7 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
return nil return nil
} }
// clear all values in mysql session
func (st *MysqlSessionStore) Flush() error { func (st *MysqlSessionStore) Flush() error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
@ -56,26 +71,31 @@ func (st *MysqlSessionStore) Flush() error {
return nil return nil
} }
// get session id of this mysql session store
func (st *MysqlSessionStore) SessionID() string { func (st *MysqlSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MysqlSessionStore) SessionRelease() { // save mysql session values to database.
// must call this method to save values to database.
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
if len(st.values) > 0 { b, err := session.EncodeGob(st.values)
b, err := encodeGob(st.values)
if err != nil { if err != nil {
return return
} }
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid) st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
} b, time.Now().Unix(), st.sid)
} }
// mysql session provider
type MysqlProvider struct { type MysqlProvider struct {
maxlifetime int64 maxlifetime int64
savePath string savePath string
} }
// connect to mysql
func (mp *MysqlProvider) connectInit() *sql.DB { func (mp *MysqlProvider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath) db, e := sql.Open("mysql", mp.savePath)
if e != nil { if e != nil {
@ -84,25 +104,29 @@ func (mp *MysqlProvider) connectInit() *sql.DB {
return db return db
} }
// init mysql session.
// savepath is the connection string of mysql.
func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error { func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime mp.maxlifetime = maxlifetime
mp.savePath = savePath mp.savePath = savePath
return nil return nil
} }
func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { // get mysql session by sid
func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", sid) row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
} }
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(sessiondata) == 0 { if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{}) kv = make(map[interface{}]interface{})
} else { } else {
kv, err = decodeGob(sessiondata) kv, err = session.DecodeGob(sessiondata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,8 +135,10 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil return rs, nil
} }
// check mysql session exist
func (mp *MysqlProvider) SessionExist(sid string) bool { func (mp *MysqlProvider) SessionExist(sid string) bool {
c := mp.connectInit() c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=?", sid) row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
@ -123,7 +149,8 @@ func (mp *MysqlProvider) SessionExist(sid string) bool {
} }
} }
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { // generate new sid for mysql session
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", oldsid) row := c.QueryRow("select session_data from session where session_key=?", oldsid)
var sessiondata []byte var sessiondata []byte
@ -136,7 +163,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
if len(sessiondata) == 0 { if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{}) kv = make(map[interface{}]interface{})
} else { } else {
kv, err = decodeGob(sessiondata) kv, err = session.DecodeGob(sessiondata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,6 +172,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
return rs, nil return rs, nil
} }
// delete mysql session by sid
func (mp *MysqlProvider) SessionDestroy(sid string) error { func (mp *MysqlProvider) SessionDestroy(sid string) error {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=?", sid) c.Exec("DELETE FROM session where session_key=?", sid)
@ -152,6 +180,7 @@ func (mp *MysqlProvider) SessionDestroy(sid string) error {
return nil return nil
} }
// delete expired values in mysql session
func (mp *MysqlProvider) SessionGC() { func (mp *MysqlProvider) SessionGC() {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
@ -159,6 +188,7 @@ func (mp *MysqlProvider) SessionGC() {
return return
} }
// count values in mysql session
func (mp *MysqlProvider) SessionAll() int { func (mp *MysqlProvider) SessionAll() int {
c := mp.connectInit() c := mp.connectInit()
defer c.Close() defer c.Close()
@ -171,5 +201,5 @@ func (mp *MysqlProvider) SessionAll() int {
} }
func init() { func init() {
Register("mysql", mysqlpder) session.Register("mysql", mysqlpder)
} }

Some files were not shown because too many files have changed in this diff Show More