1
0
mirror of https://github.com/astaxie/beego.git synced 2025-07-11 20:21:01 +00:00

166 Commits

Author SHA1 Message Date
42f1d1aeef change version 0.8 to 0.9 2013-08-14 10:53:22 +08:00
f4b3e7e4d2 orm small update 2013-08-13 19:33:43 +08:00
c6a436ed5d orm docs update 2013-08-13 19:28:37 +08:00
27b84841a7 orm add full regular go type support, such as int8, uint8, byte, rune. add date/datetime timezone support very well. 2013-08-13 17:17:19 +08:00
deb00809a5 orm add missed SetMaxIdleConns 2013-08-12 13:24:45 +08:00
eb06435f23 change position GOMAXPROCS to init that user can set own GOMAXPROCS 2013-08-12 12:46:59 +08:00
328f4566e4 Merge pull request #150 from miraclesu/form
Support custom label for renderform
2013-08-11 20:04:25 -07:00
a37b2bdfb0 Support custom label for renderform 2013-08-12 06:54:36 +08:00
50f3bd5835 add filter after 2013-08-12 00:14:42 +08:00
1f3ae3d682 Improve performance 2013-08-11 23:27:53 +08:00
ca1354e77f Merge pull request #147 from miraclesu/form
ignore struct field if form tag value is '-'
2013-08-11 07:44:28 -07:00
459b97858c renderform ignore struct field if form tag value is '-' 2013-08-11 22:42:35 +08:00
18c09bb2ed orm update docs 2013-08-11 22:27:54 +08:00
45345fa782 orm add postgres support 2013-08-11 22:27:45 +08:00
5c859466ef ignore struct field if form tag value is '-' 2013-08-11 22:21:31 +08:00
449fbe82f6 Update README.md 2013-08-11 13:26:28 +08:00
6c41e6dd78 orm add sqlite3 support, may be support postgres in next commit 2013-08-11 00:15:26 +08:00
9631c663d5 fix #145
this.DestroySession()
2013-08-10 23:58:25 +08:00
bbef213155 fix #144 2013-08-10 21:44:27 +08:00
bc060c95f8 Merge pull request #143 from miraclesu/form
Add renderform template function
2013-08-10 06:28:06 -07:00
9e1d5036f7 Add renderform template function 2013-08-10 16:33:46 +08:00
e47b2b677d Update ParserForm for new form tag style 2013-08-10 11:42:25 +08:00
38f6f8eef7 update TestParseForm 2013-08-10 10:29:29 +08:00
115b1d03db fix #138 2013-08-09 23:41:03 +08:00
0833d4baf8 fix #132 2013-08-09 22:19:05 +08:00
f2b359d8e8 orm full remove orm.Manager for simple use, add struct tag - for skip struct field 2013-08-09 20:14:18 +08:00
402932aa6e ... haha 2013-08-09 14:04:33 +08:00
f1e2372a56 orm update docs about debug log queries 2013-08-09 13:53:04 +08:00
45aa071261 orm add queries debug logger 2013-08-09 13:20:19 +08:00
8563000235 orm operator args now support multi types eg: []int []*int *int, Model *Model 2013-08-08 22:34:18 +08:00
9047d21ec5 orm fix, def a string in model but use int in db may cause nil pointer error 2013-08-08 22:15:27 +08:00
74a95f6cbf docs update 2013-08-08 11:07:08 +08:00
a17dcf4991 fix docs link 2013-08-07 23:39:18 +08:00
fc528c51a3 update docs 2013-08-07 23:35:45 +08:00
ad2965bbf9 update docs 2013-08-07 23:28:14 +08:00
37f8c6a04a zh docs update 2013-08-07 19:11:57 +08:00
46668b811f some fix / add test 2013-08-07 19:11:44 +08:00
10f4e822c3 add XSRFExpire 2013-08-07 11:22:23 +08:00
b191e96f51 Merge pull request #125 from miraclesu/valid
Change tag valid func default key
2013-08-06 08:22:56 -07:00
f9a31ea00a EnableXSRF 2013-08-06 23:21:52 +08:00
97d99fcef2 Change tag valid func default key 2013-08-06 23:15:20 +08:00
2fa534ff26 delete model move to orm 2013-08-06 16:40:23 +08:00
e47a147c3b serverJson Supoort 中文编码 2013-08-06 16:37:41 +08:00
4ecb9cc30b move httplib from beego to beego/httplib safemap support get all items 2013-08-06 16:13:45 +08:00
64ef8ad62b fix new version for memcahe client 2013-08-06 16:04:35 +08:00
6f2cd326bf Merge pull request #122 from pricees/master
Added Items() to return items from BeeCache
2013-08-04 23:15:37 -07:00
339346e307 Added Items() to return items from BeeCache 2013-08-05 00:47:37 -05:00
a611480b94 fix #121 2013-08-05 00:03:47 +08:00
fd3c8834da autorouter when /admin 301 to /admin/ 2013-08-04 23:13:29 +08:00
3d481178d7 improve router 2013-08-04 23:06:48 +08:00
d0cb112f4b fix test func name 2013-08-03 22:22:37 +08:00
c58445c772 add httplib support like http.client 2013-08-03 22:20:09 +08:00
f7dd376596 fix it , 2013-08-03 18:18:09 +08:00
5b3b6f7f48 add NewFlash func 2013-08-03 18:17:00 +08:00
0c2af58b8d fix fomat 2013-08-03 18:02:28 +08:00
0e0040e78d fix # 2013-08-03 18:00:57 +08:00
8ba5ea0ecf flash support 2013-08-03 17:55:53 +08:00
dbfd844ff2 beego support flash data 2013-08-03 17:55:53 +08:00
452478e779 Merge branch 'master' of github.com:astaxie/beego 2013-08-01 15:52:33 +08:00
6e06720e84 zh docs update 2013-08-01 15:52:05 +08:00
51baa35df1 now object crud is simple 2013-08-01 15:51:53 +08:00
6e2972673e Merge pull request #116 from lqixv/master
更新了部分部分文字
2013-08-01 00:25:19 -07:00
0ac7e342f0 Merge pull request #118 from miraclesu/test
Refactor template visit & Add template test
2013-07-31 22:29:04 -07:00
2a9852fa94 Add template test 2013-08-01 12:10:56 +08:00
250cbf593b fix values name 2013-08-01 12:09:17 +08:00
6fbdbaae80 Refactor template 2013-08-01 11:57:29 +08:00
5ccdaeb09e zh docs update 2013-08-01 09:23:44 +08:00
b0b64eb404 some change 2013-08-01 09:23:32 +08:00
831eeca7c8 zh docs update 2013-07-31 22:11:35 +08:00
2c5e062c2b some fix 2013-07-31 22:11:22 +08:00
c83d03c298 fix #117 2013-07-31 21:36:10 +08:00
485d89d5c8 tpl tolower 2013-07-30 22:45:50 +08:00
a997ca746f fix router's /path/ 2013-07-30 22:38:01 +08:00
572e281566 fix router's bug 2013-07-30 22:33:36 +08:00
8674b81b3a fix router & tpl tolower 2013-07-30 22:17:16 +08:00
bce35c708a init orm project, beta, unstable 2013-07-30 20:32:38 +08:00
ccbf116fd6 Merge pull request #115 from Xuyuanp/master
笔误
2013-07-30 05:17:44 -07:00
f26d81200b 笔误 2013-07-30 18:52:54 +08:00
71173aa010 Merge pull request #114 from Xuyuanp/master
bee command
2013-07-30 02:42:09 -07:00
4ee3d6aad4 bee command 2013-07-30 17:29:22 +08:00
4ffe988c30 update docs 2013-07-28 20:17:29 +08:00
df354acf97 Merge pull request #113 from miraclesu/valid
Support Match validate function for tag
2013-07-28 04:47:44 -07:00
6662eef2fd Support Match validate function for tag 2013-07-28 19:22:09 +08:00
dcdfaf36f1 Accept parameters more types 2013-07-28 16:59:35 +08:00
ae7e31717a Merge pull request #112 from miraclesu/form
Refactor ParseForm
2013-07-28 01:31:27 -07:00
fe7ecc377a Refactor ParseForm 2013-07-28 13:51:01 +08:00
29b1c8e1cb Merge pull request #111 from marswj/master
统一文档和代码中RunMode
2013-07-27 20:33:00 -07:00
ae906eed8f Update Quickstart.md 2013-07-28 11:14:54 +08:00
07ce3fb8ea Update Quickstart.md 2013-07-28 11:14:04 +08:00
1d7d6c6f99 Merge pull request #109 from miraclesu/valid
Add some validate functions
2013-07-27 06:12:23 -07:00
f490141217 Add some validate functions 2013-07-27 20:40:15 +08:00
0e748c6871 parse url to params
/object/login/2009/07/11
parse to ObjectController  Login function
params:map[0:2009 1:07 2:11]
2013-07-27 10:55:10 +08:00
f7e7fab6f2 support autorouter 2013-07-27 10:25:14 +08:00
fbde7df487 Merge pull request #106 from miraclesu/form
Fix utils test fail
2013-07-26 01:59:44 -07:00
914b6fa966 Fix utils test fail 2013-07-26 16:56:25 +08:00
69096b09f3 update zh docs 2013-07-26 16:41:11 +08:00
da5dd4d173 Merge pull request #105 from miraclesu/form
Add ParseForm Function & utils test
2013-07-26 01:36:13 -07:00
6b5dc3b7d5 Remove MarkDown test 2013-07-26 16:31:01 +08:00
d31ac49ead Merge branch 'master' of https://github.com/astaxie/beego into form 2013-07-26 16:29:22 +08:00
60afcd069a Add utils test 2013-07-26 16:29:00 +08:00
a39139c610 Update Quickstart.md
纠偏:从上面的例子中,并不能知道 DelSession(name string) 方法。
2013-07-26 09:09:09 +08:00
42370b5eb8 Update Quickstart.md
更正某些错别字,或语义不清
2013-07-26 08:31:46 +08:00
b99a09d73b Add two test cases for ParseForm 2013-07-25 22:50:29 +08:00
b8e06f6365 Add ParseForm function for *Controller 2013-07-25 22:27:25 +08:00
f552338822 Add ParseForm function 2013-07-25 22:25:12 +08:00
6373379da6 Merge pull request #103 from eXthen/master
I guess write would be better
2013-07-25 06:28:10 -07:00
ff11bcdb7c I guess write would be better 2013-07-25 19:04:26 +08:00
c2bb6b3068 Update log.go 2013-07-25 18:57:40 +08:00
259617f68d remove markdown 2013-07-25 16:40:37 +08:00
ab08aa9c9e MethodByName 2013-07-25 16:08:18 +08:00
7c610ee7c9 fix reflect find methodByName 2013-07-25 16:00:42 +08:00
23deaedd39 add to beego 2013-07-25 15:50:16 +08:00
538d39a704 add some tips 2013-07-25 15:46:47 +08:00
b961abb52b update docs 2013-07-25 15:41:25 +08:00
b32b12b208 add some comments 2013-07-25 15:40:33 +08:00
e88c2be013 update docs & update beego's version 2013-07-25 15:37:38 +08:00
d5ddd0a9dd support user define function
+//Add("/user",&UserController{})
+//Add("/api/list",&RestController{},"*:ListFood")
+//Add("/api/create",&RestController{},"post:CreateFood")
+//Add("/api/update",&RestController{},"put:UpdateFood")
+//Add("/api/delete",&RestController{},"delete:DeleteFood")
+//Add("/api",&RestController{},"get,post:ApiFunc")
+//Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
2013-07-25 15:17:09 +08:00
dff36a18a2 Merge pull request #101 from miraclesu/valid
Valid
2013-07-23 21:44:35 -07:00
fb78d83ec3 Merge branch 'master' of https://github.com/astaxie/beego into valid 2013-07-24 12:37:07 +08:00
d23700b919 update README 2013-07-24 12:36:46 +08:00
92db56c0cb add struct tag support 2013-07-24 12:20:42 +08:00
4c6163baa0 add funcmap 2013-07-24 01:20:24 +08:00
f46388fa63 setcookie set to unique. fix multi setcookie 2013-07-23 21:54:45 +08:00
aba1728bc3 add some util funcs 2013-07-23 14:42:14 +08:00
ddb9ed39a5 add validation README 2013-07-22 17:40:32 +08:00
a242f61b8e Merge pull request #100 from miraclesu/valid
Valid
2013-07-21 19:26:43 -07:00
d19de30d9c add test 2013-07-21 23:46:18 +08:00
6d05163c9f add validation funcs 2013-07-21 01:37:24 +08:00
a41cd17092 add validators 2013-07-19 16:49:28 +08:00
ec7324e972 Merge pull request #98 from Unknwon/master
Fixed bug: error page cannot show correct corresponding status code
2013-07-18 01:20:16 -07:00
7f5dd13422 Fixed bug: error page cannot show correct corresponding status code 2013-07-18 14:42:45 +08:00
7f4ad7ff46 fix #91 2013-07-16 19:05:44 +08:00
60200689f4 fix setcookie time type 2013-07-11 10:57:34 +08:00
af3797e16c Merge pull request #93 from Unknwon/master
Sync documentation of English with Chinese version.
2013-07-10 07:12:07 -07:00
d4743fb10d Sync documentation of English with Chinese version 2013-07-10 22:06:28 +08:00
38b083e117 add docs about how to write api application 2013-07-09 16:43:03 +08:00
fece5adc2a add example for api application 2013-07-09 13:59:47 +08:00
7bfb4126d7 support copy requestbody 2013-07-08 23:12:31 +08:00
2abe584bc5 support restful router 2013-07-08 18:35:10 +08:00
ee9223b1b9 fix #18
func (this *MainController) Get() {
this.GoToFunc("Test")
}

func (this *MainController) Test() {
this.Ctx.WriteString("testtest")
}
2013-07-08 17:35:09 +08:00
d2a16ff8f6 fix #26 add xsrf function 2013-07-08 16:17:08 +08:00
f1e5059682 fix #69 refer to http://www.php.net/manual/zh/function.setcookie.php 2013-07-08 15:13:51 +08:00
11977f4f77 fix #90 2013-07-07 17:58:50 +08:00
461eac46b9 fix #89 2013-07-07 17:45:39 +08:00
75af664511 change r.ParseMultipartForm position 2013-07-04 23:41:35 +08:00
174298b497 fix cache's bug expird is not changed by get method 2013-07-04 13:02:11 +08:00
9b392a0601 update to hotupdate's connet timeout 2013-07-03 16:58:15 +08:00
bf9de3bcf6 add HttpServerTimeOut setting 2013-07-03 15:29:54 +08:00
189df1280c fix #87 2013-07-02 09:45:12 +08:00
8807c327d1 fix log delete 2013-06-29 14:39:02 +08:00
d627ec013e add config to countol if enable hotupdate 2013-06-28 22:09:08 +08:00
d0bbc67b27 Merge pull request #86 from slene/master
fix logrotate
2013-06-28 06:18:51 -07:00
453557948e fix logrotate close fd before rename file, add a MuxWriter for Logger 2013-06-27 22:04:01 +08:00
4033692dcb Merge pull request #85 from Unknwon/master
sync quickstart.
2013-06-26 18:45:40 -07:00
236f28c53c sync quickstart. 2013-06-27 00:24:06 +08:00
b2bfed8937 fix close err 2013-06-26 23:34:32 +08:00
71adbdd7d7 add mutex 2013-06-26 22:16:03 +08:00
573df2e747 add isclose call close many times 2013-06-26 22:13:25 +08:00
aa9cb6d052 delete session's cookie Expires 2013-06-25 23:08:47 +08:00
8358e0ff48 Merge pull request #83 from shxsun/master
In go version 1.0.3 will call build error
2013-06-25 07:08:35 -07:00
1687ec85de fix build problem 2013-06-25 21:40:42 +08:00
d5fc0a4bda Merge pull request #82 from matrixik/master
Change: SetRotateMaxDay => SetRotateMaxDays
2013-06-25 04:08:51 -07:00
3bcff77947 Change: SetRotateMaxDay => SetRotateMaxDays 2013-06-25 12:54:51 +02:00
1521842d7a add static global function & add seocms example 2013-06-25 18:49:08 +08:00
b04813e472 fix #80
add a function to delete default StaticPath
2013-06-25 17:21:19 +08:00
9e41d93184 delete strcut map
I think if user should set field in controller, there's no need to have
thie feature
2013-06-25 16:44:53 +08:00
97 changed files with 11600 additions and 540 deletions

View File

@ -6,6 +6,7 @@ beego is a Go Framework which is inspired from tornado and sinatra.
It is a simply & powerful web framework.
more info [beego.me](http://beego.me)
## Features
@ -38,4 +39,5 @@ beego is licensed under the Apache Licence, Version 2.0
## Use case
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
- seocms: [seocms](https://github.com/chinakr/seocms)
- CMS: [toropress](https://github.com/insionng/toropress)

117
beego.go
View File

@ -10,9 +10,10 @@ import (
"os"
"path"
"runtime"
"time"
)
const VERSION = "0.7.0"
const VERSION = "0.9.0"
var (
BeeApp *App
@ -38,8 +39,15 @@ var (
SessionSavePath string // session savepath if use mysql/redis/file this set to the connectinfo
UseFcgi bool
MaxMemory int64
EnableGzip bool // enable gzip
DirectoryIndex bool //ebable DirectoryIndex default is false
EnableGzip bool // enable gzip
DirectoryIndex bool //ebable DirectoryIndex default is false
EnbaleHotUpdate bool //enable HotUpdate default is false
HttpServerTimeOut int64 //set httpserver timeout
ErrorsShow bool //set weather show errors
XSRFKEY string //set XSRF
EnableXSRF bool
XSRFExpire int
CopyRequestBody bool //When in raw application, You want to the reqeustbody
)
func init() {
@ -66,7 +74,12 @@ func init() {
EnableGzip = false
StaticDir["/static"] = "static"
AppConfigPath = path.Join(AppPath, "conf", "app.conf")
HttpServerTimeOut = 0
ErrorsShow = true
XSRFKEY = "beegoxsrf"
XSRFExpire = 60
ParseConfig()
runtime.GOMAXPROCS(runtime.NumCPU())
}
type App struct {
@ -93,24 +106,44 @@ func (app *App) Run() {
}
err = fcgi.Serve(l, app.Handlers)
} else {
server := &http.Server{Handler: app.Handlers}
laddr, err := net.ResolveTCPAddr("tcp", addr)
if nil != err {
BeeLogger.Fatal("ResolveTCPAddr:", err)
if EnbaleHotUpdate {
server := &http.Server{
Handler: app.Handlers,
ReadTimeout: time.Duration(HttpServerTimeOut) * time.Second,
WriteTimeout: time.Duration(HttpServerTimeOut) * time.Second,
}
laddr, err := net.ResolveTCPAddr("tcp", addr)
if nil != err {
BeeLogger.Fatal("ResolveTCPAddr:", err)
}
l, err = GetInitListner(laddr)
theStoppable = newStoppable(l)
err = server.Serve(theStoppable)
theStoppable.wg.Wait()
CloseSelf()
} else {
s := &http.Server{
Addr: addr,
Handler: app.Handlers,
ReadTimeout: time.Duration(HttpServerTimeOut) * time.Second,
WriteTimeout: time.Duration(HttpServerTimeOut) * time.Second,
}
err = s.ListenAndServe()
}
l, err = GetInitListner(laddr)
theStoppable = newStoppable(l)
err = server.Serve(theStoppable)
theStoppable.wg.Wait()
CloseSelf()
}
if err != nil {
BeeLogger.Fatal("ListenAndServe: ", err)
}
}
func (app *App) Router(path string, c ControllerInterface) *App {
app.Handlers.Add(path, c)
func (app *App) Router(path string, c ControllerInterface, mappingMethods ...string) *App {
app.Handlers.Add(path, c, mappingMethods...)
return app
}
func (app *App) AutoRouter(c ControllerInterface) *App {
app.Handlers.AddAuto(c)
return app
}
@ -129,6 +162,21 @@ func (app *App) FilterPrefixPath(path string, filter http.HandlerFunc) *App {
return app
}
func (app *App) FilterAfter(filter http.HandlerFunc) *App {
app.Handlers.FilterAfter(filter)
return app
}
func (app *App) FilterParamAfter(param string, filter http.HandlerFunc) *App {
app.Handlers.FilterParamAfter(param, filter)
return app
}
func (app *App) FilterPrefixPathAfter(path string, filter http.HandlerFunc) *App {
app.Handlers.FilterPrefixPathAfter(path, filter)
return app
}
func (app *App) SetViewsPath(path string) *App {
ViewsPath = path
return app
@ -139,6 +187,11 @@ func (app *App) SetStaticPath(url string, path string) *App {
return app
}
func (app *App) DelStaticPath(url string) *App {
delete(StaticDir, url)
return app
}
func (app *App) ErrorLog(ctx *Context) {
BeeLogger.Printf("[ERR] host: '%s', request: '%s %s', proto: '%s', ua: '%s', remote: '%s'\n", ctx.Request.Host, ctx.Request.Method, ctx.Request.URL.Path, ctx.Request.Proto, ctx.Request.UserAgent(), ctx.Request.RemoteAddr)
}
@ -152,8 +205,19 @@ func RegisterController(path string, c ControllerInterface) *App {
return BeeApp
}
func Router(path string, c ControllerInterface) *App {
BeeApp.Router(path, c)
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
BeeApp.Router(rootpath, c, mappingMethods...)
return BeeApp
}
func RESTRouter(rootpath string, c ControllerInterface) *App {
Router(rootpath, c)
Router(path.Join(rootpath, ":objectId"), c)
return BeeApp
}
func AutoRouter(c ControllerInterface) *App {
BeeApp.AutoRouter(c)
return BeeApp
}
@ -177,6 +241,11 @@ func SetStaticPath(url string, path string) *App {
return BeeApp
}
func DelStaticPath(url string) *App {
delete(StaticDir, url)
return BeeApp
}
func Filter(filter http.HandlerFunc) *App {
BeeApp.Filter(filter)
return BeeApp
@ -192,6 +261,21 @@ func FilterPrefixPath(path string, filter http.HandlerFunc) *App {
return BeeApp
}
func FilterAfter(filter http.HandlerFunc) *App {
BeeApp.FilterAfter(filter)
return BeeApp
}
func FilterParamAfter(param string, filter http.HandlerFunc) *App {
BeeApp.FilterParamAfter(param, filter)
return BeeApp
}
func FilterPrefixPathAfter(path string, filter http.HandlerFunc) *App {
BeeApp.FilterPrefixPathAfter(path, filter)
return BeeApp
}
func Run() {
if AppConfigPath != path.Join(AppPath, "conf", "app.conf") {
err := ParseConfig()
@ -215,7 +299,6 @@ func Run() {
Warn(err)
}
}
runtime.GOMAXPROCS(runtime.NumCPU())
registerErrorHander()
BeeApp.Run()
}

View File

@ -73,6 +73,11 @@ func (bc *BeeCache) Delete(name string) (ok bool, err error) {
return
}
// Return all of the item in a BeeCache
func (bc *BeeCache) Items() map[string]*BeeItem {
return bc.items
}
func (bc *BeeCache) IsExist(name string) bool {
bc.lock.RLock()
defer bc.lock.RUnlock()

4
cache/cache.go vendored
View File

@ -6,8 +6,10 @@ import (
type Cache interface {
Get(key string) interface{}
Put(key string, val interface{}, timeout int) error
Put(key string, val interface{}, timeout int64) error
Delete(key string) error
Incr(key string) error
Decr(key string) error
IsExist(key string) bool
ClearAll() error
StartAndGC(config string) error

19
cache/cache_test.go vendored
View File

@ -6,7 +6,7 @@ import (
)
func Test_cache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":60}`)
bm, err := NewCache("memory", `{"interval":20}`)
if err != nil {
t.Error("init err")
}
@ -21,7 +21,7 @@ func Test_cache(t *testing.T) {
t.Error("get err")
}
time.Sleep(70 * time.Second)
time.Sleep(30 * time.Second)
if bm.IsExist("astaxie") {
t.Error("check err")
@ -31,6 +31,21 @@ func Test_cache(t *testing.T) {
t.Error("set Error", err)
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")

20
cache/memcache.go vendored
View File

@ -19,16 +19,20 @@ func (rc *MemcacheCache) Get(key string) interface{} {
if rc.c == nil {
rc.c = rc.connectInit()
}
v, _, err := rc.c.Get(key)
v, err := rc.c.Get(key)
if err != nil {
return nil
}
var contain interface{}
contain = v
if len(v) > 0 {
contain = string(v[0].Value)
} else {
contain = nil
}
return contain
}
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int) error {
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil {
rc.c = rc.connectInit()
}
@ -51,11 +55,19 @@ func (rc *MemcacheCache) Delete(key string) error {
return err
}
func (rc *MemcacheCache) Incr(key string) error {
return errors.New("not support in memcache")
}
func (rc *MemcacheCache) Decr(key string) error {
return errors.New("not support in memcache")
}
func (rc *MemcacheCache) IsExist(key string) bool {
if rc.c == nil {
rc.c = rc.connectInit()
}
v, _, err := rc.c.Get(key)
v, err := rc.c.Get(key)
if err != nil {
return false
}

98
cache/memory.go vendored
View File

@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"time"
)
@ -16,12 +15,7 @@ var (
type MemoryItem struct {
val interface{}
Lastaccess time.Time
expired int
}
func (itm *MemoryItem) Access() interface{} {
itm.Lastaccess = time.Now()
return itm.val
expired int64
}
type MemoryCache struct {
@ -44,13 +38,21 @@ func (bc *MemoryCache) Get(name string) interface{} {
if !ok {
return nil
}
return itm.Access()
if (time.Now().Unix() - itm.Lastaccess.Unix()) > itm.expired {
go bc.Delete(name)
return nil
}
return itm.val
}
func (bc *MemoryCache) Put(name string, value interface{}, expired int) error {
func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error {
bc.lock.Lock()
defer bc.lock.Unlock()
t := MemoryItem{val: value, Lastaccess: time.Now(), expired: expired}
t := MemoryItem{
val: value,
Lastaccess: time.Now(),
expired: expired,
}
if _, ok := bc.items[name]; ok {
return errors.New("the key is exist")
} else {
@ -73,6 +75,70 @@ func (bc *MemoryCache) Delete(name string) error {
return nil
}
func (bc *MemoryCache) Incr(key string) error {
bc.lock.RLock()
defer bc.lock.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
}
switch itm.val.(type) {
case int:
itm.val = itm.val.(int) + 1
case int64:
itm.val = itm.val.(int64) + 1
case int32:
itm.val = itm.val.(int32) + 1
case uint:
itm.val = itm.val.(uint) + 1
case uint32:
itm.val = itm.val.(uint32) + 1
case uint64:
itm.val = itm.val.(uint64) + 1
default:
return errors.New("item val is not int int64 int32")
}
return nil
}
func (bc *MemoryCache) Decr(key string) error {
bc.lock.RLock()
defer bc.lock.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
}
switch itm.val.(type) {
case int:
itm.val = itm.val.(int) - 1
case int64:
itm.val = itm.val.(int64) - 1
case int32:
itm.val = itm.val.(int32) - 1
case uint:
if itm.val.(uint) > 0 {
itm.val = itm.val.(uint) - 1
} else {
return errors.New("item val is less than 0")
}
case uint32:
if itm.val.(uint32) > 0 {
itm.val = itm.val.(uint32) - 1
} else {
return errors.New("item val is less than 0")
}
case uint64:
if itm.val.(uint64) > 0 {
itm.val = itm.val.(uint64) - 1
} else {
return errors.New("item val is less than 0")
}
default:
return errors.New("item val is not int int64 int32")
}
return nil
}
func (bc *MemoryCache) IsExist(name string) bool {
bc.lock.RLock()
defer bc.lock.RUnlock()
@ -91,7 +157,7 @@ func (bc *MemoryCache) ClearAll() error {
func (bc *MemoryCache) StartAndGC(config string) error {
var cf map[string]int
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["every"]; !ok {
if _, ok := cf["interval"]; !ok {
cf = make(map[string]int)
cf["interval"] = DefaultEvery
}
@ -110,7 +176,7 @@ func (bc *MemoryCache) vaccuum() {
return
}
for {
<-time.After(time.Duration(bc.dur) * time.Second)
<-time.After(bc.dur)
if bc.items == nil {
return
}
@ -128,12 +194,8 @@ func (bc *MemoryCache) item_expired(name string) bool {
if !ok {
return true
}
dur := time.Now().Sub(itm.Lastaccess)
sec, err := strconv.Atoi(fmt.Sprintf("%0.0f", dur.Seconds()))
if err != nil {
delete(bc.items, name)
return true
} else if sec >= itm.expired {
sec := time.Now().Unix() - itm.Lastaccess.Unix()
if sec >= itm.expired {
delete(bc.items, name)
return true
}

24
cache/redis.go vendored
View File

@ -31,7 +31,7 @@ func (rc *RedisCache) Get(key string) interface{} {
return v
}
func (rc *RedisCache) Put(key string, val interface{}, timeout int) error {
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil {
rc.c = rc.connectInit()
}
@ -58,6 +58,28 @@ func (rc *RedisCache) IsExist(key string) bool {
return v
}
func (rc *RedisCache) Incr(key string) error {
if rc.c == nil {
rc.c = rc.connectInit()
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
if err != nil {
return err
}
return nil
}
func (rc *RedisCache) Decr(key string) error {
if rc.c == nil {
rc.c = rc.connectInit()
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
if err != nil {
return err
}
return nil
}
func (rc *RedisCache) ClearAll() error {
if rc.c == nil {
rc.c = rc.connectInit()

View File

@ -133,49 +133,70 @@ func ParseConfig() (err error) {
if v, err := AppConfig.Int("httpport"); err == nil {
HttpPort = v
}
if v, err := AppConfig.Int64("maxmemory"); err == nil {
MaxMemory = v
if maxmemory, err := AppConfig.Int64("maxmemory"); err == nil {
MaxMemory = maxmemory
}
AppName = AppConfig.String("appname")
if runmode := AppConfig.String("runmode"); runmode != "" {
RunMode = runmode
}
if ar, err := AppConfig.Bool("autorender"); err == nil {
AutoRender = ar
if autorender, err := AppConfig.Bool("autorender"); err == nil {
AutoRender = autorender
}
if ar, err := AppConfig.Bool("autorecover"); err == nil {
RecoverPanic = ar
if autorecover, err := AppConfig.Bool("autorecover"); err == nil {
RecoverPanic = autorecover
}
if ar, err := AppConfig.Bool("pprofon"); err == nil {
PprofOn = ar
if pprofon, err := AppConfig.Bool("pprofon"); err == nil {
PprofOn = pprofon
}
if views := AppConfig.String("viewspath"); views != "" {
ViewsPath = views
}
if ar, err := AppConfig.Bool("sessionon"); err == nil {
SessionOn = ar
if sessionon, err := AppConfig.Bool("sessionon"); err == nil {
SessionOn = sessionon
}
if ar := AppConfig.String("sessionprovider"); ar != "" {
SessionProvider = ar
if sessProvider := AppConfig.String("sessionprovider"); sessProvider != "" {
SessionProvider = sessProvider
}
if ar := AppConfig.String("sessionname"); ar != "" {
SessionName = ar
if sessName := AppConfig.String("sessionname"); sessName != "" {
SessionName = sessName
}
if ar := AppConfig.String("sessionsavepath"); ar != "" {
SessionSavePath = ar
if sesssavepath := AppConfig.String("sessionsavepath"); sesssavepath != "" {
SessionSavePath = sesssavepath
}
if ar, err := AppConfig.Int("sessiongcmaxlifetime"); err == nil && ar != 0 {
int64val, _ := strconv.ParseInt(strconv.Itoa(ar), 10, 64)
if sessMaxLifeTime, err := AppConfig.Int("sessiongcmaxlifetime"); err == nil && sessMaxLifeTime != 0 {
int64val, _ := strconv.ParseInt(strconv.Itoa(sessMaxLifeTime), 10, 64)
SessionGCMaxLifetime = int64val
}
if ar, err := AppConfig.Bool("usefcgi"); err == nil {
UseFcgi = ar
if usefcgi, err := AppConfig.Bool("usefcgi"); err == nil {
UseFcgi = usefcgi
}
if ar, err := AppConfig.Bool("enablegzip"); err == nil {
EnableGzip = ar
if enablegzip, err := AppConfig.Bool("enablegzip"); err == nil {
EnableGzip = enablegzip
}
if ar, err := AppConfig.Bool("directoryindex"); err == nil {
DirectoryIndex = ar
if directoryindex, err := AppConfig.Bool("directoryindex"); err == nil {
DirectoryIndex = directoryindex
}
if hotupdate, err := AppConfig.Bool("hotupdate"); err == nil {
EnbaleHotUpdate = hotupdate
}
if timeout, err := AppConfig.Int64("httpservertimeout"); err == nil {
HttpServerTimeOut = timeout
}
if errorsshow, err := AppConfig.Bool("errorsshow"); err == nil {
ErrorsShow = errorsshow
}
if copyrequestbody, err := AppConfig.Bool("copyrequestbody"); err == nil {
CopyRequestBody = copyrequestbody
}
if xsrfkey := AppConfig.String("xsrfkey"); xsrfkey != "" {
XSRFKEY = xsrfkey
}
if enablexsrf, err := AppConfig.Bool("enablexsrf"); err == nil {
EnableXSRF = enablexsrf
}
if expire, err := AppConfig.Int("xsrfexpire"); err == nil {
XSRFExpire = expire
}
}
return nil

View File

@ -1,16 +1,17 @@
package beego
import (
"bytes"
"fmt"
"mime"
"net/http"
"strings"
"time"
)
type Context struct {
ResponseWriter http.ResponseWriter
Request *http.Request
RequestBody []byte
Params map[string]string
}
@ -58,14 +59,61 @@ func (ctx *Context) SetHeader(hdr string, val string, unique bool) {
}
//Sets a cookie -- duration is the amount of time in seconds. 0 = forever
func (ctx *Context) SetCookie(name string, value string, age int64) {
var utctime time.Time
if age == 0 {
// 2^31 - 1 seconds (roughly 2038)
utctime = time.Unix(2147483647, 0)
//params:
//string name
//string value
//int64 expire = 0
//string $path
//string $domain
//bool $secure = false
//bool $httponly = false
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
if len(others) > 0 {
switch others[0].(type) {
case int:
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int))
case int64:
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64))
case int32:
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32))
}
} else {
utctime = time.Unix(time.Now().Unix()+age, 0)
fmt.Fprintf(&b, "; Max-Age=0")
}
cookie := fmt.Sprintf("%s=%s; Expires=%s; Path=/", name, value, webTime(utctime))
ctx.SetHeader("Set-Cookie", cookie, true)
if len(others) > 1 {
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string)))
}
if len(others) > 2 {
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string)))
}
if len(others) > 3 {
fmt.Fprintf(&b, "; Secure")
}
if len(others) > 4 {
fmt.Fprintf(&b, "; HttpOnly")
}
ctx.SetHeader("Set-Cookie", b.String(), false)
}
var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
func sanitizeName(n string) string {
return cookieNameSanitizer.Replace(n)
}
var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ")
func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v)
}
func (ctx *Context) GetCookie(key string) string {
keycookie, err := ctx.Request.Cookie(key)
if err != nil {
return ""
}
return keycookie.Value
}

View File

@ -2,11 +2,15 @@ package beego
import (
"bytes"
"compress/flate"
"compress/gzip"
"compress/zlib"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"github.com/astaxie/beego/session"
"html/template"
"io"
@ -18,16 +22,20 @@ import (
"path"
"strconv"
"strings"
"time"
)
type Controller struct {
Ctx *Context
Data map[interface{}]interface{}
ChildName string
TplNames string
Layout string
TplExt string
CruSession session.SessionStore
Ctx *Context
Data map[interface{}]interface{}
ChildName string
TplNames string
Layout string
TplExt string
_xsrf_token string
gotofunc string
CruSession session.SessionStore
XSRFExpire int
}
type ControllerInterface interface {
@ -51,7 +59,6 @@ func (c *Controller) Init(ctx *Context, cn string) {
c.ChildName = cn
c.Ctx = ctx
c.TplExt = "tpl"
}
func (c *Controller) Prepare() {
@ -103,39 +110,7 @@ func (c *Controller) Render() error {
return err
} else {
c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
output_writer := c.Ctx.ResponseWriter.(io.Writer)
if EnableGzip == true && c.Ctx.Request.Header.Get("Accept-Encoding") != "" {
splitted := strings.SplitN(c.Ctx.Request.Header.Get("Accept-Encoding"), ",", -1)
encodings := make([]string, len(splitted))
for i, val := range splitted {
encodings[i] = strings.TrimSpace(val)
}
for _, val := range encodings {
if val == "gzip" {
c.Ctx.ResponseWriter.Header().Set("Content-Encoding", "gzip")
output_writer, _ = gzip.NewWriterLevel(c.Ctx.ResponseWriter, gzip.BestSpeed)
break
} else if val == "deflate" {
c.Ctx.ResponseWriter.Header().Set("Content-Encoding", "deflate")
output_writer, _ = zlib.NewWriterLevel(c.Ctx.ResponseWriter, zlib.BestSpeed)
break
}
}
} else {
c.Ctx.SetHeader("Content-Length", strconv.Itoa(len(rb)), true)
}
output_writer.Write(rb)
switch output_writer.(type) {
case *gzip.Writer:
output_writer.(*gzip.Writer).Close()
case *zlib.Writer:
output_writer.(*zlib.Writer).Close()
case io.WriteCloser:
output_writer.(io.WriteCloser).Close()
}
return nil
c.writeToWriter(rb)
}
return nil
}
@ -149,7 +124,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout
if c.Layout != "" {
if c.TplNames == "" {
c.TplNames = c.ChildName + "/" + c.Ctx.Request.Method + "." + c.TplExt
c.TplNames = c.ChildName + "/" + strings.ToLower(c.Ctx.Request.Method) + "." + c.TplExt
}
if RunMode == "dev" {
BuildTemplate(ViewsPath)
@ -175,7 +150,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
return icontent, nil
} else {
if c.TplNames == "" {
c.TplNames = c.ChildName + "/" + c.Ctx.Request.Method + "." + c.TplExt
c.TplNames = c.ChildName + "/" + strings.ToLower(c.Ctx.Request.Method) + "." + c.TplExt
}
if RunMode == "dev" {
BuildTemplate(ViewsPath)
@ -197,6 +172,41 @@ func (c *Controller) RenderBytes() ([]byte, error) {
return []byte{}, nil
}
func (c *Controller) writeToWriter(rb []byte) {
output_writer := c.Ctx.ResponseWriter.(io.Writer)
if EnableGzip == true && c.Ctx.Request.Header.Get("Accept-Encoding") != "" {
splitted := strings.SplitN(c.Ctx.Request.Header.Get("Accept-Encoding"), ",", -1)
encodings := make([]string, len(splitted))
for i, val := range splitted {
encodings[i] = strings.TrimSpace(val)
}
for _, val := range encodings {
if val == "gzip" {
c.Ctx.ResponseWriter.Header().Set("Content-Encoding", "gzip")
output_writer, _ = gzip.NewWriterLevel(c.Ctx.ResponseWriter, gzip.BestSpeed)
break
} else if val == "deflate" {
c.Ctx.ResponseWriter.Header().Set("Content-Encoding", "deflate")
output_writer, _ = flate.NewWriter(c.Ctx.ResponseWriter, flate.BestSpeed)
break
}
}
} else {
c.Ctx.SetHeader("Content-Length", strconv.Itoa(len(rb)), true)
}
output_writer.Write(rb)
switch output_writer.(type) {
case *gzip.Writer:
output_writer.(*gzip.Writer).Close()
case *flate.Writer:
output_writer.(*flate.Writer).Close()
case io.WriteCloser:
output_writer.(io.WriteCloser).Close()
}
}
func (c *Controller) Redirect(url string, code int) {
c.Ctx.Redirect(code, url)
}
@ -205,15 +215,17 @@ func (c *Controller) Abort(code string) {
panic(code)
}
func (c *Controller) ServeJson() {
func (c *Controller) ServeJson(encoding ...bool) {
content, err := json.MarshalIndent(c.Data["json"], "", " ")
if err != nil {
http.Error(c.Ctx.ResponseWriter, err.Error(), http.StatusInternalServerError)
return
}
c.Ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true)
c.Ctx.ResponseWriter.Header().Set("Content-Type", "application/json;charset=UTF-8")
c.Ctx.ResponseWriter.Write(content)
if len(encoding) > 0 && encoding[0] == true {
content = []byte(stringsToJson(string(content)))
}
c.writeToWriter(content)
}
func (c *Controller) ServeJsonp() {
@ -231,9 +243,8 @@ func (c *Controller) ServeJsonp() {
callback_content.WriteString("(")
callback_content.Write(content)
callback_content.WriteString(");\r\n")
c.Ctx.SetHeader("Content-Length", strconv.Itoa(callback_content.Len()), true)
c.Ctx.ResponseWriter.Header().Set("Content-Type", "application/json;charset=UTF-8")
c.Ctx.ResponseWriter.Write(callback_content.Bytes())
c.writeToWriter(callback_content.Bytes())
}
func (c *Controller) ServeXml() {
@ -242,9 +253,8 @@ func (c *Controller) ServeXml() {
http.Error(c.Ctx.ResponseWriter, err.Error(), http.StatusInternalServerError)
return
}
c.Ctx.SetHeader("Content-Length", strconv.Itoa(len(content)), true)
c.Ctx.ResponseWriter.Header().Set("Content-Type", "application/xml;charset=UTF-8")
c.Ctx.ResponseWriter.Write(content)
c.writeToWriter(content)
}
func (c *Controller) Input() url.Values {
@ -257,20 +267,24 @@ func (c *Controller) Input() url.Values {
return c.Ctx.Request.Form
}
func (c *Controller) ParseForm(obj interface{}) error {
return ParseForm(c.Input(), obj)
}
func (c *Controller) GetString(key string) string {
return c.Input().Get(key)
}
func (c *Controller) GetStrings(key string) []string {
r := c.Ctx.Request;
if r.Form == nil {
r := c.Ctx.Request
if r.Form == nil {
return []string{}
}
vs := r.Form[key]
if len(vs) > 0 {
return vs
}
return []string{}
vs := r.Form[key]
if len(vs) > 0 {
return vs
}
return []string{}
}
func (c *Controller) GetInt(key string) (int64, error) {
@ -327,3 +341,62 @@ func (c *Controller) DelSession(name interface{}) {
}
c.CruSession.Delete(name)
}
func (c *Controller) DestroySession() {
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
}
func (c *Controller) IsAjax() bool {
return (c.Ctx.Request.Header.Get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest")
}
func (c *Controller) XsrfToken() string {
if c._xsrf_token == "" {
token := c.Ctx.GetCookie("_xsrf")
if token == "" {
h := hmac.New(sha1.New, []byte(XSRFKEY))
fmt.Fprintf(h, "%s:%d", c.Ctx.Request.RemoteAddr, time.Now().UnixNano())
tok := fmt.Sprintf("%s:%d", h.Sum(nil), time.Now().UnixNano())
token = base64.URLEncoding.EncodeToString([]byte(tok))
expire := 0
if c.XSRFExpire > 0 {
expire = c.XSRFExpire
} else {
expire = XSRFExpire
}
c.Ctx.SetCookie("_xsrf", token, expire)
}
c._xsrf_token = token
}
return c._xsrf_token
}
func (c *Controller) CheckXsrfCookie() bool {
token := c.GetString("_xsrf")
if token == "" {
token = c.Ctx.Request.Header.Get("X-Xsrftoken")
}
if token == "" {
token = c.Ctx.Request.Header.Get("X-Csrftoken")
}
if token == "" {
c.Ctx.Abort(403, "'_xsrf' argument missing from POST")
}
if c._xsrf_token != token {
c.Ctx.Abort(403, "XSRF cookie does not match POST argument")
}
return true
}
func (c *Controller) XsrfFormHtml() string {
return "<input type=\"hidden\" name=\"_xsrf\" value=\"" +
c._xsrf_token + "\"/>"
}
func (c *Controller) GoToFunc(funcname string) {
if funcname[0] < 65 || funcname[0] > 90 {
panic("GoToFunc should exported function")
}
c.gotofunc = funcname
}

102
docs/en/API.md Normal file
View File

@ -0,0 +1,102 @@
# Getting start with API application development
Go is very good for developing API applications which I think is the biggest strength compare to other dynamic languages. Beego provides powerful and quick setup tool for developing API applications, which gives you more focus on business logic.
## Quick setup
bee can setup a API application very quick by executing commands under any `$GOPATH/src`.
`bee api beeapi`
## Application directory structure
```
├── conf
│ └── app.conf
├── controllers
│ └── default.go
├── models
│ └── object.go
└── main.go
```
## Source code explanation
- app.conf has following configuration options for your API applications:
- autorender = false // Disable auto-render since API applications don't need.
- copyrequestbody = true // RESTFul applications sends raw body instead of form, so we need to read body specifically.
- main.go is for registering routers of RESTFul.
beego.RESTRouter("/object", &controllers.ObejctController{})
Match rules as follows:
<table>
<tr>
<th>URL</th> <th>HTTP Verb</th> <th>Functionality</th>
</tr>
<tr>
<td>/object</td> <td>POST</td> <td>Creating Objects</td>
</tr>
<tr>
<td>/object/objectId</td> <td>GET</td> <td>Retrieving Objects</td>
</tr>
<tr>
<td>/object/objectId</td> <td>PUT</td> <td>Updating Objects</td>
</tr>
<tr>
<td>/object</td> <td>GET</td> <td>Queries</td>
</tr>
<tr>
<td>/object/objectId</td> <td>DELETE</td> <td>Deleting Objects</td>
</tr>
</table>
- ObejctController implemented corresponding methods:
type ObejctController struct {
beego.Controller
}
func (this *ObejctController) Post(){
}
func (this *ObejctController) Get(){
}
func (this *ObejctController) Put(){
}
func (this *ObejctController) Delete(){
}
- models implemented corresponding object operation for adding, deleting, updating and getting.
## Test
- Add a new object:
curl -X POST -d '{"Score":1337,"PlayerName":"Sean Plott"}' http://127.0.0.1:8080/object
Returns a corresponding objectID:astaxie1373349756660423900
- Query a object:
`curl -X GET http://127.0.0.1:8080/object/astaxie1373349756660423900`
- Query all objects:
`curl -X GET http://127.0.0.1:8080/object`
- Update a object:
`curl -X PUT -d '{"Score":10000}'http://127.0.0.1:8080/object/astaxie1373349756660423900`
- Delete a object:
`curl -X DELETE http://127.0.0.1:8080/object/astaxie1373349756660423900`

View File

@ -74,7 +74,7 @@ What happened in behind above example?
Get into your $GOPATH, then use following command to setup Beego project:
bee create hello
bee new hello
It generates folders and files for your project, directory structure as follows:
@ -97,11 +97,11 @@ It generates folders and files for your project, directory structure as follows:
Beego uses development mode as default, you can use following code to change mode in your application:
beego.RunMode = "pro"
beego.RunMode = "prod"
Or use configuration file in `conf/app.conf`, and input following content:
runmode = pro
runmode = prod
No differences between two ways.
@ -338,7 +338,7 @@ To disable auto-render in `main.go`(before you call `beego.Run()` to run the app
You can use `this.Data` in controller methods to access the data in templates. Suppose you want to get content of `{{.Content}}`, you can use following code to do this:
this.Data["Context"] = "value"
this.Data["Content"] = "value"
### Template name
@ -762,16 +762,29 @@ Beego has a default BeeLogger object that outputs log into stdout, and you can u
beego.SetLogger(*log.Logger)
You can output everything that implemented `*log.Logger`, for example, write to file:
Now Beego supports new way to record your log with automatically log rotate. Use following code in your main function:
fd,err := os.OpenFile("/var/log/beeapp/beeapp.log", os.O_RDWR|os.O_APPEND, 0644)
filew := beego.NewFileWriter("tmp/log.log", true)
err := filew.StartLogger()
if err != nil {
beego.Critical("openfile beeapp.log:", err)
return
beego.Critical("NewFileWriter err", err)
}
lg := log.New(fd, "", log.Ldate|log.Ltime)
beego.SetLogger(lg)
So Beego records your log into file `tmp/log.log`, the second argument indicates whether enable log rotate or not. The rules of rotate as follows:
1. segment log every 1,000,000 lines.
2. segment log every 256 MB file size.
3. segment log daily.
4. save log file up to 7 days as default.
You cannot segment log over 999 times everyday, the segmented file name with format `<defined file name>.<date>.<three digits>`.
You are able to modify rotate rules with following methods, be sure that you call them before `StartLogger()`.
- func (w *FileLogWriter) SetRotateDaily(daily bool) *FileLogWriter
- func (w *FileLogWriter) SetRotateLines(maxlines int) *FileLogWriter
- func (w *FileLogWriter) SetRotateMaxDays(maxdays int64) *FileLogWriter
- func (w *FileLogWriter) SetRotateSize(maxsize int) *FileLogWriter
### Different levels of log

106
docs/zh/API.md Normal file
View File

@ -0,0 +1,106 @@
# API应用开发入门
Go是非常适合用来开发API应用的而且我认为也是Go相对于其他动态语言的最大优势应用。beego在开发API应用方面提供了非常强大和快速的工具方便用户快速的建立API应用原型专心业务逻辑就行了。
## 快速建立原型
bee快速开发工具提供了一个API应用建立的工具在gopath/src下的任意目录执行如下命令就可以快速的建立一个API应用
`bee api beeapi`
## 应用的目录结构
应用的目录结构如下所示:
```
├── conf
│ └── app.conf
├── controllers
│ └── default.go
├── models
│ └── object.go
└── main.go
```
## 源码解析
- app.conf里面主要针对API的配置如下
autorender = false //API应用不需要模板渲染所以关闭自动渲染
copyrequestbody = true //RESTFul应用发送信息的时候是raw body而不是普通的form表单所以需要额外的读取body信息
- main.go文件主要针对RESTFul的路由注册
`beego.RESTRouter("/object", &controllers.ObejctController{})`
这个路由可以匹配如下的规则
<table>
<tr>
<th>URL</th> <th>HTTP Verb</th> <th>Functionality</th>
</tr>
<tr>
<td>/object</td> <td>POST</td> <td>Creating Objects</td>
</tr>
<tr>
<td>/object/objectId</td> <td>GET</td> <td>Retrieving Objects</td>
</tr>
<tr>
<td>/object/objectId</td> <td>PUT</td> <td>Updating Objects</td>
</tr>
<tr>
<td>/object</td> <td>GET</td> <td>Queries</td>
</tr>
<tr>
<td>/object/objectId</td> <td>DELETE</td> <td>Deleting Objects</td>
</tr>
</table>
- ObejctController实现了对应的方法
```
type ObejctController struct {
beego.Controller
}
func (this *ObejctController) Post(){
}
func (this *ObejctController) Get(){
}
func (this *ObejctController) Put(){
}
func (this *ObejctController) Delete(){
}
```
- models里面实现了对应操作对象的增删改取等操作
## 测试
- 添加一个对象:
`curl -X POST -d '{"Score":1337,"PlayerName":"Sean Plott"}' http://127.0.0.1:8080/object`
返回一个相应的objectID:astaxie1373349756660423900
- 查询一个对象
`curl -X GET http://127.0.0.1:8080/object/astaxie1373349756660423900`
- 查询全部的对象
`curl -X GET http://127.0.0.1:8080/object`
- 更新一个对象
`curl -X PUT -d '{"Score":10000}'http://127.0.0.1:8080/object/astaxie1373349756660423900`
- 删除一个对象
`curl -X DELETE http://127.0.0.1:8080/object/astaxie1373349756660423900`

78
docs/zh/HttpLib.md Normal file
View File

@ -0,0 +1,78 @@
## 方便的http客户端
我们经常会使用Go来请求其他API应用例如你使用beego开发了一个RESTFul的API应用那么如果来请求呢当然可以使用`http.Client`来实现但是需要自己来操作很多步骤自己需要考虑很多东西所以我就基于net下的一些包实现了这个简便的http客户端工具。
该工具的主要特点:
- 链式操作
- 超时控制
- 方便的解析
- 可控的debug
## 例子
我们上次开发的RESTful应用最后我写过如何通过curl来进行测试那么下面一一对每个操作如何用httplib来操作进行展示
- 添加一个对象:
`curl -X POST -d '{"Score":1337,"PlayerName":"Sean Plott"}' http://127.0.0.1:8080/object`
返回一个相应的objectID:astaxie1373349756660423900
str,err:=beego.Post("http://127.0.0.1:8080/object").Body(`{"Score":1337,"PlayerName":"Sean Plott"}`).String()
if err != nil{
println(err)
}
- 查询一个对象
`curl -X GET http://127.0.0.1:8080/object/astaxie1373349756660423900`
var object Obeject
err:=beego.Get("http://127.0.0.1:8080/object/astaxie1373349756660423900").ToJson(&object)
if err != nil{
println(err)
}
- 查询全部的对象
`curl -X GET http://127.0.0.1:8080/object`
var objects []Object
err:=beego.Get("http://127.0.0.1:8080/object").ToJson(&objects)
if err != nil{
println(err)
}
- 更新一个对象
`curl -X PUT -d '{"Score":10000}'http://127.0.0.1:8080/object/astaxie1373349756660423900`
str,err:=beego.Put("http://127.0.0.1:8080/object/astaxie1373349756660423900").Body(`{"Score":10000}`).String()
if err != nil{
println(err)
}
- 删除一个对象
`curl -X DELETE http://127.0.0.1:8080/object/astaxie1373349756660423900`
str,er:=beego.Delete("http://127.0.0.1:8080/object/astaxie1373349756660423900").String()
if err != nil{
println(err)
}
## 开启调试模式
用户可以开启调试打印request信息默认是关闭模式
beego.Post(url).Debug(true)
## ToFile、ToXML、ToJson
上面我演示了Json的解析其实还有直接保存为文件的ToFile操作解析XML的ToXML操作
## 设置链接超时和读写超时
默认都设置为60秒用户可以通过函数来设置相应的超时时间
beego.Get(url).SetTimeout(100*time.Second,100*time.Second)
更加详细的请参考[API接口](http://gowalker.org/github.com/astaxie/beego)

View File

@ -25,7 +25,6 @@ beego虽然是一个简单的框架但是其中用到了很多第三方的包
> - session模块中支持mysql引擎github.com/go-sql-driver/mysql
> - 模板函数中支持markdown转化github.com/russross/blackfriday
- [beego介绍](README.md)

View File

@ -4,25 +4,25 @@
**导航**
- [最小应用](#-1)
- [新建项目](#-2)
- [开发模式](#-3)
- [路由设置](#-4)
- [静态文件](#-5)
- [过滤和中间件](#-6)
- [Controller设计](#-7)
- [模板处理](#-8)
- [request处理](#request)
- [跳转和错误](#-15)
- [response处理](#response)
- [Sessions](#sessions)
- [Cache设置](#cache)
- [安全的Map](#map)
- [日志处理](#-16)
- [配置管理](#-17)
- [beego参数](#-18)
- [第三方应用集成](#-19)
- [部署编译应用](#-20)
- [最小应用](#%E6%9C%80%E5%B0%8F%E5%BA%94%E7%94%A8)
- [新建项目](#%E6%96%B0%E5%BB%BA%E9%A1%B9%E7%9B%AE)
- [开发模式](#%E5%BC%80%E5%8F%91%E6%A8%A1%E5%BC%8F)
- [路由设置](#%E8%B7%AF%E7%94%B1%E8%AE%BE%E7%BD%AE)
- [静态文件](#%E9%9D%99%E6%80%81%E6%96%87%E4%BB%B6)
- [过滤和中间件](#%E8%BF%87%E6%BB%A4%E5%92%8C%E4%B8%AD%E9%97%B4%E4%BB%B6)
- [Controller设计](#%E6%8E%A7%E5%88%B6%E5%99%A8%E8%AE%BE%E8%AE%A1)
- [模板处理](#%E6%A8%A1%E6%9D%BF%E5%A4%84%E7%90%86)
- [request处理](#request%E5%A4%84%E7%90%86)
- [跳转和错误](#%E8%B7%B3%E8%BD%AC%E5%92%8C%E9%94%99%E8%AF%AF)
- [response处理](#response%E5%A4%84%E7%90%86)
- [Sessions/Flash](#sessionsflash)
- [Cache设置](#cache%E8%AE%BE%E7%BD%AE)
- [安全的Map](#%E5%AE%89%E5%85%A8%E7%9A%84map)
- [日志处理](#%E6%97%A5%E5%BF%97%E5%A4%84%E7%90%86)
- [配置管理](#%E9%85%8D%E7%BD%AE%E7%AE%A1%E7%90%86)
- [beego参数](#%E7%B3%BB%E7%BB%9F%E9%BB%98%E8%AE%A4%E5%8F%82%E6%95%B0)
- [第三方应用集成](#%E7%AC%AC%E4%B8%89%E6%96%B9%E5%BA%94%E7%94%A8%E9%9B%86%E6%88%90)
- [部署编译应用](#%E9%83%A8%E7%BD%B2%E7%BC%96%E8%AF%91%E5%BA%94%E7%94%A8)
## 最小应用
@ -76,7 +76,7 @@
通过如下命令创建beego项目首先进入gopath目录
bee create hello
bee new hello
这样就建立了一个项目hello目录结构如下所示
@ -101,11 +101,11 @@
我们可以通过如下的方式改变我们的模式:
beego.RunMode = "pro"
beego.RunMode = "prod"
或者我们在conf/app.conf下面设置如下
runmode = pro
runmode = prod
以上两种效果一样。
@ -115,7 +115,7 @@
2013/04/13 19:36:17 [W] [stat views: no such file or directory]
- 模板会自动重新加载不缓存。
- 模板每次使用都会重新加载,不进行缓存。
- 如果服务端出错,那么就会在浏览器端显示如下类似的截图:
![](images/dev.png)
@ -123,6 +123,8 @@
## 路由设置
### 默认路由RESTFul规则
路由的主要功能是实现从请求地址到实现方法beego中封装了`Controller`,所以路由是从路径到`ControllerInterface`的过程,`ControllerInterface`的方法有如下:
type ControllerInterface interface {
@ -179,6 +181,66 @@
this.Ctx.Params[":path"]
this.Ctx.Params[":ext"]
### 自定义方法及RESTFul规则
上面列举的是默认的请求方法名(请求的method和函数名一致例如GET请求执行Get函数POST请求执行Post函数),如果用户期望自定义函数名,那么可以使用如下方式:
beego.Router("/",&IndexController{},"*:Index")
使用第三个参数第三个参数就是用来设置对应method到函数名定义如下
- *表示任意的method都执行该函数
- 使用`httpmethod:funcname`格式来展示
- 多个不同的格式使用`;`分割
- 多个method对应同一个funcnamemethod之间通过`,`来分割
以下是一个RESTful的设计如下
- beego.Router("/api/list",&RestController{},"*:ListFood")
- beego.Router("/api/create",&RestController{},"post:CreateFood")
- beego.Router("/api/update",&RestController{},"put:UpdateFood")
- beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
以下是多个http method指向同一个函数
beego.Router("/api",&RestController{},"get,post:ApiFunc")
一下是不同的method对应不同的函数通过`;`进行分割
beego.Router("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
可用的http method
- * :包含一下所有的函数
- get GET请求
- post POST请求
- put PUT请求
- delete DELETE请求
- patch PATCH请求
- options OPTIONS请求
- head HEAD请求
>>>如果同时存在*和对应的http method那么优先执行http method的方法例如同时注册了如下所示的路由
>>> beego.Router("/simple",&SimpleController{},"*:AllFunc;post:PostFunc")
>>>那么执行POST请求的时候执行PostFunc而不执行AllFunc
### 自动化路由
用户首先需要把需要路由的控制器注册到自动路由中:
beego.AutoRouter(&controllers.ObjectController{})
那么beego就会通过反射获取该结构体中所有的实现方法你就可以通过如下的方式访问到对应的方法中
/object/login 调用ObjectController中的Login方法
/object/logout 调用ObjectController中的Logout方法
除了前缀两个/:controller/:method的匹配之外剩下的urlbeego会帮你自动化解析为参数保存在`this.Ctx.Params`当中:
/object/blog/2013/09/12 调用ObjectController中的Blog方法参数如下map[0:2013 1:09 2:12]
>>> 方法名在内部是保存了用户设置的例如Loginurl匹配的时候都会转化为小写所以/object/LOGIN这样的url也一样可以路由到用户定义的Login方法中
## 静态文件
@ -343,7 +405,7 @@ main.go文件中设置如下
模板中的数据是通过在Controller中`this.Data`获取的,所以如果你想在模板中获取内容`{{.Content}}`,那么你需要在Controller中如下设置
this.Data["Context"] = "value"
this.Data["Content"] = "value"
### 模板名称
@ -581,11 +643,11 @@ beego更加人性化的还有一个设计就是支持用户自定义字符串错
## response处理
response可能会有集中情况
response可能会有几种情况
1. 模板输出
模板输出上面模板介绍里面已经介绍beego会在执行完相应的Controller里面的对应的Method之后输出到模板
上面模板介绍里面已经介绍beego会在执行完相应的Controller里面的对应的Method之后输出到模板
2. 跳转
@ -598,7 +660,7 @@ response可能会有集中情况
this.Ctx.WriteString("ok")
## Sessions
## Sessions/Flash
beego内置了session模块目前session模块支持的后端引擎包括memoryfilemysqlredis四中用户也可以根据相应的interface实现自己的引擎
@ -624,7 +686,7 @@ beego中使用session相当方便只要在main入口函数中设置如下
this.TplNames = "index.tpl"
}
上面的例子中我们知道session有几个方便的方法
session有几个方便的方法
- SetSession(name string, value interface{})
- GetSession(name string) interface{}
@ -684,6 +746,64 @@ sess对象具有如下方法
beego.SessionProvider = "redis"
beego.SessionSavePath = "127.0.0.1:6379"
这个flash与Adobe/Macromedia Flash没有任何关系它主要用于在两个逻辑间传递临时数据flash中存放的所有数据会在紧接着的下一个逻辑中调用后清除一般用于传递提示和错误消息它适合[Post/Redirect/Get](http://en.wikipedia.org/wiki/Post/Redirect/Get)模式下面看使用的例子
// 显示设置信息
func (c *MainController) Get() {
flash:=beego.ReadFromRequest(c)
if n,ok:=flash.Data["notice"];ok{
//显示设置成功
c.TplNames = "set_success.html"
}else if n,ok=flash.Data["error"];ok{
//显示错误
c.TplNames = "set_error.html"
}else{
// 不然默认显示设置页面
this.Data["list"]=GetInfo()
c.TplNames = "setting_list.html"
}
}
// 处理设置信息
func (c *MainController) Post() {
flash:=beego.NewFlash()
setting:=Settings{}
valid := Validation{}
c.ParseForm(&setting)
if b, err := valid.Valid(setting);err!=nil {
flash.Error("Settings invalid!")
flash.Store(c)
c.Redirect("/setting",302)
return
}else if b!=nil{
flash.Error("validation err!")
flash.Store(c)
c.Redirect("/setting",302)
return
}
saveSetting(setting)
flash.Notice("Settings saved!")
flash.Store(c)
c.Redirect("/setting",302)
}
上面的代码执行的大概逻辑是这样的
1. Get方法执行因为没有flash数据所以显示设置页面
2. 用户设置信息之后点击递交执行Post然后初始化一个flash通过验证验证出错或者验证不通过设置flash的错误如果通过了就保存设置然后设置flash成功设置的信息
3. 设置完成后跳转到Get请求
4. Get请求获取到了Flash信息然后执行相应的逻辑如果出错显示出错的页面如果成功显示成功的页面
默认情况下`ReadFromRequest`函数已经实现了读取的数据赋值给flash所以在你的模板里面你可以这样读取数据
{{.flash.error}}
{{.flash.warning}}
{{.flash.notice}}
flash对象有三个级别的设置
* Notice提示信息
* Warning警告信息
* Error错误信息
## Cache设置
@ -787,7 +907,7 @@ beego默认有一个初始化的BeeLogger对象输出内容到stdout中你可
- func (w *FileLogWriter) SetRotateDaily(daily bool) *FileLogWriter
- func (w *FileLogWriter) SetRotateLines(maxlines int) *FileLogWriter
- func (w *FileLogWriter) SetRotateMaxDay(maxday int64) *FileLogWriter
- func (w *FileLogWriter) SetRotateMaxDays(maxdays int64) *FileLogWriter
- func (w *FileLogWriter) SetRotateSize(maxsize int) *FileLogWriter
但是这些函数调用必须在调用`StartLogger`之前

View File

@ -43,6 +43,8 @@ beego是一个类似tornado的Go应用框架采用了RESTFul的方式来实
* [一步一步开发应用](Tutorial.md)
* [beego案例](Application.md)
* [热升级](HotUpdate.md)
* [API应用开发入门](API.md)
* [HTTPLIB客户端](HttpLib.md)
# API接口

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

View File

@ -189,6 +189,7 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
"<br>You like 404 pages" +
"</ul>")
data["BeegoVersion"] = VERSION
rw.WriteHeader(http.StatusNotFound)
t.Execute(rw, data)
}
@ -204,6 +205,7 @@ func Unauthorized(rw http.ResponseWriter, r *http.Request) {
"<br>Check the address for errors" +
"</ul>")
data["BeegoVersion"] = VERSION
rw.WriteHeader(http.StatusUnauthorized)
t.Execute(rw, data)
}
@ -220,6 +222,7 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
"<br>You need to log in" +
"</ul>")
data["BeegoVersion"] = VERSION
rw.WriteHeader(http.StatusForbidden)
t.Execute(rw, data)
}
@ -235,6 +238,7 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
"<br>Please try again later." +
"</ul>")
data["BeegoVersion"] = VERSION
rw.WriteHeader(http.StatusServiceUnavailable)
t.Execute(rw, data)
}
@ -249,6 +253,7 @@ func InternalServerError(rw http.ResponseWriter, r *http.Request) {
"<br>you should report the fault to the website administrator" +
"</ul>")
data["BeegoVersion"] = VERSION
rw.WriteHeader(http.StatusInternalServerError)
t.Execute(rw, data)
}

View File

@ -0,0 +1,5 @@
appname = beeapi
httpport = 8080
runmode = dev
autorender = false
copyrequestbody = true

View File

@ -0,0 +1,56 @@
package controllers
import (
"encoding/json"
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/models"
)
type ObejctController struct {
beego.Controller
}
func (this *ObejctController) Post() {
var ob models.Object
json.Unmarshal(this.Ctx.RequestBody, &ob)
objectid := models.AddOne(ob)
this.Data["json"] = "{\"ObjectId\":\"" + objectid + "\"}"
this.ServeJson()
}
func (this *ObejctController) Get() {
objectId := this.Ctx.Params[":objectId"]
if objectId != "" {
ob, err := models.GetOne(objectId)
if err != nil {
this.Data["json"] = err
} else {
this.Data["json"] = ob
}
} else {
obs := models.GetAll()
this.Data["json"] = obs
}
this.ServeJson()
}
func (this *ObejctController) Put() {
objectId := this.Ctx.Params[":objectId"]
var ob models.Object
json.Unmarshal(this.Ctx.RequestBody, &ob)
err := models.Update(objectId, ob.Score)
if err != nil {
this.Data["json"] = err
} else {
this.Data["json"] = "update success!"
}
this.ServeJson()
}
func (this *ObejctController) Delete() {
objectId := this.Ctx.Params[":objectId"]
models.Delete(objectId)
this.Data["json"] = "delete success!"
this.ServeJson()
}

20
example/beeapi/main.go Normal file
View File

@ -0,0 +1,20 @@
package main
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/controllers"
)
// Objects
// URL HTTP Verb Functionality
// /object POST Creating Objects
// /object/<objectId> GET Retrieving Objects
// /object/<objectId> PUT Updating Objects
// /object GET Queries
// /object/<objectId> DELETE Deleting Objects
func main() {
beego.RESTRouter("/object", &controllers.ObejctController{})
beego.Run()
}

View File

@ -0,0 +1,52 @@
package models
import (
"errors"
"strconv"
"time"
)
var (
Objects map[string]*Object
)
type Object struct {
ObjectId string
Score int64
PlayerName string
}
func init() {
Objects = make(map[string]*Object)
Objects["hjkhsbnmn123"] = &Object{"hjkhsbnmn123", 100, "astaxie"}
Objects["mjjkxsxsaa23"] = &Object{"mjjkxsxsaa23", 101, "someone"}
}
func AddOne(object Object) (ObjectId string) {
object.ObjectId = "astaxie" + strconv.FormatInt(time.Now().UnixNano(), 10)
Objects[object.ObjectId] = &object
return object.ObjectId
}
func GetOne(ObjectId string) (object *Object, err error) {
if v, ok := Objects[ObjectId]; ok {
return v, nil
}
return nil, errors.New("ObjectId Not Exist")
}
func GetAll() map[string]*Object {
return Objects
}
func Update(ObjectId string, Score int64) (err error) {
if v, ok := Objects[ObjectId]; ok {
v.Score = Score
return nil
}
return errors.New("ObjectId Not Exist")
}
func Delete(ObjectId string) {
delete(Objects, ObjectId)
}

72
flash.go Normal file
View File

@ -0,0 +1,72 @@
package beego
import (
"fmt"
"net/url"
"strings"
)
type FlashData struct {
Data map[string]string
}
func NewFlash() *FlashData {
return &FlashData{
Data: make(map[string]string),
}
}
func (fd *FlashData) Notice(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["notice"] = msg
} else {
fd.Data["notice"] = fmt.Sprintf(msg, args...)
}
}
func (fd *FlashData) Warning(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["warning"] = msg
} else {
fd.Data["warning"] = fmt.Sprintf(msg, args...)
}
}
func (fd *FlashData) Error(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["error"] = msg
} else {
fd.Data["error"] = fmt.Sprintf(msg, args...)
}
}
func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data
var flashValue string
for key, value := range fd.Data {
flashValue += "\x00" + key + ":" + value + "\x00"
}
c.Ctx.SetCookie("BEEGO_FLASH", url.QueryEscape(flashValue), 0, "/")
}
func ReadFromRequest(c *Controller) *FlashData {
flash := &FlashData{
Data: make(map[string]string),
}
if cookie, err := c.Ctx.Request.Cookie("BEEGO_FLASH"); err == nil {
vals := strings.Split(cookie.Value, "\x00")
for _, v := range vals {
if len(v) > 0 {
kv := strings.Split(v, ":")
if len(kv) == 2 {
flash.Data[kv[0]] = kv[1]
}
}
}
//read one time then delete it
cookie.MaxAge = -1
c.Ctx.Request.AddCookie(cookie)
}
c.Data["flash"] = flash.Data
return flash
}

242
httplib/httplib.go Normal file
View File

@ -0,0 +1,242 @@
package httplib
import (
"bytes"
"encoding/json"
"encoding/xml"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
)
var defaultUserAgent = "beegoServer"
func Get(url string) *BeegoHttpRequest {
var req http.Request
req.Method = "GET"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second}
}
func Post(url string) *BeegoHttpRequest {
var req http.Request
req.Method = "POST"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second}
}
func Put(url string) *BeegoHttpRequest {
var req http.Request
req.Method = "PUT"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second}
}
func Delete(url string) *BeegoHttpRequest {
var req http.Request
req.Method = "DELETE"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second}
}
func Head(url string) *BeegoHttpRequest {
var req http.Request
req.Method = "HEAD"
req.Header = http.Header{}
req.Header.Set("User-Agent", defaultUserAgent)
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second}
}
type BeegoHttpRequest struct {
url string
req *http.Request
params map[string]string
showdebug bool
connectTimeout time.Duration
readWriteTimeout time.Duration
}
func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest {
b.showdebug = isdebug
return b
}
func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest {
b.connectTimeout = connectTimeout
b.readWriteTimeout = readWriteTimeout
return b
}
func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest {
b.req.Header.Set(key, value)
return b
}
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
b.params[key] = value
return b
}
func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
switch t := data.(type) {
case string:
bf := bytes.NewBufferString(t)
b.req.Body = ioutil.NopCloser(bf)
b.req.ContentLength = int64(len(t))
case []byte:
bf := bytes.NewBuffer(t)
b.req.Body = ioutil.NopCloser(bf)
b.req.ContentLength = int64(len(t))
}
return b
}
func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
var paramBody string
if b.params != nil && len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
buf.WriteString(url.QueryEscape(k))
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(v))
buf.WriteByte('&')
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Index(b.url, "?") != -1 {
b.url += "&" + paramBody
} else {
b.url = b.url + "?" + paramBody
}
} else if b.req.Method == "POST" && b.req.Body == nil && len(paramBody) > 0 {
b.Header("Content-Type", "application/x-www-form-urlencoded")
b.Body(paramBody)
}
url, err := url.Parse(b.url)
if url.Scheme == "" {
b.url = "http://" + b.url
url, err = url.Parse(b.url)
}
if err != nil {
return nil, err
}
b.req.URL = url
if b.showdebug {
dump, err := httputil.DumpRequest(b.req, true)
if err != nil {
println(err.Error())
}
println(string(dump))
}
client := &http.Client{
Transport: &http.Transport{
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout),
},
}
resp, err := client.Do(b.req)
if err != nil {
return nil, err
}
return resp, nil
}
func (b *BeegoHttpRequest) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
}
return string(data), nil
}
func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
resp, err := b.getResponse()
if err != nil {
return nil, err
}
if resp.Body == nil {
return nil, nil
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return data, nil
}
func (b *BeegoHttpRequest) ToFile(filename string) error {
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
resp, err := b.getResponse()
if err != nil {
return err
}
if resp.Body == nil {
return nil
}
defer resp.Body.Close()
_, err = io.Copy(f, resp.Body)
if err != nil {
return err
}
return nil
}
func (b *BeegoHttpRequest) ToJson(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
err = json.Unmarshal(data, v)
if err != nil {
return err
}
return nil
}
func (b *BeegoHttpRequest) ToXML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
err = xml.Unmarshal(data, v)
if err != nil {
return err
}
return nil
}
func (b *BeegoHttpRequest) Response() (*http.Response, error) {
return b.getResponse()
}
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
return func(netw, addr string) (net.Conn, error) {
conn, err := net.DialTimeout(netw, addr, cTimeout)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(rwTimeout))
return conn, nil
}
}

32
httplib/httplib_test.go Normal file
View File

@ -0,0 +1,32 @@
package httplib
import (
"io/ioutil"
"testing"
)
func TestGetUrl(t *testing.T) {
resp, err := Get("http://beego.me/").Response()
if err != nil {
t.Fatal(err)
}
if resp.Body == nil {
t.Fatal("body is nil")
}
data, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if len(data) == 0 {
t.Fatal("data is no")
}
str, err := Get("http://beego.me/").String()
if err != nil {
t.Fatal(err)
}
if len(str) == 0 {
t.Fatal("has no info")
}
}

130
log.go
View File

@ -14,6 +14,7 @@ import (
type FileLogWriter struct {
*log.Logger
mw *MuxWriter
// The opened file
filename string
@ -26,12 +27,30 @@ type FileLogWriter struct {
// Rotate daily
daily bool
maxday int64
maxdays int64
daily_opendate int
rotate bool
startLock sync.Mutex //only one log can writer to the file
startLock sync.Mutex // Only one log can write to the file
}
type MuxWriter struct {
sync.Mutex
fd *os.File
}
func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.fd.Write(b)
}
func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil {
l.fd.Close()
}
l.fd = fd
}
func NewFileWriter(fname string, rotate bool) *FileLogWriter {
@ -40,9 +59,13 @@ func NewFileWriter(fname string, rotate bool) *FileLogWriter {
maxlines: 1000000,
maxsize: 1 << 28, //256 MB
daily: true,
maxday: 7,
maxdays: 7,
rotate: rotate,
}
// use MuxWriter instead direct use os.File for lock write when rotate
w.mw = new(MuxWriter)
// set MuxWriter as Logger's io.Writer
w.Logger = log.New(w.mw, "", log.Ldate|log.Ltime)
return w
}
@ -64,16 +87,23 @@ func (w *FileLogWriter) SetRotateDaily(daily bool) *FileLogWriter {
return w
}
// Set rotate daily's log keep for maxday,other will delete
func (w *FileLogWriter) SetRotateMaxDay(maxday int64) *FileLogWriter {
w.maxday = maxday
// Set rotate daily's log keep for maxdays, other will delete
func (w *FileLogWriter) SetRotateMaxDays(maxdays int64) *FileLogWriter {
w.maxdays = maxdays
return w
}
func (w *FileLogWriter) StartLogger() error {
if err := w.DoRotate(false); err != nil {
fd, err := w.createLogFile()
if err != nil {
return err
}
w.mw.SetFd(fd)
err = w.initFd()
if err != nil {
return err
}
BeeLogger = w
return nil
}
@ -83,7 +113,7 @@ func (w *FileLogWriter) docheck(size int) {
if (w.maxlines > 0 && w.maxlines_curlines >= w.maxlines) ||
(w.maxsize > 0 && w.maxsize_cursize >= w.maxsize) ||
(w.daily && time.Now().Day() != w.daily_opendate) {
if err := w.DoRotate(true); err != nil {
if err := w.DoRotate(); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.filename, err)
return
}
@ -101,37 +131,14 @@ func (w *FileLogWriter) Printf(format string, v ...interface{}) {
w.Logger.Printf(format, v...)
}
func (w *FileLogWriter) DoRotate(rotate bool) error {
if rotate {
_, err := os.Lstat(w.filename)
if err == nil { // file exists
// Find the next available number
num := 1
fname := ""
for ; err == nil && num <= 999; num++ {
fname = w.filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num)
_, err = os.Lstat(fname)
}
// return error if the last file checked still existed
if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.filename)
}
// Rename the file to its newfound home
err = os.Rename(w.filename, fname)
if err != nil {
return fmt.Errorf("Rotate: %s\n", err)
}
go w.deleteOldLog()
}
}
func (w *FileLogWriter) createLogFile() (*os.File, error) {
// Open the log file
fd, err := os.OpenFile(w.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
if err != nil {
return err
}
w.Logger = log.New(fd, "", log.Ldate|log.Ltime)
return fd, err
}
func (w *FileLogWriter) initFd() error {
fd := w.mw.fd
finfo, err := fd.Stat()
if err != nil {
return fmt.Errorf("get stat err: %s\n", err)
@ -144,19 +151,60 @@ func (w *FileLogWriter) DoRotate(rotate bool) error {
fmt.Println(err)
}
w.maxlines_curlines = len(strings.Split(string(content), "\n"))
} else {
w.maxlines_curlines = 0
}
BeeLogger = w
return nil
}
func (w *FileLogWriter) DoRotate() error {
_, err := os.Lstat(w.filename)
if err == nil { // file exists
// Find the next available number
num := 1
fname := ""
for ; err == nil && num <= 999; num++ {
fname = w.filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num)
_, err = os.Lstat(fname)
}
// return error if the last file checked still existed
if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.filename)
}
// block Logger's io.Writer
w.mw.Lock()
defer w.mw.Unlock()
fd := w.mw.fd
fd.Close()
// close fd before rename
// Rename the file to its newfound home
err = os.Rename(w.filename, fname)
if err != nil {
return fmt.Errorf("Rotate: %s\n", err)
}
// re-start logger
err = w.StartLogger()
if err != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", err)
}
go w.deleteOldLog()
}
return nil
}
func (w *FileLogWriter) deleteOldLog() {
dir := path.Dir(w.filename)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.maxday) {
os.Remove(path)
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.maxdays) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.filename)) {
os.Remove(path)
}
}
return nil
})

View File

@ -1,36 +0,0 @@
package beego
type BeeModel struct {
driver string
}
func (this *BeeModel) Insert() {
}
func (this *BeeModel) MultipleInsert() {
}
func (this *BeeModel) Update() {
}
func (this *BeeModel) Query() {
}
//Deletes from table with clauses where and using.
func (this *BeeModel) Delete() {
}
//Start a transaction
func (this *BeeModel) Transaction() {
}
//commit transaction
func (this *BeeModel) Commit() {
}

162
orm/README.md Normal file
View File

@ -0,0 +1,162 @@
# beego orm
[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest)
A powerful orm framework for go.
It is heavily influenced by Django ORM, SQLAlchemy.
now, beta, unstable, may be changing some api make your app build failed.
**Support Database:**
* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq)
* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
Passed all test, but need more feedback.
**Features:**
* full go type support
* easy for usage, simple CRUD operation
* auto join with relation table
* cross DataBase compatible query
* Raw SQL query / mapper without orm model
* full test keep stable and strong
more features please read the docs
**Install:**
go get github.com/astaxie/beego/orm
## Changelog
* 2013-08-13: update test for database types
* 2013-08-13: go type support, such as int8, uint8, byte, rune
* 2013-08-13: date / datetime timezone support very well
## Quick Start
#### Simple Usage
```go
package main
import (
"fmt"
"github.com/astaxie/beego/orm"
_ "github.com/go-sql-driver/mysql" // import your used driver
)
// Model Struct
type User struct {
Id int `orm:"auto"`
Name string `orm:"size(100)"`
}
func init() {
// register model
orm.RegisterModel(new(User))
// set default database
orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
}
func main() {
o := orm.NewOrm()
user := User{Name: "slene"}
// insert
id, err := o.Insert(&user)
// update
user.Name = "astaxie"
num, err := o.Update(&user)
// read one
u := User{Id: user.Id}
err = o.Read(&u)
// delete
num, err = o.Delete(&u)
}
```
#### Next with relation
```go
type Post struct {
Id int `orm:"auto"`
Title string `orm:"size(100)"`
User *User `orm:"rel(fk)"`
}
var posts []*Post
qs := o.QueryTable("post")
num, err := qs.Filter("User__Name", "slene").All(&posts)
```
#### Use Raw sql
If you don't like ORMuse Raw SQL to query / mapping without ORM setting
```go
var maps []Params
num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps)
if num > 0 {
fmt.Println(maps[0]["id"])
}
```
#### Transaction
```go
o.Begin()
...
user := User{Name: "slene"}
id, err := o.Insert(&user)
if err != nil {
o.Commit()
} else {
o.Rollback()
}
```
#### Debug Log Queries
In development env, you can simple use
```go
func main() {
orm.Debug = true
...
```
enable log queries.
output include all queries, such as exec / prepare / transaction.
like this:
```go
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene`
...
```
note: not recommend use this in product env.
## Docs
more details and examples in docs and test
* [中文](docs/zh)
* English
## TODO
- some unrealized api
- examples
- docs

44
orm/command.go Normal file
View File

@ -0,0 +1,44 @@
package orm
import (
"flag"
"fmt"
"os"
)
func printHelp() {
}
func getSqlAll() (sql string) {
for _, mi := range modelCache.allOrdered() {
_ = mi
}
return
}
func runCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
}
_ = flag.NewFlagSet("orm command", flag.ExitOnError)
args := argString(os.Args[2:])
cmd := args.Get(0)
switch cmd {
case "syncdb":
case "sqlall":
sql := getSqlAll()
fmt.Println(sql)
default:
if cmd != "" {
fmt.Printf("unknown command %s", cmd)
} else {
printHelp()
}
os.Exit(2)
}
}

1114
orm/db.go Normal file

File diff suppressed because it is too large Load Diff

183
orm/db_alias.go Normal file
View File

@ -0,0 +1,183 @@
package orm
import (
"database/sql"
"fmt"
"os"
"sync"
"time"
)
const defaultMaxIdle = 30
type DriverType int
const (
_ DriverType = iota
DR_MySQL
DR_Sqlite
DR_Oracle
DR_Postgres
)
type driver string
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
func (d driver) Name() string {
return string(d)
}
var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
"mysql": DR_MySQL,
"postgres": DR_Postgres,
"sqlite3": DR_Sqlite,
}
dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(),
DR_Postgres: newdbBasePostgres(),
}
)
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
if _, ok := ac.cache[name]; ok == false {
ac.cache[name] = al
added = true
}
return
}
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
al, ok = ac.cache[name]
return
}
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
}
type alias struct {
Name string
Driver DriverType
DriverName string
DataSource string
MaxIdle int
DB *sql.DB
DbBaser dbBaser
TZ *time.Location
}
func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
if maxIdle <= 0 {
maxIdle = defaultMaxIdle
}
al := new(alias)
al.Name = name
al.DriverName = driverName
al.DataSource = dataSource
al.MaxIdle = maxIdle
var (
err error
)
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end
}
if dataBaseCache.add(name, al) == false {
err = fmt.Errorf("db name `%s` already registered, cannot reuse", name)
goto end
}
al.DB, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", name, err.Error())
goto end
}
al.DB.SetMaxIdleConns(al.MaxIdle)
// orm timezone system match database
// default use Local
al.TZ = time.Local
switch al.Driver {
case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone")
var tz string
row.Scan(&tz)
if tz != "SYSTEM" {
t, err := time.Parse("-07:00", tz)
if err == nil {
al.TZ = t.Location()
}
}
case DR_Sqlite:
al.TZ = time.UTC
case DR_Postgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
}
}
err = al.DB.Ping()
if err != nil {
err = fmt.Errorf("register db `%s`, %s", name, err.Error())
goto end
}
end:
if err != nil {
fmt.Println(err.Error())
os.Exit(2)
}
}
func RegisterDriver(driverName string, typ DriverType) {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
} else {
if t != typ {
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
os.Exit(2)
}
}
}
func SetDataBaseTZ(name string, tz *time.Location) {
if al, ok := dataBaseCache.get(name); ok {
al.TZ = tz
} else {
fmt.Sprintf("DataBase name `%s` not registered\n", name)
os.Exit(2)
}
}

34
orm/db_mysql.go Normal file
View File

@ -0,0 +1,34 @@
package orm
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
"contains": "LIKE BINARY ?",
"icontains": "LIKE ?",
// "regex": "REGEXP BINARY ?",
// "iregex": "REGEXP ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
"iendswith": "LIKE ?",
}
type dbBaseMysql struct {
dbBase
}
var _ dbBaser = new(dbBaseMysql)
func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator]
}
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b
return b
}

17
orm/db_oracle.go Normal file
View File

@ -0,0 +1,17 @@
package orm
type dbBaseOracle struct {
dbBase
}
var _ dbBaser = new(dbBaseOracle)
func (d *dbBase) OperatorSql(operator string) string {
return ""
}
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b
return b
}

94
orm/db_postgres.go Normal file
View File

@ -0,0 +1,94 @@
package orm
import (
"fmt"
"strconv"
)
var postgresOperators = map[string]string{
"exact": "= ?",
"iexact": "= UPPER(?)",
"contains": "LIKE ?",
"icontains": "LIKE UPPER(?)",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE ?",
"endswith": "LIKE ?",
"istartswith": "LIKE UPPER(?)",
"iendswith": "LIKE UPPER(?)",
}
type dbBasePostgres struct {
dbBase
}
var _ dbBaser = new(dbBasePostgres)
func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator]
}
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol)
case "iexact", "icontains", "istartswith", "iendswith":
*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
}
}
func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false
}
func (d *dbBasePostgres) MaxLimit() uint64 {
return 0
}
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query
num := 0
for _, c := range q {
if c == '?' {
num += 1
}
}
if num == 0 {
return
}
data := make([]byte, 0, len(q)+num)
num = 1
for i := 0; i < len(q); i++ {
c := q[i]
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
num += 1
} else {
data = append(data, c)
}
}
*query = string(data)
}
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto {
if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column)
}
has = true
}
return
}
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
return b
}

50
orm/db_sqlite.go Normal file
View File

@ -0,0 +1,50 @@
package orm
import (
"fmt"
)
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
"contains": "LIKE ? ESCAPE '\\'",
"icontains": "LIKE ? ESCAPE '\\'",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE ? ESCAPE '\\'",
"endswith": "LIKE ? ESCAPE '\\'",
"istartswith": "LIKE ? ESCAPE '\\'",
"iendswith": "LIKE ? ESCAPE '\\'",
}
type dbBaseSqlite struct {
dbBase
}
var _ dbBaser = new(dbBaseSqlite)
func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator]
}
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}
func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
return b
}

391
orm/db_tables.go Normal file
View File

@ -0,0 +1,391 @@
package orm
import (
"fmt"
"strings"
"time"
)
type dbTable struct {
id int
index string
name string
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
jtl *dbTable
}
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
}
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
j.mi = mi
j.fi = fi
j.inner = inner
} else {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
}
return t.tablesM[name]
}
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
return jt, true
}
return t.tablesM[name], false
}
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
} else {
prefix = prefix + ExprSep + fi.name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
copy(related, rels)
relDepth := depth
if relsNum != 0 {
relDepth = 0
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
for i, s := range related {
var (
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cansel = true
jtl *dbTable
)
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
jt := t.set(names, mmi, fi, fi.null == false)
jt.jtl = jtl
if fi.reverse {
cansel = false
}
if cansel {
jt.sel = depth > 0
if i < relsNum {
jt.sel = true
}
}
jtl = jt
} else {
panic(fmt.Sprintf("unknown model/table name `%s`", ex))
}
}
}
}
func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
if jt.inner {
join += "INNER JOIN "
} else {
join += "LEFT OUTER JOIN "
}
var (
table string
t1, t2 string
c1, c2 string
)
t1 = "T0"
if jt.jtl != nil {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk.column
if jt.fi.reverse {
c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column
}
}
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
t2, Q, c2, Q, t1, Q, c1, Q)
}
return
}
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
ffi *fieldInfo
jtl *dbTable
mmi = mi
)
num := len(exprs) - 1
names := make([]string, 0)
for i, ex := range exprs {
exist := false
check:
fi, ok := mmi.fields.GetByAny(ex)
if ok {
if num != i {
names = append(names, fi.name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo
}
default:
return
}
jt, _ := d.add(names, mmi, fi, fi.null == false)
jt.jtl = jtl
jtl = jt
if fi.rel && fi.fieldType == RelManyToMany {
ex = fi.relModelInfo.name
goto check
}
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
ex = fi.reverseFieldInfo.mi.name
goto check
}
exist = true
} else {
if ffi == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl != nil {
name = jtl.name + ExprSep + fi.name
} else {
name = fi.name
}
switch fi.fieldType {
case RelManyToMany, RelReverseMany:
default:
exist = true
}
}
ffi = fi
}
if exist == false {
index = ""
name = ""
info = nil
success = false
return
}
}
success = index != "" && info != nil
return
}
func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
Q := d.base.TableQuote()
mi := d.mi
// outFor:
for i, p := range cond.params {
if i > 0 {
if p.isOr {
where += "OR "
} else {
where += "AND "
}
}
if p.isNot {
where += "NOT "
}
if p.isCond {
w, ps := d.getCondSql(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
where += w
params = append(params, ps...)
} else {
exprs := p.exprs
num := len(exprs) - 1
operator := ""
if operators[exprs[num]] {
operator = exprs[num]
exprs = exprs[:num]
}
index, _, fi, suc := d.parseExprs(mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
if operator == "" {
operator = "exact"
}
operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
d.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql)
params = append(params, args...)
}
}
if sub == false && where != "" {
where = "WHERE " + where
}
return
}
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
}
Q := d.base.TableQuote()
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
index, _, fi, suc := d.parseExprs(d.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits string) {
if limit == 0 {
limit = DefaultRowsLimit
}
if limit < 0 {
// no limit
if offset > 0 {
maxLimit := d.base.MaxLimit()
if maxLimit == 0 {
limits = fmt.Sprintf("OFFSET %d", offset)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
}
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
return
}
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi
tables.base = base
return tables
}

98
orm/db_utils.go Normal file
View File

@ -0,0 +1,98 @@
package orm
import (
"fmt"
"reflect"
"time"
)
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = vu > 0
value = vu
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
return
}
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:
for _, arg := range args {
val := reflect.ValueOf(arg)
if arg == nil {
params = append(params, arg)
continue
}
switch v := arg.(type) {
case []byte:
case time.Time:
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(DefaultTimeLoc).Format(format_Date)
} else {
arg = v.In(tz).Format(format_DateTime)
}
default:
kind := val.Kind()
switch kind {
case reflect.Slice, reflect.Array:
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor
case reflect.Ptr, reflect.Struct:
ind := reflect.Indirect(val)
if ind.Kind() == reflect.Struct {
typ := ind.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, ind); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
}
} else {
arg = ind.Interface()
}
}
}
params = append(params, arg)
}
return
}

View File

@ -0,0 +1,38 @@
## Custom Fields
TypeBooleanField = 1 << iota
// string
TypeCharField
// string
TypeTextField
// time.Time
TypeDateField
// time.Time
TypeDateTimeField
// int16
TypeSmallIntegerField
// int32
TypeIntegerField
// int64
TypeBigIntegerField
// uint16
TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField
// float64
TypeFloatField
// float64
TypeDecimalField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany

257
orm/docs/zh/Models.md Normal file
View File

@ -0,0 +1,257 @@
## 模型定义
复杂的模型定义不是必须的,此功能用作数据库数据转换和自动建表
## Struct Tag 设置参数
```go
orm:"null;rel(fk)"
```
通常每个 Field 的 StructTag 里包含两种类型的设置,类似 null 的 bool 型设置,还有 类似 rel(fk) 的指定值设置bool 型默认为 false指定以后即表示为 true
多个设置间使用 `;` 分隔,设置的值如果是多个,使用 `,` 分隔。
#### 忽略字段
设置 `-` 即可忽略 struct 中的字段
```go
type User struct {
...
AnyField string `orm:"-"`
...
```
#### auto
设置为 Autoincrement Primary Key
#### pk
设置为 Primary Key
#### null
数据库表默认为 `NOT NULL`,设置 null 代表 `ALLOW NULL`
#### blank
设置 string 类型的字段允许为空,否则 clean 会返回错误
#### index
为字段增加索引
#### unique
为字段增加 unique 键
#### column
为字段设置 db 字段的名称
```go
Name `orm:"column(user_name)"`
```
#### default
为字段设置默认值,类型必须符合
```go
type User struct {
...
Status int `orm:"default(1)"`
```
#### size (string)
string 类型字段设置 size 以后db type 将使用 varchar
```go
Title string `orm:"size(60)"`
```
#### digits / decimals
设置 float32, float64 类型的浮点精度
```go
Money float64 `orm:"digits(12);decimals(4)"`
```
总长度 12 小数点后 4 位 eg: `99999999.9999`
#### auto_now / auto_now_add
```go
Created time.Time `auto_now_add`
Updated time.Time `auto_now`
```
* auto_now 每次 model 保存时都会对时间自动更新
* auto_now_add 第一次保存时才设置时间
对于批量的 update 此设置是不生效的
#### type
设置为 date, time.Time 字段的对应 db 类型使用 date
```go
Created time.Time `orm:"auto_now_add;type(date)"`
```
## 表关系设置
#### rel / reverse
**RelOneToOne**:
```go
type User struct {
...
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
```
对应的反向关系 **RelReverseOne**:
```go
type Profile struct {
...
User *User `orm:"reverse(one)" json:"-"`
```
**RelForeignKey**:
```go
type Post struct {
...
User*User `orm:"rel(fk)"` // RelForeignKey relation
```
对应的反向关系 **RelReverseMany**:
```go
type User struct {
...
Posts []*Post `orm:"reverse(many)" json:"-"` // fk 的反向关系
```
**RelManyToMany**:
```go
type Post struct {
...
Tags []*Tag `orm:"rel(m2m)"` // ManyToMany relation
```
对应的反向关系 **RelReverseMany**:
```go
type Tag struct {
...
Posts []*Post `orm:"reverse(many)" json:"-"`
```
#### rel_table / rel_through
此设置针对 `orm:"rel(m2m)"` 的关系字段
rel_table 设置自动生成的 m2m 关系表的名称
rel_through 如果要在 m2m 关系中使用自定义的 m2m 关系表
通过这个设置其名称,格式为 pkg.path.ModelName
eg: app.models.PostTagRel
PostTagRel 表需要有到 Post 和 Tag 的关系
当设置 rel_table 时会忽略 rel_through
#### on_delete
设置对应的 rel 关系删除时,如何处理关系字段。
cascade 级联删除(默认值)
set_null 设置为 NULL需要设置 null = true
set_default 设置为默认值,需要设置 default 值
do_nothing 什么也不做,忽略
```go
type User struct {
...
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
...
type Profile struct {
...
User *User `orm:"reverse(one)" json:"-"`
// 删除 Profile 时将设置 User.Profile 的数据库字段为 NULL
```
## 模型字段与数据库类型的对应
在此列出 orm 推荐的对应数据库类型,自动建表功能也会以此为标准。
默认所有的字段都是 **NOT NULL**
#### MySQL
| go |mysql
| :--- | :---
| bool | bool
| string - 设置 size 时 | varchar(size)
| string | longtext
| time.Time - 设置 type 为 date 时 | date
| time.TIme | datetime
| byte | tinyint unsigned
| rune | integer
| int | integer
| int8 | tinyint
| int16 | smallint
| int32 | integer
| int64 | bigint
| uint | integer unsigned
| uint8 | tinyint unsigned
| uint16 | smallint unsigned
| uint32 | integer unsigned
| uint64 | bigint unsigned
| float32 | double precision
| float64 | double precision
| float64 - 设置 digits, decimals 时 | numeric(digits, decimals)
#### Sqlite3
| go | sqlite3
| :--- | :---
| bool | bool
| string - 设置 size 时 | varchar(size)
| string | text
| time.Time - 设置 type 为 date 时 | date
| time.TIme | datetime
| byte | tinyint unsigned
| rune | integer
| int | integer
| int8 | tinyint
| int16 | smallint
| int32 | integer
| int64 | bigint
| uint | integer unsigned
| uint8 | tinyint unsigned
| uint16 | smallint unsigned
| uint32 | integer unsigned
| uint64 | bigint unsigned
| float32 | real
| float64 | real
| float64 - 设置 digits, decimals 时 | decimal
#### PostgreSQL
| go | postgres
| :--- | :---
| bool | bool
| string - 设置 size 时 | varchar(size)
| string | text
| time.Time - 设置 type 为 date 时 | date
| time.TIme | timestamp with time zone
| byte | smallint CHECK("column" >= 0 AND "column" <= 255)
| rune | integer
| int | integer
| int8 | smallint CHECK("column" >= -127 AND "column" <= 128)
| int16 | smallint
| int32 | integer
| int64 | bigint
| uint | bigint CHECK("column" >= 0)
| uint8 | smallint CHECK("column" >= 0 AND "column" <= 255)
| uint16 | integer CHECK("column" >= 0)
| uint32 | bigint CHECK("column" >= 0)
| uint64 | bigint CHECK("column" >= 0)
| float32 | double precision
| float64 | double precision
| float64 - 设置 digits, decimals 时 | numeric(digits, decimals)
## 关系型字段
其字段类型取决于对应的主键。
* RelForeignKey
* RelOneToOne
* RelManyToMany
* RelReverseOne
* RelReverseMany

83
orm/docs/zh/Models.sql Normal file
View File

@ -0,0 +1,83 @@
SET NAMES utf8;
SET FOREIGN_KEY_CHECKS = 0;
-- ----------------------------
-- Table structure for `comment`
-- ----------------------------
DROP TABLE IF EXISTS `comment`;
CREATE TABLE `comment` (
`id` int(11) NOT NULL,
`post_id` bigint(200) NOT NULL,
`content` longtext NOT NULL,
`parent_id` int(11) DEFAULT NULL,
`status` smallint(4) NOT NULL,
`created` datetime NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
-- ----------------------------
-- Table structure for `post`
-- ----------------------------
DROP TABLE IF EXISTS `post`;
CREATE TABLE `post` (
`id` int(11) NOT NULL,
`user_id` int(11) NOT NULL,
`title` varchar(60) NOT NULL,
`content` longtext NOT NULL,
`created` datetime NOT NULL,
`updated` datetime NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
-- ----------------------------
-- Table structure for `post_tag_rel`
-- ----------------------------
DROP TABLE IF EXISTS `post_tag_rel`;
CREATE TABLE `post_tag_rel` (
`id` int(11) NOT NULL,
`post_id` int(11) NOT NULL,
`tag_id` int(11) NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
-- ----------------------------
-- Table structure for `tag`
-- ----------------------------
DROP TABLE IF EXISTS `tag`;
CREATE TABLE `tag` (
`id` int(11) NOT NULL,
`name` varchar(30) NOT NULL,
`status` smallint(4) NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
-- ----------------------------
-- Table structure for `user`
-- ----------------------------
DROP TABLE IF EXISTS `user`;
CREATE TABLE `user` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`user_name` varchar(30) NOT NULL,
`email` varchar(100) NOT NULL,
`password` varchar(30) NOT NULL,
`status` smallint(4) NOT NULL,
`is_staff` tinyint(1) NOT NULL,
`is_active` tinyint(1) NOT NULL,
`created` date NOT NULL,
`updated` datetime NOT NULL,
`profile_id` int(11) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
-- ----------------------------
-- Table structure for `profile`
-- ----------------------------
DROP TABLE IF EXISTS `profile`;
CREATE TABLE `profile` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`age` smallint(4) NOT NULL,
`money` double NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
SET FOREIGN_KEY_CHECKS = 1;

59
orm/docs/zh/Object.md Normal file
View File

@ -0,0 +1,59 @@
## 对象的CRUD操作
对 object 操作简单的三个方法 Read / Insert / Update / Delete
```go
o := orm.NewOrm()
user := NewUser()
user.Name = "slene"
fmt.Println(o.Insert(user))
user.Name = "Your"
fmt.Println(o.Update(user))
fmt.Println(o.Read(user))
fmt.Println(o.Delete(user))
```
### Read
```go
o := orm.NewOrm()
user := User{Id: 1}
err = o.Read(&user)
if err == sql.ErrNoRows {
fmt.Println("查询不到")
} else if err == orm.ErrMissPK {
fmt.Println("找不到主键")
} else {
fmt.Println(user.Id, user.Name)
}
```
### Insert
```go
o := orm.NewOrm()
var user User
user.Name = "slene"
user.IsActive = true
fmt.Println(o.Insert(&user))
fmt.Println(user.Id)
```
创建后会自动对 auto 的 field 赋值
### Update
```go
o := orm.NewOrm()
user := User{Id: 1}
if o.Read(&user) == nil {
user.Name = "MyName"
o.Update(&user)
}
```
### Delete
```go
o := orm.NewOrm()
o.Delete(&User{Id: 1})
```
Delete 操作会对反向关系进行操作,此例中 Post 拥有一个到 User 的外键。删除 User 的时候。如果 on_delete 设置为默认的级联操作,将删除对应的 Post
删除以后会清除 auto field 的值

303
orm/docs/zh/Orm.md Normal file
View File

@ -0,0 +1,303 @@
## Orm 使用方法
beego/orm 的使用例子
后文例子如无特殊说明都以这个为基础。
##### models.go:
```go
package main
import (
"github.com/astaxie/beego/orm"
)
type User struct {
Id int `orm:"auto"` // 设置为auto主键
Name string
Profile *Profile `orm:"rel(one)"` // OneToOne relation
}
type Profile struct {
Id int `orm:"auto"`
Age int16
User *User `orm:"reverse(one)"` // 设置反向关系(可选)
}
func init() {
// 需要在init中注册定义的model
orm.RegisterModel(new(User), new(Profile))
}
```
##### main.go
```go
package main
import (
"fmt"
"github.com/astaxie/beego/orm"
_ "github.com/go-sql-driver/mysql"
)
func init() {
orm.RegisterDriver("mysql", orm.DR_MySQL)
orm.RegisterDataBase("default", "mysql", "root:root@/orm_test?charset=utf8", 30)
}
func main() {
o := orm.NewOrm()
o.Using("default") // 默认使用 default你可以指定为其他数据库
profile := NewProfile()
profile.Age = 30
user := NewUser()
user.Profile = profile
user.Name = "slene"
fmt.Println(o.Insert(profile))
fmt.Println(o.Insert(user))
}
```
## 数据库的设置
目前 orm 支持三种数据库,以下为测试过的 driver
将你需要使用的 driver 加入 import 中
```go
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
```
#### RegisterDriver
三种默认数据库类型
```go
orm.DR_MySQL
orm.DR_Sqlite
orm.DR_Postgres
```
```go
// 参数1 driverName
// 参数2 数据库类型
// 这个用来设置 driverName 对应的数据库类型
// mysql / sqlite3 / postgres 这三种是默认已经注册过的,所以可以无需设置
orm.RegisterDriver("mymysql", orm.DR_MySQL)
```
#### RegisterDataBase
orm 必须注册一个名称为 `default` 的数据库,用以作为默认使用。
```go
// 参数1 自定义数据库名称用来在orm中切换数据库使用
// 参数2 driverName
// 参数3 对应的链接字符串
// 参数4 设置最大的空闲连接数,使用 golang 自己的连接池
orm.RegisterDataBase("default", "mysql", "root:root@/orm_test?charset=utf8", 30)
```
#### 时区设置
orm 默认使用 time.Local 本地时区
* 作用于 orm 自动创建的时间
* 从数据库中取回的时间转换成 orm 本地时间
如果需要的话,你也可以进行更改
```go
// 设置为 UTC 时间
orm.DefaultTimeLoc = time.UTC
```
orm 在进行 RegisterDataBase 的同时,会获取数据库使用的时区,然后在 time.Time 类型存取的时做相应转换,以匹配时间系统,从而保证时间不会出错。
**注意:** 鉴于 Sqlite3 的设计,存取默认都为 UTC 时间
## RegisterModel
如果使用 orm.QuerySeter 进行高级查询的话,这个是必须的。
反之,如果只使用 Raw 查询和 map struct是无需这一步的。您可以去查看 [Raw SQL 查询](Raw.md)
将你定义的 Model 进行注册,最佳设计是有单独的 models.go 文件,在他的 init 函数中进行注册。
迷你版 models.go
```go
package main
import "github.com/astaxie/beego/orm"
type User struct {
Id int `orm:"auto"`
name string
}
func init(){
orm.RegisterModel(new(User))
}
```
RegisterModel 也可以同时注册多个 model
```go
orm.RegisterModel(new(User), new(Profile), new(Post))
```
## ORM 接口使用
使用 orm 必然接触的 Ormer 接口,我们来熟悉一下
```go
var o Ormer
o = orm.NewOrm() // 创建一个 Ormer
// NewOrm 的同时会执行 orm.BootStrap (整个 app 只执行一次),用以验证模型之间的定义并缓存。
```
* type Ormer interface {
* [Read(Modeler) error](Object.md#read)
* [Insert(Modeler) (int64, error)](Object.md#insert)
* [Update(Modeler) (int64, error)](Object.md#update)
* [Delete(Modeler) (int64, error)](Object.md#delete)
* [M2mAdd(Modeler, string, ...interface{}) (int64, error)](Object.md#m2madd)
* [M2mDel(Modeler, string, ...interface{}) (int64, error)](Object.md#m2mdel)
* [LoadRel(Modeler, string) (int64, error)](Object.md#loadRel)
* [QueryTable(interface{}) QuerySeter](#querytable)
* [Using(string) error](#using)
* [Begin() error](Transaction.md)
* [Commit() error](Transaction.md)
* [Rollback() error](Transaction.md)
* [Raw(string, ...interface{}) RawSeter](#raw)
* [Driver() Driver](#driver)
* }
#### QueryTable
传入表名,或者 Modeler 对象,返回一个 [QuerySeter](Query.md#queryseter)
```go
o := orm.NewOrm()
var qs QuerySeter
qs = o.QueryTable("user")
// 如果表没有定义过,会立刻 panic
```
#### Using
切换为其他数据库
```go
orm.RegisterDataBase("db1", "mysql", "root:root@/orm_db2?charset=utf8", 30)
orm.RegisterDataBase("db2", "sqlite3", "data.db", 30)
o1 := orm.NewOrm()
o1.Using("db1")
o2 := orm.NewOrm()
o2.Using("db2")
// 切换为其他数据库以后
// 这个 Ormer 对象的其下的 api 调用都将使用这个数据库
```
默认使用 `default` 数据库,无需调用 Using
#### Raw
使用 sql 语句直接进行操作
Raw 函数,返回一个 [RawSeter](Raw.md) 用以对设置的 sql 语句和参数进行操作
```go
o := NewOrm()
var r RawSeter
r = o.Raw("UPDATE user SET name = ? WHERE name = ?", "testing", "slene")
```
#### Driver
返回当前 orm 使用的 db 信息
```go
type Driver interface {
Name() string
Type() DriverType
}
```
```go
orm.RegisterDataBase("db1", "mysql", "root:root@/orm_db2?charset=utf8", 30)
orm.RegisterDataBase("db2", "sqlite3", "data.db", 30)
o1 := orm.NewOrm()
o1.Using("db1")
dr := o1.Driver()
fmt.Println(dr.Name() == "db1") // true
fmt.Println(dr.Type() == orm.DR_MySQL) // true
o2 := orm.NewOrm()
o2.Using("db2")
dr = o2.Driver()
fmt.Println(dr.Name() == "db2") // true
fmt.Println(dr.Type() == orm.DR_Sqlite) // true
```
## 调试模式打印查询语句
简单的设置 Debug 为 true 打印查询的语句
可能存在性能问题,不建议使用在产品模式
```go
func main() {
orm.Debug = true
...
```
默认使用 os.Stderr 输出日志信息
改变输出到你自己的 io.Writer
```go
var w io.Writer
...
// 设置为你的 io.Writer
...
orm.DebugLog = orm.NewLog(w)
```
日志格式
```go
[ORM] - 时间 - [Queries/数据库名] - [执行操作/执行时间] - [SQL语句] - 使用标点 `,` 分隔的参数列表 - 打印遇到的错误
```
```go
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.5ms] - [UPDATE `user` SET `name` = ? WHERE `id` = ?] - `astaxie`, `14`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [db.QueryRow / 0.4ms] - [SELECT `id`, `name` FROM `user` WHERE `id` = ?] - `14`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `post` (`user_id`,`title`,`content`) VALUES (?, ?, ?)] - `14`, `beego orm`, `powerful amazing`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Query / 0.4ms] - [SELECT T1.`name` `User__Name`, T0.`user_id` `User`, T1.`id` `User__Id` FROM `post` T0 INNER JOIN `user` T1 ON T1.`id` = T0.`user_id` WHERE T0.`id` = ? LIMIT 1000] - `68`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [DELETE FROM `user` WHERE `id` = ?] - `14`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Query / 0.3ms] - [SELECT T0.`id` FROM `post` T0 WHERE T0.`user_id` IN (?) ] - `14`
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [DELETE FROM `post` WHERE `id` IN (?)] - `68`
```
日志内容包括 **所有的数据库操作**事务Prepare

411
orm/docs/zh/Query.md Normal file
View File

@ -0,0 +1,411 @@
## 高级查询
orm 以 **QuerySeter** 来组织查询,每个返回 **QuerySeter** 的方法都会获得一个新的 **QuerySeter** 对象。
基本使用方法:
```go
o := orm.NewOrm()
// 获取 QuerySeter 对象user 为表名
qs := o.QueryTable("user")
// 也可以直接使用对象作为表名
user := NewUser()
qs = o.QueryTable(user) // 返回 QuerySeter
```
## expr
QuerySeter 中用于描述字段和 sql 操作符,使用简单的 expr 查询方法
字段组合的前后顺序依照表的关系,比如 User 表拥有 Profile 的外键,那么对 User 表查询对应的 Profile.Age 为条件,则使用 `Profile__Age` 注意,字段的分隔符号使用双下划线 `__`,除了描述字段, expr 的尾部可以增加操作符以执行对应的 sql 操作。比如 `Profile__Age__gt` 代表 Profile.Age > 18 的条件查询。
注释后面将描述对应的 sql 语句,仅仅是描述 expr 的类似结果,并不代表实际生成的语句。
```go
qs.Filter("id", 1) // WHERE id = 1
qs.Filter("profile__age", 18) // WHERE profile.age = 18
qs.Filter("Profile__Age", 18) // 使用字段名和Field名都是允许的
qs.Filter("profile__age", 18) // WHERE profile.age = 18
qs.Filter("profile__age__gt", 18) // WHERE profile.age > 18
qs.Filter("profile__age__gte", 18) // WHERE profile.age >= 18
qs.Filter("profile__age__in", 18, 20) // WHERE profile.age IN (18, 20)
qs.Filter("profile__age__in", 18, 20).Exclude("profile__lt", 1000)
// WHERE profile.age IN (18, 20) AND NOT profile_id < 1000
```
## Operators
当前支持的操作符号:
* [exact](#exact) / [iexact](#iexact) 等于
* [contains](#contains) / [icontains](#icontains) 包含
* [gt / gte](#gt / gte) 大于 / 大于等于
* [lt / lte](#lt / lte) 小于 / 小于等于
* [startswith](#startswith) / [istartswith](#istartswith) 以...起始
* [endswith](#endswith) / [iendswith](#iendswith) 以...结束
* [in](#in)
* [isnull](#isnull)
后面以 `i` 开头的表示:大小写不敏感
#### exact
Filter / Exclude / Condition expr 的默认值
```go
qs.Filter("name", "slene") // WHERE name = 'slene'
qs.Filter("name__exact", "slene") // WHERE name = 'slene'
// 使用 = 匹配,大小写是否敏感取决于数据表使用的 collation
qs.Filter("profile", nil) // WHERE profile_id IS NULL
```
#### iexact
```go
qs.Filter("name__iexact", "slene")
// WHERE name LIKE 'slene'
// 大小写不敏感,匹配任意 'Slene' 'sLENE'
```
#### contains
```go
qs.Filter("name__contains", "slene")
// WHERE name LIKE BINARY '%slene%'
// 大小写敏感, 匹配包含 slene 的字符
```
#### icontains
```go
qs.Filter("name__icontains", "slene")
// WHERE name LIKE '%slene%'
// 大小写不敏感, 匹配任意 'im Slene', 'im sLENE'
```
#### in
```go
qs.Filter("profile__age__in", 17, 18, 19, 20)
// WHERE profile.age IN (17, 18, 19, 20)
```
#### gt / gte
```go
qs.Filter("profile__age__gt", 17)
// WHERE profile.age > 17
qs.Filter("profile__age__gte", 18)
// WHERE profile.age >= 18
```
#### lt / lte
```go
qs.Filter("profile__age__lt", 17)
// WHERE profile.age < 17
qs.Filter("profile__age__lte", 18)
// WHERE profile.age <= 18
```
#### startswith
```go
qs.Filter("name__startswith", "slene")
// WHERE name LIKE BINARY 'slene%'
// 大小写敏感, 匹配以 'slene' 起始的字符串
```
#### istartswith
```go
qs.Filter("name__istartswith", "slene")
// WHERE name LIKE 'slene%'
// 大小写不敏感, 匹配任意以 'slene', 'Slene' 起始的字符串
```
#### endswith
```go
qs.Filter("name__endswith", "slene")
// WHERE name LIKE BINARY '%slene'
// 大小写敏感, 匹配以 'slene' 结束的字符串
```
#### iendswith
```go
qs.Filter("name__startswith", "slene")
// WHERE name LIKE '%slene'
// 大小写不敏感, 匹配任意以 'slene', 'Slene' 结束的字符串
```
#### isnull
```go
qs.Filter("profile__isnull", true)
qs.Filter("profile_id__isnull", true)
// WHERE profile_id IS NULL
qs.Filter("profile__isnull", false)
// WHERE profile_id IS NOT NULL
```
## 高级查询接口使用
QuerySeter 是高级查询使用的接口,我们来熟悉下他的接口方法
* type QuerySeter interface {
* [Filter(string, ...interface{}) QuerySeter](#filter)
* [Exclude(string, ...interface{}) QuerySeter](#exclude)
* [SetCond(*Condition) QuerySeter](#setcond)
* [Limit(int, ...int64) QuerySeter](#limit)
* [Offset(int64) QuerySeter](#offset)
* [OrderBy(...string) QuerySeter](#orderby)
* [RelatedSel(...interface{}) QuerySeter](#relatedsel)
* [Count() (int64, error)](#count)
* [Update(Params) (int64, error)](#update)
* [Delete() (int64, error)](#delete)
* [PrepareInsert() (Inserter, error)](#prepareinsert)
* [All(interface{}) (int64, error)](#all)
* [One(Modeler) error](#one)
* [Values(*[]Params, ...string) (int64, error)](#values)
* [ValuesList(*[]ParamsList, ...string) (int64, error)](#valueslist)
* [ValuesFlat(*ParamsList, string) (int64, error)](#valuesflat)
* }
* 每个返回 QuerySeter 的 api 调用时都会新建一个 QuerySeter不影响之前创建的。
* 高级查询使用 Filter 和 Exclude 来做常用的条件查询。囊括两种清晰的过滤规则:包含, 排除
#### Filter
用来过滤查询结果,起到 **包含条件** 的作用
多个 Filter 之间使用 `AND` 连接
```go
qs.Filter("profile__isnull", true).Filter("name", "slene")
// WHERE profile_id IS NULL AND name = 'slene'
```
#### Exclude
用来过滤查询结果,起到 **排除条件** 的作用
使用 `NOT` 排除条件
多个 Exclude 之间使用 `AND` 连接
```go
qs.Exclude("profile__isnull", true).Filter("name", "slene")
// WHERE NOT profile_id IS NULL AND name = 'slene'
```
#### SetCond
自定义条件表达式
```go
cond := NewCondition()
cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
qs := orm.QueryTable("user")
qs = qs.SetCond(cond1)
// WHERE ... AND ... AND NOT ... OR ...
cond2 := cond.AndCond(cond1).OrCond(cond.And("name", "slene"))
qs = qs.SetCond(cond2).Count()
// WHERE (... AND ... AND NOT ... OR ...) OR ( ... )
```
#### Limit
限制最大返回数据行数,第二个参数可以设置 `Offset`
```go
var DefaultRowsLimit = 1000 // orm 默认的 limit 值为 1000
// 默认情况下 select 查询的最大行数为 1000
// LIMIT 1000
qs.Limit(10)
// LIMIT 10
qs.Limit(10, 20)
// LIMIT 10 OFFSET 20
qs.Limit(-1)
// no limit
qs.Limit(-1, 100)
// LIMIT 18446744073709551615 OFFSET 100
// 18446744073709551615 是 1<<64 - 1 用来指定无 limit 限制 但有 offset 偏移的情况
```
#### Offset
设置 偏移行数
```go
qs.Offset(20)
// LIMIT 1000 OFFSET 20
```
#### OrderBy
参数使用 **expr**
在 expr 前使用减号 `-` 表示 `DESC` 的排列
```go
qs.OrderBy("id", "-profile__age")
// ORDER BY id ASC, profile.age DESC
qs.OrderBy("-profile__age", "profile")
// ORDER BY profile.age DESC, profile_id ASC
```
#### RelatedSel
关系查询,参数使用 **expr**
```go
var DefaultRelsDepth = 5 // 默认情况下直接调用 RelatedSel 将进行最大 5 层的关系查询
qs := o.QueryTable("post")
qs.RelateSel()
// INNER JOIN user ... LEFT OUTER JOIN profile ...
qs.RelateSel("user")
// INNER JOIN user ...
// 设置 expr 只对设置的字段进行关系查询
// 对设置 null 属性的 Field 将使用 LEFT OUTER JOIN
```
#### Count
依据当前的查询条件,返回结果行数
```go
cnt, err := o.QueryTable("user").Count() // SELECT COUNT(*) FROM USER
fmt.Printf("Count Num: %s, %s", cnt, err)
```
#### Update
依据当前查询条件,进行批量更新操作
```go
num, err := o.QueryTable("user").Filter("name", "slene").Update(orm.Params{
"name": "astaxie",
})
fmt.Printf("Affected Num: %s, %s", num, err)
// SET name = "astaixe" WHERE name = "slene"
```
#### Delete
依据当前查询条件,进行批量删除操作
```go
num, err := o.QueryTable("user").Filter("name", "slene").Delete()
fmt.Printf("Affected Num: %s, %s", num, err)
// DELETE FROM user WHERE name = "slene"
```
#### PrepareInsert
用于一次 prepare 多次 insert 插入,以提高批量插入的速度。
```go
var users []*User
...
qs := o.QueryTable("user")
i, _ := qs.PrepareInsert()
for _, user := range users {
id, err := i.Insert(user)
if err != nil {
...
}
}
// PREPARE INSERT INTO user (`name`, ...) VALUES (?, ...)
// EXECUTE INSERT INTO user (`name`, ...) VALUES ("slene", ...)
// EXECUTE ...
// ...
i.Close() // 别忘记关闭 statement
```
#### All
返回对应的结果集对象
```go
var users []*User
num, err := o.QueryTable("user").Filter("name", "slene").All(&users)
fmt.Printf("Returned Rows Num: %s, %s", num, err)
```
#### One
尝试返回单条记录
```go
var user *User
err := o.QueryTable("user").Filter("name", "slene").One(&user)
if err == orm.ErrMultiRows {
// 多条的时候报错
fmt.Printf("Returned Multi Rows Not One")
}
if err == orm.ErrNoRows {
// 没有找到记录
fmt.Printf("Not row found")
}
```
#### Values
返回结果集的 key => value 值
key 为 Model 里的 Field namevalue 的值 以 string 保存
```go
var maps []orm.Params
num, err := o.QueryTable("user").Values(&maps)
if err != nil {
fmt.Printf("Result Nums: %d\n", num)
for _, m := range maps {
fmt.Println(m["Id"], m["Name"])
}
}
```
返回指定的 Field 数据
**TODO**: 暂不支持级联查询 **RelatedSel** 直接返回 Values
但可以直接指定 expr 级联返回需要的数据
```go
var maps []orm.Params
num, err := o.QueryTable("user").Values(&maps, "id", "name", "profile", "profile__age")
if err != nil {
fmt.Printf("Result Nums: %d\n", num)
for _, m := range maps {
fmt.Println(m["Id"], m["Name"], m["Profile"], m["Profile__Age"])
// map 中的数据都是展开的,没有复杂的嵌套
}
}
```
#### ValuesList
顾名思义返回的结果集以slice存储
结果的排列与 Model 中定义的 Field 顺序一致
返回的每个元素值以 string 保存
```go
var lists []orm.ParamsList
num, err := o.QueryTable("user").ValuesList(&lists)
if err != nil {
fmt.Printf("Result Nums: %d\n", num)
for _, row := range lists {
fmt.Println(row)
}
}
```
当然也可以指定 expr 返回指定的 Field
```go
var lists []orm.ParamsList
num, err := o.QueryTable("user").ValuesList(&lists, "name", "profile__age")
if err != nil {
fmt.Printf("Result Nums: %d\n", num)
for _, row := range lists {
fmt.Printf("Name: %s, Age: %s\m", row[0], row[1])
}
}
```
#### ValuesFlat
只返回特定的 Field 值,讲结果集展开到单个 slice 里
```go
var list orm.ParamsList
num, err := o.QueryTable("user").ValuesFlat(&list, "name")
if err != nil {
fmt.Printf("Result Nums: %d\n", num)
fmt.Printf("All User Names: %s", strings.Join(list, ", ")
}
```

29
orm/docs/zh/README.md Normal file
View File

@ -0,0 +1,29 @@
## 文档目录
1. [Orm 使用方法](Orm.md)
- [数据库的设置](Orm.md#数据库的设置)
* [驱动类型设置](Orm.md#registerdriver)
* [参数设置](Orm.md#registerdataBase)
* [时区设置](Orm.md#时区设置)
- [注册 ORM 使用的模型](Orm.md#registermodel)
- [ORM 接口使用](Orm.md#orm-接口使用)
- [调试模式打印查询语句](Orm.md#调试模式打印查询语句)
2. [对象的CRUD操作](Object.md)
3. [高级查询](Query.md)
- [使用的表达式语法](Query.md#expr)
- [支持的操作符号](Query.md#operators)
- [高级查询接口使用](Query.md#高级查询接口使用)
4. [使用SQL语句进行查询](Raw.md)
5. [事务处理](Transaction.md)
6. [模型定义](Models.md)
- [Struct Tag 设置参数](Models.md#struct-tag-设置参数)
- [表关系设置](Models.md#表关系设置)
- [模型字段与数据库类型的对应](Models.md#模型字段与数据库类型的对应)
7. Custom Fields
8. Faq
### 文档更新记录
* 2013-08-13: ORM 的 [时区设置](Orm.md#时区设置)
* 2013-08-13: [模型字段与数据库类型的对应](Models.md#模型字段与数据库类型的对应) 推荐的数据库对应使用的类型

116
orm/docs/zh/Raw.md Normal file
View File

@ -0,0 +1,116 @@
## 使用SQL语句进行查询
* 使用 Raw SQL 查询,无需使用 ORM 表定义
* 多数据库,都可直接使用占位符号 `?`,自动转换
* 查询时的参数,支持使用 Model Struct 和 Slice, Array
```go
ids := []int{1, 2, 3}
p.Raw("SELECT name FROM user WHERE id IN (?, ?, ?)", ids)
```
创建一个 **RawSeter**
```go
o := NewOrm()
var r RawSeter
r = o.Raw("UPDATE user SET name = ? WHERE name = ?", "testing", "slene")
```
* type RawSeter interface {
* [Exec() (int64, error)](#exec)
* [QueryRow(...interface{}) error](#queryrow)
* [QueryRows(...interface{}) (int64, error)](#queryrows)
* [SetArgs(...interface{}) RawSeter](#setargs)
* [Values(*[]Params) (int64, error)](#values)
* [ValuesList(*[]ParamsList) (int64, error)](#valueslist)
* [ValuesFlat(*ParamsList) (int64, error)](#valuesflat)
* [Prepare() (RawPreparer, error)](#prepare)
* }
#### Exec
执行sql语句
```go
num, err := r.Exec()
```
#### QueryRow
TODO
#### QueryRows
TODO
#### SetArgs
改变 Raw(sql, args...) 中的 args 参数,返回一个新的 RawSeter
用于单条 sql 语句,重复利用,替换参数然后执行。
```go
num, err := r.SetArgs("arg1", "arg2").Exec()
num, err := r.SetArgs("arg1", "arg2").Exec()
...
```
#### Values / ValuesList / ValuesFlat
Raw SQL 查询获得的结果集 Value 为 `string` 类型NULL 字段的值为空 ``
#### Values
返回结果集的 key => value 值
```go
var maps []orm.Params
num, err = o.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps)
if err == nil && num > 0 {
fmt.Println(maps[0]["user_name"]) // slene
}
```
#### ValuesList
返回结果集 slice
```go
var lists []orm.ParamsList
num, err = o.Raw("SELECT user_name FROM user WHERE status = ?", 1).ValuesList(&lists)
if err == nil && num > 0 {
fmt.Println(lists[0][0]) // slene
}
```
#### ValuesFlat
返回单一字段的平铺 slice 数据
```go
var list orm.ParamsList
num, err = o.Raw("SELECT id FROM user WHERE id < ?", 10).ValuesList(&list)
if err == nil && num > 0 {
fmt.Println(list) // []{"1","2","3",...}
}
```
#### Prepare
用于一次 prepare 多次 exec以提高批量执行的速度。
```go
p, err := o.Raw("UPDATE user SET name = ? WHERE name = ?").Prepare()
num, err := p.Exec("testing", "slene")
num, err = p.Exec("testing", "astaxie")
...
...
p.Close() // 别忘记关闭 statement
```
## FAQ
1. 我的 app 需要支持多类型数据库,如何在使用 Raw SQL 的时候判断当前使用的数据库类型。
使用 Ormer 的 [Driver方法](Orm.md#driver) 可以进行判断

View File

@ -0,0 +1,17 @@
## 事务处理
orm 可以简单的进行事务操作
```go
o := NewOrm()
err := o.Begin()
// 事务处理过程
...
...
// 此过程中的所有使用 o Ormer 对象的查询都在事务处理范围内
if SomeError {
err = o.Rollback()
} else {
err = o.Commit()
}
```

86
orm/models.go Normal file
View File

@ -0,0 +1,86 @@
package orm
import (
"sync"
)
const (
od_CASCADE = "cascade"
od_SET_NULL = "set_null"
od_SET_DEFAULT = "set_default"
od_DO_NOTHING = "do_nothing"
defaultStructTagName = "orm"
)
var (
modelCache = &_modelCache{
cache: make(map[string]*modelInfo),
cacheByFN: make(map[string]*modelInfo),
}
supportTag = map[string]int{
"-": 1,
"null": 1,
"blank": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"size": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
"type": 2,
}
)
type _modelCache struct {
sync.RWMutex
orders []string
cache map[string]*modelInfo
cacheByFN map[string]*modelInfo
done bool
}
func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache {
m[k] = v
}
return m
}
func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
for _, v := range mc.cache {
m = append(m, v)
}
return m
}
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table]
return
}
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name]
return
}
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
mc.cacheByFN[mi.fullName] = mi
if mii == nil {
mc.orders = append(mc.orders, table)
}
return mii
}

240
orm/models_boot.go Normal file
View File

@ -0,0 +1,240 @@
package orm
import (
"errors"
"fmt"
"os"
"reflect"
"strings"
)
func registerModel(model interface{}) {
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
typ := ind.Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Sprintf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
info := newModelInfo(val)
name := getFullName(typ)
if _, ok := modelCache.getByFN(name); ok {
fmt.Printf("<orm.RegisterModel> model `%s` redeclared, must be unique\n", name)
os.Exit(2)
}
table := getTableName(val)
if _, ok := modelCache.get(table); ok {
fmt.Printf("<orm.RegisterModel> table name `%s` redeclared, must be unique\n", table)
os.Exit(2)
}
if info.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name)
os.Exit(2)
}
info.table = table
info.pkg = typ.PkgPath()
info.model = model
info.manual = true
modelCache.set(table, info)
}
func bootStrap() {
if modelCache.done {
return
}
var (
err error
models map[string]*modelInfo
)
if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register alias named `default`")
goto end
}
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.columns {
if fi.rel || fi.reverse {
elm := fi.addrValue.Type().Elem()
switch fi.fieldType {
case RelReverseMany, RelManyToMany:
elm = elm.Elem()
}
name := getFullName(elm)
mii, ok := modelCache.getByFN(name)
if ok == false || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end
}
fi.relModelInfo = mii
switch fi.fieldType {
case RelManyToMany:
if fi.relThrough != "" {
msg := fmt.Sprintf("filed `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
mn := fi.relThrough[i+1:]
tn := snakeString(mn)
rmi, ok := modelCache.get(tn)
if ok == false || pn != rmi.pkg {
err = errors.New(msg + " cannot find table")
goto end
}
fi.relThroughModelInfo = rmi
fi.relTable = rmi.table
} else {
err = errors.New(msg)
goto end
}
err = nil
} else {
i := newM2MModelInfo(mi, mii)
if fi.relTable != "" {
i.table = fi.relTable
}
if v := modelCache.set(i.table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
goto end
}
fi.relTable = i.table
fi.relThroughModelInfo = i
}
}
}
}
}
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
inModel := false
for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
if ffi.relModelInfo == mi {
inModel = true
break
}
}
if inModel == false {
rmi := fi.relModelInfo
ffi := new(fieldInfo)
ffi.name = mi.name
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
ffi.reverse = true
ffi.relModelInfo = mi
ffi.mi = rmi
if fi.fieldType == RelOneToOne {
ffi.fieldType = RelReverseOne
} else {
ffi.fieldType = RelReverseMany
}
if rmi.fields.Add(ffi) == false {
added := false
for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
if added = rmi.fields.Add(ffi); added {
break
}
}
if added == false {
panic(fmt.Sprintf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
}
}
}
}
}
}
for _, mi := range models {
if fields, ok := mi.fields.fieldsByType[RelReverseOne]; ok {
for _, fi := range fields {
found := false
mForA:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
break mForA
}
}
if found == false {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
}
if fields, ok := mi.fields.fieldsByType[RelReverseMany]; ok {
for _, fi := range fields {
found := false
mForB:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
break mForB
}
}
if found == false {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
break mForC
}
}
}
if found == false {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
}
}
end:
if err != nil {
fmt.Println(err)
os.Exit(2)
}
}
func RegisterModel(models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
}
for _, model := range models {
registerModel(model)
}
}
func BootStrap() {
if modelCache.done {
return
}
modelCache.Lock()
defer modelCache.Unlock()
bootStrap()
modelCache.done = true
}

599
orm/models_fields.go Normal file
View File

@ -0,0 +1,599 @@
package orm
import (
"errors"
"fmt"
"strconv"
"time"
)
const (
// bool
TypeBooleanField = 1 << iota
// string
TypeCharField
// string
TypeTextField
// time.Time
TypeDateField
// time.Time
TypeDateTimeField
// int8
TypeBitField
// int16
TypeSmallIntegerField
// int32
TypeIntegerField
// int64
TypeBigIntegerField
// uint8
TypePostiveBitField
// uint16
TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField
// float64
TypeFloatField
// float64
TypeDecimalField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany
)
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
IsRelField = ^-RelReverseMany >> 14 << 15
IsFieldType = ^-RelReverseMany<<1 + 1
)
// A true/false field.
type BooleanField bool
func (e BooleanField) Value() bool {
return bool(e)
}
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := StrTo(d).Bool()
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
func (e *BooleanField) Clean() error {
return nil
}
var _ Fielder = new(BooleanField)
// A string field
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
func (e CharField) Value() string {
return string(e)
}
func (e *CharField) Set(d string) {
*e = CharField(d)
}
func (e *CharField) String() string {
return e.Value()
}
func (e *CharField) FieldType() int {
return TypeCharField
}
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *CharField) RawValue() interface{} {
return e.Value()
}
func (e *CharField) Clean() error {
return nil
}
var _ Fielder = new(CharField)
// A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; its not just a default value that you can override.
//
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
func (e *DateField) String() string {
return e.Value().String()
}
func (e *DateField) FieldType() int {
return TypeDateField
}
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_Date)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *DateField) RawValue() interface{} {
return e.Value()
}
func (e *DateField) Clean() error {
return nil
}
var _ Fielder = new(DateField)
// A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
func (e *DateTimeField) String() string {
return e.Value().String()
}
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_DateTime)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
func (e *DateTimeField) Clean() error {
return nil
}
var _ Fielder = new(DateTimeField)
// A floating-point number represented in go by a float32 value.
type FloatField float64
func (e FloatField) Value() float64 {
return float64(e)
}
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
func (e *FloatField) FieldType() int {
return TypeFloatField
}
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := StrTo(d).Float64()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
func (e *FloatField) Clean() error {
return nil
}
var _ Fielder = new(FloatField)
// -32768 to 32767
type SmallIntegerField int16
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := StrTo(d).Int16()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
func (e *SmallIntegerField) Clean() error {
return nil
}
var _ Fielder = new(SmallIntegerField)
// -2147483648 to 2147483647
type IntegerField int32
func (e IntegerField) Value() int32 {
return int32(e)
}
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := StrTo(d).Int32()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
func (e *IntegerField) Clean() error {
return nil
}
var _ Fielder = new(IntegerField)
// -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := StrTo(d).Int64()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
func (e *BigIntegerField) Clean() error {
return nil
}
var _ Fielder = new(BigIntegerField)
// 0 to 65535
type PositiveSmallIntegerField uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := StrTo(d).Uint16()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
func (e *PositiveSmallIntegerField) Clean() error {
return nil
}
var _ Fielder = new(PositiveSmallIntegerField)
// 0 to 4294967295
type PositiveIntegerField uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := StrTo(d).Uint32()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
func (e *PositiveIntegerField) Clean() error {
return nil
}
var _ Fielder = new(PositiveIntegerField)
// 0 to 18446744073709551615
type PositiveBigIntegerField uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := StrTo(d).Uint64()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
func (e *PositiveBigIntegerField) Clean() error {
return nil
}
var _ Fielder = new(PositiveBigIntegerField)
// A large text field.
type TextField string
func (e TextField) Value() string {
return string(e)
}
func (e *TextField) Set(d string) {
*e = TextField(d)
}
func (e *TextField) String() string {
return e.Value()
}
func (e *TextField) FieldType() int {
return TypeTextField
}
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *TextField) RawValue() interface{} {
return e.Value()
}
func (e *TextField) Clean() error {
return nil
}
var _ Fielder = new(TextField)

417
orm/models_info_f.go Normal file
View File

@ -0,0 +1,417 @@
package orm
import (
"errors"
"fmt"
"reflect"
"strings"
)
var errSkipField = errors.New("skip field")
type fields struct {
pk *fieldInfo
columns map[string]*fieldInfo
fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo
fieldsByType map[int][]*fieldInfo
fieldsRel []*fieldInfo
fieldsReverse []*fieldInfo
fieldsDB []*fieldInfo
rels []*fieldInfo
orders []string
dbcols []string
}
func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi
f.fields[fi.name] = fi
f.fieldsLow[strings.ToLower(fi.name)] = fi
} else {
return
}
if _, ok := f.fieldsByType[fi.fieldType]; ok == false {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
}
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
f.orders = append(f.orders, fi.column)
if fi.dbcol {
f.dbcols = append(f.dbcols, fi.column)
f.fieldsDB = append(f.fieldsDB, fi)
}
if fi.rel {
f.fieldsRel = append(f.fieldsRel, fi)
}
if fi.reverse {
f.fieldsReverse = append(f.fieldsReverse, fi)
}
return true
}
func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name]
}
func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column]
}
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok {
return fi, ok
}
if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok {
return fi, ok
}
if fi, ok := f.columns[name]; ok {
return fi, ok
}
return nil, false
}
func newFields() *fields {
f := new(fields)
f.fields = make(map[string]*fieldInfo)
f.fieldsLow = make(map[string]*fieldInfo)
f.columns = make(map[string]*fieldInfo)
f.fieldsByType = make(map[int][]*fieldInfo)
return f
}
type fieldInfo struct {
mi *modelInfo
fieldIndex int
fieldType int
dbcol bool
inModel bool
name string
fullName string
column string
addrValue reflect.Value
sf *reflect.StructField
auto bool
pk bool
null bool
blank bool
index bool
unique bool
initial StrTo
size int
auto_now bool
auto_now_add bool
rel bool
reverse bool
reverseField string
reverseFieldInfo *fieldInfo
relTable string
relThrough string
relThroughModelInfo *modelInfo
relModelInfo *modelInfo
digits int
decimals int
isFielder bool
onDelete string
}
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var (
tag string
tagValue string
initial StrTo
fieldType int
attrs map[string]bool
tags map[string]string
addrField reflect.Value
)
fi = new(fieldInfo)
if field.Kind() != reflect.Ptr && field.Kind() != reflect.Slice && field.CanAddr() {
addrField = field.Addr()
} else {
addrField = field
}
parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags)
if _, ok := attrs["-"]; ok {
return nil, errSkipField
}
digits := tags["digits"]
decimals := tags["decimals"]
size := tags["size"]
onDelete := tags["on_delete"]
initial.Clear()
if v, ok := tags["default"]; ok {
initial.Set(v)
}
checkType:
switch f := addrField.Interface().(type) {
case Fielder:
fi.isFielder = true
if field.Kind() == reflect.Ptr {
err = fmt.Errorf("the model Fielder can not be use ptr")
goto end
}
fieldType = f.FieldType()
if fieldType&IsRelField > 0 {
err = fmt.Errorf("unsupport rel type custom field")
goto end
}
default:
tag = "rel"
tagValue = tags[tag]
if tagValue != "" {
switch tagValue {
case "fk":
fieldType = RelForeignKey
break checkType
case "one":
fieldType = RelOneToOne
break checkType
case "m2m":
fieldType = RelManyToMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
}
break checkType
default:
err = fmt.Errorf("error")
goto wrongTag
}
}
tag = "reverse"
tagValue = tags[tag]
if tagValue != "" {
switch tagValue {
case "one":
fieldType = RelReverseOne
break checkType
case "many":
fieldType = RelReverseMany
break checkType
default:
err = fmt.Errorf("error")
goto wrongTag
}
}
fieldType, err = getFieldType(addrField)
if err != nil {
goto end
}
if fieldType == TypeTextField && size != "" {
fieldType = TypeCharField
}
if fieldType == TypeFloatField && (digits != "" || decimals != "") {
fieldType = TypeDecimalField
}
if fieldType == TypeDateTimeField && tags["type"] == "date" {
fieldType = TypeDateField
}
}
switch fieldType {
case RelForeignKey, RelOneToOne, RelReverseOne:
if field.Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name())
goto end
}
case RelManyToMany, RelReverseMany:
if field.Kind() != reflect.Slice {
err = fmt.Errorf("rel/reverse:many field must be slice")
goto end
} else {
if field.Type().Elem().Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name())
goto end
}
}
}
if fieldType&IsFieldType == 0 {
err = fmt.Errorf("wrong field type")
goto end
}
fi.fieldType = fieldType
fi.name = sf.Name
fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = addrField
fi.sf = &sf
fi.fullName = mi.fullName + "." + sf.Name
fi.null = attrs["null"]
fi.blank = attrs["blank"]
fi.index = attrs["index"]
fi.auto = attrs["auto"]
fi.pk = attrs["pk"]
fi.unique = attrs["unique"]
switch fieldType {
case RelManyToMany, RelReverseMany, RelReverseOne:
fi.null = false
fi.blank = false
fi.index = false
fi.auto = false
fi.pk = false
fi.unique = false
default:
fi.dbcol = true
}
switch fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
fi.rel = true
if fieldType == RelOneToOne {
fi.unique = true
}
case RelReverseMany, RelReverseOne:
fi.reverse = true
}
if fi.rel && fi.dbcol {
switch onDelete {
case od_CASCADE, od_DO_NOTHING:
case od_SET_DEFAULT:
if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case od_SET_NULL:
if fi.null == false {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
onDelete = od_CASCADE
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
}
}
fi.onDelete = onDelete
}
switch fieldType {
case TypeBooleanField:
case TypeCharField:
if size != "" {
v, e := StrTo(size).Int32()
if e != nil {
err = fmt.Errorf("wrong size value `%s`", size)
} else {
fi.size = int(v)
}
} else {
err = fmt.Errorf("size must be specify")
}
case TypeTextField:
fi.index = false
fi.unique = false
case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
fi.auto_now = true
} else if attrs["auto_now_add"] {
fi.auto_now_add = true
}
case TypeFloatField:
case TypeDecimalField:
d1 := digits
d2 := decimals
v1, er1 := StrTo(d1).Int8()
v2, er2 := StrTo(d2).Int8()
if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end
}
fi.digits = int(v1)
fi.decimals = int(v2)
default:
switch {
case fieldType&IsIntegerField > 0:
case fieldType&IsRelField > 0:
}
}
if fieldType&IsIntegerField == 0 {
if fi.auto {
err = fmt.Errorf("non-integer type cannot set auto")
goto end
}
if fi.pk || fi.index || fi.unique {
if fieldType != TypeCharField && fieldType != RelOneToOne {
err = fmt.Errorf("cannot set pk/index/unique")
goto end
}
}
}
if fi.auto || fi.pk {
if fi.auto {
fi.pk = true
}
fi.null = false
fi.blank = false
fi.index = false
fi.unique = false
}
if fi.unique {
fi.blank = false
fi.index = false
}
if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField {
// can not set default
initial.Clear()
}
if initial.Exist() {
v := initial
switch fieldType {
case TypeBooleanField:
_, err = v.Bool()
case TypeFloatField, TypeDecimalField:
_, err = v.Float64()
case TypeBitField:
_, err = v.Int8()
case TypeSmallIntegerField:
_, err = v.Int16()
case TypeIntegerField:
_, err = v.Int32()
case TypeBigIntegerField:
_, err = v.Int64()
case TypePostiveBitField:
_, err = v.Uint8()
case TypePositiveSmallIntegerField:
_, err = v.Uint16()
case TypePositiveIntegerField:
_, err = v.Uint32()
case TypePositiveBigIntegerField:
_, err = v.Uint64()
}
if err != nil {
tag, tagValue = "default", tags["default"]
goto wrongTag
}
}
fi.initial = initial
end:
if err != nil {
return nil, err
}
return
wrongTag:
return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err)
}

118
orm/models_info_m.go Normal file
View File

@ -0,0 +1,118 @@
package orm
import (
"errors"
"fmt"
"os"
"reflect"
)
type modelInfo struct {
pkg string
name string
fullName string
table string
model interface{}
fields *fields
manual bool
addrField reflect.Value
}
func newModelInfo(val reflect.Value) (info *modelInfo) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
info = &modelInfo{}
info.fields = newFields()
ind := reflect.Indirect(val)
typ := ind.Type()
info.addrField = ind.Addr()
info.name = typ.Name()
info.fullName = getFullName(typ)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
fi, err = newFieldInfo(info, field, sf)
if err != nil {
if err == errSkipField {
err = nil
continue
}
break
}
added := info.fields.Add(fi)
if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
break
}
if fi.pk {
if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only"))
break
} else {
info.fields.pk = fi
}
}
fi.fieldIndex = i
fi.mi = info
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
return
}
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo)
info.fields = newFields()
info.table = m1.table + "_" + m2.table + "s"
info.name = camelString(info.table)
info.fullName = m1.pkg + "." + info.name
fa := new(fieldInfo)
f1 := new(fieldInfo)
f2 := new(fieldInfo)
fa.fieldType = TypeBigIntegerField
fa.auto = true
fa.pk = true
fa.dbcol = true
f1.dbcol = true
f2.dbcol = true
f1.fieldType = RelForeignKey
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
f1.fullName = info.fullName + "." + f1.name
f2.fullName = info.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
f2.rel = true
f1.relTable = m1.table
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
f1.mi = info
f2.mi = info
info.fields.Add(fa)
info.fields.Add(f1)
info.fields.Add(f2)
info.fields.pk = fa
return
}

512
orm/models_test.go Normal file
View File

@ -0,0 +1,512 @@
package orm
import (
"fmt"
"os"
"strings"
"time"
// _ "github.com/bylevel/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
type Data struct {
Id int `orm:"auto"`
Boolean bool
Char string `orm:"size(50)"`
Text string
Date time.Time `orm:"type(date)"`
DateTime time.Time
Byte byte
Rune rune
Int int
Int8 int8
Int16 int16
Int32 int32
Int64 int64
Uint uint
Uint8 uint8
Uint16 uint16
Uint32 uint32
Uint64 uint64
Float32 float32
Float64 float64
Decimal float64 `orm:"digits(8);decimals(4)"`
}
type DataNull struct {
Id int `orm:"auto"`
Boolean bool `orm:"null"`
Char string `orm:"size(50);null"`
Text string `orm:"null"`
Date time.Time `orm:"type(date);null"`
DateTime time.Time `orm:"null"`
Byte byte `orm:"null"`
Rune rune `orm:"null"`
Int int `orm:"null"`
Int8 int8 `orm:"null"`
Int16 int16 `orm:"null"`
Int32 int32 `orm:"null"`
Int64 int64 `orm:"null"`
Uint uint `orm:"null"`
Uint8 uint8 `orm:"null"`
Uint16 uint16 `orm:"null"`
Uint32 uint32 `orm:"null"`
Uint64 uint64 `orm:"null"`
Float32 float32 `orm:"null"`
Float64 float64 `orm:"null"`
Decimal float64 `orm:"digits(8);decimals(4);null"`
}
type User struct {
Id int `orm:"auto"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16
IsStaff bool
IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
ShouldSkip string `orm:"-"`
}
func NewUser() *User {
obj := new(User)
return obj
}
type Profile struct {
Id int `orm:"auto"`
Age int16 ``
Money float64 ``
User *User `orm:"reverse(one)" json:"-"`
}
func (u *Profile) TableName() string {
return "user_profile"
}
func NewProfile() *Profile {
obj := new(Profile)
return obj
}
type Post struct {
Id int `orm:"auto"`
User *User `orm:"rel(fk)"` //
Title string `orm:"size(60)"`
Content string ``
Created time.Time `orm:"auto_now_add"`
Updated time.Time `orm:"auto_now"`
Tags []*Tag `orm:"rel(m2m)"`
}
func NewPost() *Post {
obj := new(Post)
return obj
}
type Tag struct {
Id int `orm:"auto"`
Name string `orm:"size(30)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
}
func NewTag() *Tag {
obj := new(Tag)
return obj
}
type Comment struct {
Id int `orm:"auto"`
Post *Post `orm:"rel(fk)"`
Content string ``
Parent *Comment `orm:"null;rel(fk)"`
Created time.Time `orm:"auto_now_add"`
}
func NewComment() *Comment {
obj := new(Comment)
return obj
}
var DBARGS = struct {
Driver string
Source string
Debug string
}{
os.Getenv("ORM_DRIVER"),
os.Getenv("ORM_SOURCE"),
os.Getenv("ORM_DEBUG"),
}
var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
)
var dORM Ormer
var initSQLs = map[string]string{
"mysql": "DROP TABLE IF EXISTS `user_profile`;\n" +
"DROP TABLE IF EXISTS `user`;\n" +
"DROP TABLE IF EXISTS `post`;\n" +
"DROP TABLE IF EXISTS `tag`;\n" +
"DROP TABLE IF EXISTS `post_tags`;\n" +
"DROP TABLE IF EXISTS `comment`;\n" +
"DROP TABLE IF EXISTS `data`;\n" +
"DROP TABLE IF EXISTS `data_null`;\n" +
"CREATE TABLE `user_profile` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `age` smallint NOT NULL,\n" +
" `money` double precision NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `user` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `user_name` varchar(30) NOT NULL UNIQUE,\n" +
" `email` varchar(100) NOT NULL,\n" +
" `password` varchar(100) NOT NULL,\n" +
" `status` smallint NOT NULL,\n" +
" `is_staff` bool NOT NULL,\n" +
" `is_active` bool NOT NULL,\n" +
" `created` date NOT NULL,\n" +
" `updated` datetime NOT NULL,\n" +
" `profile_id` integer\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `post` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `user_id` integer NOT NULL,\n" +
" `title` varchar(60) NOT NULL,\n" +
" `content` longtext NOT NULL,\n" +
" `created` datetime NOT NULL,\n" +
" `updated` datetime NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `tag` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `name` varchar(30) NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `post_tags` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `post_id` integer NOT NULL,\n" +
" `tag_id` integer NOT NULL,\n" +
" UNIQUE (`post_id`, `tag_id`)\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `comment` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `post_id` integer NOT NULL,\n" +
" `content` longtext NOT NULL,\n" +
" `parent_id` integer,\n" +
" `created` datetime NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `data` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `boolean` bool NOT NULL,\n" +
" `char` varchar(50) NOT NULL,\n" +
" `text` longtext NOT NULL,\n" +
" `date` date NOT NULL,\n" +
" `date_time` datetime NOT NULL,\n" +
" `byte` tinyint unsigned NOT NULL,\n" +
" `rune` integer NOT NULL,\n" +
" `int` integer NOT NULL,\n" +
" `int8` tinyint NOT NULL,\n" +
" `int16` smallint NOT NULL,\n" +
" `int32` integer NOT NULL,\n" +
" `int64` bigint NOT NULL,\n" +
" `uint` integer unsigned NOT NULL,\n" +
" `uint8` tinyint unsigned NULL,\n" +
" `uint16` smallint unsigned NOT NULL,\n" +
" `uint32` integer unsigned NOT NULL,\n" +
" `uint64` bigint unsigned NOT NULL,\n" +
" `float32` double precision NOT NULL,\n" +
" `float64` double precision NOT NULL,\n" +
" `decimal` numeric(8,4) NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `data_null` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `boolean` bool,\n" +
" `char` varchar(50),\n" +
" `text` longtext,\n" +
" `date` date,\n" +
" `date_time` datetime,\n" +
" `byte` tinyint unsigned,\n" +
" `rune` integer,\n" +
" `int` integer,\n" +
" `int8` tinyint,\n" +
" `int16` smallint,\n" +
" `int32` integer,\n" +
" `int64` bigint,\n" +
" `uint` integer unsigned,\n" +
" `uint8` tinyint unsigned,\n" +
" `uint16` smallint unsigned,\n" +
" `uint32` integer unsigned,\n" +
" `uint64` bigint unsigned,\n" +
" `float32` double precision,\n" +
" `float64` double precision,\n" +
" `decimal` numeric(8,4)\n" +
") ENGINE=INNODB;\n" +
"CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" +
"CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" +
"CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" +
"CREATE INDEX `comment_63f17a16` ON `comment` (`parent_id`);",
"sqlite3": `
DROP TABLE IF EXISTS "user_profile";
DROP TABLE IF EXISTS "user";
DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment";
DROP TABLE IF EXISTS "data";
DROP TABLE IF EXISTS "data_null";
CREATE TABLE "user_profile" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"age" smallint NOT NULL,
"money" real NOT NULL
);
CREATE TABLE "user" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_name" varchar(30) NOT NULL UNIQUE,
"email" varchar(100) NOT NULL,
"password" varchar(100) NOT NULL,
"status" smallint NOT NULL,
"is_staff" bool NOT NULL,
"is_active" bool NOT NULL,
"created" date NOT NULL,
"updated" datetime NOT NULL,
"profile_id" integer
);
CREATE TABLE "post" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL,
"title" varchar(60) NOT NULL,
"content" text NOT NULL,
"created" datetime NOT NULL,
"updated" datetime NOT NULL
);
CREATE TABLE "tag" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"name" varchar(30) NOT NULL
);
CREATE TABLE "post_tags" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"post_id" integer NOT NULL,
"tag_id" integer NOT NULL,
UNIQUE ("post_id", "tag_id")
);
CREATE TABLE "comment" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"post_id" integer NOT NULL,
"content" text NOT NULL,
"parent_id" integer,
"created" datetime NOT NULL
);
CREATE TABLE "data" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"boolean" bool NOT NULL,
"char" varchar(50) NOT NULL,
"text" text NOT NULL,
"date" date NOT NULL,
"date_time" datetime NOT NULL,
"byte" tinyint unsigned NOT NULL,
"rune" integer NOT NULL,
"int" integer NOT NULL,
"int8" tinyint NOT NULL,
"int16" smallint NOT NULL,
"int32" integer NOT NULL,
"int64" bigint NOT NULL,
"uint" integer unsigned NOT NULL,
"uint8" tinyint unsigned NOT NULL,
"uint16" smallint unsigned NOT NULL,
"uint32" integer unsigned NOT NULL,
"uint64" bigint unsigned NOT NULL,
"float32" real NOT NULL,
"float64" real NOT NULL,
"decimal" decimal
);
CREATE TABLE "data_null" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"boolean" bool,
"char" varchar(50),
"text" text,
"date" date,
"date_time" datetime,
"byte" tinyint unsigned,
"rune" integer,
"int" integer,
"int8" tinyint,
"int16" smallint,
"int32" integer,
"int64" bigint,
"uint" integer unsigned,
"uint8" tinyint unsigned,
"uint16" smallint unsigned,
"uint32" integer unsigned,
"uint64" bigint unsigned,
"float32" real,
"float64" real,
"decimal" decimal
);
CREATE INDEX "user_141c6eec" ON "user" ("profile_id");
CREATE INDEX "post_fbfc09f1" ON "post" ("user_id");
CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id");
CREATE INDEX "comment_63f17a16" ON "comment" ("parent_id");
`,
"postgres": `
DROP TABLE IF EXISTS "user_profile";
DROP TABLE IF EXISTS "user";
DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment";
DROP TABLE IF EXISTS "data";
DROP TABLE IF EXISTS "data_null";
CREATE TABLE "user_profile" (
"id" serial NOT NULL PRIMARY KEY,
"age" smallint NOT NULL,
"money" double precision NOT NULL
);
CREATE TABLE "user" (
"id" serial NOT NULL PRIMARY KEY,
"user_name" varchar(30) NOT NULL UNIQUE,
"email" varchar(100) NOT NULL,
"password" varchar(100) NOT NULL,
"status" smallint NOT NULL,
"is_staff" boolean NOT NULL,
"is_active" boolean NOT NULL,
"created" date NOT NULL,
"updated" timestamp with time zone NOT NULL,
"profile_id" integer
);
CREATE TABLE "post" (
"id" serial NOT NULL PRIMARY KEY,
"user_id" integer NOT NULL,
"title" varchar(60) NOT NULL,
"content" text NOT NULL,
"created" timestamp with time zone NOT NULL,
"updated" timestamp with time zone NOT NULL
);
CREATE TABLE "tag" (
"id" serial NOT NULL PRIMARY KEY,
"name" varchar(30) NOT NULL
);
CREATE TABLE "post_tags" (
"id" serial NOT NULL PRIMARY KEY,
"post_id" integer NOT NULL,
"tag_id" integer NOT NULL,
UNIQUE ("post_id", "tag_id")
);
CREATE TABLE "comment" (
"id" serial NOT NULL PRIMARY KEY,
"post_id" integer NOT NULL,
"content" text NOT NULL,
"parent_id" integer,
"created" timestamp with time zone NOT NULL
);
CREATE TABLE "data" (
"id" serial NOT NULL PRIMARY KEY,
"boolean" bool NOT NULL,
"char" varchar(50) NOT NULL,
"text" text NOT NULL,
"date" date NOT NULL,
"date_time" timestamp with time zone NOT NULL,
"byte" smallint CHECK("byte" >= 0 AND "byte" <= 255) NOT NULL,
"rune" integer NOT NULL,
"int" integer NOT NULL,
"int8" smallint CHECK("int8" >= -127 AND "int8" <= 128) NOT NULL,
"int16" smallint NOT NULL,
"int32" integer NOT NULL,
"int64" bigint NOT NULL,
"uint" bigint CHECK("uint" >= 0) NOT NULL,
"uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255) NOT NULL,
"uint16" integer CHECK("uint16" >= 0) NOT NULL,
"uint32" bigint CHECK("uint32" >= 0) NOT NULL,
"uint64" bigint CHECK("uint64" >= 0) NOT NULL,
"float32" double precision NOT NULL,
"float64" double precision NOT NULL,
"decimal" numeric(8, 4)
);
CREATE TABLE "data_null" (
"id" serial NOT NULL PRIMARY KEY,
"boolean" bool,
"char" varchar(50),
"text" text,
"date" date,
"date_time" timestamp with time zone,
"byte" smallint CHECK("byte" >= 0 AND "byte" <= 255),
"rune" integer,
"int" integer,
"int8" smallint CHECK("int8" >= -127 AND "int8" <= 128),
"int16" smallint,
"int32" integer,
"int64" bigint,
"uint" bigint CHECK("uint" >= 0),
"uint8" smallint CHECK("uint8" >= 0 AND "uint8" <= 255),
"uint16" integer CHECK("uint16" >= 0),
"uint32" bigint CHECK("uint32" >= 0),
"uint64" bigint CHECK("uint64" >= 0),
"float32" double precision,
"float64" double precision,
"decimal" numeric(8, 4)
);
CREATE INDEX "user_profile_id" ON "user" ("profile_id");
CREATE INDEX "post_user_id" ON "post" ("user_id");
CREATE INDEX "comment_post_id" ON "comment" ("post_id");
CREATE INDEX "comment_parent_id" ON "comment" ("parent_id");
`}
func init() {
// err := os.Setenv("TZ", "+00:00")
// fmt.Println(err)
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
RegisterModel(new(Tag))
RegisterModel(new(Comment))
Debug, _ = StrTo(DBARGS.Debug).Bool()
if DBARGS.Driver == "" || DBARGS.Source == "" {
fmt.Println(`need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
eg: mysql
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
`)
os.Exit(2)
}
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
BootStrap()
dORM = NewOrm()
queries := strings.Split(initSQLs[DBARGS.Driver], ";")
for _, query := range queries {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
_, err := dORM.Raw(query).Exec()
if err != nil {
fmt.Println(err)
os.Exit(2)
}
}
}

99
orm/models_utils.go Normal file
View File

@ -0,0 +1,99 @@
package orm
import (
"fmt"
"reflect"
"strings"
"time"
)
func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val)
fun := val.MethodByName("TableName")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
if len(vals) > 0 {
val := vals[0]
if val.Kind() == reflect.String {
return val.String()
}
}
}
return snakeString(ind.Type().Name())
}
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := strings.ToLower(col)
if column == "" {
column = snakeString(sf.Name)
}
switch ft {
case RelForeignKey, RelOneToOne:
column = column + "_id"
case RelManyToMany, RelReverseMany, RelReverseOne:
column = sf.Name
}
return column
}
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
switch elm.Kind() {
case reflect.Int8:
ft = TypeBitField
case reflect.Int16:
ft = TypeSmallIntegerField
case reflect.Int32, reflect.Int:
ft = TypeIntegerField
case reflect.Int64:
ft = TypeBigIntegerField
case reflect.Uint8:
ft = TypePostiveBitField
case reflect.Uint16:
ft = TypePositiveSmallIntegerField
case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField
case reflect.Uint64:
ft = TypePositiveBigIntegerField
case reflect.Float32, reflect.Float64:
ft = TypeFloatField
case reflect.Bool:
ft = TypeBooleanField
case reflect.String:
ft = TypeTextField
case reflect.Invalid:
default:
if elm.CanInterface() {
if _, ok := elm.Interface().(time.Time); ok {
ft = TypeDateTimeField
}
}
}
if ft&IsFieldType == 0 {
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
}
return
}
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool)
tag := make(map[string]string)
for _, v := range strings.Split(data, ";") {
v = strings.TrimSpace(v)
if supportTag[v] == 1 {
attr[v] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
name := v[:i]
if supportTag[name] == 2 {
v = v[i+1 : len(v)-1]
tag[name] = v
}
}
}
*attrs = attr
*tags = tag
}

221
orm/orm.go Normal file
View File

@ -0,0 +1,221 @@
package orm
import (
"database/sql"
"errors"
"fmt"
"os"
"reflect"
"time"
)
const (
Debug_Queries = iota
)
var (
// DebugLevel = Debug_Queries
Debug = false
DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000
DefaultRelsDepth = 5
DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrNotImplement = errors.New("have not implement")
)
type Params map[string]interface{}
type ParamsList []interface{}
type orm struct {
alias *alias
db dbQuerier
isTx bool
}
var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md)
ind = reflect.Indirect(val)
typ := ind.Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Sprintf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
if mi, ok := modelCache.getByFN(name); ok {
return mi, ind
}
panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
}
func (o *orm) Read(md interface{}) error {
mi, ind := o.getMiInd(md)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ)
if err != nil {
return err
}
return nil
}
func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
if id > 0 {
if mi.fields.pk.auto {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
return id, nil
}
func (o *orm) Update(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ)
if err != nil {
return num, err
}
return num, nil
}
func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil {
return num, err
}
if num > 0 {
if mi.fields.pk.auto {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
return num, nil
}
func (o *orm) M2mAdd(md interface{}, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) M2mDel(md interface{}, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) LoadRel(md interface{}, name string) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := ""
if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table)
if mi, ok := modelCache.get(name); ok {
qs = newQuerySet(o, mi)
}
} else {
val := reflect.ValueOf(ptrStructOrTableName)
ind := reflect.Indirect(val)
name = getFullName(ind.Type())
if mi, ok := modelCache.getByFN(name); ok {
qs = newQuerySet(o, mi)
}
}
if qs == nil {
panic(fmt.Sprintf("<Ormer.QueryTable> table name: `%s` not exists", name))
}
return
}
func (o *orm) Using(name string) error {
if o.isTx {
panic("<Ormer.Using> transaction has been start, cannot change db")
}
if al, ok := dataBaseCache.get(name); ok {
o.alias = al
if Debug {
o.db = newDbQueryLog(al, al.DB)
} else {
o.db = al.DB
}
} else {
return errors.New(fmt.Sprintf("<Ormer.Using> unknown db alias name `%s`", name))
}
return nil
}
func (o *orm) Begin() error {
if o.isTx {
return ErrTxHasBegan
}
var tx *sql.Tx
tx, err := o.db.(txer).Begin()
if err != nil {
return err
}
o.isTx = true
if Debug {
o.db.(*dbQueryLog).SetDB(tx)
} else {
o.db = tx
}
return nil
}
func (o *orm) Commit() error {
if o.isTx == false {
return ErrTxDone
}
err := o.db.(txEnder).Commit()
if err == nil {
o.isTx = false
o.Using(o.alias.Name)
} else if err == sql.ErrTxDone {
return ErrTxDone
}
return err
}
func (o *orm) Rollback() error {
if o.isTx == false {
return ErrTxDone
}
err := o.db.(txEnder).Rollback()
if err == nil {
o.isTx = false
o.Using(o.alias.Name)
} else if err == sql.ErrTxDone {
return ErrTxDone
}
return err
}
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args)
}
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
func NewOrm() Ormer {
BootStrap() // execute only once
o := new(orm)
err := o.Using("default")
if err != nil {
panic(err)
}
return o
}

89
orm/orm_conds.go Normal file
View File

@ -0,0 +1,89 @@
package orm
import (
"strings"
)
const (
ExprSep = "__"
)
type condValue struct {
exprs []string
args []interface{}
cond *Condition
isOr bool
isNot bool
isCond bool
}
type Condition struct {
params []condValue
}
func NewCondition() *Condition {
c := &Condition{}
return c
}
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.And> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
return &c
}
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.AndNot> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
return &c
}
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic("cannot use self as sub cond")
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true})
}
return c
}
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.Or> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
return &c
}
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.OrNot> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
return &c
}
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic("cannot use self as sub cond")
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true})
}
return c
}
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
func (c Condition) clone() *Condition {
return &c
}

156
orm/orm_log.go Normal file
View File

@ -0,0 +1,156 @@
package orm
import (
"database/sql"
"fmt"
"io"
"log"
"strings"
"time"
)
type Log struct {
*log.Logger
}
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
return d
}
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
sub := time.Now().Sub(t) / 1e5
elsp := float64(int(sub)) / 10.0
flag := " OK"
if err != nil {
flag = "FAIL"
}
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))
}
if len(cons) > 0 {
con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `"))
}
if err != nil {
con += " - " + err.Error()
}
DebugLog.Println(con)
}
type stmtQueryLog struct {
alias *alias
query string
stmt stmtQuerier
}
var _ stmtQuerier = new(stmtQueryLog)
func (d *stmtQueryLog) Close() error {
a := time.Now()
err := d.stmt.Close()
debugLogQueies(d.alias, "st.Close", d.query, a, err)
return err
}
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.stmt.Exec(args...)
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.stmt.Query(args...)
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
return res, err
}
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
a := time.Now()
res := d.stmt.QueryRow(args...)
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
return res
}
func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
d := new(stmtQueryLog)
d.stmt = stmt
d.alias = alias
d.query = query
return d
}
type dbQueryLog struct {
alias *alias
db dbQuerier
tx txer
txe txEnder
}
var _ dbQuerier = new(dbQueryLog)
var _ txer = new(dbQueryLog)
var _ txEnder = new(dbQueryLog)
func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.Prepare(query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.Exec(query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.Query(query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRow(query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).Begin()
debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err)
return tx, err
}
func (d *dbQueryLog) Commit() error {
a := time.Now()
err := d.db.(txEnder).Commit()
debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err)
return err
}
func (d *dbQueryLog) Rollback() error {
a := time.Now()
err := d.db.(txEnder).Rollback()
debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
return err
}
func (d *dbQueryLog) SetDB(db dbQuerier) {
d.db = db
}
func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier {
d := new(dbQueryLog)
d.alias = alias
d.db = db
return d
}

65
orm/orm_object.go Normal file
View File

@ -0,0 +1,65 @@
package orm
import (
"fmt"
"reflect"
)
type insertSet struct {
mi *modelInfo
orm *orm
stmt stmtQuerier
closed bool
}
var _ Inserter = new(insertSet)
func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
}
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
typ := ind.Type()
name := getFullName(typ)
if val.Kind() != reflect.Ptr {
panic(fmt.Sprintf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
}
if name != o.mi.fullName {
panic(fmt.Sprintf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
}
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
if err != nil {
return id, err
}
if id > 0 {
if o.mi.fields.pk.auto {
ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
}
}
return id, nil
}
func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
}
o.closed = true
return o.stmt.Close()
}
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
bi.mi = mi
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
if err != nil {
return nil, err
}
if Debug {
bi.stmt = newStmtQueryLog(orm.alias, st, query)
} else {
bi.stmt = st
}
return bi, nil
}

130
orm/orm_queryset.go Normal file
View File

@ -0,0 +1,130 @@
package orm
import (
"fmt"
)
type querySet struct {
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int
offset int64
orders []string
orm *orm
}
var _ QuerySeter = new(querySet)
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.And(expr, args...)
return &o
}
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.AndNot(expr, args...)
return &o
}
func (o querySet) Limit(limit int, args ...int64) QuerySeter {
o.limit = limit
if len(args) > 0 {
o.offset = args[0]
}
return &o
}
func (o querySet) Offset(offset int64) QuerySeter {
o.offset = offset
return &o
}
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
return &o
}
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string
if len(params) == 0 {
o.relDepth = DefaultRelsDepth
} else {
for _, p := range params {
switch val := p.(type) {
case string:
related = append(o.related, val)
case int:
o.relDepth = val
default:
panic(fmt.Sprintf("<QuerySeter.RelatedSel> wrong param kind: %v", val))
}
}
}
o.related = related
return &o
}
func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond
return &o
}
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
}
func (o *querySet) All(container interface{}) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
}
func (o *querySet) One(container interface{}) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ)
if err != nil {
return err
}
if num > 1 {
return ErrMultiRows
}
if num == 0 {
return ErrNoRows
}
return nil
}
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)
o.mi = mi
o.orm = orm
return o
}

190
orm/orm_raw.go Normal file
View File

@ -0,0 +1,190 @@
package orm
import (
"database/sql"
"fmt"
"reflect"
)
type rawPrepare struct {
rs *rawSet
stmt stmtQuerier
closed bool
}
func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) {
if o.closed {
return nil, ErrStmtClosed
}
return o.stmt.Exec(args...)
}
func (o *rawPrepare) Close() error {
o.closed = true
return o.stmt.Close()
}
func newRawPreparer(rs *rawSet) (RawPreparer, error) {
o := new(rawPrepare)
o.rs = rs
query := rs.query
rs.orm.alias.DbBaser.ReplaceMarks(&query)
st, err := rs.orm.db.Prepare(query)
if err != nil {
return nil, err
}
if Debug {
o.stmt = newStmtQueryLog(rs.orm.alias, st, query)
} else {
o.stmt = st
}
return o, nil
}
type rawSet struct {
query string
args []interface{}
orm *orm
}
var _ RawSeter = new(rawSet)
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args
return &o
}
func (o *rawSet) Exec() (sql.Result, error) {
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
return o.orm.db.Exec(query, args...)
}
func (o *rawSet) QueryRow(...interface{}) error {
//TODO
return nil
}
func (o *rawSet) QueryRows(...interface{}) (int64, error) {
//TODO
return 0, nil
}
func (o *rawSet) readValues(container interface{}) (int64, error) {
var (
maps []Params
lists []ParamsList
list ParamsList
)
typ := 0
switch container.(type) {
case *[]Params:
typ = 1
case *[]ParamsList:
typ = 2
case *ParamsList:
typ = 3
default:
panic(fmt.Sprintf("unsupport read values type `%T`", container))
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
var (
refs []interface{}
cnt int64
cols []string
)
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
return 0, err
} else {
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
var ref sql.NullString
refs[i] = &ref
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
switch typ {
case 1:
params := make(Params, len(cols))
for i, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params[cols[i]] = value.String
}
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params = append(params, value.String)
}
lists = append(lists, params)
case 3:
for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
list = append(list, value.String)
}
}
cnt++
}
switch v := container.(type) {
case *[]Params:
*v = maps
case *[]ParamsList:
*v = lists
case *ParamsList:
*v = list
}
return cnt, nil
}
func (o *rawSet) Values(container *[]Params) (int64, error) {
return o.readValues(container)
}
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
return o.readValues(container)
}
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
return o.readValues(container)
}
func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o)
}
func newRawSet(orm *orm, query string, args []interface{}) RawSeter {
o := new(rawSet)
o.query = query
o.args = args
o.orm = orm
return o
}

1018
orm/orm_test.go Normal file

File diff suppressed because it is too large Load Diff

135
orm/types.go Normal file
View File

@ -0,0 +1,135 @@
package orm
import (
"database/sql"
"reflect"
"time"
)
type Driver interface {
Name() string
Type() DriverType
}
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
Clean() error
}
type Ormer interface {
Read(interface{}) error
Insert(interface{}) (int64, error)
Update(interface{}) (int64, error)
Delete(interface{}) (int64, error)
M2mAdd(interface{}, string, ...interface{}) (int64, error)
M2mDel(interface{}, string, ...interface{}) (int64, error)
LoadRel(interface{}, string) (int64, error)
QueryTable(interface{}) QuerySeter
Using(string) error
Begin() error
Commit() error
Rollback() error
Raw(string, ...interface{}) RawSeter
Driver() Driver
}
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Limit(int, ...int64) QuerySeter
Offset(int64) QuerySeter
OrderBy(...string) QuerySeter
RelatedSel(...interface{}) QuerySeter
Count() (int64, error)
Update(Params) (int64, error)
Delete() (int64, error)
PrepareInsert() (Inserter, error)
All(interface{}) (int64, error)
One(interface{}) error
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error)
}
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
type RawSeter interface {
Exec() (sql.Result, error)
QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error)
ValuesList(*[]ParamsList) (int64, error)
ValuesFlat(*ParamsList) (int64, error)
Prepare() (RawPreparer, error)
}
type IFieldError interface {
Name() string
Error() error
}
type IFieldErrors interface {
Get(string) IFieldError
Set(string, IFieldError)
List() []IFieldError
}
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row
}
type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type txer interface {
Begin() (*sql.Tx, error)
}
type txEnder interface {
Commit() error
Rollback() error
}
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
HasReturningID(*modelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
}

207
orm/utils.go Normal file
View File

@ -0,0 +1,207 @@
package orm
import (
"fmt"
"strconv"
"strings"
"time"
)
type StrTo string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int8:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint8:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, argInt(args).Get(0, 10))
case string:
s = v
default:
s = fmt.Sprintf("%v", v)
}
return s
}
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data[:len(data)]))
}
func camelString(s string) string {
data := make([]byte, 0, len(s))
j := false
k := false
num := len(s) - 1
for i := 0; i <= num; i++ {
d := s[i]
if k == false && d >= 'A' && d <= 'Z' {
k = true
}
if d >= 'a' && d <= 'z' && (j || k == false) {
d = d - 32
j = false
k = true
}
if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
j = true
continue
}
data = append(data, d)
}
return string(data[:len(data)])
}
type argString []string
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type argInt []int
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
type argAny []interface{}
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
func timeFormat(t time.Time, format string) string {
return t.Format(format)
}

View File

@ -27,12 +27,19 @@ var ErrInitStart = errors.New("init from")
// Allows for us to notice when the connection is closed.
type conn struct {
net.Conn
wg *sync.WaitGroup
wg *sync.WaitGroup
isclose bool
lock sync.Mutex
}
func (c conn) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
err := c.Conn.Close()
c.wg.Done()
if !c.isclose && err == nil {
c.wg.Done()
c.isclose = true
}
return err
}
@ -137,16 +144,15 @@ func GetInitListner(tcpaddr *net.TCPAddr) (l net.Listener, err error) {
countStr := os.Getenv(FDKey)
if countStr == "" {
return net.ListenTCP("tcp", tcpaddr)
} else {
count, err := strconv.Atoi(countStr)
if err != nil {
return nil, err
}
f := os.NewFile(uintptr(count), "listen socket")
l, err = net.FileListener(f)
if err != nil {
return nil, err
}
return l, nil
}
count, err := strconv.Atoi(countStr)
if err != nil {
return nil, err
}
f := os.NewFile(uintptr(count), "listen socket")
l, err = net.FileListener(f)
if err != nil {
return nil, err
}
return l, nil
}

512
router.go
View File

@ -1,7 +1,9 @@
package beego
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
@ -12,15 +14,15 @@ import (
"strings"
)
var (
sc *Controller = &Controller{}
)
var HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
type controllerInfo struct {
pattern string
regex *regexp.Regexp
params map[int]string
controllerType reflect.Type
methods map[string]string
hasMethod bool
}
type userHandler struct {
@ -33,15 +35,34 @@ type userHandler struct {
type ControllerRegistor struct {
routers []*controllerInfo
fixrouters []*controllerInfo
enableFilter bool
filters []http.HandlerFunc
enableAfter bool
afterFilters []http.HandlerFunc
enableUser bool
userHandlers map[string]*userHandler
enableAuto bool
autoRouter map[string]map[string]reflect.Type //key:controller key:method value:reflect.type
}
func NewControllerRegistor() *ControllerRegistor {
return &ControllerRegistor{routers: make([]*controllerInfo, 0), userHandlers: make(map[string]*userHandler)}
return &ControllerRegistor{
routers: make([]*controllerInfo, 0),
userHandlers: make(map[string]*userHandler),
autoRouter: make(map[string]map[string]reflect.Type),
}
}
func (p *ControllerRegistor) Add(pattern string, c ControllerInterface) {
//methods support like this:
//default methods is the same name as method
//Add("/user",&UserController{})
//Add("/api/list",&RestController{},"*:ListFood")
//Add("/api/create",&RestController{},"post:CreateFood")
//Add("/api/update",&RestController{},"put:UpdateFood")
//Add("/api/delete",&RestController{},"delete:DeleteFood")
//Add("/api",&RestController{},"get,post:ApiFunc")
//Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
parts := strings.Split(pattern, "/")
j := 0
@ -85,13 +106,39 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface) {
}
}
}
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
if len(mappingMethods) > 0 {
semi := strings.Split(mappingMethods[0], ";")
for _, v := range semi {
colon := strings.Split(v, ":")
if len(colon) != 2 {
panic("method mapping fomate is error")
}
comma := strings.Split(colon[0], ",")
for _, m := range comma {
if m == "*" || inSlice(strings.ToLower(m), HTTPMETHOD) {
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToLower(m)] = colon[1]
} else {
panic(colon[1] + " method don't exist in the controller " + t.Name())
}
} else {
panic(v + " is an error method mapping,Don't exist method named " + m)
}
}
}
}
if j == 0 {
//now create the Route
t := reflect.Indirect(reflect.ValueOf(c)).Type()
route := &controllerInfo{}
route.pattern = pattern
route.controllerType = t
route.methods = methods
if len(methods) > 0 {
route.hasMethod = true
}
p.fixrouters = append(p.fixrouters, route)
} else { // add regexp routers
//recreate the url pattern, with parameters replaced
@ -105,17 +152,38 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface) {
}
//now create the Route
t := reflect.Indirect(reflect.ValueOf(c)).Type()
route := &controllerInfo{}
route.regex = regex
route.params = params
route.pattern = pattern
route.methods = methods
if len(methods) > 0 {
route.hasMethod = true
}
route.controllerType = t
p.routers = append(p.routers, route)
}
}
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
firstParam := strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
if _, ok := p.autoRouter[firstParam]; ok {
return
} else {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
func (p *ControllerRegistor) AddHandler(pattern string, c http.Handler) {
p.enableUser = true
parts := strings.Split(pattern, "/")
j := 0
@ -163,6 +231,7 @@ func (p *ControllerRegistor) AddHandler(pattern string, c http.Handler) {
// Filter adds the middleware filter.
func (p *ControllerRegistor) Filter(filter http.HandlerFunc) {
p.enableFilter = true
p.filters = append(p.filters, filter)
}
@ -189,83 +258,33 @@ func (p *ControllerRegistor) FilterPrefixPath(path string, filter http.HandlerFu
})
}
func StructMap(vc reflect.Value, r *http.Request) error {
for k, t := range r.Form {
v := t[0]
names := strings.Split(k, ".")
var value reflect.Value = vc
for i, name := range names {
name = strings.Title(name)
if i == 0 {
if reflect.ValueOf(sc).Elem().FieldByName(name).IsValid() {
Trace("Controller's property should not be changed by mapper.")
break
}
}
if value.Kind() != reflect.Struct {
Trace(fmt.Sprintf("arg error, value kind is %v", value.Kind()))
break
}
// Filter adds the middleware after filter.
func (p *ControllerRegistor) FilterAfter(filter http.HandlerFunc) {
p.enableAfter = true
p.afterFilters = append(p.afterFilters, filter)
}
if i != len(names)-1 {
value = value.FieldByName(name)
if !value.IsValid() {
Trace(fmt.Sprintf("(%v value is not valid %v)", name, value))
break
}
} else {
tv := value.FieldByName(name)
if !tv.IsValid() {
Trace(fmt.Sprintf("struct %v has no field named %v", value, name))
break
}
if !tv.CanSet() {
Trace("can not set " + k)
break
}
var l interface{}
switch k := tv.Kind(); k {
case reflect.String:
l = v
case reflect.Bool:
l = (v == "true")
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
x, err := strconv.Atoi(v)
if err != nil {
Trace("arg " + v + " as int: " + err.Error())
break
}
l = x
case reflect.Int64:
x, err := strconv.ParseInt(v, 10, 64)
if err != nil {
Trace("arg " + v + " as int: " + err.Error())
break
}
l = x
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(v, 64)
if err != nil {
Trace("arg " + v + " as float64: " + err.Error())
break
}
l = x
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
x, err := strconv.ParseUint(v, 10, 64)
if err != nil {
Trace("arg " + v + " as int: " + err.Error())
break
}
l = x
case reflect.Struct:
Trace("can not set an struct")
}
tv.Set(reflect.ValueOf(l))
}
}
// FilterParam adds the middleware filter if the REST URL parameter exists.
func (p *ControllerRegistor) FilterParamAfter(param string, filter http.HandlerFunc) {
if !strings.HasPrefix(param, ":") {
param = ":" + param
}
return nil
p.FilterAfter(func(w http.ResponseWriter, r *http.Request) {
p := r.URL.Query().Get(param)
if len(p) > 0 {
filter(w, r)
}
})
}
// FilterPrefixPath adds the middleware filter if the prefix path exists.
func (p *ControllerRegistor) FilterPrefixPathAfter(path string, filter http.HandlerFunc) {
p.FilterAfter(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, path) {
filter(w, r)
}
})
}
// AutoRoute
@ -273,7 +292,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
defer func() {
if err := recover(); err != nil {
errstr := fmt.Sprint(err)
if handler, ok := ErrorMaps[errstr]; ok {
if handler, ok := ErrorMaps[errstr]; ok && ErrorsShow {
handler(rw, r)
} else {
if !RecoverPanic {
@ -341,68 +360,79 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
requestPath := r.URL.Path
var requestbody []byte
if CopyRequestBody {
requestbody, _ = ioutil.ReadAll(r.Body)
r.Body.Close()
bf := bytes.NewBuffer(requestbody)
r.Body = ioutil.NopCloser(bf)
}
r.ParseMultipartForm(MaxMemory)
//user defined Handler
for pattern, c := range p.userHandlers {
if c.regex == nil && pattern == requestPath {
if p.enableUser {
for pattern, c := range p.userHandlers {
if c.regex == nil && pattern == requestPath {
c.h.ServeHTTP(rw, r)
return
} else if c.regex == nil {
continue
}
//check if Route pattern matches url
if !c.regex.MatchString(requestPath) {
continue
}
//get submatches (params)
matches := c.regex.FindStringSubmatch(requestPath)
//double check that the Route matches the URL pattern.
if len(matches[0]) != len(requestPath) {
continue
}
if len(c.params) > 0 {
//add url parameters to the query param map
values := r.URL.Query()
for i, match := range matches[1:] {
values.Add(c.params[i], match)
r.Form.Add(c.params[i], match)
params[c.params[i]] = match
}
//reassemble query params and add to RawQuery
r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery
//r.URL.RawQuery = url.Values(values).Encode()
}
c.h.ServeHTTP(rw, r)
return
} else if c.regex == nil {
continue
}
//check if Route pattern matches url
if !c.regex.MatchString(requestPath) {
continue
}
//get submatches (params)
matches := c.regex.FindStringSubmatch(requestPath)
//double check that the Route matches the URL pattern.
if len(matches[0]) != len(requestPath) {
continue
}
if len(c.params) > 0 {
//add url parameters to the query param map
values := r.URL.Query()
for i, match := range matches[1:] {
values.Add(c.params[i], match)
r.Form.Add(c.params[i], match)
params[c.params[i]] = match
}
//reassemble query params and add to RawQuery
r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery
//r.URL.RawQuery = url.Values(values).Encode()
}
c.h.ServeHTTP(rw, r)
return
}
//first find path from the fixrouters to Improve Performance
for _, route := range p.fixrouters {
n := len(requestPath)
//route like "/"
if n == 1 {
if requestPath == route.pattern {
runrouter = route
findrouter = true
break
} else {
continue
}
}
if (requestPath[n-1] != '/' && route.pattern == requestPath) ||
(requestPath[n-1] == '/' && len(route.pattern) >= n-1 && requestPath[0:n-1] == route.pattern) {
if requestPath == route.pattern {
runrouter = route
findrouter = true
break
}
// pattern /admin url /admin 200 /admin/ 404
// pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
route.pattern[n] == '/' && route.pattern[:n-1] == requestPath {
http.Redirect(w, r, requestPath+"/", 301)
return
}
}
//find regex's router
if !findrouter {
//find a matching Route
for _, route := range p.routers {
@ -429,7 +459,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
params[route.params[i]] = match
}
//reassemble query params and add to RawQuery
r.URL.RawQuery = url.Values(values).Encode() + "&" + r.URL.RawQuery
r.URL.RawQuery = url.Values(values).Encode()
//r.URL.RawQuery = url.Values(values).Encode()
}
runrouter = route
@ -440,22 +470,22 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if runrouter != nil {
//execute middleware filters
for _, filter := range p.filters {
filter(w, r)
if w.started {
return
if p.enableFilter {
for _, filter := range p.filters {
filter(w, r)
if w.started {
return
}
}
}
//Invoke the request handler
vc := reflect.New(runrouter.controllerType)
StructMap(vc.Elem(), r)
//call the controller init function
init := vc.MethodByName("Init")
in := make([]reflect.Value, 2)
ct := &Context{ResponseWriter: w, Request: r, Params: params}
ct := &Context{ResponseWriter: w, Request: r, Params: params, RequestBody: requestbody}
in[0] = reflect.ValueOf(ct)
in[1] = reflect.ValueOf(runrouter.controllerType.Name())
@ -465,38 +495,136 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
method := vc.MethodByName("Prepare")
method.Call(in)
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
if EnableXSRF {
method = vc.MethodByName("XsrfToken")
method.Call(in)
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
(r.Method == "POST" && (r.Form.Get("_method") == "delete" || r.Form.Get("_method") == "put")) {
method = vc.MethodByName("CheckXsrfCookie")
method.Call(in)
}
}
//if response has written,yes don't run next
if !w.started {
if r.Method == "GET" {
method = vc.MethodByName("Get")
if runrouter.hasMethod {
if m, ok := runrouter.methods["get"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Get")
}
} else {
method = vc.MethodByName("Get")
}
method.Call(in)
} else if r.Method == "HEAD" {
method = vc.MethodByName("Head")
if runrouter.hasMethod {
if m, ok := runrouter.methods["head"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Head")
}
} else {
method = vc.MethodByName("Head")
}
method.Call(in)
} else if r.Method == "DELETE" || (r.Method == "POST" && r.Form.Get("_method") == "delete") {
method = vc.MethodByName("Delete")
if runrouter.hasMethod {
if m, ok := runrouter.methods["delete"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Delete")
}
} else {
method = vc.MethodByName("Delete")
}
method.Call(in)
} else if r.Method == "PUT" || (r.Method == "POST" && r.Form.Get("_method") == "put") {
method = vc.MethodByName("Put")
if runrouter.hasMethod {
if m, ok := runrouter.methods["put"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Put")
}
} else {
method = vc.MethodByName("Put")
}
method.Call(in)
} else if r.Method == "POST" {
method = vc.MethodByName("Post")
if runrouter.hasMethod {
if m, ok := runrouter.methods["post"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Post")
}
} else {
method = vc.MethodByName("Post")
}
method.Call(in)
} else if r.Method == "PATCH" {
method = vc.MethodByName("Patch")
if runrouter.hasMethod {
if m, ok := runrouter.methods["patch"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Patch")
}
} else {
method = vc.MethodByName("Patch")
}
method.Call(in)
} else if r.Method == "OPTIONS" {
method = vc.MethodByName("Options")
if runrouter.hasMethod {
if m, ok := runrouter.methods["options"]; ok {
method = vc.MethodByName(m)
} else if m, ok = runrouter.methods["*"]; ok {
method = vc.MethodByName(m)
} else {
method = vc.MethodByName("Options")
}
} else {
method = vc.MethodByName("Options")
}
method.Call(in)
}
gotofunc := vc.Elem().FieldByName("gotofunc").String()
if gotofunc != "" {
method = vc.MethodByName(gotofunc)
if method.IsValid() {
method.Call(in)
} else {
panic("gotofunc is exists:" + gotofunc)
}
}
if !w.started {
if AutoRender {
method = vc.MethodByName("Render")
method.Call(in)
}
if !w.started {
method = vc.MethodByName("Finish")
method.Call(in)
}
}
method = vc.MethodByName("Finish")
method.Call(in)
//execute middleware filters
if p.enableAfter {
for _, filter := range p.afterFilters {
filter(w, r)
if w.started {
return
}
}
}
@ -504,9 +632,99 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
method.Call(in)
}
//start autorouter
if p.enableAuto {
if !findrouter {
for cName, methodmap := range p.autoRouter {
if strings.ToLower(requestPath) == "/"+cName {
http.Redirect(w, r, requestPath+"/", 301)
return
}
if strings.ToLower(requestPath) == "/"+cName+"/" {
requestPath = requestPath + "index"
}
if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/") {
for mName, controllerType := range methodmap {
if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/"+strings.ToLower(mName)) {
//execute middleware filters
if p.enableFilter {
for _, filter := range p.filters {
filter(w, r)
if w.started {
return
}
}
}
//parse params
otherurl := requestPath[len("/"+cName+"/"+strings.ToLower(mName)):]
if len(otherurl) > 1 {
plist := strings.Split(otherurl, "/")
for k, v := range plist[1:] {
params[strconv.Itoa(k)] = v
}
}
//Invoke the request handler
vc := reflect.New(controllerType)
//call the controller init function
init := vc.MethodByName("Init")
in := make([]reflect.Value, 2)
ct := &Context{ResponseWriter: w, Request: r, Params: params, RequestBody: requestbody}
in[0] = reflect.ValueOf(ct)
in[1] = reflect.ValueOf(controllerType.Name())
init.Call(in)
//call prepare function
in = make([]reflect.Value, 0)
method := vc.MethodByName("Prepare")
method.Call(in)
method = vc.MethodByName(mName)
method.Call(in)
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
if EnableXSRF {
method = vc.MethodByName("XsrfToken")
method.Call(in)
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
(r.Method == "POST" && (r.Form.Get("_method") == "delete" || r.Form.Get("_method") == "put")) {
method = vc.MethodByName("CheckXsrfCookie")
method.Call(in)
}
}
if !w.started {
if AutoRender {
method = vc.MethodByName("Render")
method.Call(in)
}
}
method = vc.MethodByName("Finish")
method.Call(in)
//execute middleware filters
if p.enableAfter {
for _, filter := range p.afterFilters {
filter(w, r)
if w.started {
return
}
}
}
method = vc.MethodByName("Destructor")
method.Call(in)
// set find
findrouter = true
}
}
}
}
}
}
//if no matches to url, throw a not found exception
if !findrouter {
if h, ok := ErrorMaps["404"]; ok {
w.status = 404
h(w, r)
} else {
http.NotFound(w, r)

View File

@ -56,3 +56,9 @@ func (m *BeeMap) Delete(k interface{}) {
defer m.lock.Unlock()
delete(m.bm, k)
}
func (m *BeeMap) Items() map[interface{}]interface{} {
m.lock.RLock()
defer m.lock.RUnlock()
return m.bm
}

View File

@ -113,8 +113,8 @@ func (pder *MemProvider) SessionGC() {
}
func (pder *MemProvider) SessionUpdate(sid string) error {
pder.lock.RLock()
defer pder.lock.RUnlock()
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
element.Value.(*MemSessionStore).timeAccessed = time.Now()
pder.list.MoveToFront(element)

View File

@ -22,12 +22,11 @@ func (rs *RedisSessionStore) Set(key, value interface{}) error {
}
func (rs *RedisSessionStore) Get(key interface{}) interface{} {
//v, err := rs.c.Do("GET", rs.sid, key)
v, err := redis.String(rs.c.Do("HGET", rs.sid, key))
reply, err := rs.c.Do("HGET", rs.sid, key)
if err != nil {
return nil
}
return v
return reply
}
func (rs *RedisSessionStore) Delete(key interface{}) error {
@ -56,7 +55,7 @@ func (rp *RedisProvider) connectInit() redis.Conn {
}
return c*/
//if redisPool == nil {
redisPool = make(chan redis.Conn, MAX_POOL_SIZE)
redisPool = make(chan redis.Conn, MAX_POOL_SIZE)
//}
if len(redisPool) == 0 {
go func() {

View File

@ -66,11 +66,11 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
Path: "/",
HttpOnly: true,
Secure: false}
cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
//cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
http.SetCookie(w, &cookie)
r.AddCookie(&cookie)
} else {
cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
//cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
cookie.HttpOnly = true
cookie.Path = "/"
http.SetCookie(w, cookie)

View File

@ -23,7 +23,6 @@ func init() {
beegoTplFuncMap = make(template.FuncMap)
BeeTemplateExt = make([]string, 0)
BeeTemplateExt = append(BeeTemplateExt, "tpl", "html")
beegoTplFuncMap["markdown"] = MarkDown
beegoTplFuncMap["dateformat"] = DateFormat
beegoTplFuncMap["date"] = Date
beegoTplFuncMap["compare"] = Compare
@ -32,6 +31,7 @@ func init() {
beegoTplFuncMap["str2html"] = Str2html
beegoTplFuncMap["htmlquote"] = Htmlquote
beegoTplFuncMap["htmlunquote"] = Htmlunquote
beegoTplFuncMap["renderform"] = RenderForm
}
// AddFuncMap let user to register a func in the template
@ -52,34 +52,35 @@ func (self *templatefile) visit(paths string, f os.FileInfo, err error) error {
if f == nil {
return err
}
if f.IsDir() {
if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 {
return nil
} else if (f.Mode() & os.ModeSymlink) > 0 {
}
if !HasTemplateEXt(paths) {
return nil
} else {
hasExt := false
for _, v := range BeeTemplateExt {
if strings.HasSuffix(paths, v) {
hasExt = true
break
}
}
if hasExt {
replace := strings.NewReplacer("\\", "/")
a := []byte(paths)
a = a[len([]byte(self.root)):]
subdir := path.Dir(strings.TrimLeft(replace.Replace(string(a)), "/"))
if _, ok := self.files[subdir]; ok {
self.files[subdir] = append(self.files[subdir], paths)
} else {
m := make([]string, 1)
m[0] = paths
self.files[subdir] = m
}
}
replace := strings.NewReplacer("\\", "/")
a := []byte(paths)
a = a[len([]byte(self.root)):]
subdir := path.Dir(strings.TrimLeft(replace.Replace(string(a)), "/"))
if _, ok := self.files[subdir]; ok {
self.files[subdir] = append(self.files[subdir], paths)
} else {
m := make([]string, 1)
m[0] = paths
self.files[subdir] = m
}
return nil
}
func HasTemplateEXt(paths string) bool {
for _, v := range BeeTemplateExt {
if strings.HasSuffix(paths, "."+v) {
return true
}
}
return nil
return false
}
func AddTemplateExt(ext string) {

49
template_test.go Normal file
View File

@ -0,0 +1,49 @@
package beego
import (
"os"
"path/filepath"
"testing"
)
func TestBuildTemplate(t *testing.T) {
dir := "_beeTmp"
files := []string{
"1.tpl",
"2.html",
"3.htmltpl",
"4.mystyle",
}
if err := os.MkdirAll(dir, 0777); err != nil {
t.Fatal(err)
}
for _, name := range files {
if _, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
}
}
if err := BuildTemplate(dir); err != nil {
t.Fatal(err)
}
if len(BeeTemplates) != 1 {
t.Fatalf("should be 1 but got %v", len(BeeTemplates))
}
for _, v := range BeeTemplates {
if len(v.Templates()) != 3 {
t.Errorf("should be 3 but got %v", len(v.Templates()))
}
}
AddTemplateExt("mystyle")
if err := BuildTemplate(dir); err != nil {
t.Fatal(err)
}
if len(BeeTemplates) != 1 {
t.Fatalf("should be 1 but got %v", len(BeeTemplates))
}
for _, v := range BeeTemplates {
if len(v.Templates()) != 4 {
t.Errorf("should be 4 but got %v", len(v.Templates()))
}
}
}

183
utils.go
View File

@ -2,9 +2,11 @@ package beego
import (
"fmt"
"github.com/russross/blackfriday"
"html/template"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
@ -17,14 +19,6 @@ func webTime(t time.Time) string {
return ftime
}
// MarkDown parses a string in MarkDown format and returns HTML. Used by the template parser as "markdown"
func MarkDown(raw string) (output template.HTML) {
input := []byte(raw)
bOutput := blackfriday.MarkdownBasic(input)
output = template.HTML(string(bOutput))
return
}
func Substr(s string, start, length int) string {
bt := []rune(s)
if start < 0 {
@ -170,3 +164,174 @@ func Htmlunquote(src string) string {
return strings.TrimSpace(text)
}
func inSlice(v string, sl []string) bool {
for _, vv := range sl {
if vv == v {
return true
}
}
return false
}
// parse form values to struct via tag
func ParseForm(form url.Values, obj interface{}) error {
objT := reflect.TypeOf(obj)
objV := reflect.ValueOf(obj)
if !isStructPtr(objT) {
return fmt.Errorf("%v must be a struct pointer", obj)
}
objT = objT.Elem()
objV = objV.Elem()
for i := 0; i < objT.NumField(); i++ {
fieldV := objV.Field(i)
if !fieldV.CanSet() {
continue
}
fieldT := objT.Field(i)
tags := strings.Split(fieldT.Tag.Get("form"), ",")
var tag string
if len(tags) == 0 || len(tags[0]) == 0 {
tag = fieldT.Name
} else if tags[0] == "-" {
continue
} else {
tag = tags[0]
}
value := form.Get(tag)
if len(value) == 0 {
continue
}
switch fieldT.Type.Kind() {
case reflect.Bool:
b, err := strconv.ParseBool(value)
if err != nil {
return err
}
fieldV.SetBool(b)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return err
}
fieldV.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
x, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return err
}
fieldV.SetUint(x)
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
fieldV.SetFloat(x)
case reflect.Interface:
fieldV.Set(reflect.ValueOf(value))
case reflect.String:
fieldV.SetString(value)
}
}
return nil
}
// form types for RenderForm function
var FormType = map[string]bool{
"text": true,
"textarea": true,
"hidden": true,
"password": true,
}
var unKind = map[reflect.Kind]bool{
reflect.Uintptr: true,
reflect.Complex64: true,
reflect.Complex128: true,
reflect.Array: true,
reflect.Chan: true,
reflect.Func: true,
reflect.Map: true,
reflect.Ptr: true,
reflect.Slice: true,
reflect.Struct: true,
reflect.UnsafePointer: true,
}
// obj must be a struct pointer
func RenderForm(obj interface{}) template.HTML {
objT := reflect.TypeOf(obj)
objV := reflect.ValueOf(obj)
if !isStructPtr(objT) {
return template.HTML("")
}
objT = objT.Elem()
objV = objV.Elem()
var raw []string
for i := 0; i < objT.NumField(); i++ {
fieldV := objV.Field(i)
if !fieldV.CanSet() || unKind[fieldV.Kind()] {
continue
}
fieldT := objT.Field(i)
tags := strings.Split(fieldT.Tag.Get("form"), ",")
label := fieldT.Name + ": "
name := fieldT.Name
fType := "text"
switch len(tags) {
case 1:
if tags[0] == "-" {
continue
}
if len(tags[0]) > 0 {
name = tags[0]
}
case 2:
if len(tags[0]) > 0 {
name = tags[0]
}
if len(tags[1]) > 0 {
fType = tags[1]
}
case 3:
if len(tags[0]) > 0 {
name = tags[0]
}
if len(tags[1]) > 0 {
fType = tags[1]
}
if len(tags[2]) > 0 {
label = tags[2]
}
}
raw = append(raw, fmt.Sprintf(`%v<input name="%v" type="%v" value="%v">`,
label, name, fType, fieldV.Interface()))
}
return template.HTML(strings.Join(raw, "</br>"))
}
func isStructPtr(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
func stringsToJson(str string) string {
rs := []rune(str)
jsons := ""
for _, r := range rs {
rint := int(r)
if rint < 128 {
jsons += string(r)
} else {
jsons += "\\u" + strconv.FormatInt(int64(rint), 16) // json
}
}
return jsons
}

176
utils_test.go Normal file
View File

@ -0,0 +1,176 @@
package beego
import (
"html/template"
"net/url"
"testing"
"time"
)
func TestWebTime(t *testing.T) {
ts := "Fri, 26 Jul 2013 12:27:42 CST"
l, _ := time.LoadLocation("GST")
tt, _ := time.ParseInLocation(time.RFC1123, ts, l)
if ts != webTime(tt) {
t.Error("should be equal")
}
if "Fri, 26 Jul 2013 12:27:42 GMT" != webTime(tt.UTC()) {
t.Error("should be equal")
}
}
func TestSubstr(t *testing.T) {
s := `012345`
if Substr(s, 0, 2) != "01" {
t.Error("should be equal")
}
if Substr(s, 0, 100) != "012345" {
t.Error("should be equal")
}
}
func TestHtml2str(t *testing.T) {
h := `<HTML><style></style><script>x<x</script></HTML><123> 123\n
\n`
if Html2str(h) != "123\\n\n\\n" {
t.Error("should be equal")
}
}
func TestDateFormat(t *testing.T) {
ts := "Mon, 01 Jul 2013 13:27:42 CST"
tt, _ := time.Parse(time.RFC1123, ts)
if DateFormat(tt, "2006-01-02 15:04:05") != "2013-07-01 13:27:42" {
t.Error("should be equal")
}
}
func TestDate(t *testing.T) {
ts := "Mon, 01 Jul 2013 13:27:42 CST"
tt, _ := time.Parse(time.RFC1123, ts)
if Date(tt, "Y-m-d H:i:s") != "2013-07-01 13:27:42" {
t.Error("should be equal")
}
if Date(tt, "y-n-j h:i:s A") != "13-7-1 01:27:42 PM" {
t.Error("should be equal")
}
if Date(tt, "D, d M Y g:i:s a") != "Mon, 01 Jul 2013 1:27:42 pm" {
t.Error("should be equal")
}
if Date(tt, "l, d F Y G:i:s") != "Monday, 01 July 2013 13:27:42" {
t.Error("should be equal")
}
}
func TestCompare(t *testing.T) {
if !Compare("abc", "abc") {
t.Error("should be equal")
}
if Compare("abc", "aBc") {
t.Error("should be not equal")
}
if !Compare("1", 1) {
t.Error("should be equal")
}
}
func TestHtmlquote(t *testing.T) {
h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&quot;&gt;`
s := `<' ”“&">`
if Htmlquote(s) != h {
t.Error("should be equal")
}
}
func TestHtmlunquote(t *testing.T) {
h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&quot;&gt;`
s := `<' ”“&">`
if Htmlunquote(h) != s {
t.Error("should be equal")
}
}
func TestInSlice(t *testing.T) {
sl := []string{"A", "b"}
if !inSlice("A", sl) {
t.Error("should be true")
}
if inSlice("B", sl) {
t.Error("should be false")
}
}
func TestParseForm(t *testing.T) {
type user struct {
Id int `form:"-"`
tag string `form:"tag"`
Name interface{} `form:"username"`
Age int `form:"age,text"`
Email string
Intro string `form:",textarea"`
}
u := user{}
form := url.Values{
"Id": []string{"1"},
"-": []string{"1"},
"tag": []string{"no"},
"username": []string{"test"},
"age": []string{"40"},
"Email": []string{"test@gmail.com"},
"Intro": []string{"I am an engineer!"},
}
if err := ParseForm(form, u); err == nil {
t.Fatal("nothing will be changed")
}
if err := ParseForm(form, &u); err != nil {
t.Fatal(err)
}
if u.Id != 0 {
t.Errorf("Id should equal 0 but got %v", u.Id)
}
if len(u.tag) != 0 {
t.Errorf("tag's length should equal 0 but got %v", len(u.tag))
}
if u.Name.(string) != "test" {
t.Errorf("Name should equal `test` but got `%v`", u.Name.(string))
}
if u.Age != 40 {
t.Errorf("Age should equal 40 but got %v", u.Age)
}
if u.Email != "test@gmail.com" {
t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email)
}
if u.Intro != "I am an engineer!" {
t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro)
}
}
func TestRenderForm(t *testing.T) {
type user struct {
Id int `form:"-"`
tag string `form:"tag"`
Name interface{} `form:"username"`
Age int `form:"age,text,年龄:"`
Sex string
Email []string
Intro string `form:",textarea"`
}
u := user{Name: "test"}
output := RenderForm(u)
if output != template.HTML("") {
t.Errorf("output should be empty but got %v", output)
}
output = RenderForm(&u)
result := template.HTML(
`Name: <input name="username" type="text" value="test"></br>` +
`年龄:<input name="age" type="text" value="0"></br>` +
`Sex: <input name="Sex" type="text" value=""></br>` +
`Intro: <input name="Intro" type="textarea" value="">`)
if output != result {
t.Errorf("output should equal `%v` but got `%v`", result, output)
}
}

103
validation/README.md Normal file
View File

@ -0,0 +1,103 @@
validation
==============
validation is a form validation for a data validation and error collecting using Go.
## Installation and tests
Install:
go get github.com/astaxie/beego/validation
Test:
go test github.com/astaxie/beego/validation
## Example
Direct Use:
import (
"github.com/astaxie/beego/validation"
"log"
)
type User struct {
Name string
Age int
}
func main() {
u := User{"man", 40}
valid := validation.Validation{}
valid.Required(u.Name, "name")
valid.MaxSize(u.Name, 15, "nameMax")
valid.Range(u.Age, 0, 140, "age")
if valid.HasErrors {
// validation does not pass
// print invalid message
for _, err := range valid.Errors {
log.Println(err.Key, err.Message)
}
}
// or use like this
if v := valid.Max(u.Age, 140); !v.Ok {
log.Println(v.Error.Key, v.Error.Message)
}
}
Struct Tag Use:
import (
"github.com/astaxie/beego/validation"
)
// validation function follow with "valid" tag
// functions divide with ";"
// parameters in parentheses "()" and divide with ","
// Match function's pattern string must in "//"
type user struct {
Id int
Name string `valid:"Required;Match(/^(test)?\\w*@;com$/)"`
Age int `valid:"Required;Range(1, 140)"`
}
func main() {
valid := Validation{}
u := user{Name: "test", Age: 40}
b, err := valid.Valid(u)
if err != nil {
// handle error
}
if !b {
// validation does not pass
// blabla...
}
}
Struct Tag Functions:
Required
Min(min int)
Max(max int)
Range(min, max int)
MinSize(min int)
MaxSize(max int)
Length(length int)
Alpha
Numeric
AlphaNumeric
Match(pattern string)
AlphaDash
Email
IP
Base64
Mobile
Tel
Phone
ZipCode
## LICENSE
BSD License http://creativecommons.org/licenses/BSD/

229
validation/util.go Normal file
View File

@ -0,0 +1,229 @@
package validation
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
)
const (
VALIDTAG = "valid"
)
var (
// key: function name
// value: the number of parameters
funcs = make(Funcs)
// doesn't belong to validation functions
unFuncs = map[string]bool{
"Clear": true,
"HasErrors": true,
"ErrorMap": true,
"Error": true,
"apply": true,
"Check": true,
"Valid": true,
"NoMatch": true,
}
)
func init() {
v := &Validation{}
t := reflect.TypeOf(v)
for i := 0; i < t.NumMethod(); i++ {
m := t.Method(i)
if !unFuncs[m.Name] {
funcs[m.Name] = m.Func
}
}
}
type ValidFunc struct {
Name string
Params []interface{}
}
type Funcs map[string]reflect.Value
func (f Funcs) Call(name string, params ...interface{}) (result []reflect.Value, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("%v", r)
}
}()
if _, ok := f[name]; !ok {
err = fmt.Errorf("%s does not exist", name)
return
}
if len(params) != f[name].Type().NumIn() {
err = fmt.Errorf("The number of params is not adapted")
return
}
in := make([]reflect.Value, len(params))
for k, param := range params {
in[k] = reflect.ValueOf(param)
}
result = f[name].Call(in)
return
}
func isStruct(t reflect.Type) bool {
return t.Kind() == reflect.Struct
}
func isStructPtr(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) {
tag := f.Tag.Get(VALIDTAG)
if len(tag) == 0 {
return
}
if vfs, tag, err = getRegFuncs(tag, f.Name); err != nil {
fmt.Printf("%+v\n", err)
return
}
fs := strings.Split(tag, ";")
for _, vfunc := range fs {
var vf ValidFunc
if len(vfunc) == 0 {
continue
}
vf, err = parseFunc(vfunc, f.Name)
if err != nil {
return
}
vfs = append(vfs, vf)
}
return
}
// Get Match function
// May be get NoMatch function in the future
func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) {
tag = strings.TrimSpace(tag)
index := strings.Index(tag, "Match(/")
if index == -1 {
str = tag
return
}
end := strings.LastIndex(tag, "/)")
if end < index {
err = fmt.Errorf("invalid Match function")
return
}
reg, err := regexp.Compile(tag[index+len("Match(/") : end])
if err != nil {
return
}
vfs = []ValidFunc{ValidFunc{"Match", []interface{}{reg, key + ".Match"}}}
str = strings.TrimSpace(tag[:index]) + strings.TrimSpace(tag[end+len("/)"):])
return
}
func parseFunc(vfunc, key string) (v ValidFunc, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("%v", r)
}
}()
vfunc = strings.TrimSpace(vfunc)
start := strings.Index(vfunc, "(")
var num int
// doesn't need parameter valid function
if start == -1 {
if num, err = numIn(vfunc); err != nil {
return
}
if num != 0 {
err = fmt.Errorf("%s require %d parameters", vfunc, num)
return
}
v = ValidFunc{vfunc, []interface{}{key + "." + vfunc}}
return
}
end := strings.Index(vfunc, ")")
if end == -1 {
err = fmt.Errorf("invalid valid function")
return
}
name := strings.TrimSpace(vfunc[:start])
if num, err = numIn(name); err != nil {
return
}
params := strings.Split(vfunc[start+1:end], ",")
// the num of param must be equal
if num != len(params) {
err = fmt.Errorf("%s require %d parameters", name, num)
return
}
tParams, err := trim(name, key+"."+name, params)
if err != nil {
return
}
v = ValidFunc{name, tParams}
return
}
func numIn(name string) (num int, err error) {
fn, ok := funcs[name]
if !ok {
err = fmt.Errorf("doesn't exsits %s valid function", name)
return
}
// sub *Validation obj and key
num = fn.Type().NumIn() - 3
return
}
func trim(name, key string, s []string) (ts []interface{}, err error) {
ts = make([]interface{}, len(s), len(s)+1)
fn, ok := funcs[name]
if !ok {
err = fmt.Errorf("doesn't exsits %s valid function", name)
return
}
for i := 0; i < len(s); i++ {
var param interface{}
// skip *Validation and obj params
if param, err = magic(fn.Type().In(i+2), strings.TrimSpace(s[i])); err != nil {
return
}
ts[i] = param
}
ts = append(ts, key)
return
}
// modify the parameters's type to adapt the function input parameters' type
func magic(t reflect.Type, s string) (i interface{}, err error) {
switch t.Kind() {
case reflect.Int:
i, err = strconv.Atoi(s)
case reflect.String:
i = s
case reflect.Ptr:
if t.Elem().String() != "regexp.Regexp" {
err = fmt.Errorf("does not support %s", t.Elem().String())
return
}
i, err = regexp.Compile(s)
default:
err = fmt.Errorf("does not support %s", t.Kind().String())
}
return
}
func mergeParam(v *Validation, obj interface{}, params []interface{}) []interface{} {
return append([]interface{}{v, obj}, params...)
}

86
validation/util_test.go Normal file
View File

@ -0,0 +1,86 @@
package validation
import (
"reflect"
"testing"
)
type user struct {
Id int
Tag string `valid:"Maxx(aa)"`
Name string `valid:"Required;"`
Age int `valid:"Required;Range(1, 140)"`
match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"`
}
func TestGetValidFuncs(t *testing.T) {
u := user{Name: "test", Age: 1}
tf := reflect.TypeOf(u)
var vfs []ValidFunc
var err error
f, _ := tf.FieldByName("Id")
if vfs, err = getValidFuncs(f); err != nil {
t.Fatal(err)
}
if len(vfs) != 0 {
t.Fatal("should get none ValidFunc")
}
f, _ = tf.FieldByName("Tag")
if vfs, err = getValidFuncs(f); err.Error() != "doesn't exsits Maxx valid function" {
t.Fatal(err)
}
f, _ = tf.FieldByName("Name")
if vfs, err = getValidFuncs(f); err != nil {
t.Fatal(err)
}
if len(vfs) != 1 {
t.Fatal("should get 1 ValidFunc")
}
if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 {
t.Error("Required funcs should be got")
}
f, _ = tf.FieldByName("Age")
if vfs, err = getValidFuncs(f); err != nil {
t.Fatal(err)
}
if len(vfs) != 2 {
t.Fatal("should get 2 ValidFunc")
}
if vfs[0].Name != "Required" && len(vfs[0].Params) != 0 {
t.Error("Required funcs should be got")
}
if vfs[1].Name != "Range" && len(vfs[1].Params) != 2 {
t.Error("Range funcs should be got")
}
f, _ = tf.FieldByName("match")
if vfs, err = getValidFuncs(f); err != nil {
t.Fatal(err)
}
if len(vfs) != 3 {
t.Fatal("should get 3 ValidFunc but now is", len(vfs))
}
}
func TestCall(t *testing.T) {
u := user{Name: "test", Age: 180}
tf := reflect.TypeOf(u)
var vfs []ValidFunc
var err error
f, _ := tf.FieldByName("Age")
if vfs, err = getValidFuncs(f); err != nil {
t.Fatal(err)
}
valid := &Validation{}
vfs[1].Params = append([]interface{}{valid, u.Age}, vfs[1].Params...)
if _, err = funcs.Call(vfs[1].Name, vfs[1].Params...); err != nil {
t.Fatal(err)
}
if len(valid.Errors) != 1 {
t.Error("age out of range should be has an error")
}
}

227
validation/validation.go Normal file
View File

@ -0,0 +1,227 @@
package validation
import (
"fmt"
"reflect"
"regexp"
)
type ValidationError struct {
Message, Key string
}
// Returns the Message.
func (e *ValidationError) String() string {
if e == nil {
return ""
}
return e.Message
}
// A Validation context manages data validation and error messages.
type Validation struct {
Errors []*ValidationError
}
func (v *Validation) Clear() {
v.Errors = []*ValidationError{}
}
func (v *Validation) HasErrors() bool {
return len(v.Errors) > 0
}
// Return the errors mapped by key.
// If there are multiple validation errors associated with a single key, the
// first one "wins". (Typically the first validation will be the more basic).
func (v *Validation) ErrorMap() map[string]*ValidationError {
m := map[string]*ValidationError{}
for _, e := range v.Errors {
if _, ok := m[e.Key]; !ok {
m[e.Key] = e
}
}
return m
}
// Add an error to the validation context.
func (v *Validation) Error(message string, args ...interface{}) *ValidationResult {
result := (&ValidationResult{
Ok: false,
Error: &ValidationError{},
}).Message(message, args...)
v.Errors = append(v.Errors, result.Error)
return result
}
// A ValidationResult is returned from every validation method.
// It provides an indication of success, and a pointer to the Error (if any).
type ValidationResult struct {
Error *ValidationError
Ok bool
}
func (r *ValidationResult) Key(key string) *ValidationResult {
if r.Error != nil {
r.Error.Key = key
}
return r
}
func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult {
if r.Error != nil {
if len(args) == 0 {
r.Error.Message = message
} else {
r.Error.Message = fmt.Sprintf(message, args...)
}
}
return r
}
// Test that the argument is non-nil and non-empty (if string or list)
func (v *Validation) Required(obj interface{}, key string) *ValidationResult {
return v.apply(Required{key}, obj)
}
// Test that the obj is greater than min if obj's type is int
func (v *Validation) Min(obj interface{}, min int, key string) *ValidationResult {
return v.apply(Min{min, key}, obj)
}
// Test that the obj is less than max if obj's type is int
func (v *Validation) Max(obj interface{}, max int, key string) *ValidationResult {
return v.apply(Max{max, key}, obj)
}
// Test that the obj is between mni and max if obj's type is int
func (v *Validation) Range(obj interface{}, min, max int, key string) *ValidationResult {
return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj)
}
func (v *Validation) MinSize(obj interface{}, min int, key string) *ValidationResult {
return v.apply(MinSize{min, key}, obj)
}
func (v *Validation) MaxSize(obj interface{}, max int, key string) *ValidationResult {
return v.apply(MaxSize{max, key}, obj)
}
func (v *Validation) Length(obj interface{}, n int, key string) *ValidationResult {
return v.apply(Length{n, key}, obj)
}
func (v *Validation) Alpha(obj interface{}, key string) *ValidationResult {
return v.apply(Alpha{key}, obj)
}
func (v *Validation) Numeric(obj interface{}, key string) *ValidationResult {
return v.apply(Numeric{key}, obj)
}
func (v *Validation) AlphaNumeric(obj interface{}, key string) *ValidationResult {
return v.apply(AlphaNumeric{key}, obj)
}
func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
return v.apply(Match{regex, key}, obj)
}
func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
return v.apply(NoMatch{Match{Regexp: regex}, key}, obj)
}
func (v *Validation) AlphaDash(obj interface{}, key string) *ValidationResult {
return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj)
}
func (v *Validation) Email(obj interface{}, key string) *ValidationResult {
return v.apply(Email{Match{Regexp: emailPattern}, key}, obj)
}
func (v *Validation) IP(obj interface{}, key string) *ValidationResult {
return v.apply(IP{Match{Regexp: ipPattern}, key}, obj)
}
func (v *Validation) Base64(obj interface{}, key string) *ValidationResult {
return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj)
}
func (v *Validation) Mobile(obj interface{}, key string) *ValidationResult {
return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj)
}
func (v *Validation) Tel(obj interface{}, key string) *ValidationResult {
return v.apply(Tel{Match{Regexp: telPattern}, key}, obj)
}
func (v *Validation) Phone(obj interface{}, key string) *ValidationResult {
return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}},
Tel{Match: Match{Regexp: telPattern}}, key}, obj)
}
func (v *Validation) ZipCode(obj interface{}, key string) *ValidationResult {
return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj)
}
func (v *Validation) apply(chk Validator, obj interface{}) *ValidationResult {
if chk.IsSatisfied(obj) {
return &ValidationResult{Ok: true}
}
// Add the error to the validation context.
err := &ValidationError{
Message: chk.DefaultMessage(),
Key: chk.GetKey(),
}
v.Errors = append(v.Errors, err)
// Also return it in the result.
return &ValidationResult{
Ok: false,
Error: err,
}
}
// Apply a group of validators to a field, in order, and return the
// ValidationResult from the first one that fails, or the last one that
// succeeds.
func (v *Validation) Check(obj interface{}, checks ...Validator) *ValidationResult {
var result *ValidationResult
for _, check := range checks {
result = v.apply(check, obj)
if !result.Ok {
return result
}
}
return result
}
// the obj parameter must be a struct or a struct pointer
func (v *Validation) Valid(obj interface{}) (b bool, err error) {
objT := reflect.TypeOf(obj)
objV := reflect.ValueOf(obj)
switch {
case isStruct(objT):
case isStructPtr(objT):
objT = objT.Elem()
objV = objV.Elem()
default:
err = fmt.Errorf("%v must be a struct or a struct pointer", obj)
return
}
for i := 0; i < objT.NumField(); i++ {
var vfs []ValidFunc
if vfs, err = getValidFuncs(objT.Field(i)); err != nil {
return
}
for _, vf := range vfs {
if _, err = funcs.Call(vf.Name,
mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil {
return
}
}
}
return !v.HasErrors(), nil
}

View File

@ -0,0 +1,331 @@
package validation
import (
"regexp"
"testing"
"time"
)
func TestRequired(t *testing.T) {
valid := Validation{}
if valid.Required(nil, "nil").Ok {
t.Error("nil object should be false")
}
if valid.Required("", "string").Ok {
t.Error("\"'\" string should be false")
}
if !valid.Required("astaxie", "string").Ok {
t.Error("string should be true")
}
if valid.Required(0, "zero").Ok {
t.Error("Integer should not be equal 0")
}
if !valid.Required(1, "int").Ok {
t.Error("Integer except 0 should be true")
}
if !valid.Required(time.Now(), "time").Ok {
t.Error("time should be true")
}
if valid.Required([]string{}, "emptySlice").Ok {
t.Error("empty slice should be false")
}
if !valid.Required([]interface{}{"ok"}, "slice").Ok {
t.Error("slice should be true")
}
}
func TestMin(t *testing.T) {
valid := Validation{}
if valid.Min(-1, 0, "min0").Ok {
t.Error("-1 is less than the minimum value of 0 should be false")
}
if !valid.Min(1, 0, "min0").Ok {
t.Error("1 is greater or equal than the minimum value of 0 should be true")
}
}
func TestMax(t *testing.T) {
valid := Validation{}
if valid.Max(1, 0, "max0").Ok {
t.Error("1 is greater than the minimum value of 0 should be false")
}
if !valid.Max(-1, 0, "max0").Ok {
t.Error("-1 is less or equal than the maximum value of 0 should be true")
}
}
func TestRange(t *testing.T) {
valid := Validation{}
if valid.Range(-1, 0, 1, "range0_1").Ok {
t.Error("-1 is between 0 and 1 should be false")
}
if !valid.Range(1, 0, 1, "range0_1").Ok {
t.Error("1 is between 0 and 1 should be true")
}
}
func TestMinSize(t *testing.T) {
valid := Validation{}
if valid.MinSize("", 1, "minSize1").Ok {
t.Error("the length of \"\" is less than the minimum value of 1 should be false")
}
if !valid.MinSize("ok", 1, "minSize1").Ok {
t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true")
}
if valid.MinSize([]string{}, 1, "minSize1").Ok {
t.Error("the length of empty slice is less than the minimum value of 1 should be false")
}
if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok {
t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true")
}
}
func TestMaxSize(t *testing.T) {
valid := Validation{}
if valid.MaxSize("ok", 1, "maxSize1").Ok {
t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false")
}
if !valid.MaxSize("", 1, "maxSize1").Ok {
t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true")
}
if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok {
t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false")
}
if !valid.MaxSize([]string{}, 1, "maxSize1").Ok {
t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true")
}
}
func TestLength(t *testing.T) {
valid := Validation{}
if valid.Length("", 1, "length1").Ok {
t.Error("the length of \"\" must equal 1 should be false")
}
if !valid.Length("1", 1, "length1").Ok {
t.Error("the length of \"1\" must equal 1 should be true")
}
if valid.Length([]string{}, 1, "length1").Ok {
t.Error("the length of empty slice must equal 1 should be false")
}
if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok {
t.Error("the length of [\"ok\"] must equal 1 should be true")
}
}
func TestAlpha(t *testing.T) {
valid := Validation{}
if valid.Alpha("a,1-@ $", "alpha").Ok {
t.Error("\"a,1-@ $\" are valid alpha characters should be false")
}
if !valid.Alpha("abCD", "alpha").Ok {
t.Error("\"abCD\" are valid alpha characters should be true")
}
}
func TestNumeric(t *testing.T) {
valid := Validation{}
if valid.Numeric("a,1-@ $", "numeric").Ok {
t.Error("\"a,1-@ $\" are valid numeric characters should be false")
}
if !valid.Numeric("1234", "numeric").Ok {
t.Error("\"1234\" are valid numeric characters should be true")
}
}
func TestAlphaNumeric(t *testing.T) {
valid := Validation{}
if valid.AlphaNumeric("a,1-@ $", "alphaNumeric").Ok {
t.Error("\"a,1-@ $\" are valid alpha or numeric characters should be false")
}
if !valid.AlphaNumeric("1234aB", "alphaNumeric").Ok {
t.Error("\"1234aB\" are valid alpha or numeric characters should be true")
}
}
func TestMatch(t *testing.T) {
valid := Validation{}
if valid.Match("suchuangji@gmail", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok {
t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false")
}
if !valid.Match("suchuangji@gmail.com", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok {
t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true")
}
}
func TestNoMatch(t *testing.T) {
valid := Validation{}
if valid.NoMatch("123@gmail", regexp.MustCompile("[^\\w\\d]"), "nomatch").Ok {
t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false")
}
if !valid.NoMatch("123gmail", regexp.MustCompile("[^\\w\\d]"), "match").Ok {
t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true")
}
}
func TestAlphaDash(t *testing.T) {
valid := Validation{}
if valid.AlphaDash("a,1-@ $", "alphaDash").Ok {
t.Error("\"a,1-@ $\" are valid alpha or numeric or dash(-_) characters should be false")
}
if !valid.AlphaDash("1234aB-_", "alphaDash").Ok {
t.Error("\"1234aB\" are valid alpha or numeric or dash(-_) characters should be true")
}
}
func TestEmail(t *testing.T) {
valid := Validation{}
if valid.Email("not@a email", "email").Ok {
t.Error("\"not@a email\" is a valid email address should be false")
}
if !valid.Email("suchuangji@gmail.com", "email").Ok {
t.Error("\"suchuangji@gmail.com\" is a valid email address should be true")
}
}
func TestIP(t *testing.T) {
valid := Validation{}
if valid.IP("11.255.255.256", "IP").Ok {
t.Error("\"11.255.255.256\" is a valid ip address should be false")
}
if !valid.IP("01.11.11.11", "IP").Ok {
t.Error("\"suchuangji@gmail.com\" is a valid ip address should be true")
}
}
func TestBase64(t *testing.T) {
valid := Validation{}
if valid.Base64("suchuangji@gmail.com", "base64").Ok {
t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false")
}
if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok {
t.Error("\"c3VjaHVhbmdqaUBnbWFpbC5jb20=\" are a valid base64 characters should be true")
}
}
func TestMobile(t *testing.T) {
valid := Validation{}
if valid.Mobile("19800008888", "mobile").Ok {
t.Error("\"19800008888\" is a valid mobile phone number should be false")
}
if !valid.Mobile("18800008888", "mobile").Ok {
t.Error("\"18800008888\" is a valid mobile phone number should be true")
}
if !valid.Mobile("18000008888", "mobile").Ok {
t.Error("\"18000008888\" is a valid mobile phone number should be true")
}
if !valid.Mobile("8618300008888", "mobile").Ok {
t.Error("\"8618300008888\" is a valid mobile phone number should be true")
}
if !valid.Mobile("+8614700008888", "mobile").Ok {
t.Error("\"+8614700008888\" is a valid mobile phone number should be true")
}
}
func TestTel(t *testing.T) {
valid := Validation{}
if valid.Tel("222-00008888", "telephone").Ok {
t.Error("\"222-00008888\" is a valid telephone number should be false")
}
if !valid.Tel("022-70008888", "telephone").Ok {
t.Error("\"022-70008888\" is a valid telephone number should be true")
}
if !valid.Tel("02270008888", "telephone").Ok {
t.Error("\"02270008888\" is a valid telephone number should be true")
}
if !valid.Tel("70008888", "telephone").Ok {
t.Error("\"70008888\" is a valid telephone number should be true")
}
}
func TestPhone(t *testing.T) {
valid := Validation{}
if valid.Phone("222-00008888", "phone").Ok {
t.Error("\"222-00008888\" is a valid phone number should be false")
}
if !valid.Mobile("+8614700008888", "phone").Ok {
t.Error("\"+8614700008888\" is a valid phone number should be true")
}
if !valid.Tel("02270008888", "phone").Ok {
t.Error("\"02270008888\" is a valid phone number should be true")
}
}
func TestZipCode(t *testing.T) {
valid := Validation{}
if valid.ZipCode("", "zipcode").Ok {
t.Error("\"00008888\" is a valid zipcode should be false")
}
if !valid.ZipCode("536000", "zipcode").Ok {
t.Error("\"536000\" is a valid zipcode should be true")
}
}
func TestValid(t *testing.T) {
type user struct {
Id int
Name string `valid:"Required;Match(/^(test)?\\w*@(/test/);com$/)"`
Age int `valid:"Required;Range(1, 140)"`
}
valid := Validation{}
u := user{Name: "test@/test/;com", Age: 40}
b, err := valid.Valid(u)
if err != nil {
t.Fatal(err)
}
if !b {
t.Error("validation should be passed")
}
uptr := &user{Name: "test", Age: 40}
valid.Clear()
b, err = valid.Valid(uptr)
if err != nil {
t.Fatal(err)
}
if b {
t.Error("validation should not be passed")
}
if len(valid.Errors) != 1 {
t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors))
}
if valid.Errors[0].Key != "Name.Match" {
t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key)
}
u = user{Name: "test@/test/;com", Age: 180}
valid.Clear()
b, err = valid.Valid(u)
if err != nil {
t.Fatal(err)
}
if b {
t.Error("validation should not be passed")
}
if len(valid.Errors) != 1 {
t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors))
}
if valid.Errors[0].Key != "Age.Range" {
t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key)
}
}

421
validation/validators.go Normal file
View File

@ -0,0 +1,421 @@
package validation
import (
"fmt"
"reflect"
"regexp"
"time"
)
type Validator interface {
IsSatisfied(interface{}) bool
DefaultMessage() string
GetKey() string
}
type Required struct {
Key string
}
func (r Required) IsSatisfied(obj interface{}) bool {
if obj == nil {
return false
}
if str, ok := obj.(string); ok {
return len(str) > 0
}
if b, ok := obj.(bool); ok {
return b
}
if i, ok := obj.(int); ok {
return i != 0
}
if t, ok := obj.(time.Time); ok {
return !t.IsZero()
}
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Slice {
return v.Len() > 0
}
return true
}
func (r Required) DefaultMessage() string {
return "Required"
}
func (r Required) GetKey() string {
return r.Key
}
type Min struct {
Min int
Key string
}
func (m Min) IsSatisfied(obj interface{}) bool {
num, ok := obj.(int)
if ok {
return num >= m.Min
}
return false
}
func (m Min) DefaultMessage() string {
return fmt.Sprint("Minimum is ", m.Min)
}
func (m Min) GetKey() string {
return m.Key
}
type Max struct {
Max int
Key string
}
func (m Max) IsSatisfied(obj interface{}) bool {
num, ok := obj.(int)
if ok {
return num <= m.Max
}
return false
}
func (m Max) DefaultMessage() string {
return fmt.Sprint("Maximum is ", m.Max)
}
func (m Max) GetKey() string {
return m.Key
}
// Requires an integer to be within Min, Max inclusive.
type Range struct {
Min
Max
Key string
}
func (r Range) IsSatisfied(obj interface{}) bool {
return r.Min.IsSatisfied(obj) && r.Max.IsSatisfied(obj)
}
func (r Range) DefaultMessage() string {
return fmt.Sprint("Range is ", r.Min.Min, " to ", r.Max.Max)
}
func (r Range) GetKey() string {
return r.Key
}
// Requires an array or string to be at least a given length.
type MinSize struct {
Min int
Key string
}
func (m MinSize) IsSatisfied(obj interface{}) bool {
if str, ok := obj.(string); ok {
return len(str) >= m.Min
}
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Slice {
return v.Len() >= m.Min
}
return false
}
func (m MinSize) DefaultMessage() string {
return fmt.Sprint("Minimum size is ", m.Min)
}
func (m MinSize) GetKey() string {
return m.Key
}
// Requires an array or string to be at most a given length.
type MaxSize struct {
Max int
Key string
}
func (m MaxSize) IsSatisfied(obj interface{}) bool {
if str, ok := obj.(string); ok {
return len(str) <= m.Max
}
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Slice {
return v.Len() <= m.Max
}
return false
}
func (m MaxSize) DefaultMessage() string {
return fmt.Sprint("Maximum size is ", m.Max)
}
func (m MaxSize) GetKey() string {
return m.Key
}
// Requires an array or string to be exactly a given length.
type Length struct {
N int
Key string
}
func (l Length) IsSatisfied(obj interface{}) bool {
if str, ok := obj.(string); ok {
return len(str) == l.N
}
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Slice {
return v.Len() == l.N
}
return false
}
func (l Length) DefaultMessage() string {
return fmt.Sprint("Required length is ", l.N)
}
func (l Length) GetKey() string {
return l.Key
}
type Alpha struct {
Key string
}
func (a Alpha) IsSatisfied(obj interface{}) bool {
if str, ok := obj.(string); ok {
for _, v := range str {
if ('Z' < v || v < 'A') && ('z' < v || v < 'a') {
return false
}
}
return true
}
return false
}
func (a Alpha) DefaultMessage() string {
return fmt.Sprint("Must be valid alpha characters")
}
func (a Alpha) GetKey() string {
return a.Key
}
type Numeric struct {
Key string
}
func (n Numeric) IsSatisfied(obj interface{}) bool {
if str, ok := obj.(string); ok {
for _, v := range str {
if '9' < v || v < '0' {
return false
}
}
return true
}
return false
}
func (n Numeric) DefaultMessage() string {
return fmt.Sprint("Must be valid numeric characters")
}
func (n Numeric) GetKey() string {
return n.Key
}
type AlphaNumeric struct {
Key string
}
func (a AlphaNumeric) IsSatisfied(obj interface{}) bool {
if str, ok := obj.(string); ok {
for _, v := range str {
if ('Z' < v || v < 'A') && ('z' < v || v < 'a') && ('9' < v || v < '0') {
return false
}
}
return true
}
return false
}
func (a AlphaNumeric) DefaultMessage() string {
return fmt.Sprint("Must be valid alpha or numeric characters")
}
func (a AlphaNumeric) GetKey() string {
return a.Key
}
// Requires a string to match a given regex.
type Match struct {
Regexp *regexp.Regexp
Key string
}
func (m Match) IsSatisfied(obj interface{}) bool {
return m.Regexp.MatchString(fmt.Sprintf("%v", obj))
}
func (m Match) DefaultMessage() string {
return fmt.Sprint("Must match ", m.Regexp)
}
func (m Match) GetKey() string {
return m.Key
}
// Requires a string to not match a given regex.
type NoMatch struct {
Match
Key string
}
func (n NoMatch) IsSatisfied(obj interface{}) bool {
return !n.Match.IsSatisfied(obj)
}
func (n NoMatch) DefaultMessage() string {
return fmt.Sprint("Must not match ", n.Regexp)
}
func (n NoMatch) GetKey() string {
return n.Key
}
var alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]")
type AlphaDash struct {
NoMatch
Key string
}
func (a AlphaDash) DefaultMessage() string {
return fmt.Sprint("Must be valid alpha or numeric or dash(-_) characters")
}
func (a AlphaDash) GetKey() string {
return a.Key
}
var emailPattern = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?")
type Email struct {
Match
Key string
}
func (e Email) DefaultMessage() string {
return fmt.Sprint("Must be a valid email address")
}
func (e Email) GetKey() string {
return e.Key
}
var ipPattern = regexp.MustCompile("^((2[0-4]\\d|25[0-5]|[01]?\\d\\d?)\\.){3}(2[0-4]\\d|25[0-5]|[01]?\\d\\d?)$")
type IP struct {
Match
Key string
}
func (i IP) DefaultMessage() string {
return fmt.Sprint("Must be a valid ip address")
}
func (i IP) GetKey() string {
return i.Key
}
var base64Pattern = regexp.MustCompile("^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$")
type Base64 struct {
Match
Key string
}
func (b Base64) DefaultMessage() string {
return fmt.Sprint("Must be valid base64 characters")
}
func (b Base64) GetKey() string {
return b.Key
}
// just for chinese mobile phone number
var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|(47)|[8][01236789]))\\d{8}$")
type Mobile struct {
Match
Key string
}
func (m Mobile) DefaultMessage() string {
return fmt.Sprint("Must be valid mobile number")
}
func (m Mobile) GetKey() string {
return m.Key
}
// just for chinese telephone number
var telPattern = regexp.MustCompile("^(0\\d{2,3}(\\-)?)?\\d{7,8}$")
type Tel struct {
Match
Key string
}
func (t Tel) DefaultMessage() string {
return fmt.Sprint("Must be valid telephone number")
}
func (t Tel) GetKey() string {
return t.Key
}
// just for chinese telephone or mobile phone number
type Phone struct {
Mobile
Tel
Key string
}
func (p Phone) IsSatisfied(obj interface{}) bool {
return p.Mobile.IsSatisfied(obj) || p.Tel.IsSatisfied(obj)
}
func (p Phone) DefaultMessage() string {
return fmt.Sprint("Must be valid telephone or mobile phone number")
}
func (p Phone) GetKey() string {
return p.Key
}
// just for chinese zipcode
var zipCodePattern = regexp.MustCompile("^[1-9]\\d{5}$")
type ZipCode struct {
Match
Key string
}
func (z ZipCode) DefaultMessage() string {
return fmt.Sprint("Must be valid zipcode")
}
func (z ZipCode) GetKey() string {
return z.Key
}