mirror of
https://github.com/astaxie/beego.git
synced 2025-07-11 11:01:02 +00:00
Compare commits
247 Commits
Author | SHA1 | Date | |
---|---|---|---|
92f6181616 | |||
9270a0504a | |||
1da37f6ce1 | |||
ef6d9b9a94 | |||
c265786251 | |||
c5c806b58e | |||
e657dcfd5f | |||
2ed9b2bffd | |||
55ad951bce | |||
ef815bf5fc | |||
6082a0af3e | |||
be30fb7937 | |||
f4e7d63e65 | |||
14688f240f | |||
dce09837b9 | |||
3b9a404138 | |||
a6f55b59cf | |||
c188cbbcb4 | |||
4245521660 | |||
05e5baaa9f | |||
54b92e9599 | |||
aa68ffecec | |||
78991c81ab | |||
348ff13857 | |||
52817fb668 | |||
b6d63c84ae | |||
bc2f1fb79d | |||
8ed459512f | |||
25768f0109 | |||
a56f67e073 | |||
8164f9821d | |||
e307bd7ba9 | |||
a99802b7d1 | |||
3e16feb1e2 | |||
82ca85dc65 | |||
f9cc9e9eb3 | |||
18fee2ad9a | |||
4124760706 | |||
3fe4f8c362 | |||
5a863b45f4 | |||
3ad30d48b5 | |||
3255a43568 | |||
73d757e3f4 | |||
deb28dd873 | |||
7f394feab5 | |||
8cbea70e07 | |||
f222f5b238 | |||
3f1de576e4 | |||
f48ca96a7e | |||
9421a21037 | |||
5c06cd090c | |||
fc982feeb9 | |||
31de651053 | |||
f4d62d3193 | |||
acbdeb62e8 | |||
d58e9e6e12 | |||
6497f29ed7 | |||
1705b42546 | |||
5505cc09ed | |||
12e1ab0f80 | |||
9c5ceb70cc | |||
bf0b1af64f | |||
9c959fba4d | |||
5588bfc35e | |||
2f4acf46c6 | |||
c7437d7590 | |||
4f819dbd9a | |||
3f5fee2dc6 | |||
c7f16b5d5a | |||
8d1268c0a9 | |||
c921b0aa5d | |||
589f97130c | |||
443aaadcce | |||
ff1938054a | |||
d79c297880 | |||
65631e0522 | |||
a879e412a1 | |||
50bc1ef757 | |||
7bacb25725 | |||
ad8418720f | |||
b59dae6fb8 | |||
4951314837 | |||
8188873216 | |||
5d392b76c7 | |||
c6a34b8efd | |||
95e67ba2c2 | |||
439b1afb85 | |||
745e9fb0fb | |||
769f7c751b | |||
a8c2deb014 | |||
624f6258ee | |||
43c977ab62 | |||
6c92ca2a16 | |||
0f015d75d2 | |||
217c3a2e87 | |||
ff6120cb93 | |||
53aaf3b4a9 | |||
d5b5c18cf9 | |||
cacdb3228d | |||
d0949b64c6 | |||
d49984d47d | |||
9f3af59250 | |||
57afd3d979 | |||
f7430a2ce1 | |||
7389f0507e | |||
ee889e9975 | |||
d8b9db8d3e | |||
9b498feac7 | |||
69982c62c8 | |||
b405e19f56 | |||
828235b4c1 | |||
430a0a971f | |||
5be22a99a8 | |||
5d02b18db4 | |||
97b68bdd66 | |||
5583fa2054 | |||
00a410ad1a | |||
6ca30386b8 | |||
03b17a2ca9 | |||
9957a867cd | |||
f3ba41a991 | |||
4befa1bc1b | |||
9e3ebc88c4 | |||
6e00cfb464 | |||
c358c18018 | |||
adf2a590fc | |||
edb8bac5bc | |||
47d7ac06b7 | |||
d05270d2ec | |||
04a19685ed | |||
62555771d0 | |||
9dc93cbab0 | |||
7f5fb871de | |||
03037170e1 | |||
002e0854ab | |||
2bc70f62ce | |||
8bf0e67b79 | |||
b310be1fcf | |||
a54353b51c | |||
04c2ba01bc | |||
296bcab425 | |||
060b321182 | |||
05a0a4b046 | |||
8906d3e77c | |||
e822642cb0 | |||
a38a4f0343 | |||
e8a22660e4 | |||
92196c602b | |||
76222ac8d0 | |||
a184c23603 | |||
1b778509c9 | |||
c4250872ca | |||
17dd72241b | |||
ce2984f09a | |||
846d766499 | |||
bbc71142d7 | |||
74804bc586 | |||
1d08a54f44 | |||
682544165f | |||
3f0ec5c0ca | |||
0e2872324f | |||
2fb575838d | |||
ab8f8d532a | |||
d93f112083 | |||
9384e87083 | |||
34eff4cc1f | |||
8296713ba4 | |||
d014ccfb8e | |||
190039b6f8 | |||
edf7982567 | |||
1509a6b681 | |||
11e6c2829b | |||
38f93a7ba9 | |||
6b5108ef92 | |||
828a306069 | |||
4c527dde65 | |||
f5a5ebe16b | |||
32799bc259 | |||
91d75e8925 | |||
3e40041219 | |||
7d5ee0d692 | |||
91cbe1f29b | |||
f419c12427 | |||
fee3c2b8f9 | |||
b016102d34 | |||
c20e1ab1e2 | |||
dc767b65df | |||
63f19974cd | |||
6e9ba0ea7f | |||
3b99f37aa1 | |||
e8f5c10488 | |||
cb55009c8b | |||
b64e70e7df | |||
8d79f8387b | |||
afadb3f6df | |||
844412c302 | |||
299cb9130b | |||
0b42e5573b | |||
a369b15ef2 | |||
e34f8c4634 | |||
d7f2c738c8 | |||
d06c04277f | |||
aa2fef0d36 | |||
b766f65c26 | |||
6f3a759ba5 | |||
338124e3fb | |||
31bdb793cf | |||
9cbd475701 | |||
481448fa90 | |||
95c65de97c | |||
ef79a2b484 | |||
20cfece1ab | |||
c433b7029f | |||
f5cf2876dd | |||
480aa521e5 | |||
d57557dc55 | |||
803d91c077 | |||
62ee48dcbf | |||
1e57587fe9 | |||
61c0b3e286 | |||
383a04f4c2 | |||
eea272482b | |||
94ad13c846 | |||
412a4a04de | |||
e0e8fa6e2a | |||
a1e29b0b75 | |||
984b0cbf31 | |||
3118c6c23f | |||
3a08eec1f9 | |||
ecfd11adb4 | |||
95dc670eb4 | |||
7a3d05ebf3 | |||
62f54cbbee | |||
4d7f7ffa37 | |||
cb876268b5 | |||
094f2fbab8 | |||
2d77c4dc49 | |||
f535916fae | |||
673993fa2b | |||
6f3803ce8c | |||
a1f6039d82 | |||
0183608a59 | |||
5b1afcdb5a | |||
053e7a6aa6 | |||
ba94479efd | |||
ba3a9bee4c | |||
f96eec6dea |
2
.gitignore
vendored
2
.gitignore
vendored
@ -1 +1,3 @@
|
||||
.DS_Store
|
||||
*.swp
|
||||
*.swo
|
||||
|
13
LICENSE
Normal file
13
LICENSE
Normal file
@ -0,0 +1,13 @@
|
||||
Copyright 2014 astaxie
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
|
||||
beego is licensed under the Apache Licence, Version 2.0
|
||||
(http://www.apache.org/licenses/LICENSE-2.0.html).
|
||||
|
||||
|
||||
## Use case
|
||||
|
||||
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
|
||||
- seocms: [seocms](https://github.com/chinakr/seocms)
|
||||
- CMS: [toropress](https://github.com/insionng/toropress)
|
||||
[][koding]
|
||||
[koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
|
106
admin.go
106
admin.go
@ -1,16 +1,23 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/toolbox"
|
||||
"github.com/astaxie/beego/utils"
|
||||
)
|
||||
|
||||
// BeeAdminApp is the default AdminApp used by admin module.
|
||||
var BeeAdminApp *AdminApp
|
||||
// BeeAdminApp is the default adminApp used by admin module.
|
||||
var beeAdminApp *adminApp
|
||||
|
||||
// FilterMonitorFunc is default monitor filter when admin module is enable.
|
||||
// if this func returns, admin module records qbs for this request by condition of this function logic.
|
||||
@ -31,42 +38,43 @@ var BeeAdminApp *AdminApp
|
||||
var FilterMonitorFunc func(string, string, time.Duration) bool
|
||||
|
||||
func init() {
|
||||
BeeAdminApp = &AdminApp{
|
||||
beeAdminApp = &adminApp{
|
||||
routers: make(map[string]http.HandlerFunc),
|
||||
}
|
||||
BeeAdminApp.Route("/", AdminIndex)
|
||||
BeeAdminApp.Route("/qps", QpsIndex)
|
||||
BeeAdminApp.Route("/prof", ProfIndex)
|
||||
BeeAdminApp.Route("/healthcheck", Healthcheck)
|
||||
BeeAdminApp.Route("/task", TaskStatus)
|
||||
BeeAdminApp.Route("/runtask", RunTask)
|
||||
BeeAdminApp.Route("/listconf", ListConf)
|
||||
beeAdminApp.Route("/", adminIndex)
|
||||
beeAdminApp.Route("/qps", qpsIndex)
|
||||
beeAdminApp.Route("/prof", profIndex)
|
||||
beeAdminApp.Route("/healthcheck", healthcheck)
|
||||
beeAdminApp.Route("/task", taskStatus)
|
||||
beeAdminApp.Route("/runtask", runTask)
|
||||
beeAdminApp.Route("/listconf", listConf)
|
||||
FilterMonitorFunc = func(string, string, time.Duration) bool { return true }
|
||||
}
|
||||
|
||||
// AdminIndex is the default http.Handler for admin module.
|
||||
// it matches url pattern "/".
|
||||
func AdminIndex(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("Welcome to Admin Dashboard\n"))
|
||||
rw.Write([]byte("There are servral functions:\n"))
|
||||
rw.Write([]byte("1. Record all request and request time, http://localhost:8088/qps\n"))
|
||||
rw.Write([]byte("2. Get runtime profiling data by the pprof, http://localhost:8088/prof\n"))
|
||||
rw.Write([]byte("3. Get healthcheck result from http://localhost:8088/prof\n"))
|
||||
rw.Write([]byte("4. Get current task infomation from taskhttp://localhost:8088/task \n"))
|
||||
rw.Write([]byte("5. To run a task passed a param http://localhost:8088/runtask\n"))
|
||||
rw.Write([]byte("6. Get all confige & router infomation http://localhost:8088/listconf\n"))
|
||||
|
||||
func adminIndex(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("<html><head><title>beego admin dashboard</title></head><body>"))
|
||||
rw.Write([]byte("Welcome to Admin Dashboard<br>\n"))
|
||||
rw.Write([]byte("There are servral functions:<br>\n"))
|
||||
rw.Write([]byte("1. Record all request and request time, <a href='/qps'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/qps</a><br>\n"))
|
||||
rw.Write([]byte("2. Get runtime profiling data by the pprof, <a href='/prof'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/prof</a><br>\n"))
|
||||
rw.Write([]byte("3. Get healthcheck result from <a href='/healthcheck'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/healthcheck</a><br>\n"))
|
||||
rw.Write([]byte("4. Get current task infomation from task <a href='/task'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/task</a><br> \n"))
|
||||
rw.Write([]byte("5. To run a task passed a param <a href='/runtask'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/runtask</a><br>\n"))
|
||||
rw.Write([]byte("6. Get all confige & router infomation <a href='/listconf'>http://localhost:" + strconv.Itoa(AdminHttpPort) + "/listconf</a><br>\n"))
|
||||
rw.Write([]byte("</body></html>"))
|
||||
}
|
||||
|
||||
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
|
||||
// it's registered with url pattern "/qbs" in admin module.
|
||||
func QpsIndex(rw http.ResponseWriter, r *http.Request) {
|
||||
func qpsIndex(rw http.ResponseWriter, r *http.Request) {
|
||||
toolbox.StatisticsMap.GetMap(rw)
|
||||
}
|
||||
|
||||
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
|
||||
// it's registered with url pattern "/listconf" in admin module.
|
||||
func ListConf(rw http.ResponseWriter, r *http.Request) {
|
||||
func listConf(rw http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
command := r.Form.Get("command")
|
||||
if command != "" {
|
||||
@ -174,37 +182,41 @@ func ListConf(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write([]byte("command not support"))
|
||||
}
|
||||
} else {
|
||||
rw.Write([]byte("ListConf support this command:\n"))
|
||||
rw.Write([]byte("1. command=conf\n"))
|
||||
rw.Write([]byte("2. command=router\n"))
|
||||
rw.Write([]byte("3. command=filter\n"))
|
||||
rw.Write([]byte("<html><head><title>beego admin dashboard</title></head><body>"))
|
||||
rw.Write([]byte("ListConf support this command:<br>\n"))
|
||||
rw.Write([]byte("1. <a href='?command=conf'>command=conf</a><br>\n"))
|
||||
rw.Write([]byte("2. <a href='?command=router'>command=router</a><br>\n"))
|
||||
rw.Write([]byte("3. <a href='?command=filter'>command=filter</a><br>\n"))
|
||||
rw.Write([]byte("</body></html>"))
|
||||
}
|
||||
}
|
||||
|
||||
// ProfIndex is a http.Handler for showing profile command.
|
||||
// it's in url pattern "/prof" in admin module.
|
||||
func ProfIndex(rw http.ResponseWriter, r *http.Request) {
|
||||
func profIndex(rw http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
command := r.Form.Get("command")
|
||||
if command != "" {
|
||||
toolbox.ProcessInput(command, rw)
|
||||
} else {
|
||||
rw.Write([]byte("request url like '/prof?command=lookup goroutine'\n"))
|
||||
rw.Write([]byte("the command have below types:\n"))
|
||||
rw.Write([]byte("1. lookup goroutine\n"))
|
||||
rw.Write([]byte("2. lookup heap\n"))
|
||||
rw.Write([]byte("3. lookup threadcreate\n"))
|
||||
rw.Write([]byte("4. lookup block\n"))
|
||||
rw.Write([]byte("5. start cpuprof\n"))
|
||||
rw.Write([]byte("6. stop cpuprof\n"))
|
||||
rw.Write([]byte("7. get memprof\n"))
|
||||
rw.Write([]byte("8. gc summary\n"))
|
||||
rw.Write([]byte("<html><head><title>beego admin dashboard</title></head><body>"))
|
||||
rw.Write([]byte("request url like '/prof?command=lookup goroutine'<br>\n"))
|
||||
rw.Write([]byte("the command have below types:<br>\n"))
|
||||
rw.Write([]byte("1. <a href='?command=lookup goroutine'>lookup goroutine</a><br>\n"))
|
||||
rw.Write([]byte("2. <a href='?command=lookup heap'>lookup heap</a><br>\n"))
|
||||
rw.Write([]byte("3. <a href='?command=lookup threadcreate'>lookup threadcreate</a><br>\n"))
|
||||
rw.Write([]byte("4. <a href='?command=lookup block'>lookup block</a><br>\n"))
|
||||
rw.Write([]byte("5. <a href='?command=start cpuprof'>start cpuprof</a><br>\n"))
|
||||
rw.Write([]byte("6. <a href='?command=stop cpuprof'>stop cpuprof</a><br>\n"))
|
||||
rw.Write([]byte("7. <a href='?command=get memprof'>get memprof</a><br>\n"))
|
||||
rw.Write([]byte("8. <a href='?command=gc summary'>gc summary</a><br>\n"))
|
||||
rw.Write([]byte("</body></html>"))
|
||||
}
|
||||
}
|
||||
|
||||
// Healthcheck is a http.Handler calling health checking and showing the result.
|
||||
// it's in "/healthcheck" pattern in admin module.
|
||||
func Healthcheck(rw http.ResponseWriter, req *http.Request) {
|
||||
func healthcheck(rw http.ResponseWriter, req *http.Request) {
|
||||
for name, h := range toolbox.AdminCheckList {
|
||||
if err := h.Check(); err != nil {
|
||||
fmt.Fprintf(rw, "%s : ok\n", name)
|
||||
@ -216,7 +228,7 @@ func Healthcheck(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// TaskStatus is a http.Handler with running task status (task name, status and the last execution).
|
||||
// it's in "/task" pattern in admin module.
|
||||
func TaskStatus(rw http.ResponseWriter, req *http.Request) {
|
||||
func taskStatus(rw http.ResponseWriter, req *http.Request) {
|
||||
for tname, tk := range toolbox.AdminTaskList {
|
||||
fmt.Fprintf(rw, "%s:%s:%s", tname, tk.GetStatus(), tk.GetPrev().String())
|
||||
}
|
||||
@ -224,7 +236,7 @@ func TaskStatus(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// RunTask is a http.Handler to run a Task from the "query string.
|
||||
// the request url likes /runtask?taskname=sendmail.
|
||||
func RunTask(rw http.ResponseWriter, req *http.Request) {
|
||||
func runTask(rw http.ResponseWriter, req *http.Request) {
|
||||
req.ParseForm()
|
||||
taskname := req.Form.Get("taskname")
|
||||
if t, ok := toolbox.AdminTaskList[taskname]; ok {
|
||||
@ -232,25 +244,25 @@ func RunTask(rw http.ResponseWriter, req *http.Request) {
|
||||
if err != nil {
|
||||
fmt.Fprintf(rw, "%v", err)
|
||||
}
|
||||
fmt.Fprintf(rw, "%s run success,Now the Status is %s", t.GetStatus())
|
||||
fmt.Fprintf(rw, "%s run success,Now the Status is %s", taskname, t.GetStatus())
|
||||
} else {
|
||||
fmt.Fprintf(rw, "there's no task which named:%s", taskname)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminApp is an http.HandlerFunc map used as BeeAdminApp.
|
||||
type AdminApp struct {
|
||||
// adminApp is an http.HandlerFunc map used as beeAdminApp.
|
||||
type adminApp struct {
|
||||
routers map[string]http.HandlerFunc
|
||||
}
|
||||
|
||||
// Route adds http.HandlerFunc to AdminApp with url pattern.
|
||||
func (admin *AdminApp) Route(pattern string, f http.HandlerFunc) {
|
||||
// Route adds http.HandlerFunc to adminApp with url pattern.
|
||||
func (admin *adminApp) Route(pattern string, f http.HandlerFunc) {
|
||||
admin.routers[pattern] = f
|
||||
}
|
||||
|
||||
// Run AdminApp http server.
|
||||
// Run adminApp http server.
|
||||
// Its addr is defined in configuration file as adminhttpaddr and adminhttpport.
|
||||
func (admin *AdminApp) Run() {
|
||||
func (admin *adminApp) Run() {
|
||||
if len(toolbox.AdminTaskList) > 0 {
|
||||
toolbox.StartTask()
|
||||
}
|
||||
|
74
app.go
74
app.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -62,10 +68,14 @@ func (app *App) Run() {
|
||||
BeeLogger.Critical("ResolveTCPAddr:", err)
|
||||
}
|
||||
l, err = GetInitListener(laddr)
|
||||
if err == nil {
|
||||
theStoppable = newStoppable(l)
|
||||
err = server.Serve(theStoppable)
|
||||
if err == nil {
|
||||
theStoppable.wg.Wait()
|
||||
CloseSelf()
|
||||
err = CloseSelf()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s := &http.Server{
|
||||
Addr: addr,
|
||||
@ -118,6 +128,68 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
|
||||
return app
|
||||
}
|
||||
|
||||
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
|
||||
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
|
||||
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
|
||||
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
|
||||
app.Handlers.AddAutoPrefix(prefix, c)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Get method
|
||||
func (app *App) Get(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Get(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Post method
|
||||
func (app *App) Post(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Post(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Put method
|
||||
func (app *App) Put(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Put(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Delete method
|
||||
func (app *App) Delete(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Delete(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Options method
|
||||
func (app *App) Options(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Options(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Head method
|
||||
func (app *App) Head(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Head(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Patch method
|
||||
func (app *App) Patch(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Patch(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for Patch method
|
||||
func (app *App) Any(rootpath string, f FilterFunc) *App {
|
||||
app.Handlers.Any(rootpath, f)
|
||||
return app
|
||||
}
|
||||
|
||||
// add router for http.Handler
|
||||
func (app *App) Handler(rootpath string, h http.Handler, options ...interface{}) *App {
|
||||
app.Handlers.Handler(rootpath, h, options...)
|
||||
return app
|
||||
}
|
||||
|
||||
// UrlFor creates a url with another registered controller handler with params.
|
||||
// The endpoint is formed as path.controller.name to defined the controller method which will run.
|
||||
// The values need key-pair data to assign into controller method.
|
||||
|
220
beego.go
220
beego.go
@ -1,9 +1,17 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego/middleware"
|
||||
@ -11,7 +19,77 @@ import (
|
||||
)
|
||||
|
||||
// beego web framework version.
|
||||
const VERSION = "1.0.1"
|
||||
const VERSION = "1.2.0"
|
||||
|
||||
type hookfunc func() error //hook function to run
|
||||
var hooks []hookfunc //hook function slice to store the hookfunc
|
||||
|
||||
type groupRouter struct {
|
||||
pattern string
|
||||
controller ControllerInterface
|
||||
mappingMethods string
|
||||
}
|
||||
|
||||
// RouterGroups which will store routers
|
||||
type GroupRouters []groupRouter
|
||||
|
||||
// Get a new GroupRouters
|
||||
func NewGroupRouters() GroupRouters {
|
||||
return make(GroupRouters, 0)
|
||||
}
|
||||
|
||||
// Add Router in the GroupRouters
|
||||
// it is for plugin or module to register router
|
||||
func (gr *GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
|
||||
var newRG groupRouter
|
||||
if len(mappingMethod) > 0 {
|
||||
newRG = groupRouter{
|
||||
pattern,
|
||||
c,
|
||||
mappingMethod[0],
|
||||
}
|
||||
} else {
|
||||
newRG = groupRouter{
|
||||
pattern,
|
||||
c,
|
||||
"",
|
||||
}
|
||||
}
|
||||
*gr = append(*gr, newRG)
|
||||
}
|
||||
|
||||
func (gr *GroupRouters) AddAuto(c ControllerInterface) {
|
||||
newRG := groupRouter{
|
||||
"",
|
||||
c,
|
||||
"",
|
||||
}
|
||||
*gr = append(*gr, newRG)
|
||||
}
|
||||
|
||||
// AddGroupRouter with the prefix
|
||||
// it will register the router in BeeApp
|
||||
// the follow code is write in modules:
|
||||
// GR:=NewGroupRouters()
|
||||
// GR.AddRouter("/login",&UserController,"get:Login")
|
||||
// GR.AddRouter("/logout",&UserController,"get:Logout")
|
||||
// GR.AddRouter("/register",&UserController,"get:Reg")
|
||||
// the follow code is write in app:
|
||||
// import "github.com/beego/modules/auth"
|
||||
// AddRouterGroup("/admin", auth.GR)
|
||||
func AddGroupRouter(prefix string, groups GroupRouters) *App {
|
||||
for _, v := range groups {
|
||||
if v.pattern == "" {
|
||||
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
|
||||
} else if v.mappingMethods != "" {
|
||||
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
|
||||
} else {
|
||||
BeeApp.Router(prefix+v.pattern, v.controller)
|
||||
}
|
||||
|
||||
}
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// Router adds a patterned controller handler to BeeApp.
|
||||
// it's an alias method of App.Router.
|
||||
@ -36,6 +114,67 @@ func AutoRouter(c ControllerInterface) *App {
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// AutoPrefix adds controller handler to BeeApp with prefix.
|
||||
// it's same to App.AutoRouterWithPrefix.
|
||||
func AutoPrefix(prefix string, c ControllerInterface) *App {
|
||||
BeeApp.AutoRouterWithPrefix(prefix, c)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Get method
|
||||
func Get(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Get(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Post method
|
||||
func Post(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Post(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Delete method
|
||||
func Delete(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Delete(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Put method
|
||||
func Put(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Put(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Head method
|
||||
func Head(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Head(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Options method
|
||||
func Options(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Options(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for Patch method
|
||||
func Patch(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Patch(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for all method
|
||||
func Any(rootpath string, f FilterFunc) *App {
|
||||
BeeApp.Any(rootpath, f)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// register router for own Handler
|
||||
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
|
||||
BeeApp.Handler(rootpath, h, options...)
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// ErrorHandler registers http.HandlerFunc to each http err code string.
|
||||
// usage:
|
||||
// beego.ErrorHandler("404",NotFound)
|
||||
@ -58,6 +197,7 @@ func SetStaticPath(url string, path string) *App {
|
||||
if !strings.HasPrefix(url, "/") {
|
||||
url = "/" + url
|
||||
}
|
||||
url = strings.TrimRight(url, "/")
|
||||
StaticDir[url] = path
|
||||
return BeeApp
|
||||
}
|
||||
@ -87,30 +227,60 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
|
||||
return BeeApp
|
||||
}
|
||||
|
||||
// The hookfunc will run in beego.Run()
|
||||
// such as sessionInit, middlerware start, buildtemplate, admin start
|
||||
func AddAPPStartHook(hf hookfunc) {
|
||||
hooks = append(hooks, hf)
|
||||
}
|
||||
|
||||
// Run beego application.
|
||||
// it's alias of App.Run.
|
||||
func Run() {
|
||||
initBeforeHttpRun()
|
||||
|
||||
if EnableAdmin {
|
||||
go beeAdminApp.Run()
|
||||
}
|
||||
|
||||
BeeApp.Run()
|
||||
}
|
||||
|
||||
func initBeforeHttpRun() {
|
||||
// if AppConfigPath not In the conf/app.conf reParse config
|
||||
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
|
||||
err := ParseConfig()
|
||||
if err != nil {
|
||||
if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") {
|
||||
// configuration is critical to app, panic here if parse failed
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
//init mime
|
||||
initMime()
|
||||
// do hooks function
|
||||
for _, hk := range hooks {
|
||||
err := hk()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
if SessionOn {
|
||||
GlobalSessions, _ = session.NewManager(SessionProvider,
|
||||
SessionName,
|
||||
SessionGCMaxLifetime,
|
||||
SessionSavePath,
|
||||
HttpTLS,
|
||||
SessionHashFunc,
|
||||
SessionHashKey,
|
||||
SessionCookieLifeTime)
|
||||
var err error
|
||||
sessionConfig := AppConfig.String("sessionConfig")
|
||||
if sessionConfig == "" {
|
||||
sessionConfig = `{"cookieName":"` + SessionName + `",` +
|
||||
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
|
||||
`"providerConfig":"` + SessionSavePath + `",` +
|
||||
`"secure":` + strconv.FormatBool(HttpTLS) + `,` +
|
||||
`"sessionIDHashFunc":"` + SessionHashFunc + `",` +
|
||||
`"sessionIDHashKey":"` + SessionHashKey + `",` +
|
||||
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
|
||||
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
|
||||
}
|
||||
GlobalSessions, err = session.NewManager(SessionProvider,
|
||||
sessionConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go GlobalSessions.GC()
|
||||
}
|
||||
|
||||
@ -123,11 +293,23 @@ func Run() {
|
||||
|
||||
middleware.VERSION = VERSION
|
||||
middleware.AppName = AppName
|
||||
middleware.RegisterErrorHander()
|
||||
|
||||
if EnableAdmin {
|
||||
go BeeAdminApp.Run()
|
||||
}
|
||||
|
||||
BeeApp.Run()
|
||||
middleware.RegisterErrorHandler()
|
||||
}
|
||||
|
||||
func TestBeegoInit(apppath string) {
|
||||
AppPath = apppath
|
||||
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
|
||||
err := ParseConfig()
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
// for init if doesn't have app.conf will not panic
|
||||
Info(err)
|
||||
}
|
||||
os.Chdir(AppPath)
|
||||
initBeforeHttpRun()
|
||||
}
|
||||
|
||||
func init() {
|
||||
hooks = make([]hookfunc, 0)
|
||||
//init mime
|
||||
AddAPPStartHook(initMime)
|
||||
}
|
||||
|
2
cache/README.md
vendored
2
cache/README.md
vendored
@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
|
||||
|
||||
## Memcache adapter
|
||||
|
||||
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
|
||||
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
|
||||
|
||||
Configure like this:
|
||||
|
||||
|
6
cache/cache.go
vendored
6
cache/cache.go
vendored
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
|
58
cache/cache_test.go
vendored
58
cache/cache_test.go
vendored
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
@ -5,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_cache(t *testing.T) {
|
||||
func TestCache(t *testing.T) {
|
||||
bm, err := NewCache("memory", `{"interval":20}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
@ -40,7 +46,7 @@ func Test_cache(t *testing.T) {
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
t.Error("Decr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
@ -51,3 +57,51 @@ func Test_cache(t *testing.T) {
|
||||
t.Error("delete err")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCache(t *testing.T) {
|
||||
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
if err = bm.Put("astaxie", 1, 10); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Incr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 2 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Decr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
bm.Delete("astaxie")
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
//test string
|
||||
if err = bm.Put("astaxie", "author", 10); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(string) != "author" {
|
||||
t.Error("get err")
|
||||
}
|
||||
}
|
||||
|
6
cache/conv.go
vendored
6
cache/conv.go
vendored
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
|
6
cache/conv_test.go
vendored
6
cache/conv_test.go
vendored
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
|
33
cache/file.go
vendored
33
cache/file.go
vendored
@ -1,8 +1,9 @@
|
||||
/**
|
||||
* package: file
|
||||
* User: gouki
|
||||
* Date: 2013-10-22 - 14:22
|
||||
*/
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
@ -47,10 +48,11 @@ type FileCache struct {
|
||||
EmbedExpiry int
|
||||
}
|
||||
|
||||
// Create new file cache with default directory and suffix.
|
||||
// Create new file cache with no config.
|
||||
// the level and expiry need set in method StartAndGC as config string.
|
||||
func NewFileCache() *FileCache {
|
||||
return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
|
||||
// return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
|
||||
return &FileCache{}
|
||||
}
|
||||
|
||||
// Start and begin gc for file cache.
|
||||
@ -60,6 +62,7 @@ func (this *FileCache) StartAndGC(config string) error {
|
||||
var cfg map[string]string
|
||||
json.Unmarshal([]byte(config), &cfg)
|
||||
//fmt.Println(cfg)
|
||||
//fmt.Println(config)
|
||||
if _, ok := cfg["CachePath"]; !ok {
|
||||
cfg["CachePath"] = FileCachePath
|
||||
}
|
||||
@ -134,7 +137,7 @@ func (this *FileCache) Get(key string) interface{} {
|
||||
return ""
|
||||
}
|
||||
var to FileCacheItem
|
||||
Gob_decode([]byte(filedata), &to)
|
||||
Gob_decode(filedata, &to)
|
||||
if to.Expired < time.Now().Unix() {
|
||||
return ""
|
||||
}
|
||||
@ -142,13 +145,16 @@ func (this *FileCache) Get(key string) interface{} {
|
||||
}
|
||||
|
||||
// Put value into file cache.
|
||||
// timeout means how long to keep this file, unit of second.
|
||||
// timeout means how long to keep this file, unit of ms.
|
||||
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
|
||||
func (this *FileCache) Put(key string, val interface{}, timeout int64) error {
|
||||
gob.Register(val)
|
||||
|
||||
filename := this.getCacheFileName(key)
|
||||
var item FileCacheItem
|
||||
item.Data = val
|
||||
if timeout == FileCacheEmbedExpiry {
|
||||
item.Expired = time.Now().Unix() + (86400 * 365 * 10) //10年
|
||||
item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years
|
||||
} else {
|
||||
item.Expired = time.Now().Unix() + timeout
|
||||
}
|
||||
@ -175,7 +181,7 @@ func (this *FileCache) Delete(key string) error {
|
||||
func (this *FileCache) Incr(key string) error {
|
||||
data := this.Get(key)
|
||||
var incr int
|
||||
fmt.Println(reflect.TypeOf(data).Name())
|
||||
//fmt.Println(reflect.TypeOf(data).Name())
|
||||
if reflect.TypeOf(data).Name() != "int" {
|
||||
incr = 0
|
||||
} else {
|
||||
@ -208,8 +214,7 @@ func (this *FileCache) IsExist(key string) bool {
|
||||
// Clean cached files.
|
||||
// not implemented.
|
||||
func (this *FileCache) ClearAll() error {
|
||||
//this.CachePath .递归删除
|
||||
|
||||
//this.CachePath
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -269,7 +274,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
|
||||
}
|
||||
|
||||
// Gob decodes file cache item.
|
||||
func Gob_decode(data []byte, to interface{}) error {
|
||||
func Gob_decode(data []byte, to *FileCacheItem) error {
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
return dec.Decode(&to)
|
||||
|
54
cache/memcache.go → cache/memcache/memcache.go
vendored
54
cache/memcache.go → cache/memcache/memcache.go
vendored
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
@ -5,6 +11,8 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/beego/memcache"
|
||||
|
||||
"github.com/astaxie/beego/cache"
|
||||
)
|
||||
|
||||
// Memcache adapter.
|
||||
@ -21,7 +29,11 @@ func NewMemCache() *MemcacheCache {
|
||||
// get value from memcache.
|
||||
func (rc *MemcacheCache) Get(key string) interface{} {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v, err := rc.c.Get(key)
|
||||
if err != nil {
|
||||
@ -39,7 +51,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
|
||||
// put value to memcache. only support string.
|
||||
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
@ -55,7 +71,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
|
||||
// delete value in memcache.
|
||||
func (rc *MemcacheCache) Delete(key string) error {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Delete(key)
|
||||
return err
|
||||
@ -76,7 +96,11 @@ func (rc *MemcacheCache) Decr(key string) error {
|
||||
// check value exists in memcache.
|
||||
func (rc *MemcacheCache) IsExist(key string) bool {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
v, err := rc.c.Get(key)
|
||||
if err != nil {
|
||||
@ -87,13 +111,16 @@ func (rc *MemcacheCache) IsExist(key string) bool {
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// clear all cached in memcache.
|
||||
func (rc *MemcacheCache) ClearAll() error {
|
||||
if rc.c == nil {
|
||||
rc.c = rc.connectInit()
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := rc.c.FlushAll()
|
||||
return err
|
||||
@ -109,22 +136,25 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
rc.conninfo = cf["conn"]
|
||||
rc.c = rc.connectInit()
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
if rc.c != nil {
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return errors.New("dial tcp conn error")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect to memcache and keep the connection.
|
||||
func (rc *MemcacheCache) connectInit() *memcache.Connection {
|
||||
func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
|
||||
c, err := memcache.Connect(rc.conninfo)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return c
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("memcache", NewMemCache())
|
||||
cache.Register("memcache", NewMemCache())
|
||||
}
|
9
cache/memory.go
vendored
9
cache/memory.go
vendored
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
@ -26,7 +32,7 @@ type MemoryCache struct {
|
||||
lock sync.RWMutex
|
||||
dur time.Duration
|
||||
items map[string]*MemoryItem
|
||||
Every int // run an expiration check Every cloc; time
|
||||
Every int // run an expiration check Every clock time
|
||||
}
|
||||
|
||||
// NewMemoryCache returns a new MemoryCache.
|
||||
@ -52,6 +58,7 @@ func (bc *MemoryCache) Get(name string) interface{} {
|
||||
}
|
||||
|
||||
// Put cache to memory.
|
||||
// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute).
|
||||
func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error {
|
||||
bc.lock.Lock()
|
||||
defer bc.lock.Unlock()
|
||||
|
116
cache/redis.go → cache/redis/redis.go
vendored
116
cache/redis.go → cache/redis/redis.go
vendored
@ -1,10 +1,19 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/beego/redigo/redis"
|
||||
|
||||
"github.com/astaxie/beego/cache"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -14,7 +23,7 @@ var (
|
||||
|
||||
// Redis cache adapter.
|
||||
type RedisCache struct {
|
||||
c redis.Conn
|
||||
p *redis.Pool // redis connection pool
|
||||
conninfo string
|
||||
key string
|
||||
}
|
||||
@ -24,107 +33,62 @@ func NewRedisCache() *RedisCache {
|
||||
return &RedisCache{key: DefaultKey}
|
||||
}
|
||||
|
||||
// actually do the redis cmds
|
||||
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
|
||||
return c.Do(commandName, args...)
|
||||
}
|
||||
|
||||
// Get cache from redis.
|
||||
func (rc *RedisCache) Get(key string) interface{} {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
v, err := rc.c.Do("HGET", rc.key, key)
|
||||
v, err := rc.do("HGET", rc.key, key)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// put cache to redis.
|
||||
// timeout is ignored.
|
||||
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Do("HSET", rc.key, key, val)
|
||||
_, err := rc.do("HSET", rc.key, key, val)
|
||||
return err
|
||||
}
|
||||
|
||||
// delete cache in redis.
|
||||
func (rc *RedisCache) Delete(key string) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Do("HDEL", rc.key, key)
|
||||
_, err := rc.do("HDEL", rc.key, key)
|
||||
return err
|
||||
}
|
||||
|
||||
// check cache exist in redis.
|
||||
func (rc *RedisCache) IsExist(key string) bool {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
|
||||
v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// increase counter in redis.
|
||||
func (rc *RedisCache) Incr(key string) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decrease counter in redis.
|
||||
func (rc *RedisCache) Decr(key string) error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// clean all cache in redis. delete this redis collection.
|
||||
func (rc *RedisCache) ClearAll() error {
|
||||
if rc.c == nil {
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.c.Do("DEL", rc.key)
|
||||
_, err := rc.do("DEL", rc.key)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -135,34 +99,44 @@ func (rc *RedisCache) ClearAll() error {
|
||||
func (rc *RedisCache) StartAndGC(config string) error {
|
||||
var cf map[string]string
|
||||
json.Unmarshal([]byte(config), &cf)
|
||||
|
||||
if _, ok := cf["key"]; !ok {
|
||||
cf["key"] = DefaultKey
|
||||
}
|
||||
|
||||
if _, ok := cf["conn"]; !ok {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
|
||||
rc.key = cf["key"]
|
||||
rc.conninfo = cf["conn"]
|
||||
var err error
|
||||
rc.c, err = rc.connectInit()
|
||||
if err != nil {
|
||||
rc.connectInit()
|
||||
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
if err := c.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if rc.c == nil {
|
||||
return errors.New("dial tcp conn error")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect to redis.
|
||||
func (rc *RedisCache) connectInit() (redis.Conn, error) {
|
||||
func (rc *RedisCache) connectInit() {
|
||||
// initialize a new pool
|
||||
rc.p = &redis.Pool{
|
||||
MaxIdle: 3,
|
||||
IdleTimeout: 180 * time.Second,
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.Dial("tcp", rc.conninfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("redis", NewRedisCache())
|
||||
cache.Register("redis", NewRedisCache())
|
||||
}
|
49
config.go
49
config.go
@ -1,6 +1,13 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -11,12 +18,14 @@ import (
|
||||
"github.com/astaxie/beego/config"
|
||||
"github.com/astaxie/beego/logs"
|
||||
"github.com/astaxie/beego/session"
|
||||
"github.com/astaxie/beego/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
BeeApp *App // beego application
|
||||
AppName string
|
||||
AppPath string
|
||||
workPath string
|
||||
AppConfigPath string
|
||||
StaticDir map[string]string
|
||||
TemplateCache map[string]*template.Template // template caching map
|
||||
@ -40,6 +49,7 @@ var (
|
||||
SessionHashFunc string // session hash generation func.
|
||||
SessionHashKey string // session hash salt string.
|
||||
SessionCookieLifeTime int // the life time of session id in cookie.
|
||||
SessionAutoSetCookie bool // auto setcookie
|
||||
UseFcgi bool
|
||||
MaxMemory int64
|
||||
EnableGzip bool // flag of enable gzip
|
||||
@ -57,15 +67,28 @@ var (
|
||||
EnableAdmin bool // flag of enable admin module to log every request info.
|
||||
AdminHttpAddr string // http server configurations for admin module.
|
||||
AdminHttpPort int
|
||||
FlashName string // name of the flash variable found in response header and cookie
|
||||
FlashSeperator string // used to seperate flash key:value
|
||||
)
|
||||
|
||||
func init() {
|
||||
// create beego application
|
||||
BeeApp = NewApp()
|
||||
|
||||
workPath, _ = os.Getwd()
|
||||
workPath, _ = filepath.Abs(workPath)
|
||||
// initialize default configurations
|
||||
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
|
||||
|
||||
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
|
||||
|
||||
if workPath != AppPath {
|
||||
if utils.FileExists(AppConfigPath) {
|
||||
os.Chdir(AppPath)
|
||||
} else {
|
||||
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
|
||||
}
|
||||
}
|
||||
|
||||
StaticDir = make(map[string]string)
|
||||
StaticDir["/static"] = "static"
|
||||
@ -96,6 +119,7 @@ func init() {
|
||||
SessionHashFunc = "sha1"
|
||||
SessionHashKey = "beegoserversessionkey"
|
||||
SessionCookieLifeTime = 0 //set cookie default is the brower life
|
||||
SessionAutoSetCookie = true
|
||||
|
||||
UseFcgi = false
|
||||
|
||||
@ -103,8 +127,6 @@ func init() {
|
||||
|
||||
EnableGzip = false
|
||||
|
||||
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
|
||||
|
||||
HttpServerTimeOut = 0
|
||||
|
||||
ErrorsShow = true
|
||||
@ -121,13 +143,19 @@ func init() {
|
||||
AdminHttpAddr = "127.0.0.1"
|
||||
AdminHttpPort = 8088
|
||||
|
||||
FlashName = "BEEGO_FLASH"
|
||||
FlashSeperator = "BEEGOFLASH"
|
||||
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
|
||||
// init BeeLogger
|
||||
BeeLogger = logs.NewLogger(10000)
|
||||
BeeLogger.SetLogger("console", "")
|
||||
err := BeeLogger.SetLogger("console", "")
|
||||
if err != nil {
|
||||
fmt.Println("init console log error:", err)
|
||||
}
|
||||
|
||||
err := ParseConfig()
|
||||
err = ParseConfig()
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
// for init if doesn't have app.conf will not panic
|
||||
Info(err)
|
||||
@ -139,6 +167,7 @@ func init() {
|
||||
func ParseConfig() (err error) {
|
||||
AppConfig, err = config.NewConfig("ini", AppConfigPath)
|
||||
if err != nil {
|
||||
AppConfig = config.NewFakeConfig()
|
||||
return err
|
||||
} else {
|
||||
HttpAddr = AppConfig.String("HttpAddr")
|
||||
@ -268,6 +297,14 @@ func ParseConfig() (err error) {
|
||||
BeegoServerName = serverName
|
||||
}
|
||||
|
||||
if flashname := AppConfig.String("FlashName"); flashname != "" {
|
||||
FlashName = flashname
|
||||
}
|
||||
|
||||
if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
|
||||
FlashSeperator = flashseperator
|
||||
}
|
||||
|
||||
if sd := AppConfig.String("StaticDir"); sd != "" {
|
||||
for k := range StaticDir {
|
||||
delete(StaticDir, k)
|
||||
@ -275,9 +312,9 @@ func ParseConfig() (err error) {
|
||||
sds := strings.Fields(sd)
|
||||
for _, v := range sds {
|
||||
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
|
||||
StaticDir["/"+url2fsmap[0]] = url2fsmap[1]
|
||||
StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
|
||||
} else {
|
||||
StaticDir["/"+url2fsmap[0]] = url2fsmap[0]
|
||||
StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,20 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ConfigContainer defines how to get and set value from configuration raw data.
|
||||
type ConfigContainer interface {
|
||||
Set(key, val string) error
|
||||
String(key string) string
|
||||
Set(key, val string) error // support section::key type in given key when using ini type.
|
||||
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
|
||||
Strings(key string) []string //get string slice
|
||||
Int(key string) (int, error)
|
||||
Int64(key string) (int64, error)
|
||||
Bool(key string) (bool, error)
|
||||
@ -14,6 +22,7 @@ type ConfigContainer interface {
|
||||
DIY(key string) (interface{}, error)
|
||||
}
|
||||
|
||||
// Config is the adapter interface for parsing config file to get raw data to ConfigContainer.
|
||||
type Config interface {
|
||||
Parse(key string) (ConfigContainer, error)
|
||||
}
|
||||
@ -33,8 +42,8 @@ func Register(name string, adapter Config) {
|
||||
adapters[name] = adapter
|
||||
}
|
||||
|
||||
// adapterNamer is ini/json/xml/yaml
|
||||
// filename is the config file path
|
||||
// adapterName is ini/json/xml/yaml.
|
||||
// filename is the config file path.
|
||||
func NewConfig(adapterName, fileaname string) (ConfigContainer, error) {
|
||||
adapter, ok := adapters[adapterName]
|
||||
if !ok {
|
||||
|
68
config/fake.go
Normal file
68
config/fake.go
Normal file
@ -0,0 +1,68 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type fakeConfigContainer struct {
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) getData(key string) string {
|
||||
key = strings.ToLower(key)
|
||||
return c.data[key]
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Set(key, val string) error {
|
||||
key = strings.ToLower(key)
|
||||
c.data[key] = val
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) String(key string) string {
|
||||
return c.getData(key)
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.getData(key), ";")
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Int(key string) (int, error) {
|
||||
return strconv.Atoi(c.getData(key))
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
|
||||
return strconv.ParseInt(c.getData(key), 10, 64)
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
|
||||
return strconv.ParseBool(c.getData(key))
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) Float(key string) (float64, error) {
|
||||
return strconv.ParseFloat(c.getData(key), 64)
|
||||
}
|
||||
|
||||
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
|
||||
key = strings.ToLower(key)
|
||||
if v, ok := c.data[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return nil, errors.New("key not find")
|
||||
}
|
||||
|
||||
var _ ConfigContainer = new(fakeConfigContainer)
|
||||
|
||||
func NewFakeConfig() ConfigContainer {
|
||||
return &fakeConfigContainer{
|
||||
data: make(map[string]string),
|
||||
}
|
||||
}
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
@ -13,21 +19,21 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
DEFAULT_SECTION = "default"
|
||||
bNumComment = []byte{'#'} // number sign
|
||||
bSemComment = []byte{';'} // semicolon
|
||||
DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section,
|
||||
bNumComment = []byte{'#'} // number signal
|
||||
bSemComment = []byte{';'} // semicolon signal
|
||||
bEmpty = []byte{}
|
||||
bEqual = []byte{'='}
|
||||
bDQuote = []byte{'"'}
|
||||
sectionStart = []byte{'['}
|
||||
sectionEnd = []byte{']'}
|
||||
bEqual = []byte{'='} // equal signal
|
||||
bDQuote = []byte{'"'} // quote signal
|
||||
sectionStart = []byte{'['} // section start signal
|
||||
sectionEnd = []byte{']'} // section end signal
|
||||
)
|
||||
|
||||
// IniConfig implements Config to parse ini file.
|
||||
type IniConfig struct {
|
||||
}
|
||||
|
||||
// ParseFile creates a new Config and parses the file configuration from the
|
||||
// named file.
|
||||
// ParseFile creates a new Config and parses the file configuration from the named file.
|
||||
func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
|
||||
file, err := os.Open(name)
|
||||
if err != nil {
|
||||
@ -106,11 +112,12 @@ func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// A Config represents the configuration.
|
||||
// A Config represents the ini configuration.
|
||||
// When set and get value, support key as section:name type.
|
||||
type IniConfigContainer struct {
|
||||
filename string
|
||||
data map[string]map[string]string //section=> key:val
|
||||
sectionComment map[string]string //sction : comment
|
||||
data map[string]map[string]string // section=> key:val
|
||||
sectionComment map[string]string // section : comment
|
||||
keycomment map[string]string // id: []{comment, key...}; id 1 is for main comment.
|
||||
sync.RWMutex
|
||||
}
|
||||
@ -127,6 +134,7 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
|
||||
return strconv.Atoi(c.getdata(key))
|
||||
}
|
||||
|
||||
// Int64 returns the int64 value for a given key.
|
||||
func (c *IniConfigContainer) Int64(key string) (int64, error) {
|
||||
key = strings.ToLower(key)
|
||||
return strconv.ParseInt(c.getdata(key), 10, 64)
|
||||
@ -144,7 +152,14 @@ func (c *IniConfigContainer) String(key string) string {
|
||||
return c.getdata(key)
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *IniConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
// if write to one section, the key need be "section::key".
|
||||
// if the section is not existed, it panics.
|
||||
func (c *IniConfigContainer) Set(key, value string) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@ -169,6 +184,7 @@ func (c *IniConfigContainer) Set(key, value string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DIY returns the raw value by a given key.
|
||||
func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
key = strings.ToLower(key)
|
||||
if v, ok := c.data[key]; ok {
|
||||
@ -177,7 +193,7 @@ func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
return v, errors.New("key not find")
|
||||
}
|
||||
|
||||
//section.key or key
|
||||
// section.key or key
|
||||
func (c *IniConfigContainer) getdata(key string) string {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
@ -19,6 +25,7 @@ copyrequestbody = true
|
||||
key1="asta"
|
||||
key2 = "xie"
|
||||
CaseInsensitive = true
|
||||
peers = one;two;three
|
||||
`
|
||||
|
||||
func TestIni(t *testing.T) {
|
||||
@ -78,4 +85,11 @@ func TestIni(t *testing.T) {
|
||||
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
|
||||
t.Fatal("get demo.caseinsensitive error")
|
||||
}
|
||||
|
||||
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
|
||||
t.Fatal("get strings error", data)
|
||||
} else if data[0] != "one" {
|
||||
t.Fatal("get first params error not equat to one")
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
@ -9,9 +15,11 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// JsonConfig is a json config parser and implements Config interface.
|
||||
type JsonConfig struct {
|
||||
}
|
||||
|
||||
// Parse returns a ConfigContainer with parsed json config map.
|
||||
func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
@ -32,11 +40,14 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// A Config represents the json configuration.
|
||||
// Only when get value, support key as section:name type.
|
||||
type JsonConfigContainer struct {
|
||||
data map[string]interface{}
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// Bool returns the boolean value for a given key.
|
||||
func (c *JsonConfigContainer) Bool(key string) (bool, error) {
|
||||
val := c.getdata(key)
|
||||
if val != nil {
|
||||
@ -48,9 +59,9 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) {
|
||||
} else {
|
||||
return false, errors.New("not exist key:" + key)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Int returns the integer value for a given key.
|
||||
func (c *JsonConfigContainer) Int(key string) (int, error) {
|
||||
val := c.getdata(key)
|
||||
if val != nil {
|
||||
@ -64,6 +75,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Int64 returns the int64 value for a given key.
|
||||
func (c *JsonConfigContainer) Int64(key string) (int64, error) {
|
||||
val := c.getdata(key)
|
||||
if val != nil {
|
||||
@ -77,6 +89,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Float returns the float value for a given key.
|
||||
func (c *JsonConfigContainer) Float(key string) (float64, error) {
|
||||
val := c.getdata(key)
|
||||
if val != nil {
|
||||
@ -90,6 +103,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string value for a given key.
|
||||
func (c *JsonConfigContainer) String(key string) string {
|
||||
val := c.getdata(key)
|
||||
if val != nil {
|
||||
@ -103,6 +117,12 @@ func (c *JsonConfigContainer) String(key string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *JsonConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
func (c *JsonConfigContainer) Set(key, val string) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@ -110,6 +130,7 @@ func (c *JsonConfigContainer) Set(key, val string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DIY returns the raw value by a given key.
|
||||
func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
val := c.getdata(key)
|
||||
if val != nil {
|
||||
@ -119,7 +140,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
//section.key or key
|
||||
// section.key or key
|
||||
func (c *JsonConfigContainer) getdata(key string) interface{} {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
|
@ -1,4 +1,8 @@
|
||||
//xml parse should incluce in <config></config> tags
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
@ -7,15 +11,21 @@ import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/astaxie/beego/config"
|
||||
"github.com/beego/x2j"
|
||||
)
|
||||
|
||||
// XmlConfig is a xml config parser and implements Config interface.
|
||||
// xml configurations should be included in <config></config> tag.
|
||||
// only support key/value pair as <key>value</key> as each item.
|
||||
type XMLConfig struct {
|
||||
}
|
||||
|
||||
func (xmls *XMLConfig) Parse(filename string) (ConfigContainer, error) {
|
||||
// Parse returns a ConfigContainer with parsed xml config map.
|
||||
func (xmls *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -36,27 +46,33 @@ func (xmls *XMLConfig) Parse(filename string) (ConfigContainer, error) {
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// A Config represents the xml configuration.
|
||||
type XMLConfigContainer struct {
|
||||
data map[string]interface{}
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// Bool returns the boolean value for a given key.
|
||||
func (c *XMLConfigContainer) Bool(key string) (bool, error) {
|
||||
return strconv.ParseBool(c.data[key].(string))
|
||||
}
|
||||
|
||||
// Int returns the integer value for a given key.
|
||||
func (c *XMLConfigContainer) Int(key string) (int, error) {
|
||||
return strconv.Atoi(c.data[key].(string))
|
||||
}
|
||||
|
||||
// Int64 returns the int64 value for a given key.
|
||||
func (c *XMLConfigContainer) Int64(key string) (int64, error) {
|
||||
return strconv.ParseInt(c.data[key].(string), 10, 64)
|
||||
}
|
||||
|
||||
// Float returns the float value for a given key.
|
||||
func (c *XMLConfigContainer) Float(key string) (float64, error) {
|
||||
return strconv.ParseFloat(c.data[key].(string), 64)
|
||||
}
|
||||
|
||||
// String returns the string value for a given key.
|
||||
func (c *XMLConfigContainer) String(key string) string {
|
||||
if v, ok := c.data[key].(string); ok {
|
||||
return v
|
||||
@ -64,6 +80,12 @@ func (c *XMLConfigContainer) String(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *XMLConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
func (c *XMLConfigContainer) Set(key, val string) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@ -71,6 +93,7 @@ func (c *XMLConfigContainer) Set(key, val string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DIY returns the raw value by a given key.
|
||||
func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
if v, ok := c.data[key]; ok {
|
||||
return v, nil
|
||||
@ -79,5 +102,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("xml", &XMLConfig{})
|
||||
config.Register("xml", &XMLConfig{})
|
||||
}
|
@ -1,8 +1,16 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/config"
|
||||
)
|
||||
|
||||
//xml parse should incluce in <config></config> tags
|
||||
@ -30,7 +38,7 @@ func TestXML(t *testing.T) {
|
||||
}
|
||||
f.Close()
|
||||
defer os.Remove("testxml.conf")
|
||||
xmlconf, err := NewConfig("xml", "testxml.conf")
|
||||
xmlconf, err := config.NewConfig("xml", "testxml.conf")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
@ -7,15 +13,19 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/astaxie/beego/config"
|
||||
"github.com/beego/goyaml2"
|
||||
)
|
||||
|
||||
// YAMLConfig is a yaml config parser and implements Config interface.
|
||||
type YAMLConfig struct {
|
||||
}
|
||||
|
||||
func (yaml *YAMLConfig) Parse(filename string) (ConfigContainer, error) {
|
||||
// Parse returns a ConfigContainer with parsed yaml config map.
|
||||
func (yaml *YAMLConfig) Parse(filename string) (config.ConfigContainer, error) {
|
||||
y := &YAMLConfigContainer{
|
||||
data: make(map[string]interface{}),
|
||||
}
|
||||
@ -27,7 +37,8 @@ func (yaml *YAMLConfig) Parse(filename string) (ConfigContainer, error) {
|
||||
return y, nil
|
||||
}
|
||||
|
||||
// 从Reader读取YAML
|
||||
// Read yaml file to map.
|
||||
// if json like, use json package, unless goyaml2 package.
|
||||
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
|
||||
err = nil
|
||||
f, err := os.Open(path)
|
||||
@ -68,11 +79,13 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// A Config represents the yaml configuration.
|
||||
type YAMLConfigContainer struct {
|
||||
data map[string]interface{}
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// Bool returns the boolean value for a given key.
|
||||
func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
|
||||
if v, ok := c.data[key].(bool); ok {
|
||||
return v, nil
|
||||
@ -80,6 +93,7 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
|
||||
return false, errors.New("not bool value")
|
||||
}
|
||||
|
||||
// Int returns the integer value for a given key.
|
||||
func (c *YAMLConfigContainer) Int(key string) (int, error) {
|
||||
if v, ok := c.data[key].(int64); ok {
|
||||
return int(v), nil
|
||||
@ -87,6 +101,7 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) {
|
||||
return 0, errors.New("not int value")
|
||||
}
|
||||
|
||||
// Int64 returns the int64 value for a given key.
|
||||
func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
|
||||
if v, ok := c.data[key].(int64); ok {
|
||||
return v, nil
|
||||
@ -94,6 +109,7 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
|
||||
return 0, errors.New("not bool value")
|
||||
}
|
||||
|
||||
// Float returns the float value for a given key.
|
||||
func (c *YAMLConfigContainer) Float(key string) (float64, error) {
|
||||
if v, ok := c.data[key].(float64); ok {
|
||||
return v, nil
|
||||
@ -101,6 +117,7 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) {
|
||||
return 0.0, errors.New("not float64 value")
|
||||
}
|
||||
|
||||
// String returns the string value for a given key.
|
||||
func (c *YAMLConfigContainer) String(key string) string {
|
||||
if v, ok := c.data[key].(string); ok {
|
||||
return v
|
||||
@ -108,6 +125,12 @@ func (c *YAMLConfigContainer) String(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strings returns the []string value for a given key.
|
||||
func (c *YAMLConfigContainer) Strings(key string) []string {
|
||||
return strings.Split(c.String(key), ";")
|
||||
}
|
||||
|
||||
// WriteValue writes a new value for key.
|
||||
func (c *YAMLConfigContainer) Set(key, val string) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@ -115,6 +138,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DIY returns the raw value by a given key.
|
||||
func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
if v, ok := c.data[key]; ok {
|
||||
return v, nil
|
||||
@ -123,5 +147,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("yaml", &YAMLConfig{})
|
||||
config.Register("yaml", &YAMLConfig{})
|
||||
}
|
@ -1,8 +1,16 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/config"
|
||||
)
|
||||
|
||||
var yamlcontext = `
|
||||
@ -27,7 +35,7 @@ func TestYaml(t *testing.T) {
|
||||
}
|
||||
f.Close()
|
||||
defer os.Remove("testyaml.conf")
|
||||
yamlconf, err := NewConfig("yaml", "testyaml.conf")
|
||||
yamlconf, err := config.NewConfig("yaml", "testyaml.conf")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
21
config_test.go
Normal file
21
config_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
if FlashName != "BEEGO_FLASH" {
|
||||
t.Errorf("FlashName was not set to default.")
|
||||
}
|
||||
|
||||
if FlashSeperator != "BEEGOFLASH" {
|
||||
t.Errorf("FlashName was not set to default.")
|
||||
}
|
||||
}
|
@ -1,11 +1,26 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/middleware"
|
||||
)
|
||||
|
||||
// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
|
||||
// BeegoInput and BeegoOutput provides some api to operate request and response more easily.
|
||||
type Context struct {
|
||||
Input *BeegoInput
|
||||
Output *BeegoOutput
|
||||
@ -13,11 +28,16 @@ type Context struct {
|
||||
ResponseWriter http.ResponseWriter
|
||||
}
|
||||
|
||||
// Redirect does redirection to localurl with http header status code.
|
||||
// It sends http response header directly.
|
||||
func (ctx *Context) Redirect(status int, localurl string) {
|
||||
ctx.Output.Header("Location", localurl)
|
||||
ctx.Output.SetStatus(status)
|
||||
}
|
||||
|
||||
// Abort stops this request.
|
||||
// if middleware.ErrorMaps exists, panic body.
|
||||
// if middleware.HTTPExceptionMaps exists, panic HTTPException struct with status and body string.
|
||||
func (ctx *Context) Abort(status int, body string) {
|
||||
ctx.Output.SetStatus(status)
|
||||
// first panic from ErrorMaps, is is user defined error functions.
|
||||
@ -35,14 +55,58 @@ func (ctx *Context) Abort(status int, body string) {
|
||||
panic(body)
|
||||
}
|
||||
|
||||
// Write string to response body.
|
||||
// it sends response body.
|
||||
func (ctx *Context) WriteString(content string) {
|
||||
ctx.Output.Body([]byte(content))
|
||||
}
|
||||
|
||||
// Get cookie from request by a given key.
|
||||
// It's alias of BeegoInput.Cookie.
|
||||
func (ctx *Context) GetCookie(key string) string {
|
||||
return ctx.Input.Cookie(key)
|
||||
}
|
||||
|
||||
// Set cookie for response.
|
||||
// It's alias of BeegoOutput.Cookie.
|
||||
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
|
||||
ctx.Output.Cookie(name, value, others...)
|
||||
}
|
||||
|
||||
// Get secure cookie from request by a given key.
|
||||
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
|
||||
val := ctx.Input.Cookie(key)
|
||||
if val == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(val, "|", 3)
|
||||
|
||||
if len(parts) != 3 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
vs := parts[0]
|
||||
timestamp := parts[1]
|
||||
sig := parts[2]
|
||||
|
||||
h := hmac.New(sha1.New, []byte(Secret))
|
||||
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
||||
|
||||
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
|
||||
return "", false
|
||||
}
|
||||
res, _ := base64.URLEncoding.DecodeString(vs)
|
||||
return string(res), true
|
||||
}
|
||||
|
||||
// Set Secure cookie for response.
|
||||
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
|
||||
vs := base64.URLEncoding.EncodeToString([]byte(value))
|
||||
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
h := hmac.New(sha1.New, []byte(Secret))
|
||||
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
||||
sig := fmt.Sprintf("%02x", h.Sum(nil))
|
||||
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
|
||||
ctx.Output.Cookie(name, cookie, others...)
|
||||
}
|
||||
|
351
context/input.go
351
context/input.go
@ -1,23 +1,37 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego/session"
|
||||
)
|
||||
|
||||
// BeegoInput operates the http request header ,data ,cookie and body.
|
||||
// it also contains router params and current session.
|
||||
type BeegoInput struct {
|
||||
CruSession session.SessionStore
|
||||
Params map[string]string
|
||||
Data map[interface{}]interface{}
|
||||
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
|
||||
Request *http.Request
|
||||
RequestBody []byte
|
||||
RunController reflect.Type
|
||||
RunMethod string
|
||||
}
|
||||
|
||||
// NewInput return BeegoInput generated by http.Request.
|
||||
func NewInput(req *http.Request) *BeegoInput {
|
||||
return &BeegoInput{
|
||||
Params: make(map[string]string),
|
||||
@ -26,22 +40,27 @@ func NewInput(req *http.Request) *BeegoInput {
|
||||
}
|
||||
}
|
||||
|
||||
// Protocol returns request protocol name, such as HTTP/1.1 .
|
||||
func (input *BeegoInput) Protocol() string {
|
||||
return input.Request.Proto
|
||||
}
|
||||
|
||||
// Uri returns full request url with query string, fragment.
|
||||
func (input *BeegoInput) Uri() string {
|
||||
return input.Request.RequestURI
|
||||
}
|
||||
|
||||
// Url returns request url path (without query string, fragment).
|
||||
func (input *BeegoInput) Url() string {
|
||||
return input.Request.URL.String()
|
||||
}
|
||||
|
||||
// Site returns base site url as scheme://domain type.
|
||||
func (input *BeegoInput) Site() string {
|
||||
return input.Scheme() + "://" + input.Domain()
|
||||
}
|
||||
|
||||
// Scheme returns request scheme as "http" or "https".
|
||||
func (input *BeegoInput) Scheme() string {
|
||||
if input.Request.URL.Scheme != "" {
|
||||
return input.Request.URL.Scheme
|
||||
@ -52,10 +71,14 @@ func (input *BeegoInput) Scheme() string {
|
||||
}
|
||||
}
|
||||
|
||||
// Domain returns host name.
|
||||
// Alias of Host method.
|
||||
func (input *BeegoInput) Domain() string {
|
||||
return input.Host()
|
||||
}
|
||||
|
||||
// Host returns host name.
|
||||
// if no host info in request, return localhost.
|
||||
func (input *BeegoInput) Host() string {
|
||||
if input.Request.Host != "" {
|
||||
hostParts := strings.Split(input.Request.Host, ":")
|
||||
@ -67,30 +90,74 @@ func (input *BeegoInput) Host() string {
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
// Method returns http request method.
|
||||
func (input *BeegoInput) Method() string {
|
||||
return input.Request.Method
|
||||
}
|
||||
|
||||
// Is returns boolean of this request is on given method, such as Is("POST").
|
||||
func (input *BeegoInput) Is(method string) bool {
|
||||
return input.Method() == method
|
||||
}
|
||||
|
||||
// Is this a GET method request?
|
||||
func (input *BeegoInput) IsGet() bool {
|
||||
return input.Is("GET")
|
||||
}
|
||||
|
||||
// Is this a POST method request?
|
||||
func (input *BeegoInput) IsPost() bool {
|
||||
return input.Is("POST")
|
||||
}
|
||||
|
||||
// Is this a Head method request?
|
||||
func (input *BeegoInput) IsHead() bool {
|
||||
return input.Is("HEAD")
|
||||
}
|
||||
|
||||
// Is this a OPTIONS method request?
|
||||
func (input *BeegoInput) IsOptions() bool {
|
||||
return input.Is("OPTIONS")
|
||||
}
|
||||
|
||||
// Is this a PUT method request?
|
||||
func (input *BeegoInput) IsPut() bool {
|
||||
return input.Is("PUT")
|
||||
}
|
||||
|
||||
// Is this a DELETE method request?
|
||||
func (input *BeegoInput) IsDelete() bool {
|
||||
return input.Is("DELETE")
|
||||
}
|
||||
|
||||
// Is this a PATCH method request?
|
||||
func (input *BeegoInput) IsPatch() bool {
|
||||
return input.Is("PATCH")
|
||||
}
|
||||
|
||||
// IsAjax returns boolean of this request is generated by ajax.
|
||||
func (input *BeegoInput) IsAjax() bool {
|
||||
return input.Header("X-Requested-With") == "XMLHttpRequest"
|
||||
}
|
||||
|
||||
// IsSecure returns boolean of this request is in https.
|
||||
func (input *BeegoInput) IsSecure() bool {
|
||||
return input.Scheme() == "https"
|
||||
}
|
||||
|
||||
// IsSecure returns boolean of this request is in webSocket.
|
||||
func (input *BeegoInput) IsWebsocket() bool {
|
||||
return input.Header("Upgrade") == "websocket"
|
||||
}
|
||||
|
||||
// IsSecure returns boolean of whether file uploads in this request or not..
|
||||
func (input *BeegoInput) IsUpload() bool {
|
||||
return input.Request.MultipartForm != nil
|
||||
return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
|
||||
}
|
||||
|
||||
// IP returns request client ip.
|
||||
// if in proxy, return first proxy id.
|
||||
// if error, return 127.0.0.1.
|
||||
func (input *BeegoInput) IP() string {
|
||||
ips := input.Proxy()
|
||||
if len(ips) > 0 && ips[0] != "" {
|
||||
@ -98,13 +165,14 @@ func (input *BeegoInput) IP() string {
|
||||
}
|
||||
ip := strings.Split(input.Request.RemoteAddr, ":")
|
||||
if len(ip) > 0 {
|
||||
if ip[0] != "["{
|
||||
if ip[0] != "[" {
|
||||
return ip[0]
|
||||
}
|
||||
}
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
// Proxy returns proxy client ips slice.
|
||||
func (input *BeegoInput) Proxy() []string {
|
||||
if ips := input.Header("X-Forwarded-For"); ips != "" {
|
||||
return strings.Split(ips, ",")
|
||||
@ -112,15 +180,20 @@ func (input *BeegoInput) Proxy() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Refer returns http referer header.
|
||||
func (input *BeegoInput) Refer() string {
|
||||
return input.Header("Referer")
|
||||
}
|
||||
|
||||
// SubDomains returns sub domain string.
|
||||
// if aa.bb.domain.com, returns aa.bb .
|
||||
func (input *BeegoInput) SubDomains() string {
|
||||
parts := strings.Split(input.Host(), ".")
|
||||
return strings.Join(parts[len(parts)-2:], ".")
|
||||
}
|
||||
|
||||
// Port returns request client port.
|
||||
// when error or empty, return 80.
|
||||
func (input *BeegoInput) Port() int {
|
||||
parts := strings.Split(input.Request.Host, ":")
|
||||
if len(parts) == 2 {
|
||||
@ -130,10 +203,12 @@ func (input *BeegoInput) Port() int {
|
||||
return 80
|
||||
}
|
||||
|
||||
// UserAgent returns request client user agent string.
|
||||
func (input *BeegoInput) UserAgent() string {
|
||||
return input.Header("User-Agent")
|
||||
}
|
||||
|
||||
// Param returns router param by a given key.
|
||||
func (input *BeegoInput) Param(key string) string {
|
||||
if v, ok := input.Params[key]; ok {
|
||||
return v
|
||||
@ -141,15 +216,24 @@ func (input *BeegoInput) Param(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Query returns input data item string by a given string.
|
||||
func (input *BeegoInput) Query(key string) string {
|
||||
if val := input.Param(key); val != "" {
|
||||
return val
|
||||
}
|
||||
if input.Request.Form == nil {
|
||||
input.Request.ParseForm()
|
||||
}
|
||||
return input.Request.Form.Get(key)
|
||||
}
|
||||
|
||||
// Header returns request header item string by a given string.
|
||||
func (input *BeegoInput) Header(key string) string {
|
||||
return input.Request.Header.Get(key)
|
||||
}
|
||||
|
||||
// Cookie returns request cookie item string by a given key.
|
||||
// if non-existed, return empty string.
|
||||
func (input *BeegoInput) Cookie(key string) string {
|
||||
ck, err := input.Request.Cookie(key)
|
||||
if err != nil {
|
||||
@ -158,11 +242,13 @@ func (input *BeegoInput) Cookie(key string) string {
|
||||
return ck.Value
|
||||
}
|
||||
|
||||
// Session returns current session item value by a given key.
|
||||
func (input *BeegoInput) Session(key interface{}) interface{} {
|
||||
return input.CruSession.Get(key)
|
||||
}
|
||||
|
||||
func (input *BeegoInput) Body() []byte {
|
||||
// Body returns the raw request body data as bytes.
|
||||
func (input *BeegoInput) CopyBody() []byte {
|
||||
requestbody, _ := ioutil.ReadAll(input.Request.Body)
|
||||
input.Request.Body.Close()
|
||||
bf := bytes.NewBuffer(requestbody)
|
||||
@ -171,6 +257,7 @@ func (input *BeegoInput) Body() []byte {
|
||||
return requestbody
|
||||
}
|
||||
|
||||
// GetData returns the stored data in this context.
|
||||
func (input *BeegoInput) GetData(key interface{}) interface{} {
|
||||
if v, ok := input.Data[key]; ok {
|
||||
return v
|
||||
@ -178,6 +265,262 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetData stores data with given key in this context.
|
||||
// This data are only available in this context.
|
||||
func (input *BeegoInput) SetData(key, val interface{}) {
|
||||
input.Data[key] = val
|
||||
}
|
||||
|
||||
// parseForm or parseMultiForm based on Content-type
|
||||
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
|
||||
// Parse the body depending on the content type.
|
||||
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
|
||||
if err := input.Request.ParseMultipartForm(maxMemory); err != nil {
|
||||
return errors.New("Error parsing request body:" + err.Error())
|
||||
}
|
||||
} else if err := input.Request.ParseForm(); err != nil {
|
||||
return errors.New("Error parsing request body:" + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bind data from request.Form[key] to dest
|
||||
// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie
|
||||
// var id int beegoInput.Bind(&id, "id") id ==123
|
||||
// var isok bool beegoInput.Bind(&isok, "isok") id ==true
|
||||
// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2
|
||||
// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2]
|
||||
// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array]
|
||||
// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"}
|
||||
func (input *BeegoInput) Bind(dest interface{}, key string) error {
|
||||
value := reflect.ValueOf(dest)
|
||||
if value.Kind() != reflect.Ptr {
|
||||
return errors.New("beego: non-pointer passed to Bind: " + key)
|
||||
}
|
||||
value = value.Elem()
|
||||
if !value.CanSet() {
|
||||
return errors.New("beego: non-settable variable passed to Bind: " + key)
|
||||
}
|
||||
rv := input.bind(key, value.Type())
|
||||
if !rv.IsValid() {
|
||||
return errors.New("beego: reflect value is empty")
|
||||
}
|
||||
value.Set(rv)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
|
||||
rv := reflect.Zero(reflect.TypeOf(0))
|
||||
switch typ.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
val := input.Query(key)
|
||||
if len(val) == 0 {
|
||||
return rv
|
||||
}
|
||||
rv = input.bindInt(val, typ)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
val := input.Query(key)
|
||||
if len(val) == 0 {
|
||||
return rv
|
||||
}
|
||||
rv = input.bindUint(val, typ)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
val := input.Query(key)
|
||||
if len(val) == 0 {
|
||||
return rv
|
||||
}
|
||||
rv = input.bindFloat(val, typ)
|
||||
case reflect.String:
|
||||
val := input.Query(key)
|
||||
if len(val) == 0 {
|
||||
return rv
|
||||
}
|
||||
rv = input.bindString(val, typ)
|
||||
case reflect.Bool:
|
||||
val := input.Query(key)
|
||||
if len(val) == 0 {
|
||||
return rv
|
||||
}
|
||||
rv = input.bindBool(val, typ)
|
||||
case reflect.Slice:
|
||||
rv = input.bindSlice(&input.Request.Form, key, typ)
|
||||
case reflect.Struct:
|
||||
rv = input.bindStruct(&input.Request.Form, key, typ)
|
||||
case reflect.Ptr:
|
||||
rv = input.bindPoint(key, typ)
|
||||
case reflect.Map:
|
||||
rv = input.bindMap(&input.Request.Form, key, typ)
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
|
||||
rv := reflect.Zero(reflect.TypeOf(0))
|
||||
switch typ.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
rv = input.bindInt(val, typ)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
rv = input.bindUint(val, typ)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
rv = input.bindFloat(val, typ)
|
||||
case reflect.String:
|
||||
rv = input.bindString(val, typ)
|
||||
case reflect.Bool:
|
||||
rv = input.bindBool(val, typ)
|
||||
case reflect.Slice:
|
||||
rv = input.bindSlice(&url.Values{"": {val}}, "", typ)
|
||||
case reflect.Struct:
|
||||
rv = input.bindStruct(&url.Values{"": {val}}, "", typ)
|
||||
case reflect.Ptr:
|
||||
rv = input.bindPoint(val, typ)
|
||||
case reflect.Map:
|
||||
rv = input.bindMap(&url.Values{"": {val}}, "", typ)
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value {
|
||||
intValue, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return reflect.Zero(typ)
|
||||
}
|
||||
pValue := reflect.New(typ)
|
||||
pValue.Elem().SetInt(intValue)
|
||||
return pValue.Elem()
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value {
|
||||
uintValue, err := strconv.ParseUint(val, 10, 64)
|
||||
if err != nil {
|
||||
return reflect.Zero(typ)
|
||||
}
|
||||
pValue := reflect.New(typ)
|
||||
pValue.Elem().SetUint(uintValue)
|
||||
return pValue.Elem()
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value {
|
||||
floatValue, err := strconv.ParseFloat(val, 64)
|
||||
if err != nil {
|
||||
return reflect.Zero(typ)
|
||||
}
|
||||
pValue := reflect.New(typ)
|
||||
pValue.Elem().SetFloat(floatValue)
|
||||
return pValue.Elem()
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value {
|
||||
return reflect.ValueOf(val)
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value {
|
||||
val = strings.TrimSpace(strings.ToLower(val))
|
||||
switch val {
|
||||
case "true", "on", "1":
|
||||
return reflect.ValueOf(true)
|
||||
}
|
||||
return reflect.ValueOf(false)
|
||||
}
|
||||
|
||||
type sliceValue struct {
|
||||
index int // Index extracted from brackets. If -1, no index was provided.
|
||||
value reflect.Value // the bound value for this slice element.
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value {
|
||||
maxIndex := -1
|
||||
numNoIndex := 0
|
||||
sliceValues := []sliceValue{}
|
||||
for reqKey, vals := range *params {
|
||||
if !strings.HasPrefix(reqKey, key+"[") {
|
||||
continue
|
||||
}
|
||||
// Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey)
|
||||
index := -1
|
||||
leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key)
|
||||
if rightBracket > leftBracket+1 {
|
||||
index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket])
|
||||
}
|
||||
subKeyIndex := rightBracket + 1
|
||||
|
||||
// Handle the indexed case.
|
||||
if index > -1 {
|
||||
if index > maxIndex {
|
||||
maxIndex = index
|
||||
}
|
||||
sliceValues = append(sliceValues, sliceValue{
|
||||
index: index,
|
||||
value: input.bind(reqKey[:subKeyIndex], typ.Elem()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// It's an un-indexed element. (e.g. element[])
|
||||
numNoIndex += len(vals)
|
||||
for _, val := range vals {
|
||||
// Unindexed values can only be direct-bound.
|
||||
sliceValues = append(sliceValues, sliceValue{
|
||||
index: -1,
|
||||
value: input.bindValue(val, typ.Elem()),
|
||||
})
|
||||
}
|
||||
}
|
||||
resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex)
|
||||
for _, sv := range sliceValues {
|
||||
if sv.index != -1 {
|
||||
resultArray.Index(sv.index).Set(sv.value)
|
||||
} else {
|
||||
resultArray = reflect.Append(resultArray, sv.value)
|
||||
}
|
||||
}
|
||||
return resultArray
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value {
|
||||
result := reflect.New(typ).Elem()
|
||||
fieldValues := make(map[string]reflect.Value)
|
||||
for reqKey, val := range *params {
|
||||
if !strings.HasPrefix(reqKey, key+".") {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := reqKey[len(key)+1:]
|
||||
|
||||
if _, ok := fieldValues[fieldName]; !ok {
|
||||
// Time to bind this field. Get it and make sure we can set it.
|
||||
fieldValue := result.FieldByName(fieldName)
|
||||
if !fieldValue.IsValid() {
|
||||
continue
|
||||
}
|
||||
if !fieldValue.CanSet() {
|
||||
continue
|
||||
}
|
||||
boundVal := input.bindValue(val[0], fieldValue.Type())
|
||||
fieldValue.Set(boundVal)
|
||||
fieldValues[fieldName] = boundVal
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value {
|
||||
return input.bind(key, typ.Elem()).Addr()
|
||||
}
|
||||
|
||||
func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value {
|
||||
var (
|
||||
result = reflect.MakeMap(typ)
|
||||
keyType = typ.Key()
|
||||
valueType = typ.Elem()
|
||||
)
|
||||
for paramName, values := range *params {
|
||||
if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' {
|
||||
continue
|
||||
}
|
||||
|
||||
key := paramName[len(key)+1 : len(paramName)-1]
|
||||
result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
64
context/input_test.go
Normal file
64
context/input_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
|
||||
beegoInput := NewInput(r)
|
||||
beegoInput.ParseFormOrMulitForm(1 << 20)
|
||||
|
||||
var id int
|
||||
err := beegoInput.Bind(&id, "id")
|
||||
if id != 123 || err != nil {
|
||||
t.Fatal("id should has int value")
|
||||
}
|
||||
fmt.Println(id)
|
||||
|
||||
var isok bool
|
||||
err = beegoInput.Bind(&isok, "isok")
|
||||
if !isok || err != nil {
|
||||
t.Fatal("isok should be true")
|
||||
}
|
||||
fmt.Println(isok)
|
||||
|
||||
var float float64
|
||||
err = beegoInput.Bind(&float, "ft")
|
||||
if float != 1.2 || err != nil {
|
||||
t.Fatal("float should be equal to 1.2")
|
||||
}
|
||||
fmt.Println(float)
|
||||
|
||||
ol := make([]int, 0, 2)
|
||||
err = beegoInput.Bind(&ol, "ol")
|
||||
if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 {
|
||||
t.Fatal("ol should has two elements")
|
||||
}
|
||||
fmt.Println(ol)
|
||||
|
||||
ul := make([]string, 0, 2)
|
||||
err = beegoInput.Bind(&ul, "ul")
|
||||
if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" {
|
||||
t.Fatal("ul should has two elements")
|
||||
}
|
||||
fmt.Println(ul)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
}
|
||||
user := User{}
|
||||
err = beegoInput.Bind(&user, "user")
|
||||
if err != nil || user.Name != "astaxie" {
|
||||
t.Fatal("user should has name")
|
||||
}
|
||||
fmt.Println(user)
|
||||
}
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
@ -17,20 +23,27 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BeegoOutput does work for sending response header.
|
||||
type BeegoOutput struct {
|
||||
Context *Context
|
||||
Status int
|
||||
EnableGzip bool
|
||||
}
|
||||
|
||||
// NewOutput returns new BeegoOutput.
|
||||
// it contains nothing now.
|
||||
func NewOutput() *BeegoOutput {
|
||||
return &BeegoOutput{}
|
||||
}
|
||||
|
||||
// Header sets response header item string via given key.
|
||||
func (output *BeegoOutput) Header(key, val string) {
|
||||
output.Context.ResponseWriter.Header().Set(key, val)
|
||||
}
|
||||
|
||||
// Body sets response body content.
|
||||
// if EnableGzip, compress content string.
|
||||
// it sends out response body directly.
|
||||
func (output *BeegoOutput) Body(content []byte) {
|
||||
output_writer := output.Context.ResponseWriter.(io.Writer)
|
||||
if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" {
|
||||
@ -64,43 +77,83 @@ func (output *BeegoOutput) Body(content []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// Cookie sets cookie value via given key.
|
||||
// others are ordered as cookie's max age time, path,domain, secure and httponly.
|
||||
func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) {
|
||||
var b bytes.Buffer
|
||||
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
|
||||
if len(others) > 0 {
|
||||
switch others[0].(type) {
|
||||
switch v := others[0].(type) {
|
||||
case int:
|
||||
if others[0].(int) > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int))
|
||||
} else if others[0].(int) < 0 {
|
||||
if v > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", v)
|
||||
} else if v < 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=0")
|
||||
}
|
||||
case int64:
|
||||
if others[0].(int64) > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int64))
|
||||
} else if others[0].(int64) < 0 {
|
||||
if v > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", v)
|
||||
} else if v < 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=0")
|
||||
}
|
||||
case int32:
|
||||
if others[0].(int32) > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", others[0].(int32))
|
||||
} else if others[0].(int32) < 0 {
|
||||
if v > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", v)
|
||||
} else if v < 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=0")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// the settings below
|
||||
// Path, Domain, Secure, HttpOnly
|
||||
// can use nil skip set
|
||||
|
||||
// default "/"
|
||||
if len(others) > 1 {
|
||||
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(others[1].(string)))
|
||||
if v, ok := others[1].(string); ok && len(v) > 0 {
|
||||
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v))
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(&b, "; Path=%s", "/")
|
||||
}
|
||||
|
||||
// default empty
|
||||
if len(others) > 2 {
|
||||
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(others[2].(string)))
|
||||
if v, ok := others[2].(string); ok && len(v) > 0 {
|
||||
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v))
|
||||
}
|
||||
}
|
||||
|
||||
// default empty
|
||||
if len(others) > 3 {
|
||||
var secure bool
|
||||
switch v := others[3].(type) {
|
||||
case bool:
|
||||
secure = v
|
||||
default:
|
||||
if others[3] != nil {
|
||||
secure = true
|
||||
}
|
||||
}
|
||||
if secure {
|
||||
fmt.Fprintf(&b, "; Secure")
|
||||
}
|
||||
}
|
||||
|
||||
// default false. for session cookie default true
|
||||
httponly := false
|
||||
if len(others) > 4 {
|
||||
if v, ok := others[4].(bool); ok && v {
|
||||
// HttpOnly = true
|
||||
httponly = true
|
||||
}
|
||||
}
|
||||
|
||||
if httponly {
|
||||
fmt.Fprintf(&b, "; HttpOnly")
|
||||
}
|
||||
|
||||
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
|
||||
}
|
||||
|
||||
@ -116,6 +169,8 @@ func sanitizeValue(v string) string {
|
||||
return cookieValueSanitizer.Replace(v)
|
||||
}
|
||||
|
||||
// Json writes json to response body.
|
||||
// if coding is true, it converts utf-8 to \u0000 type.
|
||||
func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error {
|
||||
output.Header("Content-Type", "application/json;charset=UTF-8")
|
||||
var content []byte
|
||||
@ -136,6 +191,7 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// Jsonp writes jsonp to response body.
|
||||
func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
|
||||
output.Header("Content-Type", "application/javascript;charset=UTF-8")
|
||||
var content []byte
|
||||
@ -161,6 +217,7 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Xml writes xml string to response body.
|
||||
func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
|
||||
output.Header("Content-Type", "application/xml;charset=UTF-8")
|
||||
var content []byte
|
||||
@ -178,6 +235,8 @@ func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Download forces response for download file.
|
||||
// it prepares the download response header automatically.
|
||||
func (output *BeegoOutput) Download(file string) {
|
||||
output.Header("Content-Description", "File Transfer")
|
||||
output.Header("Content-Type", "application/octet-stream")
|
||||
@ -189,6 +248,8 @@ func (output *BeegoOutput) Download(file string) {
|
||||
http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
|
||||
}
|
||||
|
||||
// ContentType sets the content type from ext string.
|
||||
// MIME type is given in mime package.
|
||||
func (output *BeegoOutput) ContentType(ext string) {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
ext = "." + ext
|
||||
@ -199,43 +260,63 @@ func (output *BeegoOutput) ContentType(ext string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetStatus sets response status code.
|
||||
// It writes response header directly.
|
||||
func (output *BeegoOutput) SetStatus(status int) {
|
||||
output.Context.ResponseWriter.WriteHeader(status)
|
||||
output.Status = status
|
||||
}
|
||||
|
||||
// IsCachable returns boolean of this request is cached.
|
||||
// HTTP 304 means cached.
|
||||
func (output *BeegoOutput) IsCachable(status int) bool {
|
||||
return output.Status >= 200 && output.Status < 300 || output.Status == 304
|
||||
}
|
||||
|
||||
// IsEmpty returns boolean of this request is empty.
|
||||
// HTTP 201,204 and 304 means empty.
|
||||
func (output *BeegoOutput) IsEmpty(status int) bool {
|
||||
return output.Status == 201 || output.Status == 204 || output.Status == 304
|
||||
}
|
||||
|
||||
// IsOk returns boolean of this request runs well.
|
||||
// HTTP 200 means ok.
|
||||
func (output *BeegoOutput) IsOk(status int) bool {
|
||||
return output.Status == 200
|
||||
}
|
||||
|
||||
// IsSuccessful returns boolean of this request runs successfully.
|
||||
// HTTP 2xx means ok.
|
||||
func (output *BeegoOutput) IsSuccessful(status int) bool {
|
||||
return output.Status >= 200 && output.Status < 300
|
||||
}
|
||||
|
||||
// IsRedirect returns boolean of this request is redirection header.
|
||||
// HTTP 301,302,307 means redirection.
|
||||
func (output *BeegoOutput) IsRedirect(status int) bool {
|
||||
return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307
|
||||
}
|
||||
|
||||
// IsForbidden returns boolean of this request is forbidden.
|
||||
// HTTP 403 means forbidden.
|
||||
func (output *BeegoOutput) IsForbidden(status int) bool {
|
||||
return output.Status == 403
|
||||
}
|
||||
|
||||
// IsNotFound returns boolean of this request is not found.
|
||||
// HTTP 404 means forbidden.
|
||||
func (output *BeegoOutput) IsNotFound(status int) bool {
|
||||
return output.Status == 404
|
||||
}
|
||||
|
||||
// IsClient returns boolean of this request client sends error data.
|
||||
// HTTP 4xx means forbidden.
|
||||
func (output *BeegoOutput) IsClientError(status int) bool {
|
||||
return output.Status >= 400 && output.Status < 500
|
||||
}
|
||||
|
||||
// IsServerError returns boolean of this server handler errors.
|
||||
// HTTP 5xx means server internal error.
|
||||
func (output *BeegoOutput) IsServerError(status int) bool {
|
||||
return output.Status >= 500 && output.Status < 600
|
||||
}
|
||||
@ -254,6 +335,7 @@ func stringsToJson(str string) string {
|
||||
return jsons
|
||||
}
|
||||
|
||||
// Sessions sets session item value with given key.
|
||||
func (output *BeegoOutput) Session(name interface{}, value interface{}) {
|
||||
output.Context.Input.CruSession.Set(name, value)
|
||||
}
|
||||
|
124
controller.go
124
controller.go
@ -1,13 +1,14 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -18,10 +19,17 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/context"
|
||||
"github.com/astaxie/beego/session"
|
||||
"github.com/astaxie/beego/utils"
|
||||
)
|
||||
|
||||
//commonly used mime-types
|
||||
const (
|
||||
applicationJson = "application/json"
|
||||
applicationXml = "applicatoin/xml"
|
||||
textXml = "text/xml"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -45,6 +53,8 @@ type Controller struct {
|
||||
CruSession session.SessionStore
|
||||
XSRFExpire int
|
||||
AppController interface{}
|
||||
EnableRender bool
|
||||
EnableXSRF bool
|
||||
}
|
||||
|
||||
// ControllerInterface is an interface to uniform all controller handler.
|
||||
@ -66,7 +76,6 @@ type ControllerInterface interface {
|
||||
|
||||
// Init generates default values of controller operations.
|
||||
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
c.Layout = ""
|
||||
c.TplNames = ""
|
||||
c.controllerName = controllerName
|
||||
@ -74,6 +83,9 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
|
||||
c.Ctx = ctx
|
||||
c.TplExt = "tpl"
|
||||
c.AppController = app
|
||||
c.EnableRender = true
|
||||
c.EnableXSRF = true
|
||||
c.Data = ctx.Input.Data
|
||||
}
|
||||
|
||||
// Prepare runs after Init before request function execution.
|
||||
@ -123,6 +135,9 @@ func (c *Controller) Options() {
|
||||
|
||||
// Render sends the response with rendered template bytes as text/html type.
|
||||
func (c *Controller) Render() error {
|
||||
if !c.EnableRender {
|
||||
return nil
|
||||
}
|
||||
rb, err := c.RenderBytes()
|
||||
|
||||
if err != nil {
|
||||
@ -140,7 +155,7 @@ func (c *Controller) RenderString() (string, error) {
|
||||
return string(b), e
|
||||
}
|
||||
|
||||
// RenderBytes returns the bytes of renderd tempate string. Do not send out response.
|
||||
// RenderBytes returns the bytes of rendered template string. Do not send out response.
|
||||
func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
//if the controller has set layout, then first get the tplname's content set the content to the layout
|
||||
if c.Layout != "" {
|
||||
@ -153,7 +168,6 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
newbytes := bytes.NewBufferString("")
|
||||
if _, ok := BeeTemplates[c.TplNames]; !ok {
|
||||
panic("can't find templatefile in the path:" + c.TplNames)
|
||||
return []byte{}, errors.New("can't find templatefile in the path:" + c.TplNames)
|
||||
}
|
||||
err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data)
|
||||
if err != nil {
|
||||
@ -165,7 +179,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
|
||||
if c.LayoutSections != nil {
|
||||
for sectionName, sectionTpl := range c.LayoutSections {
|
||||
if (sectionTpl == "") {
|
||||
if sectionTpl == "" {
|
||||
c.Data[sectionName] = ""
|
||||
continue
|
||||
}
|
||||
@ -199,7 +213,6 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
ibytes := bytes.NewBufferString("")
|
||||
if _, ok := BeeTemplates[c.TplNames]; !ok {
|
||||
panic("can't find templatefile in the path:" + c.TplNames)
|
||||
return []byte{}, errors.New("can't find templatefile in the path:" + c.TplNames)
|
||||
}
|
||||
err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data)
|
||||
if err != nil {
|
||||
@ -209,7 +222,6 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
||||
icontent, _ := ioutil.ReadAll(ibytes)
|
||||
return icontent, nil
|
||||
}
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
// Redirect sends the redirection response to url with status code.
|
||||
@ -243,7 +255,6 @@ func (c *Controller) UrlFor(endpoint string, values ...string) string {
|
||||
} else {
|
||||
return UrlFor(endpoint, values...)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ServeJson sends a json response with encoding charset.
|
||||
@ -283,12 +294,23 @@ func (c *Controller) ServeXml() {
|
||||
c.Ctx.Output.Xml(c.Data["xml"], hasIndent)
|
||||
}
|
||||
|
||||
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header
|
||||
|
||||
func (c *Controller) ServeFormatted() {
|
||||
accept := c.Ctx.Input.Header("Accept")
|
||||
switch accept {
|
||||
case applicationJson:
|
||||
c.ServeJson()
|
||||
case applicationXml, textXml:
|
||||
c.ServeXml()
|
||||
default:
|
||||
c.ServeJson()
|
||||
}
|
||||
}
|
||||
|
||||
// Input returns the input data map from POST or PUT request body and query string.
|
||||
func (c *Controller) Input() url.Values {
|
||||
ct := c.Ctx.Request.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "multipart/form-data") {
|
||||
c.Ctx.Request.ParseMultipartForm(MaxMemory) //64MB
|
||||
} else {
|
||||
if c.Ctx.Request.Form == nil {
|
||||
c.Ctx.Request.ParseForm()
|
||||
}
|
||||
return c.Ctx.Request.Form
|
||||
@ -301,17 +323,17 @@ func (c *Controller) ParseForm(obj interface{}) error {
|
||||
|
||||
// GetString returns the input value by key string.
|
||||
func (c *Controller) GetString(key string) string {
|
||||
return c.Input().Get(key)
|
||||
return c.Ctx.Input.Query(key)
|
||||
}
|
||||
|
||||
// GetStrings returns the input string slice by key string.
|
||||
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
|
||||
func (c *Controller) GetStrings(key string) []string {
|
||||
r := c.Ctx.Request
|
||||
if r.Form == nil {
|
||||
f := c.Input()
|
||||
if f == nil {
|
||||
return []string{}
|
||||
}
|
||||
vs := r.Form[key]
|
||||
vs := f[key]
|
||||
if len(vs) > 0 {
|
||||
return vs
|
||||
}
|
||||
@ -320,17 +342,17 @@ func (c *Controller) GetStrings(key string) []string {
|
||||
|
||||
// GetInt returns input value as int64.
|
||||
func (c *Controller) GetInt(key string) (int64, error) {
|
||||
return strconv.ParseInt(c.Input().Get(key), 10, 64)
|
||||
return strconv.ParseInt(c.Ctx.Input.Query(key), 10, 64)
|
||||
}
|
||||
|
||||
// GetBool returns input value as bool.
|
||||
func (c *Controller) GetBool(key string) (bool, error) {
|
||||
return strconv.ParseBool(c.Input().Get(key))
|
||||
return strconv.ParseBool(c.Ctx.Input.Query(key))
|
||||
}
|
||||
|
||||
// GetFloat returns input value as float64.
|
||||
func (c *Controller) GetFloat(key string) (float64, error) {
|
||||
return strconv.ParseFloat(c.Input().Get(key), 64)
|
||||
return strconv.ParseFloat(c.Ctx.Input.Query(key), 64)
|
||||
}
|
||||
|
||||
// GetFile returns the file data in file upload field named as key.
|
||||
@ -391,12 +413,16 @@ func (c *Controller) DelSession(name interface{}) {
|
||||
// SessionRegenerateID regenerates session id for this session.
|
||||
// the session data have no changes.
|
||||
func (c *Controller) SessionRegenerateID() {
|
||||
if c.CruSession != nil {
|
||||
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
|
||||
}
|
||||
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
|
||||
c.Ctx.Input.CruSession = c.CruSession
|
||||
}
|
||||
|
||||
// DestroySession cleans session data and session cookie.
|
||||
func (c *Controller) DestroySession() {
|
||||
c.Ctx.Input.CruSession.Flush()
|
||||
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
|
||||
}
|
||||
|
||||
@ -407,40 +433,12 @@ func (c *Controller) IsAjax() bool {
|
||||
|
||||
// GetSecureCookie returns decoded cookie value from encoded browser cookie values.
|
||||
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
|
||||
val := c.Ctx.GetCookie(key)
|
||||
if val == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(val, "|", 3)
|
||||
|
||||
if len(parts) != 3 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
vs := parts[0]
|
||||
timestamp := parts[1]
|
||||
sig := parts[2]
|
||||
|
||||
h := hmac.New(sha1.New, []byte(Secret))
|
||||
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
||||
|
||||
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
|
||||
return "", false
|
||||
}
|
||||
res, _ := base64.URLEncoding.DecodeString(vs)
|
||||
return string(res), true
|
||||
return c.Ctx.GetSecureCookie(Secret, key)
|
||||
}
|
||||
|
||||
// SetSecureCookie puts value into cookie after encoded the value.
|
||||
func (c *Controller) SetSecureCookie(Secret, name, val string, age int64) {
|
||||
vs := base64.URLEncoding.EncodeToString([]byte(val))
|
||||
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
h := hmac.New(sha1.New, []byte(Secret))
|
||||
fmt.Fprintf(h, "%s%s", vs, timestamp)
|
||||
sig := fmt.Sprintf("%02x", h.Sum(nil))
|
||||
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
|
||||
c.Ctx.SetCookie(name, cookie, age, "/")
|
||||
func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) {
|
||||
c.Ctx.SetSecureCookie(Secret, name, value, others...)
|
||||
}
|
||||
|
||||
// XsrfToken creates a xsrf token string and returns.
|
||||
@ -454,7 +452,7 @@ func (c *Controller) XsrfToken() string {
|
||||
} else {
|
||||
expire = int64(XSRFExpire)
|
||||
}
|
||||
token = getRandomString(15)
|
||||
token = string(utils.RandomCreateBytes(15))
|
||||
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
|
||||
}
|
||||
c._xsrf_token = token
|
||||
@ -466,6 +464,9 @@ func (c *Controller) XsrfToken() string {
|
||||
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
|
||||
// or in form field value named as "_xsrf".
|
||||
func (c *Controller) CheckXsrfCookie() bool {
|
||||
if !c.EnableXSRF {
|
||||
return true
|
||||
}
|
||||
token := c.GetString("_xsrf")
|
||||
if token == "" {
|
||||
token = c.Ctx.Request.Header.Get("X-Xsrftoken")
|
||||
@ -491,14 +492,3 @@ func (c *Controller) XsrfFormHtml() string {
|
||||
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
|
||||
return c.controllerName, c.actionName
|
||||
}
|
||||
|
||||
// getRandomString returns random string.
|
||||
func getRandomString(n int) string {
|
||||
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
var bytes = make([]byte, n)
|
||||
rand.Read(bytes)
|
||||
for i, b := range bytes {
|
||||
bytes[i] = alphanum[b%byte(len(alphanum))]
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
@ -1,7 +1,14 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/astaxie/beego/example/beeapi/models"
|
||||
)
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors Unknwon
|
||||
|
||||
package controllers
|
||||
|
||||
import (
|
||||
|
@ -1,12 +1,19 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors Unknwon
|
||||
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/garyburd/go-websocket/websocket"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -1,3 +1,8 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors Unknwon
|
||||
package main
|
||||
|
||||
import (
|
||||
|
23
filter.go
23
filter.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -28,6 +34,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
|
||||
if router == mr.pattern {
|
||||
return true, nil
|
||||
}
|
||||
//pattern /admin router /admin/ match
|
||||
//pattern /admin/ router /admin don't match, because url will 301 in router
|
||||
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if mr.hasregex {
|
||||
if !mr.regex.MatchString(router) {
|
||||
return false, nil
|
||||
@ -46,7 +58,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
|
||||
mr := new(FilterRouter)
|
||||
mr.params = make(map[int]string)
|
||||
mr.filterFunc = filter
|
||||
@ -54,7 +66,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
j := 0
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
//a user may choose to override the default expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
@ -77,7 +89,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
j++
|
||||
}
|
||||
if strings.HasPrefix(part, "*") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
if part == "*.*" {
|
||||
mr.params[j] = ":path"
|
||||
parts[i] = "([^.]+).([^.]+)"
|
||||
@ -137,12 +149,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
//TODO add error handling here to avoid panic
|
||||
panic(regexErr)
|
||||
return nil, regexErr
|
||||
}
|
||||
mr.regex = regex
|
||||
mr.hasregex = true
|
||||
}
|
||||
mr.pattern = pattern
|
||||
return mr
|
||||
return mr, nil
|
||||
}
|
||||
|
60
filter_test.go
Normal file
60
filter_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/context"
|
||||
)
|
||||
|
||||
var FilterUser = func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddFilter("/person/:last/:first", "AfterStatic", FilterUser)
|
||||
handler.Add("/person/:last/:first", &TestController{})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am astaXie" {
|
||||
t.Errorf("user define func can't run")
|
||||
}
|
||||
}
|
||||
|
||||
var FilterAdminUser = func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("i am admin"))
|
||||
}
|
||||
|
||||
// Filter pattern /admin/:all
|
||||
// all url like /admin/ /admin/xie will all get filter
|
||||
|
||||
func TestPatternTwo(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/admin/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am admin" {
|
||||
t.Errorf("filter /admin/ can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatternThree(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am admin" {
|
||||
t.Errorf("filter /admin/astaxie can't run")
|
||||
}
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/context"
|
||||
)
|
||||
|
||||
var FilterUser = func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddFilter("/person/:last/:first", "AfterStatic", FilterUser)
|
||||
handler.Add("/person/:last/:first", &TestController{})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am astaXie" {
|
||||
t.Errorf("user define func can't run")
|
||||
}
|
||||
}
|
23
flash.go
23
flash.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -6,9 +12,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// the separation string when encoding flash data.
|
||||
const BEEGO_FLASH_SEP = "#BEEGOFLASH#"
|
||||
|
||||
// FlashData is a tools to maintain data when using across request.
|
||||
type FlashData struct {
|
||||
Data map[string]string
|
||||
@ -54,29 +57,27 @@ func (fd *FlashData) Store(c *Controller) {
|
||||
c.Data["flash"] = fd.Data
|
||||
var flashValue string
|
||||
for key, value := range fd.Data {
|
||||
flashValue += "\x00" + key + BEEGO_FLASH_SEP + value + "\x00"
|
||||
flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
|
||||
}
|
||||
c.Ctx.SetCookie("BEEGO_FLASH", url.QueryEscape(flashValue), 0, "/")
|
||||
c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
|
||||
}
|
||||
|
||||
// ReadFromRequest parsed flash data from encoded values in cookie.
|
||||
func ReadFromRequest(c *Controller) *FlashData {
|
||||
flash := &FlashData{
|
||||
Data: make(map[string]string),
|
||||
}
|
||||
if cookie, err := c.Ctx.Request.Cookie("BEEGO_FLASH"); err == nil {
|
||||
flash := NewFlash()
|
||||
if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
|
||||
v, _ := url.QueryUnescape(cookie.Value)
|
||||
vals := strings.Split(v, "\x00")
|
||||
for _, v := range vals {
|
||||
if len(v) > 0 {
|
||||
kv := strings.Split(v, BEEGO_FLASH_SEP)
|
||||
kv := strings.Split(v, "\x23"+FlashSeperator+"\x23")
|
||||
if len(kv) == 2 {
|
||||
flash.Data[kv[0]] = kv[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
//read one time then delete it
|
||||
c.Ctx.SetCookie("BEEGO_FLASH", "", -1, "/")
|
||||
c.Ctx.SetCookie(FlashName, "", -1, "/")
|
||||
}
|
||||
c.Data["flash"] = flash.Data
|
||||
return flash
|
||||
|
46
flash_test.go
Normal file
46
flash_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestFlashController struct {
|
||||
Controller
|
||||
}
|
||||
|
||||
func (this *TestFlashController) TestWriteFlash() {
|
||||
flash := NewFlash()
|
||||
flash.Notice("TestFlashString")
|
||||
flash.Store(&this.Controller)
|
||||
// we choose to serve json because we don't want to load a template html file
|
||||
this.ServeJson(true)
|
||||
}
|
||||
|
||||
func TestFlashHeader(t *testing.T) {
|
||||
// create fake GET request
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// setup the handler
|
||||
handler := NewControllerRegistor()
|
||||
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
// get the Set-Cookie value
|
||||
sc := w.Header().Get("Set-Cookie")
|
||||
// match for the expected header
|
||||
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
|
||||
// validate the assertion
|
||||
if res != true {
|
||||
t.Errorf("TestFlashHeader() unable to validate flash message")
|
||||
}
|
||||
}
|
@ -60,3 +60,21 @@ some http request need setcookie. So set it like this:
|
||||
cookie.Value = "astaxie"
|
||||
httplib.Get("http://beego.me/").SetCookie(cookie)
|
||||
|
||||
## upload file
|
||||
httplib support mutil file upload, use `b.PostFile()`
|
||||
|
||||
b:=httplib.Post("http://beego.me/")
|
||||
b.Param("username","astaxie")
|
||||
b.Param("password","123456")
|
||||
b.PostFile("uploadfile1", "httplib.pdf")
|
||||
b.PostFile("uploadfile2", "httplib.txt")
|
||||
str, err := b.String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println(str)
|
||||
|
||||
## set HTTP version
|
||||
some servers need to specify the protocol version of HTTP
|
||||
|
||||
httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1")
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package httplib
|
||||
|
||||
import (
|
||||
@ -7,6 +13,7 @@ import (
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
@ -18,87 +25,145 @@ import (
|
||||
|
||||
var defaultUserAgent = "beegoServer"
|
||||
|
||||
// Get returns *BeegoHttpRequest with GET method.
|
||||
func Get(url string) *BeegoHttpRequest {
|
||||
var req http.Request
|
||||
req.Method = "GET"
|
||||
req.Header = http.Header{}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||
}
|
||||
|
||||
// Post returns *BeegoHttpRequest with POST method.
|
||||
func Post(url string) *BeegoHttpRequest {
|
||||
var req http.Request
|
||||
req.Method = "POST"
|
||||
req.Header = http.Header{}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||
}
|
||||
|
||||
// Put returns *BeegoHttpRequest with PUT method.
|
||||
func Put(url string) *BeegoHttpRequest {
|
||||
var req http.Request
|
||||
req.Method = "PUT"
|
||||
req.Header = http.Header{}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||
}
|
||||
|
||||
// Delete returns *BeegoHttpRequest DELETE GET method.
|
||||
func Delete(url string) *BeegoHttpRequest {
|
||||
var req http.Request
|
||||
req.Method = "DELETE"
|
||||
req.Header = http.Header{}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||
}
|
||||
|
||||
// Head returns *BeegoHttpRequest with HEAD method.
|
||||
func Head(url string) *BeegoHttpRequest {
|
||||
var req http.Request
|
||||
req.Method = "HEAD"
|
||||
req.Header = http.Header{}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil}
|
||||
return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, false, 60 * time.Second, 60 * time.Second, nil, nil, nil}
|
||||
}
|
||||
|
||||
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
|
||||
type BeegoHttpRequest struct {
|
||||
url string
|
||||
req *http.Request
|
||||
params map[string]string
|
||||
files map[string]string
|
||||
showdebug bool
|
||||
connectTimeout time.Duration
|
||||
readWriteTimeout time.Duration
|
||||
tlsClientConfig *tls.Config
|
||||
proxy func(*http.Request) (*url.URL, error)
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// Debug sets show debug or not when executing request.
|
||||
func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest {
|
||||
b.showdebug = isdebug
|
||||
return b
|
||||
}
|
||||
|
||||
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
|
||||
func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest {
|
||||
b.connectTimeout = connectTimeout
|
||||
b.readWriteTimeout = readWriteTimeout
|
||||
return b
|
||||
}
|
||||
|
||||
// SetTLSClientConfig sets tls connection configurations if visiting https url.
|
||||
func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest {
|
||||
b.tlsClientConfig = config
|
||||
return b
|
||||
}
|
||||
|
||||
// Header add header item string in request.
|
||||
func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest {
|
||||
b.req.Header.Set(key, value)
|
||||
return b
|
||||
}
|
||||
|
||||
// Set the protocol version for incoming requests.
|
||||
// Client requests always use HTTP/1.1.
|
||||
func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
|
||||
if len(vers) == 0 {
|
||||
vers = "HTTP/1.1"
|
||||
}
|
||||
|
||||
major, minor, ok := http.ParseHTTPVersion(vers)
|
||||
if ok {
|
||||
b.req.Proto = vers
|
||||
b.req.ProtoMajor = major
|
||||
b.req.ProtoMinor = minor
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// SetCookie add cookie into request.
|
||||
func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
|
||||
b.req.Header.Add("Cookie", cookie.String())
|
||||
return b
|
||||
}
|
||||
|
||||
// Set transport to
|
||||
func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
|
||||
b.transport = transport
|
||||
return b
|
||||
}
|
||||
|
||||
// Set http proxy
|
||||
// example:
|
||||
//
|
||||
// func(req *http.Request) (*url.URL, error) {
|
||||
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
|
||||
// return u, nil
|
||||
// }
|
||||
func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
|
||||
b.proxy = proxy
|
||||
return b
|
||||
}
|
||||
|
||||
// Param adds query param in to request.
|
||||
// params build query string as ?key1=value1&key2=value2...
|
||||
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
|
||||
b.params[key] = value
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest {
|
||||
b.files[formname] = filename
|
||||
return b
|
||||
}
|
||||
|
||||
// Body adds request raw body.
|
||||
// it supports string and []byte.
|
||||
func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
|
||||
switch t := data.(type) {
|
||||
case string:
|
||||
@ -134,9 +199,38 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
|
||||
b.url = b.url + "?" + paramBody
|
||||
}
|
||||
} else if b.req.Method == "POST" && b.req.Body == nil && len(paramBody) > 0 {
|
||||
if len(b.files) > 0 {
|
||||
bodyBuf := &bytes.Buffer{}
|
||||
bodyWriter := multipart.NewWriter(bodyBuf)
|
||||
for formname, filename := range b.files {
|
||||
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fh, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//iocopy
|
||||
_, err = io.Copy(fileWriter, fh)
|
||||
fh.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for k, v := range b.params {
|
||||
bodyWriter.WriteField(k, v)
|
||||
}
|
||||
contentType := bodyWriter.FormDataContentType()
|
||||
bodyWriter.Close()
|
||||
b.Header("Content-Type", contentType)
|
||||
b.req.Body = ioutil.NopCloser(bodyBuf)
|
||||
b.req.ContentLength = int64(bodyBuf.Len())
|
||||
} else {
|
||||
b.Header("Content-Type", "application/x-www-form-urlencoded")
|
||||
b.Body(paramBody)
|
||||
}
|
||||
}
|
||||
|
||||
url, err := url.Parse(b.url)
|
||||
if url.Scheme == "" {
|
||||
@ -156,12 +250,34 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
|
||||
println(string(dump))
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
trans := b.transport
|
||||
|
||||
if trans == nil {
|
||||
// create default transport
|
||||
trans = &http.Transport{
|
||||
TLSClientConfig: b.tlsClientConfig,
|
||||
Proxy: b.proxy,
|
||||
Dial: TimeoutDialer(b.connectTimeout, b.readWriteTimeout),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// if b.transport is *http.Transport then set the settings.
|
||||
if t, ok := trans.(*http.Transport); ok {
|
||||
if t.TLSClientConfig == nil {
|
||||
t.TLSClientConfig = b.tlsClientConfig
|
||||
}
|
||||
if t.Proxy == nil {
|
||||
t.Proxy = b.proxy
|
||||
}
|
||||
if t.Dial == nil {
|
||||
t.Dial = TimeoutDialer(b.connectTimeout, b.readWriteTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: trans,
|
||||
}
|
||||
|
||||
resp, err := client.Do(b.req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -169,6 +285,8 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// String returns the body string in response.
|
||||
// it calls Response inner.
|
||||
func (b *BeegoHttpRequest) String() (string, error) {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
@ -178,6 +296,8 @@ func (b *BeegoHttpRequest) String() (string, error) {
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Bytes returns the body []byte in response.
|
||||
// it calls Response inner.
|
||||
func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
|
||||
resp, err := b.getResponse()
|
||||
if err != nil {
|
||||
@ -194,6 +314,8 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ToFile saves the body data in response to one file.
|
||||
// it calls Response inner.
|
||||
func (b *BeegoHttpRequest) ToFile(filename string) error {
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
@ -216,6 +338,8 @@ func (b *BeegoHttpRequest) ToFile(filename string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToJson returns the map that marshals from the body bytes as json in response .
|
||||
// it calls Response inner.
|
||||
func (b *BeegoHttpRequest) ToJson(v interface{}) error {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
@ -228,6 +352,8 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToXml returns the map that marshals from the body bytes as xml in response .
|
||||
// it calls Response inner.
|
||||
func (b *BeegoHttpRequest) ToXML(v interface{}) error {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
@ -240,10 +366,12 @@ func (b *BeegoHttpRequest) ToXML(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Response executes request client gets response mannually.
|
||||
func (b *BeegoHttpRequest) Response() (*http.Response, error) {
|
||||
return b.getResponse()
|
||||
}
|
||||
|
||||
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
|
||||
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
|
||||
return func(netw, addr string) (net.Conn, error) {
|
||||
conn, err := net.DialTimeout(netw, addr, cTimeout)
|
||||
|
@ -1,6 +1,13 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package httplib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
@ -30,3 +37,15 @@ func TestGetUrl(t *testing.T) {
|
||||
t.Fatal("has no info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPost(t *testing.T) {
|
||||
b := Post("http://beego.me/").Debug(true)
|
||||
b.Param("username", "astaxie")
|
||||
b.Param("password", "hello")
|
||||
b.PostFile("uploadfile", "httplib.go")
|
||||
str, err := b.String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println(str)
|
||||
}
|
||||
|
19
log.go
19
log.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -22,12 +28,21 @@ func SetLevel(l int) {
|
||||
BeeLogger.SetLevel(l)
|
||||
}
|
||||
|
||||
func SetLogFuncCall(b bool) {
|
||||
BeeLogger.EnableFuncCallDepth(b)
|
||||
BeeLogger.SetLogFuncCallDepth(3)
|
||||
}
|
||||
|
||||
// logger references the used application logger.
|
||||
var BeeLogger *logs.BeeLogger
|
||||
|
||||
// SetLogger sets a new logger.
|
||||
func SetLogger(adaptername string, config string) {
|
||||
BeeLogger.SetLogger(adaptername, config)
|
||||
func SetLogger(adaptername string, config string) error {
|
||||
err := BeeLogger.SetLogger(adaptername, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Trace logs a message at trace level.
|
||||
|
15
logs/conn.go
15
logs/conn.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
@ -7,6 +13,8 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ConnWriter implements LoggerInterface.
|
||||
// it writes messages in keep-live tcp connection.
|
||||
type ConnWriter struct {
|
||||
lg *log.Logger
|
||||
innerWriter io.WriteCloser
|
||||
@ -17,12 +25,15 @@ type ConnWriter struct {
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// create new ConnWrite returning as LoggerInterface.
|
||||
func NewConn() LoggerInterface {
|
||||
conn := new(ConnWriter)
|
||||
conn.Level = LevelTrace
|
||||
return conn
|
||||
}
|
||||
|
||||
// init connection writer with json config.
|
||||
// json config only need key "level".
|
||||
func (c *ConnWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), c)
|
||||
if err != nil {
|
||||
@ -31,6 +42,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// write message in connection.
|
||||
// if connection is down, try to re-connect.
|
||||
func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
||||
if level < c.Level {
|
||||
return nil
|
||||
@ -49,10 +62,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (c *ConnWriter) Flush() {
|
||||
|
||||
}
|
||||
|
||||
// destroy connection writer and close tcp listener.
|
||||
func (c *ConnWriter) Destroy() {
|
||||
if c.innerWriter == nil {
|
||||
return
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
|
@ -1,16 +1,44 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
type Brush func(string) string
|
||||
|
||||
func NewBrush(color string) Brush {
|
||||
pre := "\033["
|
||||
reset := "\033[0m"
|
||||
return func(text string) string {
|
||||
return pre + color + "m" + text + reset
|
||||
}
|
||||
}
|
||||
|
||||
var colors = []Brush{
|
||||
NewBrush("1;36"), // Trace cyan
|
||||
NewBrush("1;34"), // Debug blue
|
||||
NewBrush("1;32"), // Info green
|
||||
NewBrush("1;33"), // Warn yellow
|
||||
NewBrush("1;31"), // Error red
|
||||
NewBrush("1;35"), // Critical purple
|
||||
}
|
||||
|
||||
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
|
||||
type ConsoleWriter struct {
|
||||
lg *log.Logger
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// create ConsoleWriter returning as LoggerInterface.
|
||||
func NewConsole() LoggerInterface {
|
||||
cw := new(ConsoleWriter)
|
||||
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
|
||||
@ -18,7 +46,12 @@ func NewConsole() LoggerInterface {
|
||||
return cw
|
||||
}
|
||||
|
||||
// init console logger.
|
||||
// jsonconfig like '{"level":LevelTrace}'.
|
||||
func (c *ConsoleWriter) Init(jsonconfig string) error {
|
||||
if len(jsonconfig) == 0 {
|
||||
return nil
|
||||
}
|
||||
err := json.Unmarshal([]byte(jsonconfig), c)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -26,18 +59,25 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// write message in console.
|
||||
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
|
||||
if level < c.Level {
|
||||
return nil
|
||||
}
|
||||
if goos := runtime.GOOS; goos == "windows" {
|
||||
c.lg.Println(msg)
|
||||
} else {
|
||||
c.lg.Println(colors[level](msg))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (c *ConsoleWriter) Destroy() {
|
||||
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (c *ConsoleWriter) Flush() {
|
||||
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
@ -6,6 +12,7 @@ import (
|
||||
|
||||
func TestConsole(t *testing.T) {
|
||||
log := NewLogger(10000)
|
||||
log.EnableFuncCallDepth(true)
|
||||
log.SetLogger("console", "")
|
||||
log.Trace("trace")
|
||||
log.Info("info")
|
||||
@ -23,6 +30,7 @@ func TestConsole(t *testing.T) {
|
||||
|
||||
func BenchmarkConsole(b *testing.B) {
|
||||
log := NewLogger(10000)
|
||||
log.EnableFuncCallDepth(true)
|
||||
log.SetLogger("console", "")
|
||||
for i := 0; i < b.N; i++ {
|
||||
log.Trace("trace")
|
||||
|
39
logs/file.go
39
logs/file.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
@ -13,6 +19,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileLogWriter implements LoggerInterface.
|
||||
// It writes messages by lines limit, file size limit, or time frequency.
|
||||
type FileLogWriter struct {
|
||||
*log.Logger
|
||||
mw *MuxWriter
|
||||
@ -28,7 +36,7 @@ type FileLogWriter struct {
|
||||
|
||||
// Rotate daily
|
||||
Daily bool `json:"daily"`
|
||||
Maxdays int64 `json:"maxdays`
|
||||
Maxdays int64 `json:"maxdays"`
|
||||
daily_opendate int
|
||||
|
||||
Rotate bool `json:"rotate"`
|
||||
@ -38,17 +46,20 @@ type FileLogWriter struct {
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// an *os.File writer with locker.
|
||||
type MuxWriter struct {
|
||||
sync.Mutex
|
||||
fd *os.File
|
||||
}
|
||||
|
||||
// write to os.File.
|
||||
func (l *MuxWriter) Write(b []byte) (int, error) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.fd.Write(b)
|
||||
}
|
||||
|
||||
// set os.File in writer.
|
||||
func (l *MuxWriter) SetFd(fd *os.File) {
|
||||
if l.fd != nil {
|
||||
l.fd.Close()
|
||||
@ -56,6 +67,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
|
||||
l.fd = fd
|
||||
}
|
||||
|
||||
// create a FileLogWriter returning as LoggerInterface.
|
||||
func NewFileWriter() LoggerInterface {
|
||||
w := &FileLogWriter{
|
||||
Filename: "",
|
||||
@ -73,15 +85,16 @@ func NewFileWriter() LoggerInterface {
|
||||
return w
|
||||
}
|
||||
|
||||
// jsonconfig like this
|
||||
//{
|
||||
// Init file logger with json config.
|
||||
// jsonconfig like:
|
||||
// {
|
||||
// "filename":"logs/beego.log",
|
||||
// "maxlines":10000,
|
||||
// "maxsize":1<<30,
|
||||
// "daily":true,
|
||||
// "maxdays":15,
|
||||
// "rotate":true
|
||||
//}
|
||||
// }
|
||||
func (w *FileLogWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), w)
|
||||
if err != nil {
|
||||
@ -90,11 +103,12 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
|
||||
if len(w.Filename) == 0 {
|
||||
return errors.New("jsonconfig must have filename")
|
||||
}
|
||||
err = w.StartLogger()
|
||||
err = w.startLogger()
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *FileLogWriter) StartLogger() error {
|
||||
// start file logger. create log file and set to locker-inside file writer.
|
||||
func (w *FileLogWriter) startLogger() error {
|
||||
fd, err := w.createLogFile()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -110,9 +124,9 @@ func (w *FileLogWriter) StartLogger() error {
|
||||
func (w *FileLogWriter) docheck(size int) {
|
||||
w.startLock.Lock()
|
||||
defer w.startLock.Unlock()
|
||||
if (w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) ||
|
||||
if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) ||
|
||||
(w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) ||
|
||||
(w.Daily && time.Now().Day() != w.daily_opendate) {
|
||||
(w.Daily && time.Now().Day() != w.daily_opendate)) {
|
||||
if err := w.DoRotate(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
|
||||
return
|
||||
@ -122,6 +136,7 @@ func (w *FileLogWriter) docheck(size int) {
|
||||
w.maxsize_cursize += size
|
||||
}
|
||||
|
||||
// write logger message into file.
|
||||
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
|
||||
if level < w.Level {
|
||||
return nil
|
||||
@ -158,6 +173,8 @@ func (w *FileLogWriter) initFd() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoRotate means it need to write file in new file.
|
||||
// new file name like xx.log.2013-01-01.2
|
||||
func (w *FileLogWriter) DoRotate() error {
|
||||
_, err := os.Lstat(w.Filename)
|
||||
if err == nil { // file exists
|
||||
@ -188,7 +205,7 @@ func (w *FileLogWriter) DoRotate() error {
|
||||
}
|
||||
|
||||
// re-start logger
|
||||
err = w.StartLogger()
|
||||
err = w.startLogger()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Rotate StartLogger: %s\n", err)
|
||||
}
|
||||
@ -211,10 +228,14 @@ func (w *FileLogWriter) deleteOldLog() {
|
||||
})
|
||||
}
|
||||
|
||||
// destroy file logger, close file writer.
|
||||
func (w *FileLogWriter) Destroy() {
|
||||
w.mw.fd.Close()
|
||||
}
|
||||
|
||||
// flush file logger.
|
||||
// there are no buffering messages in file logger in memory.
|
||||
// flush file means sync file from disk.
|
||||
func (w *FileLogWriter) Flush() {
|
||||
w.mw.fd.Sync()
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
|
67
logs/log.go
67
logs/log.go
@ -1,11 +1,20 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// log message levels
|
||||
LevelTrace = iota
|
||||
LevelDebug
|
||||
LevelInfo
|
||||
@ -16,6 +25,7 @@ const (
|
||||
|
||||
type loggerType func() LoggerInterface
|
||||
|
||||
// LoggerInterface defines the behavior of a log provider.
|
||||
type LoggerInterface interface {
|
||||
Init(config string) error
|
||||
WriteMsg(msg string, level int) error
|
||||
@ -38,9 +48,13 @@ func Register(name string, log loggerType) {
|
||||
adapters[name] = log
|
||||
}
|
||||
|
||||
// BeeLogger is default logger in beego application.
|
||||
// it can contain several providers and log message into all providers.
|
||||
type BeeLogger struct {
|
||||
lock sync.Mutex
|
||||
level int
|
||||
enableFuncCallDepth bool
|
||||
loggerFuncCallDepth int
|
||||
msg chan *logMsg
|
||||
outputs map[string]LoggerInterface
|
||||
}
|
||||
@ -50,29 +64,39 @@ type logMsg struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
// config need to be correct JSON as string: {"interval":360}
|
||||
// NewLogger returns a new BeeLogger.
|
||||
// channellen means the number of messages in chan.
|
||||
// if the buffering chan is full, logger adapters write to file or other way.
|
||||
func NewLogger(channellen int64) *BeeLogger {
|
||||
bl := new(BeeLogger)
|
||||
bl.loggerFuncCallDepth = 2
|
||||
bl.msg = make(chan *logMsg, channellen)
|
||||
bl.outputs = make(map[string]LoggerInterface)
|
||||
//bl.SetLogger("console", "") // default output to console
|
||||
go bl.StartLogger()
|
||||
go bl.startLogger()
|
||||
return bl
|
||||
}
|
||||
|
||||
// SetLogger provides a given logger adapter into BeeLogger with config string.
|
||||
// config need to be correct JSON as string: {"interval":360}.
|
||||
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
|
||||
bl.lock.Lock()
|
||||
defer bl.lock.Unlock()
|
||||
if log, ok := adapters[adaptername]; ok {
|
||||
lg := log()
|
||||
lg.Init(config)
|
||||
err := lg.Init(config)
|
||||
bl.outputs[adaptername] = lg
|
||||
return nil
|
||||
if err != nil {
|
||||
fmt.Println("logs.BeeLogger.SetLogger: " + err.Error())
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// remove a logger adapter in BeeLogger.
|
||||
func (bl *BeeLogger) DelLogger(adaptername string) error {
|
||||
bl.lock.Lock()
|
||||
defer bl.lock.Unlock()
|
||||
@ -91,16 +115,40 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
|
||||
}
|
||||
lm := new(logMsg)
|
||||
lm.level = loglevel
|
||||
if bl.enableFuncCallDepth {
|
||||
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
|
||||
if ok {
|
||||
_, filename := path.Split(file)
|
||||
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
|
||||
} else {
|
||||
lm.msg = msg
|
||||
}
|
||||
} else {
|
||||
lm.msg = msg
|
||||
}
|
||||
bl.msg <- lm
|
||||
return nil
|
||||
}
|
||||
|
||||
// set log message level.
|
||||
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
|
||||
func (bl *BeeLogger) SetLevel(l int) {
|
||||
bl.level = l
|
||||
}
|
||||
|
||||
func (bl *BeeLogger) StartLogger() {
|
||||
// set log funcCallDepth
|
||||
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
|
||||
bl.loggerFuncCallDepth = d
|
||||
}
|
||||
|
||||
// enable log funcCallDepth
|
||||
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
|
||||
bl.enableFuncCallDepth = b
|
||||
}
|
||||
|
||||
// start logger chan reading.
|
||||
// when chan is full, write logs.
|
||||
func (bl *BeeLogger) startLogger() {
|
||||
for {
|
||||
select {
|
||||
case bm := <-bl.msg:
|
||||
@ -111,43 +159,50 @@ func (bl *BeeLogger) StartLogger() {
|
||||
}
|
||||
}
|
||||
|
||||
// log trace level message.
|
||||
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[T] "+format, v...)
|
||||
bl.writerMsg(LevelTrace, msg)
|
||||
}
|
||||
|
||||
// log debug level message.
|
||||
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[D] "+format, v...)
|
||||
bl.writerMsg(LevelDebug, msg)
|
||||
}
|
||||
|
||||
// log info level message.
|
||||
func (bl *BeeLogger) Info(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[I] "+format, v...)
|
||||
bl.writerMsg(LevelInfo, msg)
|
||||
}
|
||||
|
||||
// log warn level message.
|
||||
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[W] "+format, v...)
|
||||
bl.writerMsg(LevelWarn, msg)
|
||||
}
|
||||
|
||||
// log error level message.
|
||||
func (bl *BeeLogger) Error(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[E] "+format, v...)
|
||||
bl.writerMsg(LevelError, msg)
|
||||
}
|
||||
|
||||
// log critical level message.
|
||||
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("[C] "+format, v...)
|
||||
bl.writerMsg(LevelCritical, msg)
|
||||
}
|
||||
|
||||
//flush all chan data
|
||||
// flush all chan data.
|
||||
func (bl *BeeLogger) Flush() {
|
||||
for _, l := range bl.outputs {
|
||||
l.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// close logger, flush all chan data and destroy all adapters in BeeLogger.
|
||||
func (bl *BeeLogger) Close() {
|
||||
for {
|
||||
if len(bl.msg) > 0 {
|
||||
|
24
logs/smtp.go
24
logs/smtp.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
@ -12,7 +18,7 @@ const (
|
||||
subjectPhrase = "Diagnostic message from server"
|
||||
)
|
||||
|
||||
// smtpWriter is used to send emails via given SMTP-server.
|
||||
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
|
||||
type SmtpWriter struct {
|
||||
Username string `json:"Username"`
|
||||
Password string `json:"password"`
|
||||
@ -22,10 +28,21 @@ type SmtpWriter struct {
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// create smtp writer.
|
||||
func NewSmtpWriter() LoggerInterface {
|
||||
return &SmtpWriter{Level: LevelTrace}
|
||||
}
|
||||
|
||||
// init smtp writer with json config.
|
||||
// config like:
|
||||
// {
|
||||
// "Username":"example@gmail.com",
|
||||
// "password:"password",
|
||||
// "host":"smtp.gmail.com:465",
|
||||
// "subject":"email title",
|
||||
// "sendTos":["email1","email2"],
|
||||
// "level":LevelError
|
||||
// }
|
||||
func (s *SmtpWriter) Init(jsonconfig string) error {
|
||||
err := json.Unmarshal([]byte(jsonconfig), s)
|
||||
if err != nil {
|
||||
@ -34,6 +51,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// write message in smtp writer.
|
||||
// it will send an email with subject and only this message.
|
||||
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
||||
if level < s.Level {
|
||||
return nil
|
||||
@ -65,9 +84,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (s *SmtpWriter) Flush() {
|
||||
return
|
||||
}
|
||||
|
||||
// implementing method. empty.
|
||||
func (s *SmtpWriter) Destroy() {
|
||||
return
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package logs
|
||||
|
||||
import (
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -5,20 +11,21 @@ import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
//"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
|
||||
var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo)
|
||||
var lock sync.RWMutex
|
||||
|
||||
// OpenMemZipFile returns MemFile object with a compressed static file.
|
||||
// it's used for serve static file if gzip enable.
|
||||
func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
func openMemZipFile(path string, zip string) (*memFile, error) {
|
||||
osfile, e := os.Open(path)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
@ -32,15 +39,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
|
||||
modtime := osfileinfo.ModTime()
|
||||
fileSize := osfileinfo.Size()
|
||||
|
||||
lock.RLock()
|
||||
cfi, ok := gmfim[zip+":"+path]
|
||||
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
|
||||
//fmt.Printf("read %s file %s from cache\n", zip, path)
|
||||
} else {
|
||||
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
|
||||
lock.RUnlock()
|
||||
if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) {
|
||||
var content []byte
|
||||
if zip == "gzip" {
|
||||
//将文件内容压缩到zipbuf中
|
||||
var zipbuf bytes.Buffer
|
||||
gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression)
|
||||
if e != nil {
|
||||
@ -51,13 +55,11 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
//读zipbuf到content
|
||||
content, e = ioutil.ReadAll(&zipbuf)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
} else if zip == "deflate" {
|
||||
//将文件内容压缩到zipbuf中
|
||||
var zipbuf bytes.Buffer
|
||||
deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression)
|
||||
if e != nil {
|
||||
@ -68,7 +70,6 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
//将zipbuf读入到content
|
||||
content, e = ioutil.ReadAll(&zipbuf)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
@ -80,16 +81,17 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
|
||||
}
|
||||
}
|
||||
|
||||
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
|
||||
cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
gmfim[zip+":"+path] = cfi
|
||||
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
|
||||
}
|
||||
return &MemFile{fi: cfi, offset: 0}, nil
|
||||
return &memFile{fi: cfi, offset: 0}, nil
|
||||
}
|
||||
|
||||
// MemFileInfo contains a compressed file bytes and file information.
|
||||
// it implements os.FileInfo interface.
|
||||
type MemFileInfo struct {
|
||||
type memFileInfo struct {
|
||||
os.FileInfo
|
||||
modTime time.Time
|
||||
content []byte
|
||||
@ -98,62 +100,62 @@ type MemFileInfo struct {
|
||||
}
|
||||
|
||||
// Name returns the compressed filename.
|
||||
func (fi *MemFileInfo) Name() string {
|
||||
func (fi *memFileInfo) Name() string {
|
||||
return fi.Name()
|
||||
}
|
||||
|
||||
// Size returns the raw file content size, not compressed size.
|
||||
func (fi *MemFileInfo) Size() int64 {
|
||||
func (fi *memFileInfo) Size() int64 {
|
||||
return fi.contentSize
|
||||
}
|
||||
|
||||
// Mode returns file mode.
|
||||
func (fi *MemFileInfo) Mode() os.FileMode {
|
||||
func (fi *memFileInfo) Mode() os.FileMode {
|
||||
return fi.Mode()
|
||||
}
|
||||
|
||||
// ModTime returns the last modified time of raw file.
|
||||
func (fi *MemFileInfo) ModTime() time.Time {
|
||||
func (fi *memFileInfo) ModTime() time.Time {
|
||||
return fi.modTime
|
||||
}
|
||||
|
||||
// IsDir returns the compressing file is a directory or not.
|
||||
func (fi *MemFileInfo) IsDir() bool {
|
||||
func (fi *memFileInfo) IsDir() bool {
|
||||
return fi.IsDir()
|
||||
}
|
||||
|
||||
// return nil. implement the os.FileInfo interface method.
|
||||
func (fi *MemFileInfo) Sys() interface{} {
|
||||
func (fi *memFileInfo) Sys() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MemFile contains MemFileInfo and bytes offset when reading.
|
||||
// it implements io.Reader,io.ReadCloser and io.Seeker.
|
||||
type MemFile struct {
|
||||
fi *MemFileInfo
|
||||
type memFile struct {
|
||||
fi *memFileInfo
|
||||
offset int64
|
||||
}
|
||||
|
||||
// Close memfile.
|
||||
func (f *MemFile) Close() error {
|
||||
func (f *memFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get os.FileInfo of memfile.
|
||||
func (f *MemFile) Stat() (os.FileInfo, error) {
|
||||
func (f *memFile) Stat() (os.FileInfo, error) {
|
||||
return f.fi, nil
|
||||
}
|
||||
|
||||
// read os.FileInfo of files in directory of memfile.
|
||||
// it returns empty slice.
|
||||
func (f *MemFile) Readdir(count int) ([]os.FileInfo, error) {
|
||||
func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
|
||||
infos := []os.FileInfo{}
|
||||
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
// Read bytes from the compressed file bytes.
|
||||
func (f *MemFile) Read(p []byte) (n int, err error) {
|
||||
func (f *memFile) Read(p []byte) (n int, err error) {
|
||||
if len(f.fi.content)-int(f.offset) >= len(p) {
|
||||
n = len(p)
|
||||
} else {
|
||||
@ -169,7 +171,7 @@ var errWhence = errors.New("Seek: invalid whence")
|
||||
var errOffset = errors.New("Seek: invalid offset")
|
||||
|
||||
// Read bytes from the compressed file bytes by seeker.
|
||||
func (f *MemFile) Seek(offset int64, whence int) (ret int64, err error) {
|
||||
func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) {
|
||||
switch whence {
|
||||
default:
|
||||
return 0, errWhence
|
||||
@ -189,7 +191,7 @@ func (f *MemFile) Seek(offset int64, whence int) (ret int64, err error) {
|
||||
// GetAcceptEncodingZip returns accept encoding format in http header.
|
||||
// zip is first, then deflate if both accepted.
|
||||
// If no accepted, return empty string.
|
||||
func GetAcceptEncodingZip(r *http.Request) string {
|
||||
func getAcceptEncodingZip(r *http.Request) string {
|
||||
ss := r.Header.Get("Accept-Encoding")
|
||||
ss = strings.ToLower(ss)
|
||||
if strings.Contains(ss, "gzip") {
|
||||
@ -199,24 +201,4 @@ func GetAcceptEncodingZip(r *http.Request) string {
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CloseZWriter closes the io.Writer after compressing static file.
|
||||
func CloseZWriter(zwriter io.Writer) {
|
||||
if zwriter == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch zwriter.(type) {
|
||||
case *gzip.Writer:
|
||||
zwriter.(*gzip.Writer).Close()
|
||||
case *flate.Writer:
|
||||
zwriter.(*flate.Writer).Close()
|
||||
//其他情况不close, 保持和默认(非压缩)行为一致
|
||||
/*
|
||||
case io.WriteCloser:
|
||||
zwriter.(io.WriteCloser).Close()
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@ -61,6 +67,7 @@ var tpl = `
|
||||
</html>
|
||||
`
|
||||
|
||||
// render default application error page with error and stack string.
|
||||
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(tpl)
|
||||
data := make(map[string]string)
|
||||
@ -71,6 +78,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
|
||||
data["Stack"] = Stack
|
||||
data["BeegoVersion"] = VERSION
|
||||
data["GoVersion"] = runtime.Version()
|
||||
rw.WriteHeader(500)
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
@ -166,7 +174,7 @@ var errtpl = `
|
||||
{{.Content}}
|
||||
<a href="/" title="Home" class="button">Go Home</a><br />
|
||||
|
||||
<br>power by beego {{.BeegoVersion}}
|
||||
<br>Powered by beego {{.BeegoVersion}}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@ -174,18 +182,19 @@ var errtpl = `
|
||||
</html>
|
||||
`
|
||||
|
||||
// map of http handlers for each error string.
|
||||
var ErrorMaps map[string]http.HandlerFunc
|
||||
|
||||
func init() {
|
||||
ErrorMaps = make(map[string]http.HandlerFunc)
|
||||
}
|
||||
|
||||
//404
|
||||
// show 404 notfound error.
|
||||
func NotFound(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Page Not Found"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br>The page has moved" +
|
||||
@ -198,28 +207,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//401
|
||||
// show 401 unauthorized error.
|
||||
func Unauthorized(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Unauthorized"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br>Check the credentials that you supplied" +
|
||||
"<br>Check the address for errors" +
|
||||
"<br>The credentials you supplied are incorrect" +
|
||||
"<br>There are errors in the website address" +
|
||||
"</ul>")
|
||||
data["BeegoVersion"] = VERSION
|
||||
//rw.WriteHeader(http.StatusUnauthorized)
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//403
|
||||
// show 403 forbidden error.
|
||||
func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Forbidden"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested forbidden." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br>Your address may be blocked" +
|
||||
@ -231,12 +240,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//503
|
||||
// show 503 service unavailable error.
|
||||
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Service Unavailable"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested unavailable." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
|
||||
"<br>Perhaps you are here because:" +
|
||||
"<br><br><ul>" +
|
||||
"<br><br>The page is overloaded" +
|
||||
@ -247,30 +256,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
//500
|
||||
// show 500 internal server error.
|
||||
func InternalServerError(rw http.ResponseWriter, r *http.Request) {
|
||||
t, _ := template.New("beegoerrortemp").Parse(errtpl)
|
||||
data := make(map[string]interface{})
|
||||
data["Title"] = "Internal Server Error"
|
||||
data["Content"] = template.HTML("<br>The Page You have requested has down now." +
|
||||
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
|
||||
"<br><br><ul>" +
|
||||
"<br>simply try again later" +
|
||||
"<br>you should report the fault to the website administrator" +
|
||||
"</ul>")
|
||||
"<br>Please try again later and report the error to the website administrator" +
|
||||
"<br></ul>")
|
||||
data["BeegoVersion"] = VERSION
|
||||
//rw.WriteHeader(http.StatusInternalServerError)
|
||||
t.Execute(rw, data)
|
||||
}
|
||||
|
||||
// show 500 internal error with simple text string.
|
||||
func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
|
||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// add http handler for given error string.
|
||||
func Errorhandler(err string, h http.HandlerFunc) {
|
||||
ErrorMaps[err] = h
|
||||
}
|
||||
|
||||
func RegisterErrorHander() {
|
||||
// register default error http handlers, 404,401,403,500 and 503.
|
||||
func RegisterErrorHandler() {
|
||||
if _, ok := ErrorMaps["404"]; !ok {
|
||||
ErrorMaps["404"] = NotFound
|
||||
}
|
||||
@ -292,6 +303,8 @@ func RegisterErrorHander() {
|
||||
}
|
||||
}
|
||||
|
||||
// show error string as simple text message.
|
||||
// if error string is empty, show 500 error as default.
|
||||
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
|
||||
if h, ok := ErrorMaps[errcode]; ok {
|
||||
isint, err := strconv.Atoi(errcode)
|
||||
|
@ -1,17 +1,26 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package middleware
|
||||
|
||||
import "fmt"
|
||||
|
||||
// http exceptions
|
||||
type HTTPException struct {
|
||||
StatusCode int // http status code 4xx, 5xx
|
||||
Description string
|
||||
}
|
||||
|
||||
// return http exception error string, e.g. "400 Bad Request".
|
||||
func (e *HTTPException) Error() string {
|
||||
// return `status description`, e.g. `400 Bad Request`
|
||||
return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
|
||||
}
|
||||
|
||||
// map of http exceptions for each http status code int.
|
||||
// defined 400,401,403,404,405,500,502,503 and 504 default.
|
||||
var HTTPExceptionMaps map[int]HTTPException
|
||||
|
||||
func init() {
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package middleware
|
||||
|
||||
//import (
|
||||
|
9
mime.go
9
mime.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -544,8 +550,9 @@ var mimemaps map[string]string = map[string]string{
|
||||
".mustache": "text/html",
|
||||
}
|
||||
|
||||
func initMime() {
|
||||
func initMime() error {
|
||||
for k, v := range mimemaps {
|
||||
mime.AddExtensionType(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
137
namespace.go
Normal file
137
namespace.go
Normal file
@ -0,0 +1,137 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
beecontext "github.com/astaxie/beego/context"
|
||||
)
|
||||
|
||||
type namespaceCond func(*beecontext.Context) bool
|
||||
|
||||
type Namespace struct {
|
||||
prefix string
|
||||
condition namespaceCond
|
||||
handlers *ControllerRegistor
|
||||
}
|
||||
|
||||
func NewNamespace(prefix string) *Namespace {
|
||||
cr := NewControllerRegistor()
|
||||
return &Namespace{
|
||||
prefix: prefix,
|
||||
handlers: cr,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Namespace) Cond(cond namespaceCond) *Namespace {
|
||||
n.condition = cond
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Filter(action string, filter FilterFunc) *Namespace {
|
||||
if action == "before" {
|
||||
action = "BeforeRouter"
|
||||
} else if action == "after" {
|
||||
action = "FinishRouter"
|
||||
}
|
||||
n.handlers.AddFilter("*", action, filter)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
|
||||
n.handlers.Add(rootpath, c, mappingMethods...)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
|
||||
n.handlers.AddAuto(c)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
|
||||
n.handlers.AddAutoPrefix(prefix, c)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Get(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Post(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Delete(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Put(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Head(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Options(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Patch(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
|
||||
n.handlers.Any(rootpath, f)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
|
||||
n.handlers.Handler(rootpath, h)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) Namespace(ns *Namespace) *Namespace {
|
||||
n.handlers.Handler(ns.prefix, ns, true)
|
||||
return n
|
||||
}
|
||||
|
||||
func (n *Namespace) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
//trim the preifix from URL.Path
|
||||
r.URL.Path = strings.TrimPrefix(r.URL.Path, n.prefix)
|
||||
// init context
|
||||
context := &beecontext.Context{
|
||||
ResponseWriter: rw,
|
||||
Request: r,
|
||||
Input: beecontext.NewInput(r),
|
||||
Output: beecontext.NewOutput(),
|
||||
}
|
||||
context.Output.Context = context
|
||||
context.Output.EnableGzip = EnableGzip
|
||||
|
||||
if context.Input.IsWebsocket() {
|
||||
context.ResponseWriter = rw
|
||||
}
|
||||
if n.condition != nil && !n.condition(context) {
|
||||
http.Error(rw, "Method Not Allowed", 405)
|
||||
}
|
||||
n.handlers.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func AddNamespace(nl ...*Namespace) {
|
||||
for _, n := range nl {
|
||||
Handler(n.prefix, n, true)
|
||||
}
|
||||
}
|
137
namespace_test.go
Normal file
137
namespace_test.go
Normal file
@ -0,0 +1,137 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/context"
|
||||
)
|
||||
|
||||
func TestNamespaceGet(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/user", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Get("/user", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("v1_user"))
|
||||
})
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "v1_user" {
|
||||
t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespacePost(t *testing.T) {
|
||||
r, _ := http.NewRequest("POST", "/v1/user/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Post("/user/:id", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
|
||||
})
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "123" {
|
||||
t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceNest(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/admin/order", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Namespace(
|
||||
NewNamespace("/admin").
|
||||
Get("/order", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("order"))
|
||||
}),
|
||||
)
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "order" {
|
||||
t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceNestParam(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Namespace(
|
||||
NewNamespace("/admin").
|
||||
Get("/order/:id", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
|
||||
}),
|
||||
)
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "123" {
|
||||
t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceFilter(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/user/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Filter("before", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("this is Filter"))
|
||||
}).
|
||||
Get("/user/:id", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
|
||||
})
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "this is Filter" {
|
||||
t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceRouter(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/api/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Router("/api/list", &TestController{}, "*:List")
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am list" {
|
||||
t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceAutoFunc(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/test/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.AutoRouter(&TestController{})
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am list" {
|
||||
t.Errorf("user define func can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceCond(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/v1/test/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ns := NewNamespace("/v1")
|
||||
ns.Cond(func(ctx *context.Context) bool {
|
||||
if ctx.Input.Domain() == "beego.me" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}).
|
||||
AutoRouter(&TestController{})
|
||||
ns.ServeHTTP(w, r)
|
||||
if w.Code != 405 {
|
||||
t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code))
|
||||
}
|
||||
}
|
18
orm/cmd.go
18
orm/cmd.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -16,6 +22,7 @@ var (
|
||||
commands = make(map[string]commander)
|
||||
)
|
||||
|
||||
// print help.
|
||||
func printHelp(errs ...string) {
|
||||
content := `orm command usage:
|
||||
|
||||
@ -31,6 +38,7 @@ func printHelp(errs ...string) {
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
// listen for orm command and then run it if command arguments passed.
|
||||
func RunCommand() {
|
||||
if len(os.Args) < 2 || os.Args[1] != "orm" {
|
||||
return
|
||||
@ -58,6 +66,7 @@ func RunCommand() {
|
||||
}
|
||||
}
|
||||
|
||||
// sync database struct command interface.
|
||||
type commandSyncDb struct {
|
||||
al *alias
|
||||
force bool
|
||||
@ -66,6 +75,7 @@ type commandSyncDb struct {
|
||||
rtOnError bool
|
||||
}
|
||||
|
||||
// parse orm command line arguments.
|
||||
func (d *commandSyncDb) Parse(args []string) {
|
||||
var name string
|
||||
|
||||
@ -78,6 +88,7 @@ func (d *commandSyncDb) Parse(args []string) {
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
// run orm line command.
|
||||
func (d *commandSyncDb) Run() error {
|
||||
var drops []string
|
||||
if d.force {
|
||||
@ -208,10 +219,12 @@ func (d *commandSyncDb) Run() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// database creation commander interface implement.
|
||||
type commandSqlAll struct {
|
||||
al *alias
|
||||
}
|
||||
|
||||
// parse orm command line arguments.
|
||||
func (d *commandSqlAll) Parse(args []string) {
|
||||
var name string
|
||||
|
||||
@ -222,6 +235,7 @@ func (d *commandSqlAll) Parse(args []string) {
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
// run orm line command.
|
||||
func (d *commandSqlAll) Run() error {
|
||||
sqls, indexes := getDbCreateSql(d.al)
|
||||
var all []string
|
||||
@ -243,6 +257,10 @@ func init() {
|
||||
commands["sqlall"] = new(commandSqlAll)
|
||||
}
|
||||
|
||||
// run syncdb command line.
|
||||
// name means table's alias name. default is "default".
|
||||
// force means run next sql if the current is error.
|
||||
// verbose means show all info when running command or not.
|
||||
func RunSyncdb(name string, force bool, verbose bool) error {
|
||||
BootStrap()
|
||||
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -12,6 +18,7 @@ type dbIndex struct {
|
||||
Sql string
|
||||
}
|
||||
|
||||
// create database drop sql.
|
||||
func getDbDropSql(al *alias) (sqls []string) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
@ -26,6 +33,7 @@ func getDbDropSql(al *alias) (sqls []string) {
|
||||
return sqls
|
||||
}
|
||||
|
||||
// get database column type string.
|
||||
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
||||
T := al.DbBaser.DbTypes()
|
||||
fieldType := fi.fieldType
|
||||
@ -79,6 +87,7 @@ checkColumn:
|
||||
return
|
||||
}
|
||||
|
||||
// create alter sql string.
|
||||
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||
Q := al.DbBaser.TableQuote()
|
||||
typ := getColumnTyp(al, fi)
|
||||
@ -90,6 +99,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
|
||||
}
|
||||
|
||||
// create database creation string.
|
||||
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
|
272
orm/db.go
272
orm/db.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -15,7 +21,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissPK = errors.New("missed pk value")
|
||||
ErrMissPK = errors.New("missed pk value") // missing pk error
|
||||
)
|
||||
|
||||
var (
|
||||
@ -35,7 +41,7 @@ var (
|
||||
"istartswith": true,
|
||||
"iendswith": true,
|
||||
"in": true,
|
||||
// "range": true,
|
||||
"between": true,
|
||||
// "year": true,
|
||||
// "month": true,
|
||||
// "day": true,
|
||||
@ -45,13 +51,22 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// an instance of dbBaser interface/
|
||||
type dbBase struct {
|
||||
ins dbBaser
|
||||
}
|
||||
|
||||
// check dbBase implements dbBaser interface.
|
||||
var _ dbBaser = new(dbBase)
|
||||
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
|
||||
// get struct columns values as interface slice.
|
||||
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
|
||||
var columns []string
|
||||
|
||||
if names != nil {
|
||||
columns = *names
|
||||
}
|
||||
|
||||
for _, column := range cols {
|
||||
var fi *fieldInfo
|
||||
if fi, _ = mi.fields.GetByAny(column); fi != nil {
|
||||
@ -64,14 +79,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
||||
}
|
||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if names != nil {
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if names != nil {
|
||||
*names = columns
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// get one field value in struct column as interface.
|
||||
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
|
||||
var value interface{}
|
||||
if fi.pk {
|
||||
@ -84,16 +109,37 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
||||
} else {
|
||||
switch fi.fieldType {
|
||||
case TypeBooleanField:
|
||||
if nb, ok := field.Interface().(sql.NullBool); ok {
|
||||
value = nil
|
||||
if nb.Valid {
|
||||
value = nb.Bool
|
||||
}
|
||||
} else {
|
||||
value = field.Bool()
|
||||
}
|
||||
case TypeCharField, TypeTextField:
|
||||
if ns, ok := field.Interface().(sql.NullString); ok {
|
||||
value = nil
|
||||
if ns.Valid {
|
||||
value = ns.String
|
||||
}
|
||||
} else {
|
||||
value = field.String()
|
||||
}
|
||||
case TypeFloatField, TypeDecimalField:
|
||||
if nf, ok := field.Interface().(sql.NullFloat64); ok {
|
||||
value = nil
|
||||
if nf.Valid {
|
||||
value = nf.Float64
|
||||
}
|
||||
} else {
|
||||
vu := field.Interface()
|
||||
if _, ok := vu.(float32); ok {
|
||||
value, _ = StrTo(ToStr(vu)).Float64()
|
||||
} else {
|
||||
value = field.Float()
|
||||
}
|
||||
}
|
||||
case TypeDateField, TypeDateTimeField:
|
||||
value = field.Interface()
|
||||
if t, ok := value.(time.Time); ok {
|
||||
@ -105,7 +151,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
||||
case fi.fieldType&IsPostiveIntegerField > 0:
|
||||
value = field.Uint()
|
||||
case fi.fieldType&IsIntegerField > 0:
|
||||
if ni, ok := field.Interface().(sql.NullInt64); ok {
|
||||
value = nil
|
||||
if ni.Valid {
|
||||
value = ni.Int64
|
||||
}
|
||||
} else {
|
||||
value = field.Int()
|
||||
}
|
||||
case fi.fieldType&IsRelField > 0:
|
||||
if field.IsNil() {
|
||||
value = nil
|
||||
@ -125,6 +178,11 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
||||
switch fi.fieldType {
|
||||
case TypeDateField, TypeDateTimeField:
|
||||
if fi.auto_now || fi.auto_now_add && insert {
|
||||
if insert {
|
||||
if t, ok := value.(time.Time); ok && !t.IsZero() {
|
||||
break
|
||||
}
|
||||
}
|
||||
tnow := time.Now()
|
||||
d.ins.TimeToDB(&tnow, tz)
|
||||
value = tnow
|
||||
@ -140,6 +198,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// create insert sql preparation statement object.
|
||||
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
@ -165,8 +224,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
||||
return stmt, query, err
|
||||
}
|
||||
|
||||
// insert struct with prepared statement and given struct reflect value.
|
||||
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -185,6 +245,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
|
||||
}
|
||||
}
|
||||
|
||||
// query sql ,read records and persist in dbBaser.
|
||||
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
|
||||
var whereCols []string
|
||||
var args []interface{}
|
||||
@ -192,7 +253,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
// if specify cols length > 0, then use it for where condition.
|
||||
if len(cols) > 0 {
|
||||
var err error
|
||||
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
|
||||
whereCols = make([]string, 0, len(cols))
|
||||
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -202,7 +264,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
if ok == false {
|
||||
return ErrMissPK
|
||||
}
|
||||
whereCols = append(whereCols, pkColumn)
|
||||
whereCols = []string{pkColumn}
|
||||
args = append(args, pkValue)
|
||||
}
|
||||
|
||||
@ -243,16 +305,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
// execute insert sql dbQuerier with given struct reflect.Value.
|
||||
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
|
||||
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return d.InsertValue(q, mi, names, values)
|
||||
return d.InsertValue(q, mi, false, names, values)
|
||||
}
|
||||
|
||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
|
||||
// multi-insert sql with given slice struct reflect.Value.
|
||||
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
|
||||
var (
|
||||
cnt int64
|
||||
nums int
|
||||
values []interface{}
|
||||
names []string
|
||||
)
|
||||
|
||||
// typ := reflect.Indirect(mi.addrField).Type()
|
||||
|
||||
length := sind.Len()
|
||||
|
||||
for i := 1; i <= length; i++ {
|
||||
|
||||
ind := reflect.Indirect(sind.Index(i - 1))
|
||||
|
||||
// Is this needed ?
|
||||
// if !ind.Type().AssignableTo(typ) {
|
||||
// return cnt, ErrArgs
|
||||
// }
|
||||
|
||||
if i == 1 {
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
values = make([]interface{}, bulk*len(vus))
|
||||
nums += copy(values, vus)
|
||||
|
||||
} else {
|
||||
|
||||
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
if len(vus) != len(names) {
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
|
||||
nums += copy(values[nums:], vus)
|
||||
}
|
||||
|
||||
if i > 1 && i%bulk == 0 || length == i {
|
||||
num, err := d.InsertValue(q, mi, true, names, values[:nums])
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
cnt += num
|
||||
nums = 0
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// execute insert sql with given struct and given values.
|
||||
// insert the given values, not the field values in struct.
|
||||
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
marks := make([]string, len(names))
|
||||
@ -264,36 +387,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
|
||||
qmarks := strings.Join(marks, ", ")
|
||||
columns := strings.Join(names, sep)
|
||||
|
||||
multi := len(values) / len(names)
|
||||
|
||||
if isMulti {
|
||||
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
if d.ins.HasReturningID(mi, &query) {
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
} else {
|
||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||
if res, err := q.Exec(query, values...); err == nil {
|
||||
if isMulti {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
return res.LastInsertId()
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
}
|
||||
|
||||
// execute update sql dbQuerier with given struct reflect.Value.
|
||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
return 0, ErrMissPK
|
||||
}
|
||||
|
||||
var setNames []string
|
||||
|
||||
// if specify cols length is zero, then commit all columns.
|
||||
if len(cols) == 0 {
|
||||
cols = mi.fields.dbcols
|
||||
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
} else {
|
||||
setNames = make([]string, 0, len(cols))
|
||||
}
|
||||
|
||||
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
|
||||
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -314,9 +452,10 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// execute delete sql dbQuerier with given struct reflect.Value.
|
||||
// delete index is pk.
|
||||
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
|
||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||
if ok == false {
|
||||
@ -355,9 +494,10 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// update table-related record by querySet.
|
||||
// need querySet not struct reflect.Value to update related records.
|
||||
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
|
||||
columns := make([]string, 0, len(params))
|
||||
values := make([]interface{}, 0, len(params))
|
||||
@ -430,9 +570,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// delete related records.
|
||||
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
|
||||
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
|
||||
for _, fi := range mi.fields.fieldsReverse {
|
||||
fi = fi.reverseFieldInfo
|
||||
@ -459,8 +600,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete table-related records.
|
||||
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.skipEnd = true
|
||||
|
||||
if qs != nil {
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
}
|
||||
@ -486,6 +630,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
rs = r
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var ref interface{}
|
||||
|
||||
args = make([]interface{}, 0)
|
||||
@ -528,10 +674,9 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
||||
} else {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// read related records.
|
||||
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
|
||||
|
||||
val := reflect.ValueOf(container)
|
||||
@ -640,6 +785,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
slice := ind
|
||||
|
||||
var cnt int64
|
||||
@ -739,6 +886,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// excute count sql and return count result int64.
|
||||
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
|
||||
tables := newDbTables(mi, d.ins)
|
||||
tables.parseRelated(qs.related, qs.relDepth)
|
||||
@ -759,6 +907,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
||||
return
|
||||
}
|
||||
|
||||
// generate sql with replacing operator string placeholders and replaced values.
|
||||
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
||||
sql := ""
|
||||
params := getFlatParams(fi, args, tz)
|
||||
@ -768,13 +917,19 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
|
||||
}
|
||||
arg := params[0]
|
||||
|
||||
if operator == "in" {
|
||||
switch operator {
|
||||
case "in":
|
||||
marks := make([]string, len(params))
|
||||
for i, _ := range marks {
|
||||
marks[i] = "?"
|
||||
}
|
||||
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
|
||||
} else {
|
||||
case "between":
|
||||
if len(params) != 2 {
|
||||
panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params)))
|
||||
}
|
||||
sql = "BETWEEN ? AND ?"
|
||||
default:
|
||||
if len(params) > 1 {
|
||||
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
|
||||
}
|
||||
@ -812,10 +967,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
|
||||
return sql, params
|
||||
}
|
||||
|
||||
// gernerate sql string with inner function, such as UPPER(text).
|
||||
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
|
||||
// default not use
|
||||
}
|
||||
|
||||
// set values to struct column.
|
||||
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
|
||||
for i, column := range cols {
|
||||
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
|
||||
@ -837,6 +994,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
|
||||
}
|
||||
}
|
||||
|
||||
// convert value from database result to value following in field type.
|
||||
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
|
||||
if val == nil {
|
||||
return nil, nil
|
||||
@ -989,6 +1147,7 @@ end:
|
||||
|
||||
}
|
||||
|
||||
// set one value to struct column field.
|
||||
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
|
||||
|
||||
fieldType := fi.fieldType
|
||||
@ -998,18 +1157,38 @@ setValue:
|
||||
switch {
|
||||
case fieldType == TypeBooleanField:
|
||||
if isNative {
|
||||
if nb, ok := field.Interface().(sql.NullBool); ok {
|
||||
if value == nil {
|
||||
nb.Valid = false
|
||||
} else {
|
||||
nb.Bool = value.(bool)
|
||||
nb.Valid = true
|
||||
}
|
||||
field.Set(reflect.ValueOf(nb))
|
||||
} else {
|
||||
if value == nil {
|
||||
value = false
|
||||
}
|
||||
field.SetBool(value.(bool))
|
||||
}
|
||||
}
|
||||
case fieldType == TypeCharField || fieldType == TypeTextField:
|
||||
if isNative {
|
||||
if ns, ok := field.Interface().(sql.NullString); ok {
|
||||
if value == nil {
|
||||
ns.Valid = false
|
||||
} else {
|
||||
ns.String = value.(string)
|
||||
ns.Valid = true
|
||||
}
|
||||
field.Set(reflect.ValueOf(ns))
|
||||
} else {
|
||||
if value == nil {
|
||||
value = ""
|
||||
}
|
||||
field.SetString(value.(string))
|
||||
}
|
||||
}
|
||||
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
|
||||
if isNative {
|
||||
if value == nil {
|
||||
@ -1027,19 +1206,40 @@ setValue:
|
||||
}
|
||||
} else {
|
||||
if isNative {
|
||||
if ni, ok := field.Interface().(sql.NullInt64); ok {
|
||||
if value == nil {
|
||||
ni.Valid = false
|
||||
} else {
|
||||
ni.Int64 = value.(int64)
|
||||
ni.Valid = true
|
||||
}
|
||||
field.Set(reflect.ValueOf(ni))
|
||||
} else {
|
||||
if value == nil {
|
||||
value = int64(0)
|
||||
}
|
||||
field.SetInt(value.(int64))
|
||||
}
|
||||
}
|
||||
}
|
||||
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
|
||||
if isNative {
|
||||
if nf, ok := field.Interface().(sql.NullFloat64); ok {
|
||||
if value == nil {
|
||||
nf.Valid = false
|
||||
} else {
|
||||
nf.Float64 = value.(float64)
|
||||
nf.Valid = true
|
||||
}
|
||||
field.Set(reflect.ValueOf(nf))
|
||||
} else {
|
||||
|
||||
if value == nil {
|
||||
value = float64(0)
|
||||
}
|
||||
field.SetFloat(value.(float64))
|
||||
}
|
||||
}
|
||||
case fieldType&IsRelField > 0:
|
||||
if value != nil {
|
||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||
@ -1063,6 +1263,7 @@ setValue:
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// query sql, read values , save to *[]ParamList.
|
||||
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
|
||||
|
||||
var (
|
||||
@ -1150,6 +1351,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
cnt int64
|
||||
columns []string
|
||||
@ -1228,6 +1431,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// flag of update joined record.
|
||||
func (d *dbBase) SupportUpdateJoin() bool {
|
||||
return true
|
||||
}
|
||||
@ -1236,30 +1444,37 @@ func (d *dbBase) MaxLimit() uint64 {
|
||||
return 18446744073709551615
|
||||
}
|
||||
|
||||
// return quote.
|
||||
func (d *dbBase) TableQuote() string {
|
||||
return "`"
|
||||
}
|
||||
|
||||
// replace value placeholer in parametered sql string.
|
||||
func (d *dbBase) ReplaceMarks(query *string) {
|
||||
// default use `?` as mark, do nothing
|
||||
}
|
||||
|
||||
// flag of RETURNING sql.
|
||||
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// convert time from db.
|
||||
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
|
||||
*t = t.In(tz)
|
||||
}
|
||||
|
||||
// convert time to db.
|
||||
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
|
||||
*t = t.In(tz)
|
||||
}
|
||||
|
||||
// get database types.
|
||||
func (d *dbBase) DbTypes() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// gt all tables.
|
||||
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
tables := make(map[string]bool)
|
||||
query := d.ins.ShowTablesQuery()
|
||||
@ -1268,6 +1483,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
return tables, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var table string
|
||||
err := rows.Scan(&table)
|
||||
@ -1282,6 +1499,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// get all cloumns in table.
|
||||
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
columns := make(map[string][3]string)
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
@ -1290,6 +1508,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
|
||||
return columns, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
name string
|
||||
@ -1306,18 +1526,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) OperatorSql(operator string) string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) ShowTablesQuery() string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) ShowColumnsQuery(table string) string {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// not implement.
|
||||
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
176
orm/db_alias.go
176
orm/db_alias.go
@ -1,35 +1,45 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// database driver constant int.
|
||||
type DriverType int
|
||||
|
||||
const (
|
||||
_ DriverType = iota
|
||||
DR_MySQL
|
||||
DR_Sqlite
|
||||
DR_Oracle
|
||||
DR_Postgres
|
||||
_ DriverType = iota // int enum type
|
||||
DR_MySQL // mysql
|
||||
DR_Sqlite // sqlite
|
||||
DR_Oracle // oracle
|
||||
DR_Postgres // pgsql
|
||||
)
|
||||
|
||||
// database driver string.
|
||||
type driver string
|
||||
|
||||
// get type constant int of current driver..
|
||||
func (d driver) Type() DriverType {
|
||||
a, _ := dataBaseCache.get(string(d))
|
||||
return a.Driver
|
||||
}
|
||||
|
||||
// get name of current driver
|
||||
func (d driver) Name() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// check driver iis implemented Driver interface or not.
|
||||
var _ Driver = new(driver)
|
||||
|
||||
var (
|
||||
@ -47,11 +57,13 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// database alias cacher.
|
||||
type _dbCache struct {
|
||||
mux sync.RWMutex
|
||||
cache map[string]*alias
|
||||
}
|
||||
|
||||
// add database alias with original name.
|
||||
func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
@ -62,6 +74,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// get database alias if cached.
|
||||
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
||||
ac.mux.RLock()
|
||||
defer ac.mux.RUnlock()
|
||||
@ -69,6 +82,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// get default alias.
|
||||
func (ac *_dbCache) getDefault() (al *alias) {
|
||||
al, _ = ac.get("default")
|
||||
return
|
||||
@ -87,57 +101,29 @@ type alias struct {
|
||||
Engine string
|
||||
}
|
||||
|
||||
// Setting the database connect params. Use the database driver self dataSource args.
|
||||
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
al := new(alias)
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
al.DataSource = dataSource
|
||||
|
||||
var (
|
||||
err error
|
||||
)
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
err = fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
goto end
|
||||
}
|
||||
|
||||
if dataBaseCache.add(aliasName, al) == false {
|
||||
err = fmt.Errorf("db name `%s` already registered, cannot reuse", aliasName)
|
||||
goto end
|
||||
}
|
||||
|
||||
al.DB, err = sql.Open(driverName, dataSource)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
func detectTZ(al *alias) {
|
||||
// orm timezone system match database
|
||||
// default use Local
|
||||
al.TZ = time.Local
|
||||
|
||||
if al.DriverName == "sphinx" {
|
||||
return
|
||||
}
|
||||
|
||||
switch al.Driver {
|
||||
case DR_MySQL:
|
||||
row := al.DB.QueryRow("SELECT @@session.time_zone")
|
||||
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
|
||||
var tz string
|
||||
row.Scan(&tz)
|
||||
if tz == "SYSTEM" {
|
||||
tz = ""
|
||||
row = al.DB.QueryRow("SELECT @@system_time_zone")
|
||||
row.Scan(&tz)
|
||||
t, err := time.Parse("MST", tz)
|
||||
if err == nil {
|
||||
al.TZ = t.Location()
|
||||
if len(tz) >= 8 {
|
||||
if tz[0] != '-' {
|
||||
tz = "+" + tz
|
||||
}
|
||||
} else {
|
||||
t, err := time.Parse("-07:00", tz)
|
||||
t, err := time.Parse("-07:00:00", tz)
|
||||
if err == nil {
|
||||
al.TZ = t.Location()
|
||||
} else {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,8 +149,64 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err == nil {
|
||||
al.TZ = loc
|
||||
} else {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
|
||||
al := new(alias)
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
al.DB = db
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
}
|
||||
|
||||
err := db.Ping()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
||||
}
|
||||
|
||||
if dataBaseCache.add(aliasName, al) == false {
|
||||
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
|
||||
}
|
||||
|
||||
return al, nil
|
||||
}
|
||||
|
||||
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
|
||||
_, err := addAliasWthDB(aliasName, driverName, db)
|
||||
return err
|
||||
}
|
||||
|
||||
// Setting the database connect params. Use the database driver self dataSource args.
|
||||
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
|
||||
var (
|
||||
err error
|
||||
db *sql.DB
|
||||
al *alias
|
||||
)
|
||||
|
||||
db, err = sql.Open(driverName, dataSource)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
al, err = addAliasWthDB(aliasName, driverName, db)
|
||||
if err != nil {
|
||||
goto end
|
||||
}
|
||||
|
||||
al.DataSource = dataSource
|
||||
|
||||
detectTZ(al)
|
||||
|
||||
for i, v := range params {
|
||||
switch i {
|
||||
@ -175,39 +217,37 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
|
||||
}
|
||||
}
|
||||
|
||||
err = al.DB.Ping()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
end:
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(2)
|
||||
if db != nil {
|
||||
db.Close()
|
||||
}
|
||||
DebugLog.Println(err.Error())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Register a database driver use specify driver name, this can be definition the driver is which database type.
|
||||
func RegisterDriver(driverName string, typ DriverType) {
|
||||
func RegisterDriver(driverName string, typ DriverType) error {
|
||||
if t, ok := drivers[driverName]; ok == false {
|
||||
drivers[driverName] = typ
|
||||
} else {
|
||||
if t != typ {
|
||||
fmt.Sprintf("driverName `%s` db driver already registered and is other type\n", driverName)
|
||||
os.Exit(2)
|
||||
return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the database default used timezone
|
||||
func SetDataBaseTZ(aliasName string, tz *time.Location) {
|
||||
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||
al.TZ = tz
|
||||
} else {
|
||||
fmt.Sprintf("DataBase name `%s` not registered\n", aliasName)
|
||||
os.Exit(2)
|
||||
return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the max idle conns for *sql.DB, use specify database alias name
|
||||
@ -226,3 +266,19 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
|
||||
fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
|
||||
}
|
||||
}
|
||||
|
||||
// Get *sql.DB from registered database by db alias name.
|
||||
// Use "default" as alias name if you not set.
|
||||
func GetDB(aliasNames ...string) (*sql.DB, error) {
|
||||
var name string
|
||||
if len(aliasNames) > 0 {
|
||||
name = aliasNames[0]
|
||||
} else {
|
||||
name = "default"
|
||||
}
|
||||
if al, ok := dataBaseCache.get(name); ok {
|
||||
return al.DB, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,16 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// mysql operators.
|
||||
var mysqlOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ?",
|
||||
@ -21,6 +28,7 @@ var mysqlOperators = map[string]string{
|
||||
"iendswith": "LIKE ?",
|
||||
}
|
||||
|
||||
// mysql column field types.
|
||||
var mysqlTypes = map[string]string{
|
||||
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
@ -41,29 +49,35 @@ var mysqlTypes = map[string]string{
|
||||
"float64-decimal": "numeric(%d, %d)",
|
||||
}
|
||||
|
||||
// mysql dbBaser implementation.
|
||||
type dbBaseMysql struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseMysql)
|
||||
|
||||
// get mysql operator.
|
||||
func (d *dbBaseMysql) OperatorSql(operator string) string {
|
||||
return mysqlOperators[operator]
|
||||
}
|
||||
|
||||
// get mysql table field types.
|
||||
func (d *dbBaseMysql) DbTypes() map[string]string {
|
||||
return mysqlTypes
|
||||
}
|
||||
|
||||
// show table sql for mysql.
|
||||
func (d *dbBaseMysql) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||
}
|
||||
|
||||
// show columns sql of table for mysql.
|
||||
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// execute sql to check index exist.
|
||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||
@ -72,6 +86,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// create new mysql dbBaser.
|
||||
func newdbBaseMysql() dbBaser {
|
||||
b := new(dbBaseMysql)
|
||||
b.ins = b
|
||||
|
@ -1,11 +1,19 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
// oracle dbBaser
|
||||
type dbBaseOracle struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseOracle)
|
||||
|
||||
// create oracle dbBaser.
|
||||
func newdbBaseOracle() dbBaser {
|
||||
b := new(dbBaseOracle)
|
||||
b.ins = b
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -5,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// postgresql operators.
|
||||
var postgresOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "= UPPER(?)",
|
||||
@ -20,6 +27,7 @@ var postgresOperators = map[string]string{
|
||||
"iendswith": "LIKE UPPER(?)",
|
||||
}
|
||||
|
||||
// postgresql column field types.
|
||||
var postgresTypes = map[string]string{
|
||||
"auto": "serial NOT NULL PRIMARY KEY",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
@ -40,16 +48,19 @@ var postgresTypes = map[string]string{
|
||||
"float64-decimal": "numeric(%d, %d)",
|
||||
}
|
||||
|
||||
// postgresql dbBaser.
|
||||
type dbBasePostgres struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBasePostgres)
|
||||
|
||||
// get postgresql operator.
|
||||
func (d *dbBasePostgres) OperatorSql(operator string) string {
|
||||
return postgresOperators[operator]
|
||||
}
|
||||
|
||||
// generate functioned sql string, such as contains(text).
|
||||
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||
switch operator {
|
||||
case "contains", "startswith", "endswith":
|
||||
@ -59,6 +70,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
|
||||
}
|
||||
}
|
||||
|
||||
// postgresql unsupports updating joined record.
|
||||
func (d *dbBasePostgres) SupportUpdateJoin() bool {
|
||||
return false
|
||||
}
|
||||
@ -67,10 +79,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// postgresql quote is ".
|
||||
func (d *dbBasePostgres) TableQuote() string {
|
||||
return `"`
|
||||
}
|
||||
|
||||
// postgresql value placeholder is $n.
|
||||
// replace default ? to $n.
|
||||
func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||
q := *query
|
||||
num := 0
|
||||
@ -97,6 +112,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||
*query = string(data)
|
||||
}
|
||||
|
||||
// make returning sql support for postgresql.
|
||||
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
|
||||
if mi.fields.pk.auto {
|
||||
if query != nil {
|
||||
@ -107,18 +123,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
|
||||
return
|
||||
}
|
||||
|
||||
// show table sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||
}
|
||||
|
||||
// show table columns sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// get column types of postgresql.
|
||||
func (d *dbBasePostgres) DbTypes() map[string]string {
|
||||
return postgresTypes
|
||||
}
|
||||
|
||||
// check index exist in postgresql.
|
||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||
row := db.QueryRow(query)
|
||||
@ -127,6 +147,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// create new postgresql dbBaser.
|
||||
func newdbBasePostgres() dbBaser {
|
||||
b := new(dbBasePostgres)
|
||||
b.ins = b
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -5,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// sqlite operators.
|
||||
var sqliteOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ? ESCAPE '\\'",
|
||||
@ -20,6 +27,7 @@ var sqliteOperators = map[string]string{
|
||||
"iendswith": "LIKE ? ESCAPE '\\'",
|
||||
}
|
||||
|
||||
// sqlite column types.
|
||||
var sqliteTypes = map[string]string{
|
||||
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
@ -40,38 +48,47 @@ var sqliteTypes = map[string]string{
|
||||
"float64-decimal": "decimal",
|
||||
}
|
||||
|
||||
// sqlite dbBaser.
|
||||
type dbBaseSqlite struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseSqlite)
|
||||
|
||||
// get sqlite operator.
|
||||
func (d *dbBaseSqlite) OperatorSql(operator string) string {
|
||||
return sqliteOperators[operator]
|
||||
}
|
||||
|
||||
// generate functioned sql for sqlite.
|
||||
// only support DATE(text).
|
||||
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||
if fi.fieldType == TypeDateField {
|
||||
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
|
||||
}
|
||||
}
|
||||
|
||||
// unable updating joined record in sqlite.
|
||||
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// max int in sqlite.
|
||||
func (d *dbBaseSqlite) MaxLimit() uint64 {
|
||||
return 9223372036854775807
|
||||
}
|
||||
|
||||
// get column types in sqlite.
|
||||
func (d *dbBaseSqlite) DbTypes() map[string]string {
|
||||
return sqliteTypes
|
||||
}
|
||||
|
||||
// get show tables sql in sqlite.
|
||||
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
||||
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||
}
|
||||
|
||||
// get columns in sqlite.
|
||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
rows, err := db.Query(query)
|
||||
@ -92,10 +109,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// get show columns sql in sqlite.
|
||||
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("pragma table_info('%s')", table)
|
||||
}
|
||||
|
||||
// check index exist in sqlite.
|
||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||
rows, err := db.Query(query)
|
||||
@ -113,6 +132,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
|
||||
return false
|
||||
}
|
||||
|
||||
// create new sqlite dbBaser.
|
||||
func newdbBaseSqlite() dbBaser {
|
||||
b := new(dbBaseSqlite)
|
||||
b.ins = b
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -6,6 +12,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// table info struct.
|
||||
type dbTable struct {
|
||||
id int
|
||||
index string
|
||||
@ -18,13 +25,17 @@ type dbTable struct {
|
||||
jtl *dbTable
|
||||
}
|
||||
|
||||
// tables collection struct, contains some tables.
|
||||
type dbTables struct {
|
||||
tablesM map[string]*dbTable
|
||||
tables []*dbTable
|
||||
mi *modelInfo
|
||||
base dbBaser
|
||||
skipEnd bool
|
||||
}
|
||||
|
||||
// set table info to collection.
|
||||
// if not exist, create new.
|
||||
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
||||
name := strings.Join(names, ExprSep)
|
||||
if j, ok := t.tablesM[name]; ok {
|
||||
@ -41,6 +52,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
||||
return t.tablesM[name]
|
||||
}
|
||||
|
||||
// add table info to collection.
|
||||
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
||||
name := strings.Join(names, ExprSep)
|
||||
if _, ok := t.tablesM[name]; ok == false {
|
||||
@ -53,11 +65,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
||||
return t.tablesM[name], false
|
||||
}
|
||||
|
||||
// get table info in collection.
|
||||
func (t *dbTables) get(name string) (*dbTable, bool) {
|
||||
j, ok := t.tablesM[name]
|
||||
return j, ok
|
||||
}
|
||||
|
||||
// get related fields info in recursive depth loop.
|
||||
// loop once, depth decreases one.
|
||||
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
||||
if depth < 0 || fi.fieldType == RelManyToMany {
|
||||
return related
|
||||
@ -78,6 +93,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
|
||||
return related
|
||||
}
|
||||
|
||||
// parse related fields.
|
||||
func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
|
||||
relsNum := len(rels)
|
||||
@ -111,7 +127,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
names = append(names, fi.name)
|
||||
mmi = fi.relModelInfo
|
||||
|
||||
if fi.null {
|
||||
if fi.null || t.skipEnd {
|
||||
inner = false
|
||||
}
|
||||
|
||||
@ -139,6 +155,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
}
|
||||
}
|
||||
|
||||
// generate join string.
|
||||
func (t *dbTables) getJoinSql() (join string) {
|
||||
Q := t.base.TableQuote()
|
||||
|
||||
@ -185,9 +202,12 @@ func (t *dbTables) getJoinSql() (join string) {
|
||||
return
|
||||
}
|
||||
|
||||
// parse orm model struct field tag expression.
|
||||
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
|
||||
var (
|
||||
jtl *dbTable
|
||||
fi *fieldInfo
|
||||
fiN *fieldInfo
|
||||
mmi = mi
|
||||
)
|
||||
|
||||
@ -196,9 +216,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
|
||||
inner := true
|
||||
|
||||
loopFor:
|
||||
for i, ex := range exprs {
|
||||
|
||||
fi, ok := mmi.fields.GetByAny(ex)
|
||||
var ok, okN bool
|
||||
|
||||
if fiN != nil {
|
||||
fi = fiN
|
||||
ok = true
|
||||
fiN = nil
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
fi, ok = mmi.fields.GetByAny(ex)
|
||||
}
|
||||
|
||||
_ = okN
|
||||
|
||||
if ok {
|
||||
|
||||
@ -216,17 +249,33 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
mmi = fi.reverseFieldInfo.mi
|
||||
}
|
||||
|
||||
if i < num {
|
||||
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
||||
}
|
||||
|
||||
if isRel && (fi.mi.isThrough == false || num != i) {
|
||||
if fi.null {
|
||||
if fi.null || t.skipEnd {
|
||||
inner = false
|
||||
}
|
||||
|
||||
if t.skipEnd && okN || !t.skipEnd {
|
||||
if t.skipEnd && okN && fiN.pk {
|
||||
goto loopEnd
|
||||
}
|
||||
|
||||
jt, _ := t.add(names, mmi, fi, inner)
|
||||
jt.jtl = jtl
|
||||
jtl = jt
|
||||
}
|
||||
|
||||
if num == i {
|
||||
}
|
||||
|
||||
if num != i {
|
||||
continue
|
||||
}
|
||||
|
||||
loopEnd:
|
||||
|
||||
if i == 0 || jtl == nil {
|
||||
index = "T0"
|
||||
} else {
|
||||
@ -252,7 +301,8 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
name = info.name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
break loopFor
|
||||
|
||||
} else {
|
||||
index = ""
|
||||
@ -267,6 +317,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
|
||||
return
|
||||
}
|
||||
|
||||
// generate condition sql.
|
||||
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
|
||||
if cond == nil || cond.IsEmpty() {
|
||||
return
|
||||
@ -331,6 +382,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
|
||||
return
|
||||
}
|
||||
|
||||
// generate order sql.
|
||||
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
||||
if len(orders) == 0 {
|
||||
return
|
||||
@ -359,6 +411,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
|
||||
return
|
||||
}
|
||||
|
||||
// generate limit sql.
|
||||
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
|
||||
if limit == 0 {
|
||||
limit = int64(DefaultRowsLimit)
|
||||
@ -381,6 +434,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
|
||||
return
|
||||
}
|
||||
|
||||
// crete new tables collection.
|
||||
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
||||
tables := &dbTables{}
|
||||
tables.tablesM = make(map[string]*dbTable)
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -6,15 +12,16 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// get table alias.
|
||||
func getDbAlias(name string) *alias {
|
||||
if al, ok := dataBaseCache.get(name); ok {
|
||||
return al
|
||||
} else {
|
||||
panic(fmt.Errorf("unknown DataBase alias name %s", name))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get pk column info.
|
||||
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
|
||||
fi := mi.fields.pk
|
||||
|
||||
@ -37,6 +44,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
|
||||
return
|
||||
}
|
||||
|
||||
// get fields description as flatted string.
|
||||
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
|
||||
|
||||
outFor:
|
||||
@ -48,9 +56,16 @@ outFor:
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case []byte:
|
||||
case string:
|
||||
kind := val.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
kind = val.Kind()
|
||||
arg = val.Interface()
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.String:
|
||||
v := val.String()
|
||||
if fi != nil {
|
||||
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
|
||||
var t time.Time
|
||||
@ -75,16 +90,20 @@ outFor:
|
||||
}
|
||||
}
|
||||
arg = v
|
||||
case time.Time:
|
||||
if fi != nil && fi.fieldType == TypeDateField {
|
||||
arg = v.In(tz).Format(format_Date)
|
||||
} else {
|
||||
arg = v.In(tz).Format(format_DateTime)
|
||||
}
|
||||
default:
|
||||
kind := val.Kind()
|
||||
switch kind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
arg = val.Int()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
arg = val.Uint()
|
||||
case reflect.Float32:
|
||||
arg, _ = StrTo(ToStr(arg)).Float64()
|
||||
case reflect.Float64:
|
||||
arg = val.Float()
|
||||
case reflect.Bool:
|
||||
arg = val.Bool()
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := arg.([]byte); ok {
|
||||
continue outFor
|
||||
}
|
||||
|
||||
var args []interface{}
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
@ -107,16 +126,19 @@ outFor:
|
||||
params = append(params, p...)
|
||||
}
|
||||
continue outFor
|
||||
|
||||
case reflect.Ptr, reflect.Struct:
|
||||
ind := reflect.Indirect(val)
|
||||
|
||||
if ind.Kind() == reflect.Struct {
|
||||
typ := ind.Type()
|
||||
case reflect.Struct:
|
||||
if v, ok := arg.(time.Time); ok {
|
||||
if fi != nil && fi.fieldType == TypeDateField {
|
||||
arg = v.In(tz).Format(format_Date)
|
||||
} else {
|
||||
arg = v.In(tz).Format(format_DateTime)
|
||||
}
|
||||
} else {
|
||||
typ := val.Type()
|
||||
name := getFullName(typ)
|
||||
var value interface{}
|
||||
if mmi, ok := modelCache.getByFN(name); ok {
|
||||
if _, vu, exist := getExistPk(mmi, ind); exist {
|
||||
if _, vu, exist := getExistPk(mmi, val); exist {
|
||||
value = vu
|
||||
}
|
||||
}
|
||||
@ -125,11 +147,9 @@ outFor:
|
||||
if arg == nil {
|
||||
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
|
||||
}
|
||||
} else {
|
||||
arg = ind.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
params = append(params, arg)
|
||||
}
|
||||
return
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -41,6 +47,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// model info collection
|
||||
type _modelCache struct {
|
||||
sync.RWMutex
|
||||
orders []string
|
||||
@ -49,6 +56,7 @@ type _modelCache struct {
|
||||
done bool
|
||||
}
|
||||
|
||||
// get all model info
|
||||
func (mc *_modelCache) all() map[string]*modelInfo {
|
||||
m := make(map[string]*modelInfo, len(mc.cache))
|
||||
for k, v := range mc.cache {
|
||||
@ -57,6 +65,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
|
||||
return m
|
||||
}
|
||||
|
||||
// get orderd model info
|
||||
func (mc *_modelCache) allOrdered() []*modelInfo {
|
||||
m := make([]*modelInfo, 0, len(mc.orders))
|
||||
for _, table := range mc.orders {
|
||||
@ -65,16 +74,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
|
||||
return m
|
||||
}
|
||||
|
||||
// get model info by table name
|
||||
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
|
||||
mi, ok = mc.cache[table]
|
||||
return
|
||||
}
|
||||
|
||||
// get model info by field name
|
||||
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
|
||||
mi, ok = mc.cacheByFN[name]
|
||||
return
|
||||
}
|
||||
|
||||
// set model info to collection
|
||||
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
||||
mii := mc.cache[table]
|
||||
mc.cache[table] = mi
|
||||
@ -85,9 +97,16 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
||||
return mii
|
||||
}
|
||||
|
||||
// clean all model info.
|
||||
func (mc *_modelCache) clean() {
|
||||
mc.orders = make([]string, 0)
|
||||
mc.cache = make(map[string]*modelInfo)
|
||||
mc.cacheByFN = make(map[string]*modelInfo)
|
||||
mc.done = false
|
||||
}
|
||||
|
||||
// Clean model cache. Then you can re-RegisterModel.
|
||||
// Common use this api for test case.
|
||||
func ResetModelCache() {
|
||||
modelCache.clean()
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -8,7 +14,9 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func registerModel(model interface{}, prefix string) {
|
||||
// register models.
|
||||
// prefix means table name prefix.
|
||||
func registerModel(prefix string, model interface{}) {
|
||||
val := reflect.ValueOf(model)
|
||||
ind := reflect.Indirect(val)
|
||||
typ := ind.Type()
|
||||
@ -67,6 +75,7 @@ func registerModel(model interface{}, prefix string) {
|
||||
modelCache.set(table, info)
|
||||
}
|
||||
|
||||
// boostrap models
|
||||
func bootStrap() {
|
||||
if modelCache.done {
|
||||
return
|
||||
@ -281,27 +290,24 @@ end:
|
||||
}
|
||||
}
|
||||
|
||||
// register models
|
||||
func RegisterModel(models ...interface{}) {
|
||||
if modelCache.done {
|
||||
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
registerModel(model, "")
|
||||
}
|
||||
RegisterModelWithPrefix("", models...)
|
||||
}
|
||||
|
||||
// register model with a prefix
|
||||
// register models with a prefix
|
||||
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
|
||||
if modelCache.done {
|
||||
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
registerModel(model, prefix)
|
||||
registerModel(prefix, model)
|
||||
}
|
||||
}
|
||||
|
||||
// bootrap models.
|
||||
// make all model parsed and can not add more models
|
||||
func BootStrap() {
|
||||
if modelCache.done {
|
||||
return
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -9,6 +15,7 @@ import (
|
||||
|
||||
var errSkipField = errors.New("skip field")
|
||||
|
||||
// field info collection
|
||||
type fields struct {
|
||||
pk *fieldInfo
|
||||
columns map[string]*fieldInfo
|
||||
@ -23,6 +30,7 @@ type fields struct {
|
||||
dbcols []string
|
||||
}
|
||||
|
||||
// add field info
|
||||
func (f *fields) Add(fi *fieldInfo) (added bool) {
|
||||
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
|
||||
f.columns[fi.column] = fi
|
||||
@ -49,14 +57,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// get field info by name
|
||||
func (f *fields) GetByName(name string) *fieldInfo {
|
||||
return f.fields[name]
|
||||
}
|
||||
|
||||
// get field info by column name
|
||||
func (f *fields) GetByColumn(column string) *fieldInfo {
|
||||
return f.columns[column]
|
||||
}
|
||||
|
||||
// get field info by string, name is prior
|
||||
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
||||
if fi, ok := f.fields[name]; ok {
|
||||
return fi, ok
|
||||
@ -70,6 +81,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// create new field info collection
|
||||
func newFields() *fields {
|
||||
f := new(fields)
|
||||
f.fields = make(map[string]*fieldInfo)
|
||||
@ -79,6 +91,7 @@ func newFields() *fields {
|
||||
return f
|
||||
}
|
||||
|
||||
// single field info
|
||||
type fieldInfo struct {
|
||||
mi *modelInfo
|
||||
fieldIndex int
|
||||
@ -115,6 +128,7 @@ type fieldInfo struct {
|
||||
onDelete string
|
||||
}
|
||||
|
||||
// new field info
|
||||
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
|
||||
var (
|
||||
tag string
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -7,6 +13,7 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// single model info
|
||||
type modelInfo struct {
|
||||
pkg string
|
||||
name string
|
||||
@ -20,6 +27,7 @@ type modelInfo struct {
|
||||
isThrough bool
|
||||
}
|
||||
|
||||
// new model info
|
||||
func newModelInfo(val reflect.Value) (info *modelInfo) {
|
||||
var (
|
||||
err error
|
||||
@ -41,6 +49,9 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
|
||||
for i := 0; i < ind.NumField(); i++ {
|
||||
field := ind.Field(i)
|
||||
sf = ind.Type().Field(i)
|
||||
if sf.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
fi, err = newFieldInfo(info, field, sf)
|
||||
|
||||
if err != nil {
|
||||
@ -79,6 +90,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// combine related model info to new model info.
|
||||
// prepare for relation models query.
|
||||
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
|
||||
info = new(modelInfo)
|
||||
info.fields = newFields()
|
||||
|
@ -1,6 +1,13 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@ -82,7 +89,6 @@ func (e *JsonField) SetRaw(value interface{}) error {
|
||||
default:
|
||||
return fmt.Errorf("<JsonField.SetRaw> unknown value `%v`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *JsonField) RawValue() interface{} {
|
||||
@ -121,7 +127,7 @@ type DataNull struct {
|
||||
Char string `orm:"null;size(50)"`
|
||||
Text string `orm:"null;type(text)"`
|
||||
Date time.Time `orm:"null;type(date)"`
|
||||
DateTime time.Time `orm:"null;column(datetime)""`
|
||||
DateTime time.Time `orm:"null;column(datetime)"`
|
||||
Byte byte `orm:"null"`
|
||||
Rune rune `orm:"null"`
|
||||
Int int `orm:"null"`
|
||||
@ -137,6 +143,49 @@ type DataNull struct {
|
||||
Float32 float32 `orm:"null"`
|
||||
Float64 float64 `orm:"null"`
|
||||
Decimal float64 `orm:"digits(8);decimals(4);null"`
|
||||
NullString sql.NullString `orm:"null"`
|
||||
NullBool sql.NullBool `orm:"null"`
|
||||
NullFloat64 sql.NullFloat64 `orm:"null"`
|
||||
NullInt64 sql.NullInt64 `orm:"null"`
|
||||
}
|
||||
|
||||
type String string
|
||||
type Boolean bool
|
||||
type Byte byte
|
||||
type Rune rune
|
||||
type Int int
|
||||
type Int8 int8
|
||||
type Int16 int16
|
||||
type Int32 int32
|
||||
type Int64 int64
|
||||
type Uint uint
|
||||
type Uint8 uint8
|
||||
type Uint16 uint16
|
||||
type Uint32 uint32
|
||||
type Uint64 uint64
|
||||
type Float32 float64
|
||||
type Float64 float64
|
||||
|
||||
type DataCustom struct {
|
||||
Id int
|
||||
Boolean Boolean
|
||||
Char string `orm:"size(50)"`
|
||||
Text string `orm:"type(text)"`
|
||||
Byte Byte
|
||||
Rune Rune
|
||||
Int Int
|
||||
Int8 Int8
|
||||
Int16 Int16
|
||||
Int32 Int32
|
||||
Int64 Int64
|
||||
Uint Uint
|
||||
Uint8 Uint8
|
||||
Uint16 Uint16
|
||||
Uint32 Uint32
|
||||
Uint64 Uint64
|
||||
Float32 Float32
|
||||
Float64 Float64
|
||||
Decimal Float64 `orm:"digits(8);decimals(4)"`
|
||||
}
|
||||
|
||||
// only for mysql
|
||||
@ -150,7 +199,7 @@ type User struct {
|
||||
UserName string `orm:"size(30);unique"`
|
||||
Email string `orm:"size(100)"`
|
||||
Password string `orm:"size(100)"`
|
||||
Status int16
|
||||
Status int16 `orm:"column(Status)"`
|
||||
IsStaff bool
|
||||
IsActive bool `orm:"default(1)"`
|
||||
Created time.Time `orm:"auto_now_add;type(date)"`
|
||||
@ -161,6 +210,8 @@ type User struct {
|
||||
Nums int
|
||||
Langs SliceStringField `orm:"size(100)"`
|
||||
Extra JsonField `orm:"type(text)"`
|
||||
unexport bool `orm:"-"`
|
||||
unexport_ bool
|
||||
}
|
||||
|
||||
func (u *User) TableIndex() [][]string {
|
||||
@ -303,9 +354,8 @@ go test -v github.com/astaxie/beego/orm
|
||||
|
||||
|
||||
#### Sqlite3
|
||||
touch /path/to/orm_test.db
|
||||
export ORM_DRIVER=sqlite3
|
||||
export ORM_SOURCE=/path/to/orm_test.db
|
||||
export ORM_SOURCE='file:memory_test?mode=memory'
|
||||
go test -v github.com/astaxie/beego/orm
|
||||
|
||||
|
||||
|
@ -1,16 +1,25 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// get reflect.Type name with package path.
|
||||
func getFullName(typ reflect.Type) string {
|
||||
return typ.PkgPath() + "." + typ.Name()
|
||||
}
|
||||
|
||||
// get table name. method, or field name. auto snaked.
|
||||
func getTableName(val reflect.Value) string {
|
||||
ind := reflect.Indirect(val)
|
||||
fun := val.MethodByName("TableName")
|
||||
@ -26,6 +35,7 @@ func getTableName(val reflect.Value) string {
|
||||
return snakeString(ind.Type().Name())
|
||||
}
|
||||
|
||||
// get table engine, mysiam or innodb.
|
||||
func getTableEngine(val reflect.Value) string {
|
||||
fun := val.MethodByName("TableEngine")
|
||||
if fun.IsValid() {
|
||||
@ -40,6 +50,7 @@ func getTableEngine(val reflect.Value) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// get table index from method.
|
||||
func getTableIndex(val reflect.Value) [][]string {
|
||||
fun := val.MethodByName("TableIndex")
|
||||
if fun.IsValid() {
|
||||
@ -56,6 +67,7 @@ func getTableIndex(val reflect.Value) [][]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get table unique from method
|
||||
func getTableUnique(val reflect.Value) [][]string {
|
||||
fun := val.MethodByName("TableUnique")
|
||||
if fun.IsValid() {
|
||||
@ -72,8 +84,8 @@ func getTableUnique(val reflect.Value) [][]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get snaked column name
|
||||
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
||||
col = strings.ToLower(col)
|
||||
column := col
|
||||
if col == "" {
|
||||
column = snakeString(sf.Name)
|
||||
@ -89,6 +101,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
|
||||
return column
|
||||
}
|
||||
|
||||
// return field type as type constant from reflect.Value
|
||||
func getFieldType(val reflect.Value) (ft int, err error) {
|
||||
elm := reflect.Indirect(val)
|
||||
switch elm.Kind() {
|
||||
@ -114,20 +127,27 @@ func getFieldType(val reflect.Value) (ft int, err error) {
|
||||
ft = TypeBooleanField
|
||||
case reflect.String:
|
||||
ft = TypeCharField
|
||||
case reflect.Invalid:
|
||||
default:
|
||||
if elm.CanInterface() {
|
||||
if _, ok := elm.Interface().(time.Time); ok {
|
||||
switch elm.Interface().(type) {
|
||||
case sql.NullInt64:
|
||||
ft = TypeBigIntegerField
|
||||
case sql.NullFloat64:
|
||||
ft = TypeFloatField
|
||||
case sql.NullBool:
|
||||
ft = TypeBooleanField
|
||||
case sql.NullString:
|
||||
ft = TypeCharField
|
||||
case time.Time:
|
||||
ft = TypeDateTimeField
|
||||
}
|
||||
}
|
||||
}
|
||||
if ft&IsFieldType == 0 {
|
||||
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// parse struct tag string
|
||||
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
|
||||
attr := make(map[string]bool)
|
||||
tag := make(map[string]string)
|
||||
|
159
orm/orm.go
159
orm/orm.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -25,6 +31,7 @@ var (
|
||||
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
|
||||
ErrNoRows = errors.New("<QuerySeter> no row found")
|
||||
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
|
||||
ErrArgs = errors.New("<Ormer> args error may be empty")
|
||||
ErrNotImplement = errors.New("have not implement")
|
||||
)
|
||||
|
||||
@ -39,11 +46,12 @@ type orm struct {
|
||||
|
||||
var _ Ormer = new(orm)
|
||||
|
||||
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
||||
// get model info and model reflect value
|
||||
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
|
||||
val := reflect.ValueOf(md)
|
||||
ind = reflect.Indirect(val)
|
||||
typ := ind.Type()
|
||||
if val.Kind() != reflect.Ptr {
|
||||
if needPtr && val.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
|
||||
}
|
||||
name := getFullName(typ)
|
||||
@ -53,6 +61,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
|
||||
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
||||
}
|
||||
|
||||
// get field info from model info by given field name
|
||||
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
||||
fi, ok := mi.fields.GetByAny(name)
|
||||
if !ok {
|
||||
@ -61,8 +70,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
||||
return fi
|
||||
}
|
||||
|
||||
// read data to model
|
||||
func (o *orm) Read(md interface{}, cols ...string) error {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -70,13 +80,35 @@ func (o *orm) Read(md interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to read a row from the database, or insert one if it doesn't exist
|
||||
func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
cols = append([]string{col1}, cols...)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err == ErrNoRows {
|
||||
// Create
|
||||
id, err := o.Insert(md)
|
||||
return (err == nil), id, err
|
||||
}
|
||||
|
||||
return false, ind.Field(mi.fields.pk.fieldIndex).Int(), err
|
||||
}
|
||||
|
||||
// insert model data to database
|
||||
func (o *orm) Insert(md interface{}) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
if id > 0 {
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// set auto pk field
|
||||
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
|
||||
if mi.fields.pk.auto {
|
||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
||||
@ -84,12 +116,47 @@ func (o *orm) Insert(md interface{}) (int64, error) {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// insert some models to database
|
||||
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
||||
var cnt int64
|
||||
|
||||
sind := reflect.Indirect(reflect.ValueOf(mds))
|
||||
|
||||
switch sind.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
if sind.Len() == 0 {
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
default:
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
|
||||
if bulk <= 1 {
|
||||
for i := 0; i < sind.Len(); i++ {
|
||||
ind := sind.Index(i)
|
||||
mi, _ := o.getMiInd(ind.Interface(), false)
|
||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
cnt += 1
|
||||
}
|
||||
} else {
|
||||
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
||||
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
|
||||
}
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// update model to database.
|
||||
// cols set the columns those want to update.
|
||||
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return num, err
|
||||
@ -97,26 +164,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// delete model in database
|
||||
func (o *orm) Delete(md interface{}) (int64, error) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return num, err
|
||||
}
|
||||
if num > 0 {
|
||||
if mi.fields.pk.auto {
|
||||
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
|
||||
} else {
|
||||
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
|
||||
}
|
||||
}
|
||||
o.setPk(mi, ind, 0)
|
||||
}
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// create a models to models queryer
|
||||
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
fi := o.getFieldInfo(mi, name)
|
||||
|
||||
switch {
|
||||
@ -129,6 +192,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
return newQueryM2M(md, o, mi, fi, ind)
|
||||
}
|
||||
|
||||
// load related models to md model.
|
||||
// args are limit, offset int and order string.
|
||||
//
|
||||
// example:
|
||||
// orm.LoadRelated(post,"Tags")
|
||||
// for _,tag := range post.Tags{...}
|
||||
//
|
||||
// make sure the relation is defined in model struct tags.
|
||||
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
|
||||
_, fi, ind, qseter := o.queryRelated(md, name)
|
||||
|
||||
@ -190,14 +261,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
|
||||
return nums, err
|
||||
}
|
||||
|
||||
// return a QuerySeter for related models to md model.
|
||||
// it can do all, update, delete in QuerySeter.
|
||||
// example:
|
||||
// qs := orm.QueryRelated(post,"Tag")
|
||||
// qs.All(&[]*Tag{})
|
||||
//
|
||||
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
|
||||
// is this api needed ?
|
||||
_, _, _, qs := o.queryRelated(md, name)
|
||||
return qs
|
||||
}
|
||||
|
||||
// get QuerySeter for related models to md model
|
||||
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
|
||||
mi, ind := o.getMiInd(md)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
fi := o.getFieldInfo(mi, name)
|
||||
|
||||
_, _, exist := getExistPk(mi, ind)
|
||||
@ -221,12 +299,13 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
|
||||
}
|
||||
|
||||
if qs == nil {
|
||||
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field"))
|
||||
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
|
||||
}
|
||||
|
||||
return mi, fi, ind, qs
|
||||
}
|
||||
|
||||
// get reverse relation QuerySeter
|
||||
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
switch fi.fieldType {
|
||||
case RelReverseOne, RelReverseMany:
|
||||
@ -247,6 +326,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
|
||||
return q
|
||||
}
|
||||
|
||||
// get relation QuerySeter
|
||||
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
switch fi.fieldType {
|
||||
case RelOneToOne, RelForeignKey, RelManyToMany:
|
||||
@ -266,6 +346,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
return q
|
||||
}
|
||||
|
||||
// return a QuerySeter for table operations.
|
||||
// table name can be string or struct.
|
||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||
name := ""
|
||||
if table, ok := ptrStructOrTableName.(string); ok {
|
||||
@ -285,6 +368,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||
return
|
||||
}
|
||||
|
||||
// switch to another registered database driver by given name.
|
||||
func (o *orm) Using(name string) error {
|
||||
if o.isTx {
|
||||
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
|
||||
@ -302,6 +386,7 @@ func (o *orm) Using(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
func (o *orm) Begin() error {
|
||||
if o.isTx {
|
||||
return ErrTxHasBegan
|
||||
@ -320,6 +405,7 @@ func (o *orm) Begin() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// commit transaction
|
||||
func (o *orm) Commit() error {
|
||||
if o.isTx == false {
|
||||
return ErrTxDone
|
||||
@ -334,6 +420,7 @@ func (o *orm) Commit() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// rollback transaction
|
||||
func (o *orm) Rollback() error {
|
||||
if o.isTx == false {
|
||||
return ErrTxDone
|
||||
@ -348,14 +435,21 @@ func (o *orm) Rollback() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// return a raw query seter for raw sql string.
|
||||
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
|
||||
return newRawSet(o, query, args)
|
||||
}
|
||||
|
||||
// return current using database Driver
|
||||
func (o *orm) Driver() Driver {
|
||||
return driver(o.alias.Name)
|
||||
}
|
||||
|
||||
func (o *orm) GetDB() dbQuerier {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// create new orm
|
||||
func NewOrm() Ormer {
|
||||
BootStrap() // execute only once
|
||||
|
||||
@ -366,3 +460,30 @@ func NewOrm() Ormer {
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// create a new ormer object with specify *sql.DB for query
|
||||
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
|
||||
var al *alias
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al = new(alias)
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
}
|
||||
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
|
||||
o := new(orm)
|
||||
o.alias = al
|
||||
|
||||
if Debug {
|
||||
o.db = newDbQueryLog(o.alias, db)
|
||||
} else {
|
||||
o.db = db
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -18,15 +24,19 @@ type condValue struct {
|
||||
isCond bool
|
||||
}
|
||||
|
||||
// condition struct.
|
||||
// work for WHERE conditions.
|
||||
type Condition struct {
|
||||
params []condValue
|
||||
}
|
||||
|
||||
// return new condition struct
|
||||
func NewCondition() *Condition {
|
||||
c := &Condition{}
|
||||
return c
|
||||
}
|
||||
|
||||
// add expression to condition
|
||||
func (c Condition) And(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.And> args cannot empty"))
|
||||
@ -35,6 +45,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// add NOT expression to condition
|
||||
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
|
||||
@ -43,6 +54,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// combine a condition to current condition
|
||||
func (c *Condition) AndCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
@ -54,6 +66,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
|
||||
return c
|
||||
}
|
||||
|
||||
// add OR expression to condition
|
||||
func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
|
||||
@ -62,6 +75,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// add OR NOT expression to condition
|
||||
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
|
||||
@ -70,6 +84,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
||||
return &c
|
||||
}
|
||||
|
||||
// combine a OR condition to current condition
|
||||
func (c *Condition) OrCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
@ -81,10 +96,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
|
||||
return c
|
||||
}
|
||||
|
||||
// check the condition arguments are empty or not.
|
||||
func (c *Condition) IsEmpty() bool {
|
||||
return len(c.params) == 0
|
||||
}
|
||||
|
||||
// clone a condition
|
||||
func (c Condition) clone() *Condition {
|
||||
return &c
|
||||
}
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -13,6 +19,7 @@ type Log struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
// set io.Writer to create a Logger.
|
||||
func NewLog(out io.Writer) *Log {
|
||||
d := new(Log)
|
||||
d.Logger = log.New(out, "[ORM]", 1e9)
|
||||
@ -40,6 +47,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
|
||||
DebugLog.Println(con)
|
||||
}
|
||||
|
||||
// statement query logger struct.
|
||||
// if dev mode, use stmtQueryLog, or use stmtQuerier.
|
||||
type stmtQueryLog struct {
|
||||
alias *alias
|
||||
query string
|
||||
@ -84,6 +93,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
|
||||
return d
|
||||
}
|
||||
|
||||
// database query logger struct.
|
||||
// if dev mode, use dbQueryLog, or use dbQuerier.
|
||||
type dbQueryLog struct {
|
||||
alias *alias
|
||||
db dbQuerier
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -5,6 +11,7 @@ import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// an insert queryer struct
|
||||
type insertSet struct {
|
||||
mi *modelInfo
|
||||
orm *orm
|
||||
@ -14,6 +21,7 @@ type insertSet struct {
|
||||
|
||||
var _ Inserter = new(insertSet)
|
||||
|
||||
// insert model ignore it's registered or not.
|
||||
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||
if o.closed {
|
||||
return 0, ErrStmtClosed
|
||||
@ -44,6 +52,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// close insert queryer statement
|
||||
func (o *insertSet) Close() error {
|
||||
if o.closed {
|
||||
return ErrStmtClosed
|
||||
@ -52,6 +61,7 @@ func (o *insertSet) Close() error {
|
||||
return o.stmt.Close()
|
||||
}
|
||||
|
||||
// create new insert queryer.
|
||||
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
|
||||
bi := new(insertSet)
|
||||
bi.orm = orm
|
||||
|
@ -1,9 +1,16 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// model to model struct
|
||||
type queryM2M struct {
|
||||
md interface{}
|
||||
mi *modelInfo
|
||||
@ -12,6 +19,13 @@ type queryM2M struct {
|
||||
ind reflect.Value
|
||||
}
|
||||
|
||||
// add models to origin models when creating queryM2M.
|
||||
// example:
|
||||
// m2m := orm.QueryM2M(post,"Tag")
|
||||
// m2m.Add(&Tag1{},&Tag2{})
|
||||
// for _,tag := range post.Tags{}
|
||||
//
|
||||
// make sure the relation is defined in post model struct tag.
|
||||
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
fi := o.fi
|
||||
mi := fi.relThroughModelInfo
|
||||
@ -44,7 +58,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
|
||||
names := []string{mfi.column, rfi.column}
|
||||
|
||||
var nums int64
|
||||
values := make([]interface{}, 0, len(models)*2)
|
||||
|
||||
for _, md := range models {
|
||||
|
||||
ind := reflect.Indirect(reflect.ValueOf(md))
|
||||
@ -59,18 +74,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
values := []interface{}{v1, v2}
|
||||
_, err := dbase.InsertValue(orm.db, mi, names, values)
|
||||
if err != nil {
|
||||
return nums, err
|
||||
values = append(values, v1, v2)
|
||||
|
||||
}
|
||||
|
||||
nums += 1
|
||||
}
|
||||
|
||||
return nums, nil
|
||||
return dbase.InsertValue(orm.db, mi, true, names, values)
|
||||
}
|
||||
|
||||
// remove models following the origin model relationship
|
||||
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||
fi := o.fi
|
||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||
@ -82,17 +93,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||
return nums, nil
|
||||
}
|
||||
|
||||
// check model is existed in relationship of origin model
|
||||
func (o *queryM2M) Exist(md interface{}) bool {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
||||
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
||||
}
|
||||
|
||||
// clean all models in related of origin model
|
||||
func (o *queryM2M) Clear() (int64, error) {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
||||
}
|
||||
|
||||
// count all related models of origin model
|
||||
func (o *queryM2M) Count() (int64, error) {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
||||
@ -100,6 +114,7 @@ func (o *queryM2M) Count() (int64, error) {
|
||||
|
||||
var _ QueryM2Mer = new(queryM2M)
|
||||
|
||||
// create new M2M queryer.
|
||||
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
|
||||
qm2m := new(queryM2M)
|
||||
qm2m.md = md
|
||||
|
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -18,6 +24,10 @@ const (
|
||||
Col_Except
|
||||
)
|
||||
|
||||
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
|
||||
// Params{
|
||||
// "Nums": ColValue(Col_Add, 10),
|
||||
// }
|
||||
func ColValue(opt operator, value interface{}) interface{} {
|
||||
switch opt {
|
||||
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
|
||||
@ -34,6 +44,7 @@ func ColValue(opt operator, value interface{}) interface{} {
|
||||
return val
|
||||
}
|
||||
|
||||
// real query struct
|
||||
type querySet struct {
|
||||
mi *modelInfo
|
||||
cond *Condition
|
||||
@ -47,6 +58,7 @@ type querySet struct {
|
||||
|
||||
var _ QuerySeter = new(querySet)
|
||||
|
||||
// add condition expression to QuerySeter.
|
||||
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
@ -55,6 +67,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// add NOT condition to querySeter.
|
||||
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
@ -63,10 +76,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// set offset number
|
||||
func (o *querySet) setOffset(num interface{}) {
|
||||
o.offset = ToInt64(num)
|
||||
}
|
||||
|
||||
// add LIMIT value.
|
||||
// args[0] means offset, e.g. LIMIT num,offset.
|
||||
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
||||
o.limit = ToInt64(limit)
|
||||
if len(args) > 0 {
|
||||
@ -75,16 +91,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// add OFFSET value
|
||||
func (o querySet) Offset(offset interface{}) QuerySeter {
|
||||
o.setOffset(offset)
|
||||
return &o
|
||||
}
|
||||
|
||||
// add ORDER expression.
|
||||
// "column" means ASC, "-column" means DESC.
|
||||
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
||||
o.orders = exprs
|
||||
return &o
|
||||
}
|
||||
|
||||
// set relation model to query together.
|
||||
// it will query relation models and assign to parent model.
|
||||
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
||||
var related []string
|
||||
if len(params) == 0 {
|
||||
@ -105,36 +126,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
||||
return &o
|
||||
}
|
||||
|
||||
// set condition to QuerySeter.
|
||||
func (o querySet) SetCond(cond *Condition) QuerySeter {
|
||||
o.cond = cond
|
||||
return &o
|
||||
}
|
||||
|
||||
// return QuerySeter execution result number
|
||||
func (o *querySet) Count() (int64, error) {
|
||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// check result empty or not after QuerySeter executed
|
||||
func (o *querySet) Exist() bool {
|
||||
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// execute update with parameters
|
||||
func (o *querySet) Update(values Params) (int64, error) {
|
||||
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// execute delete
|
||||
func (o *querySet) Delete() (int64, error) {
|
||||
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// return a insert queryer.
|
||||
// it can be used in times.
|
||||
// example:
|
||||
// i,err := sq.PrepareInsert()
|
||||
// i.Add(&user1{},&user2{})
|
||||
func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||
return newInsertSet(o.orm, o.mi)
|
||||
}
|
||||
|
||||
// query all data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
}
|
||||
|
||||
// query one row data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
if err != nil {
|
||||
@ -149,18 +184,54 @@ func (o *querySet) One(container interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// query all data and map to []map[string]interface.
|
||||
// expres means condition expression.
|
||||
// it converts data to []map[column]value.
|
||||
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all data and map to [][]interface
|
||||
// it converts data to [][column_index]value
|
||||
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all data and map to []interface.
|
||||
// it's designed for one row record set, auto change to []value, not [][column]value.
|
||||
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// create new QuerySeter.
|
||||
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
|
||||
o := new(querySet)
|
||||
o.mi = mi
|
||||
|
493
orm/orm_raw.go
493
orm/orm_raw.go
@ -1,13 +1,19 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// raw sql string prepared statement
|
||||
type rawPrepare struct {
|
||||
rs *rawSet
|
||||
stmt stmtQuerier
|
||||
@ -45,6 +51,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// raw query seter
|
||||
type rawSet struct {
|
||||
query string
|
||||
args []interface{}
|
||||
@ -53,11 +60,13 @@ type rawSet struct {
|
||||
|
||||
var _ RawSeter = new(rawSet)
|
||||
|
||||
// set args for every query
|
||||
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
|
||||
o.args = args
|
||||
return &o
|
||||
}
|
||||
|
||||
// execute raw sql and return sql.Result
|
||||
func (o *rawSet) Exec() (sql.Result, error) {
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
@ -66,6 +75,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
|
||||
return o.orm.db.Exec(query, args...)
|
||||
}
|
||||
|
||||
// set field value to row container
|
||||
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
||||
switch ind.Kind() {
|
||||
case reflect.Bool:
|
||||
@ -164,65 +174,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) {
|
||||
sIdxes := *sIdxesPtr
|
||||
refs := *refsPtr
|
||||
|
||||
if typ.Kind() == reflect.Struct {
|
||||
if typ.String() == "time.Time" {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
sIdxes = append(sIdxes, []int{0})
|
||||
} else {
|
||||
idxs := []int{}
|
||||
outFor:
|
||||
for idx := 0; idx < typ.NumField(); idx++ {
|
||||
ctyp := typ.Field(idx)
|
||||
|
||||
tag := ctyp.Tag.Get(defaultStructTagName)
|
||||
for _, v := range strings.Split(tag, defaultStructTagDelim) {
|
||||
if v == "-" {
|
||||
continue outFor
|
||||
}
|
||||
}
|
||||
|
||||
tp := ctyp.Type
|
||||
if tp.Kind() == reflect.Ptr {
|
||||
tp = tp.Elem()
|
||||
}
|
||||
|
||||
if tp.String() == "time.Time" {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
|
||||
} else if tp.Kind() != reflect.Struct {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
|
||||
} else {
|
||||
// skip other type
|
||||
continue
|
||||
}
|
||||
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
sIdxes = append(sIdxes, idxs)
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
sIdxes = append(sIdxes, []int{0})
|
||||
}
|
||||
|
||||
*sIdxesPtr = sIdxes
|
||||
*refsPtr = refs
|
||||
}
|
||||
|
||||
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
||||
// set field value in loop for slice container
|
||||
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
||||
nInds := *nIndsPtr
|
||||
|
||||
cur := 0
|
||||
for i, idxs := range sIdxes {
|
||||
for i := 0; i < len(sInds); i++ {
|
||||
sInd := sInds[i]
|
||||
eTyp := eTyps[i]
|
||||
|
||||
@ -258,32 +215,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
|
||||
o.setFieldValue(ind, value)
|
||||
}
|
||||
cur++
|
||||
} else {
|
||||
hasValue := false
|
||||
for _, idx := range idxs {
|
||||
tind := ind.Field(idx)
|
||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||
if value != nil {
|
||||
hasValue = true
|
||||
}
|
||||
if tind.Kind() == reflect.Ptr {
|
||||
if value == nil {
|
||||
tindV := reflect.New(tind.Type()).Elem()
|
||||
tind.Set(tindV)
|
||||
} else {
|
||||
tindV := reflect.New(tind.Type().Elem())
|
||||
o.setFieldValue(tindV.Elem(), value)
|
||||
tind.Set(tindV)
|
||||
}
|
||||
} else {
|
||||
o.setFieldValue(tind, value)
|
||||
}
|
||||
cur++
|
||||
}
|
||||
if hasValue == false && isPtr {
|
||||
val = reflect.New(val.Type()).Elem()
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||
if isPtr && value == nil {
|
||||
@ -312,16 +245,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
|
||||
}
|
||||
}
|
||||
|
||||
// query data and map to container
|
||||
func (o *rawSet) QueryRow(containers ...interface{}) error {
|
||||
if len(containers) == 0 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
|
||||
}
|
||||
|
||||
refs := make([]interface{}, 0, len(containers))
|
||||
sIdxes := make([][]int, 0)
|
||||
sInds := make([]reflect.Value, 0)
|
||||
eTyps := make([]reflect.Type, 0)
|
||||
|
||||
structMode := false
|
||||
var sMi *modelInfo
|
||||
for _, container := range containers {
|
||||
val := reflect.ValueOf(container)
|
||||
ind := reflect.Indirect(val)
|
||||
@ -335,44 +266,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
sInds = append(sInds, ind)
|
||||
eTyps = append(eTyps, etyp)
|
||||
|
||||
o.loopInitRefs(typ, &refs, &sIdxes)
|
||||
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||
if len(containers) > 1 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||
}
|
||||
|
||||
structMode = true
|
||||
fn := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFN(fn); ok {
|
||||
sMi = mi
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
row := o.orm.db.QueryRow(query, args...)
|
||||
|
||||
if err := row.Scan(refs...); err == sql.ErrNoRows {
|
||||
rows, err := o.orm.db.Query(query, args...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
} else if err != nil {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
if structMode {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
columnsMp := make(map[string]interface{}, len(columns))
|
||||
|
||||
refs = make([]interface{}, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
var ref interface{}
|
||||
columnsMp[col] = &ref
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ind := sInds[0]
|
||||
|
||||
if ind.Kind() == reflect.Ptr {
|
||||
if ind.IsNil() || !ind.IsValid() {
|
||||
ind.Set(reflect.New(eTyps[0].Elem()))
|
||||
}
|
||||
ind = ind.Elem()
|
||||
}
|
||||
|
||||
if sMi != nil {
|
||||
for _, col := range columns {
|
||||
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < ind.NumField(); i++ {
|
||||
f := ind.Field(i)
|
||||
fe := ind.Type().Field(i)
|
||||
|
||||
var attrs map[string]bool
|
||||
var tags map[string]string
|
||||
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
|
||||
var col string
|
||||
if col = tags["column"]; len(col) == 0 {
|
||||
col = snakeString(fe.Name)
|
||||
}
|
||||
if v, ok := columnsMp[col]; ok {
|
||||
value := reflect.ValueOf(v).Elem().Interface()
|
||||
o.setFieldValue(f, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
|
||||
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
return ErrNoRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// query data rows and map to container
|
||||
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
refs := make([]interface{}, 0)
|
||||
sIdxes := make([][]int, 0)
|
||||
refs := make([]interface{}, 0, len(containers))
|
||||
sInds := make([]reflect.Value, 0)
|
||||
eTyps := make([]reflect.Type, 0)
|
||||
|
||||
structMode := false
|
||||
var sMi *modelInfo
|
||||
for _, container := range containers {
|
||||
val := reflect.ValueOf(container)
|
||||
sInd := reflect.Indirect(val)
|
||||
@ -389,7 +399,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
sInds = append(sInds, sInd)
|
||||
eTyps = append(eTyps, etyp)
|
||||
|
||||
o.loopInitRefs(typ, &refs, &sIdxes)
|
||||
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||
if len(containers) > 1 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||
}
|
||||
|
||||
structMode = true
|
||||
fn := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFN(fn); ok {
|
||||
sMi = mi
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
}
|
||||
|
||||
query := o.query
|
||||
@ -401,30 +424,107 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
defer rows.Close()
|
||||
|
||||
var cnt int64
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
sInd := sInds[0]
|
||||
|
||||
for rows.Next() {
|
||||
|
||||
if structMode {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
columnsMp := make(map[string]interface{}, len(columns))
|
||||
|
||||
refs = make([]interface{}, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
var ref interface{}
|
||||
columnsMp[col] = &ref
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0)
|
||||
if cnt == 0 && !sInd.IsNil() {
|
||||
sInd.Set(reflect.New(sInd.Type()).Elem())
|
||||
}
|
||||
|
||||
var ind reflect.Value
|
||||
if eTyps[0].Kind() == reflect.Ptr {
|
||||
ind = reflect.New(eTyps[0].Elem())
|
||||
} else {
|
||||
ind = reflect.New(eTyps[0])
|
||||
}
|
||||
|
||||
if ind.Kind() == reflect.Ptr {
|
||||
ind = ind.Elem()
|
||||
}
|
||||
|
||||
if sMi != nil {
|
||||
for _, col := range columns {
|
||||
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < ind.NumField(); i++ {
|
||||
f := ind.Field(i)
|
||||
fe := ind.Type().Field(i)
|
||||
|
||||
var attrs map[string]bool
|
||||
var tags map[string]string
|
||||
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
|
||||
var col string
|
||||
if col = tags["column"]; len(col) == 0 {
|
||||
col = snakeString(fe.Name)
|
||||
}
|
||||
if v, ok := columnsMp[col]; ok {
|
||||
value := reflect.ValueOf(v).Elem().Interface()
|
||||
o.setFieldValue(f, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if eTyps[0].Kind() == reflect.Ptr {
|
||||
ind = ind.Addr()
|
||||
}
|
||||
|
||||
sInd = reflect.Append(sInd, ind)
|
||||
|
||||
} else {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
if cnt > 0 {
|
||||
|
||||
if structMode {
|
||||
sInds[0].Set(sInd)
|
||||
} else {
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
|
||||
var (
|
||||
maps []Params
|
||||
lists []ParamsList
|
||||
@ -455,21 +555,41 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
rs = r
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
indexs []int
|
||||
)
|
||||
|
||||
for rs.Next() {
|
||||
if cnt == 0 {
|
||||
if columns, err := rs.Columns(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
if len(needCols) > 0 {
|
||||
indexs = make([]int, 0, len(needCols))
|
||||
} else {
|
||||
indexs = make([]int, 0, len(columns))
|
||||
}
|
||||
|
||||
cols = columns
|
||||
refs = make([]interface{}, len(cols))
|
||||
for i, _ := range refs {
|
||||
var ref sql.NullString
|
||||
refs[i] = &ref
|
||||
|
||||
if len(needCols) > 0 {
|
||||
for _, c := range needCols {
|
||||
if c == cols[i] {
|
||||
indexs = append(indexs, i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
indexs = append(indexs, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -481,7 +601,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
switch typ {
|
||||
case 1:
|
||||
params := make(Params, len(cols))
|
||||
for i, ref := range refs {
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
params[cols[i]] = value.String
|
||||
@ -492,7 +613,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
maps = append(maps, params)
|
||||
case 2:
|
||||
params := make(ParamsList, 0, len(cols))
|
||||
for _, ref := range refs {
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
params = append(params, value.String)
|
||||
@ -502,7 +624,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
}
|
||||
lists = append(lists, params)
|
||||
case 3:
|
||||
for _, ref := range refs {
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
list = append(list, value.String)
|
||||
@ -527,18 +650,166 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) Values(container *[]Params) (int64, error) {
|
||||
return o.readValues(container)
|
||||
func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
|
||||
var (
|
||||
maps Params
|
||||
ind *reflect.Value
|
||||
)
|
||||
|
||||
typ := 0
|
||||
switch container.(type) {
|
||||
case *Params:
|
||||
typ = 1
|
||||
default:
|
||||
typ = 2
|
||||
vl := reflect.ValueOf(container)
|
||||
id := reflect.Indirect(vl)
|
||||
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
|
||||
}
|
||||
|
||||
ind = &id
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
|
||||
var rs *sql.Rows
|
||||
if r, err := o.orm.db.Query(query, args...); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
rs = r
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
)
|
||||
|
||||
var (
|
||||
keyIndex = -1
|
||||
valueIndex = -1
|
||||
)
|
||||
|
||||
for rs.Next() {
|
||||
if cnt == 0 {
|
||||
if columns, err := rs.Columns(); err != nil {
|
||||
return 0, err
|
||||
} else {
|
||||
cols = columns
|
||||
refs = make([]interface{}, len(cols))
|
||||
for i, _ := range refs {
|
||||
if keyCol == cols[i] {
|
||||
keyIndex = i
|
||||
}
|
||||
|
||||
if typ == 1 || keyIndex == i {
|
||||
var ref sql.NullString
|
||||
refs[i] = &ref
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs[i] = &ref
|
||||
}
|
||||
|
||||
if valueCol == cols[i] {
|
||||
valueIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
if keyIndex == -1 || valueIndex == -1 {
|
||||
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rs.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if cnt == 0 {
|
||||
switch typ {
|
||||
case 1:
|
||||
maps = make(Params)
|
||||
}
|
||||
}
|
||||
|
||||
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
|
||||
|
||||
switch typ {
|
||||
case 1:
|
||||
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
maps[key] = value.String
|
||||
} else {
|
||||
maps[key] = nil
|
||||
}
|
||||
|
||||
default:
|
||||
if id := ind.FieldByName(camelString(key)); id.IsValid() {
|
||||
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
if typ == 1 {
|
||||
v, _ := container.(*Params)
|
||||
*v = maps
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
|
||||
return o.readValues(container)
|
||||
// query data to []map[string]interface
|
||||
func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
|
||||
return o.readValues(container)
|
||||
// query data to [][]interface
|
||||
func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query data to []interface
|
||||
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||
return o.queryRowsTo(result, keyCol, valueCol)
|
||||
}
|
||||
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
|
||||
}
|
||||
|
||||
// return prepared raw statement for used in times.
|
||||
func (o *rawSet) Prepare() (RawPreparer, error) {
|
||||
return newRawPreparer(o)
|
||||
}
|
||||
|
347
orm/orm_test.go
347
orm/orm_test.go
@ -1,7 +1,14 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
@ -138,8 +145,17 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDB(t *testing.T) {
|
||||
if db, err := GetDB(); err != nil {
|
||||
throwFailNow(t, err)
|
||||
} else {
|
||||
err = db.Ping()
|
||||
throwFailNow(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncDb(t *testing.T) {
|
||||
RegisterModel(new(Data), new(DataNull))
|
||||
RegisterModel(new(Data), new(DataNull), new(DataCustom))
|
||||
RegisterModel(new(User))
|
||||
RegisterModel(new(Profile))
|
||||
RegisterModel(new(Post))
|
||||
@ -155,7 +171,7 @@ func TestSyncDb(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegisterModels(t *testing.T) {
|
||||
RegisterModel(new(Data), new(DataNull))
|
||||
RegisterModel(new(Data), new(DataNull), new(DataCustom))
|
||||
RegisterModel(new(User))
|
||||
RegisterModel(new(Profile))
|
||||
RegisterModel(new(Post))
|
||||
@ -258,12 +274,78 @@ func TestNullDataTypes(t *testing.T) {
|
||||
err = dORM.Read(&d)
|
||||
throwFail(t, err)
|
||||
|
||||
throwFail(t, AssertIs(d.NullBool.Valid, false))
|
||||
throwFail(t, AssertIs(d.NullString.Valid, false))
|
||||
throwFail(t, AssertIs(d.NullInt64.Valid, false))
|
||||
throwFail(t, AssertIs(d.NullFloat64.Valid, false))
|
||||
|
||||
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
|
||||
throwFail(t, err)
|
||||
|
||||
d = DataNull{Id: 2}
|
||||
err = dORM.Read(&d)
|
||||
throwFail(t, err)
|
||||
|
||||
d = DataNull{
|
||||
DateTime: time.Now(),
|
||||
NullString: sql.NullString{String: "test", Valid: true},
|
||||
NullBool: sql.NullBool{Bool: true, Valid: true},
|
||||
NullInt64: sql.NullInt64{Int64: 42, Valid: true},
|
||||
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
|
||||
}
|
||||
|
||||
id, err = dORM.Insert(&d)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(id, 3))
|
||||
|
||||
d = DataNull{Id: 3}
|
||||
err = dORM.Read(&d)
|
||||
throwFail(t, err)
|
||||
|
||||
throwFail(t, AssertIs(d.NullBool.Valid, true))
|
||||
throwFail(t, AssertIs(d.NullBool.Bool, true))
|
||||
|
||||
throwFail(t, AssertIs(d.NullString.Valid, true))
|
||||
throwFail(t, AssertIs(d.NullString.String, "test"))
|
||||
|
||||
throwFail(t, AssertIs(d.NullInt64.Valid, true))
|
||||
throwFail(t, AssertIs(d.NullInt64.Int64, 42))
|
||||
|
||||
throwFail(t, AssertIs(d.NullFloat64.Valid, true))
|
||||
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
|
||||
}
|
||||
|
||||
func TestDataCustomTypes(t *testing.T) {
|
||||
d := DataCustom{}
|
||||
ind := reflect.Indirect(reflect.ValueOf(&d))
|
||||
|
||||
for name, value := range Data_Values {
|
||||
e := ind.FieldByName(name)
|
||||
if !e.IsValid() {
|
||||
continue
|
||||
}
|
||||
e.Set(reflect.ValueOf(value).Convert(e.Type()))
|
||||
}
|
||||
|
||||
id, err := dORM.Insert(&d)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(id, 1))
|
||||
|
||||
d = DataCustom{Id: 1}
|
||||
err = dORM.Read(&d)
|
||||
throwFail(t, err)
|
||||
|
||||
ind = reflect.Indirect(reflect.ValueOf(&d))
|
||||
|
||||
for name, value := range Data_Values {
|
||||
e := ind.FieldByName(name)
|
||||
if !e.IsValid() {
|
||||
continue
|
||||
}
|
||||
vu := e.Interface()
|
||||
value = reflect.ValueOf(value).Convert(e.Type()).Interface()
|
||||
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCRUD(t *testing.T) {
|
||||
@ -519,6 +601,10 @@ func TestOperators(t *testing.T) {
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.Filter("user_name__exact", String("slene")).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
num, err = qs.Filter("user_name__exact", "slene").Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
@ -559,11 +645,11 @@ func TestOperators(t *testing.T) {
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 3))
|
||||
|
||||
num, err = qs.Filter("status__lt", 3).Count()
|
||||
num, err = qs.Filter("status__lt", Uint(3)).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 2))
|
||||
|
||||
num, err = qs.Filter("status__lte", 3).Count()
|
||||
num, err = qs.Filter("status__lte", Int(3)).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 3))
|
||||
|
||||
@ -619,6 +705,14 @@ func TestOperators(t *testing.T) {
|
||||
num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 2))
|
||||
|
||||
num, err = qs.Filter("id__between", 2, 3).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 2))
|
||||
|
||||
num, err = qs.Filter("id__between", []int{2, 3}).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 2))
|
||||
}
|
||||
|
||||
func TestSetCond(t *testing.T) {
|
||||
@ -1322,58 +1416,6 @@ func TestRawQueryRow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type Tmp struct {
|
||||
Skip0 string
|
||||
Id int
|
||||
Char *string
|
||||
Skip1 int `orm:"-"`
|
||||
Date time.Time
|
||||
DateTime time.Time
|
||||
}
|
||||
|
||||
Boolean = false
|
||||
Text = ""
|
||||
Int64 = 0
|
||||
Uint = 0
|
||||
|
||||
tmp := new(Tmp)
|
||||
|
||||
cols = []string{
|
||||
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
|
||||
}
|
||||
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
|
||||
values = []interface{}{
|
||||
tmp, &Boolean, &Text, &Int64, &Uint,
|
||||
}
|
||||
err = dORM.Raw(query, 1).QueryRow(values...)
|
||||
throwFailNow(t, err)
|
||||
|
||||
for _, col := range cols {
|
||||
switch col {
|
||||
case "id":
|
||||
throwFail(t, AssertIs(tmp.Id, data_values[col]))
|
||||
case "char":
|
||||
c := tmp.Char
|
||||
throwFail(t, AssertIs(*c, data_values[col]))
|
||||
case "date":
|
||||
v := tmp.Date.In(DefaultTimeLoc)
|
||||
value := data_values[col].(time.Time).In(DefaultTimeLoc)
|
||||
throwFail(t, AssertIs(v, value, test_Date))
|
||||
case "datetime":
|
||||
v := tmp.DateTime.In(DefaultTimeLoc)
|
||||
value := data_values[col].(time.Time).In(DefaultTimeLoc)
|
||||
throwFail(t, AssertIs(v, value, test_DateTime))
|
||||
case "boolean":
|
||||
throwFail(t, AssertIs(Boolean, data_values[col]))
|
||||
case "text":
|
||||
throwFail(t, AssertIs(Text, data_values[col]))
|
||||
case "int64":
|
||||
throwFail(t, AssertIs(Int64, data_values[col]))
|
||||
case "uint":
|
||||
throwFail(t, AssertIs(Uint, data_values[col]))
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
uid int
|
||||
status *int
|
||||
@ -1381,7 +1423,7 @@ func TestRawQueryRow(t *testing.T) {
|
||||
)
|
||||
|
||||
cols = []string{
|
||||
"id", "status", "profile_id",
|
||||
"id", "Status", "profile_id",
|
||||
}
|
||||
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q)
|
||||
err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid)
|
||||
@ -1394,22 +1436,13 @@ func TestRawQueryRow(t *testing.T) {
|
||||
func TestQueryRows(t *testing.T) {
|
||||
Q := dDbBaser.TableQuote()
|
||||
|
||||
cols := []string{
|
||||
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
|
||||
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
|
||||
}
|
||||
|
||||
var datas []*Data
|
||||
var dids []int
|
||||
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q)
|
||||
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
|
||||
query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
|
||||
num, err := dORM.Raw(query).QueryRows(&datas)
|
||||
throwFailNow(t, err)
|
||||
throwFailNow(t, AssertIs(num, 1))
|
||||
throwFailNow(t, AssertIs(len(datas), 1))
|
||||
throwFailNow(t, AssertIs(len(dids), 1))
|
||||
throwFailNow(t, AssertIs(dids[0], 1))
|
||||
|
||||
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
|
||||
|
||||
@ -1427,97 +1460,50 @@ func TestQueryRows(t *testing.T) {
|
||||
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||
}
|
||||
|
||||
type Tmp struct {
|
||||
Id int
|
||||
Name string
|
||||
Skiped0 string `orm:"-"`
|
||||
Pid *int
|
||||
Skiped1 Data
|
||||
Skiped2 *Data
|
||||
var datas2 []Data
|
||||
|
||||
query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&datas2)
|
||||
throwFailNow(t, err)
|
||||
throwFailNow(t, AssertIs(num, 1))
|
||||
throwFailNow(t, AssertIs(len(datas2), 1))
|
||||
|
||||
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
|
||||
|
||||
for name, value := range Data_Values {
|
||||
e := ind.FieldByName(name)
|
||||
vu := e.Interface()
|
||||
switch name {
|
||||
case "Date":
|
||||
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
|
||||
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
|
||||
case "DateTime":
|
||||
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
|
||||
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
|
||||
}
|
||||
throwFail(t, AssertIs(vu == value, true), value, vu)
|
||||
}
|
||||
|
||||
var (
|
||||
ids []int
|
||||
userNames []string
|
||||
profileIds1 []int
|
||||
profileIds2 []*int
|
||||
createds []time.Time
|
||||
updateds []time.Time
|
||||
tmps1 []*Tmp
|
||||
tmps2 []Tmp
|
||||
)
|
||||
cols = []string{
|
||||
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
|
||||
}
|
||||
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
|
||||
var ids []int
|
||||
var usernames []string
|
||||
query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
|
||||
throwFailNow(t, err)
|
||||
throwFailNow(t, AssertIs(num, 3))
|
||||
|
||||
var users []User
|
||||
dORM.QueryTable("user").OrderBy("Id").All(&users)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
id := ids[i]
|
||||
name := userNames[i]
|
||||
pid1 := profileIds1[i]
|
||||
pid2 := profileIds2[i]
|
||||
created := createds[i]
|
||||
updated := updateds[i]
|
||||
|
||||
user := users[i]
|
||||
throwFailNow(t, AssertIs(id, user.Id))
|
||||
throwFailNow(t, AssertIs(name, user.UserName))
|
||||
if user.Profile != nil {
|
||||
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
|
||||
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
|
||||
} else {
|
||||
throwFailNow(t, AssertIs(pid1, 0))
|
||||
throwFailNow(t, AssertIs(pid2, nil))
|
||||
}
|
||||
throwFailNow(t, AssertIs(created, user.Created, test_Date))
|
||||
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
|
||||
|
||||
tmp := tmps1[i]
|
||||
tmp1 := *tmp
|
||||
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
|
||||
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
|
||||
if user.Profile != nil {
|
||||
pid := tmp1.Pid
|
||||
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
|
||||
} else {
|
||||
throwFailNow(t, AssertIs(tmp1.Pid, nil))
|
||||
}
|
||||
|
||||
tmp2 := tmps2[i]
|
||||
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
|
||||
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
|
||||
if user.Profile != nil {
|
||||
pid := tmp2.Pid
|
||||
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
|
||||
} else {
|
||||
throwFailNow(t, AssertIs(tmp2.Pid, nil))
|
||||
}
|
||||
}
|
||||
|
||||
type Sec struct {
|
||||
Id int
|
||||
Name string
|
||||
}
|
||||
|
||||
var tmp []*Sec
|
||||
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
|
||||
num, err = dORM.Raw(query).QueryRows(&tmp)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
throwFail(t, AssertIs(tmp[0], nil))
|
||||
throwFailNow(t, AssertIs(len(ids), 3))
|
||||
throwFailNow(t, AssertIs(ids[0], 2))
|
||||
throwFailNow(t, AssertIs(usernames[0], "slene"))
|
||||
throwFailNow(t, AssertIs(ids[1], 3))
|
||||
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
|
||||
throwFailNow(t, AssertIs(ids[2], 4))
|
||||
throwFailNow(t, AssertIs(usernames[2], "nobody"))
|
||||
}
|
||||
|
||||
func TestRawValues(t *testing.T) {
|
||||
Q := dDbBaser.TableQuote()
|
||||
|
||||
var maps []Params
|
||||
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sstatus%s = ?", Q, Q, Q, Q, Q, Q)
|
||||
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q)
|
||||
num, err := dORM.Raw(query, 1).Values(&maps)
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
@ -1669,6 +1655,31 @@ func TestDelete(t *testing.T) {
|
||||
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 6))
|
||||
|
||||
qs = dORM.QueryTable("post")
|
||||
num, err = qs.Filter("Id", 3).Delete()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 4))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Filter("Post__User", 3).Delete()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 3))
|
||||
|
||||
qs = dORM.QueryTable("comment")
|
||||
num, err = qs.Count()
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
@ -1724,3 +1735,41 @@ func TestTransaction(t *testing.T) {
|
||||
throwFail(t, AssertIs(num, 1))
|
||||
|
||||
}
|
||||
|
||||
func TestReadOrCreate(t *testing.T) {
|
||||
u := &User{
|
||||
UserName: "Kyle",
|
||||
Email: "kylemcc@gmail.com",
|
||||
Password: "other_pass",
|
||||
Status: 7,
|
||||
IsStaff: false,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
created, pk, err := dORM.ReadOrCreate(u, "UserName")
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(created, true))
|
||||
throwFail(t, AssertIs(u.UserName, "Kyle"))
|
||||
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
|
||||
throwFail(t, AssertIs(u.Password, "other_pass"))
|
||||
throwFail(t, AssertIs(u.Status, 7))
|
||||
throwFail(t, AssertIs(u.IsStaff, false))
|
||||
throwFail(t, AssertIs(u.IsActive, true))
|
||||
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
|
||||
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
|
||||
|
||||
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
|
||||
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
|
||||
throwFail(t, err)
|
||||
throwFail(t, AssertIs(created, false))
|
||||
throwFail(t, AssertIs(nu.Id, u.Id))
|
||||
throwFail(t, AssertIs(pk, u.Id))
|
||||
throwFail(t, AssertIs(nu.UserName, u.UserName))
|
||||
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
|
||||
throwFail(t, AssertIs(nu.Password, u.Password))
|
||||
throwFail(t, AssertIs(nu.Status, u.Status))
|
||||
throwFail(t, AssertIs(nu.IsStaff, u.IsStaff))
|
||||
throwFail(t, AssertIs(nu.IsActive, u.IsActive))
|
||||
|
||||
dORM.Delete(u)
|
||||
}
|
||||
|
44
orm/types.go
44
orm/types.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -6,11 +12,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// database driver
|
||||
type Driver interface {
|
||||
Name() string
|
||||
Type() DriverType
|
||||
}
|
||||
|
||||
// field info
|
||||
type Fielder interface {
|
||||
String() string
|
||||
FieldType() int
|
||||
@ -18,9 +26,12 @@ type Fielder interface {
|
||||
RawValue() interface{}
|
||||
}
|
||||
|
||||
// orm struct
|
||||
type Ormer interface {
|
||||
Read(interface{}, ...string) error
|
||||
ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
|
||||
Insert(interface{}) (int64, error)
|
||||
InsertMulti(int, interface{}) (int64, error)
|
||||
Update(interface{}, ...string) (int64, error)
|
||||
Delete(interface{}) (int64, error)
|
||||
LoadRelated(interface{}, string, ...interface{}) (int64, error)
|
||||
@ -32,13 +43,16 @@ type Ormer interface {
|
||||
Rollback() error
|
||||
Raw(string, ...interface{}) RawSeter
|
||||
Driver() Driver
|
||||
GetDB() dbQuerier
|
||||
}
|
||||
|
||||
// insert prepared statement
|
||||
type Inserter interface {
|
||||
Insert(interface{}) (int64, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// query seter
|
||||
type QuerySeter interface {
|
||||
Filter(string, ...interface{}) QuerySeter
|
||||
Exclude(string, ...interface{}) QuerySeter
|
||||
@ -57,8 +71,11 @@ type QuerySeter interface {
|
||||
Values(*[]Params, ...string) (int64, error)
|
||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||
ValuesFlat(*ParamsList, string) (int64, error)
|
||||
RowsToMap(*Params, string, string) (int64, error)
|
||||
RowsToStruct(interface{}, string, string) (int64, error)
|
||||
}
|
||||
|
||||
// model to model query struct
|
||||
type QueryM2Mer interface {
|
||||
Add(...interface{}) (int64, error)
|
||||
Remove(...interface{}) (int64, error)
|
||||
@ -67,22 +84,27 @@ type QueryM2Mer interface {
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
// raw query statement
|
||||
type RawPreparer interface {
|
||||
Exec(...interface{}) (sql.Result, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// raw query seter
|
||||
type RawSeter interface {
|
||||
Exec() (sql.Result, error)
|
||||
QueryRow(...interface{}) error
|
||||
QueryRows(...interface{}) (int64, error)
|
||||
SetArgs(...interface{}) RawSeter
|
||||
Values(*[]Params) (int64, error)
|
||||
ValuesList(*[]ParamsList) (int64, error)
|
||||
ValuesFlat(*ParamsList) (int64, error)
|
||||
Values(*[]Params, ...string) (int64, error)
|
||||
ValuesList(*[]ParamsList, ...string) (int64, error)
|
||||
ValuesFlat(*ParamsList, ...string) (int64, error)
|
||||
RowsToMap(*Params, string, string) (int64, error)
|
||||
RowsToStruct(interface{}, string, string) (int64, error)
|
||||
Prepare() (RawPreparer, error)
|
||||
}
|
||||
|
||||
// statement querier
|
||||
type stmtQuerier interface {
|
||||
Close() error
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
@ -90,6 +112,7 @@ type stmtQuerier interface {
|
||||
QueryRow(args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// db querier
|
||||
type dbQuerier interface {
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
@ -97,19 +120,31 @@ type dbQuerier interface {
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// type DB interface {
|
||||
// Begin() (*sql.Tx, error)
|
||||
// Prepare(query string) (stmtQuerier, error)
|
||||
// Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
// Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
// QueryRow(query string, args ...interface{}) *sql.Row
|
||||
// }
|
||||
|
||||
// transaction beginner
|
||||
type txer interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// transaction ending
|
||||
type txEnder interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// base database struct
|
||||
type dbBaser interface {
|
||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
|
||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
|
||||
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
@ -123,6 +158,7 @@ type dbBaser interface {
|
||||
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
||||
RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error)
|
||||
MaxLimit() uint64
|
||||
TableQuote() string
|
||||
ReplaceMarks(*string)
|
||||
|
34
orm/utils.go
34
orm/utils.go
@ -1,3 +1,9 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors slene
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
@ -10,6 +16,7 @@ import (
|
||||
|
||||
type StrTo string
|
||||
|
||||
// set string
|
||||
func (f *StrTo) Set(v string) {
|
||||
if v != "" {
|
||||
*f = StrTo(v)
|
||||
@ -18,77 +25,93 @@ func (f *StrTo) Set(v string) {
|
||||
}
|
||||
}
|
||||
|
||||
// clean string
|
||||
func (f *StrTo) Clear() {
|
||||
*f = StrTo(0x1E)
|
||||
}
|
||||
|
||||
// check string exist
|
||||
func (f StrTo) Exist() bool {
|
||||
return string(f) != string(0x1E)
|
||||
}
|
||||
|
||||
// string to bool
|
||||
func (f StrTo) Bool() (bool, error) {
|
||||
return strconv.ParseBool(f.String())
|
||||
}
|
||||
|
||||
// string to float32
|
||||
func (f StrTo) Float32() (float32, error) {
|
||||
v, err := strconv.ParseFloat(f.String(), 32)
|
||||
return float32(v), err
|
||||
}
|
||||
|
||||
// string to float64
|
||||
func (f StrTo) Float64() (float64, error) {
|
||||
return strconv.ParseFloat(f.String(), 64)
|
||||
}
|
||||
|
||||
// string to int
|
||||
func (f StrTo) Int() (int, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int(v), err
|
||||
}
|
||||
|
||||
// string to int8
|
||||
func (f StrTo) Int8() (int8, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 8)
|
||||
return int8(v), err
|
||||
}
|
||||
|
||||
// string to int16
|
||||
func (f StrTo) Int16() (int16, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 16)
|
||||
return int16(v), err
|
||||
}
|
||||
|
||||
// string to int32
|
||||
func (f StrTo) Int32() (int32, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int32(v), err
|
||||
}
|
||||
|
||||
// string to int64
|
||||
func (f StrTo) Int64() (int64, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 64)
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
// string to uint
|
||||
func (f StrTo) Uint() (uint, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint(v), err
|
||||
}
|
||||
|
||||
// string to uint8
|
||||
func (f StrTo) Uint8() (uint8, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 8)
|
||||
return uint8(v), err
|
||||
}
|
||||
|
||||
// string to uint16
|
||||
func (f StrTo) Uint16() (uint16, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 16)
|
||||
return uint16(v), err
|
||||
}
|
||||
|
||||
// string to uint31
|
||||
func (f StrTo) Uint32() (uint32, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint32(v), err
|
||||
}
|
||||
|
||||
// string to uint64
|
||||
func (f StrTo) Uint64() (uint64, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 64)
|
||||
return uint64(v), err
|
||||
}
|
||||
|
||||
// string to string
|
||||
func (f StrTo) String() string {
|
||||
if f.Exist() {
|
||||
return string(f)
|
||||
@ -96,6 +119,7 @@ func (f StrTo) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// interface to string
|
||||
func ToStr(value interface{}, args ...int) (s string) {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
@ -134,6 +158,7 @@ func ToStr(value interface{}, args ...int) (s string) {
|
||||
return s
|
||||
}
|
||||
|
||||
// interface to int64
|
||||
func ToInt64(value interface{}) (d int64) {
|
||||
val := reflect.ValueOf(value)
|
||||
switch value.(type) {
|
||||
@ -147,6 +172,7 @@ func ToInt64(value interface{}) (d int64) {
|
||||
return
|
||||
}
|
||||
|
||||
// snake string, XxYy to xx_yy
|
||||
func snakeString(s string) string {
|
||||
data := make([]byte, 0, len(s)*2)
|
||||
j := false
|
||||
@ -164,6 +190,7 @@ func snakeString(s string) string {
|
||||
return strings.ToLower(string(data[:len(data)]))
|
||||
}
|
||||
|
||||
// camel string, xx_yy to XxYy
|
||||
func camelString(s string) string {
|
||||
data := make([]byte, 0, len(s))
|
||||
j := false
|
||||
@ -190,6 +217,7 @@ func camelString(s string) string {
|
||||
|
||||
type argString []string
|
||||
|
||||
// get string by index from string slice
|
||||
func (a argString) Get(i int, args ...string) (r string) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
@ -201,6 +229,7 @@ func (a argString) Get(i int, args ...string) (r string) {
|
||||
|
||||
type argInt []int
|
||||
|
||||
// get int by index from int slice
|
||||
func (a argInt) Get(i int, args ...int) (r int) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
@ -213,6 +242,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
|
||||
|
||||
type argAny []interface{}
|
||||
|
||||
// get interface by index from interface slice
|
||||
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
@ -223,15 +253,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
|
||||
return
|
||||
}
|
||||
|
||||
// parse time to string with location
|
||||
func timeParse(dateString, format string) (time.Time, error) {
|
||||
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
|
||||
return tp, err
|
||||
}
|
||||
|
||||
// format time string
|
||||
func timeFormat(t time.Time, format string) string {
|
||||
return t.Format(format)
|
||||
}
|
||||
|
||||
// get pointer indirect type
|
||||
func indirectType(v reflect.Type) reflect.Type {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr:
|
||||
@ -239,5 +272,4 @@ func indirectType(v reflect.Type) reflect.Type {
|
||||
default:
|
||||
return v
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
80
plugins/auth/basic.go
Normal file
80
plugins/auth/basic.go
Normal file
@ -0,0 +1,80 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package auth
|
||||
|
||||
// Example:
|
||||
// func SecretAuth(username, password string) bool {
|
||||
// if username == "astaxie" && password == "helloBeego" {
|
||||
// return true
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm")
|
||||
// beego.AddFilter("*","AfterStatic",authPlugin)
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego"
|
||||
"github.com/astaxie/beego/context"
|
||||
)
|
||||
|
||||
func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
|
||||
return func(ctx *context.Context) {
|
||||
a := &BasicAuth{Secrets: secrets, Realm: Realm}
|
||||
if username := a.CheckAuth(ctx.Request); username == "" {
|
||||
a.RequireAuth(ctx.ResponseWriter, ctx.Request)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type SecretProvider func(user, pass string) bool
|
||||
|
||||
type BasicAuth struct {
|
||||
Secrets SecretProvider
|
||||
Realm string
|
||||
}
|
||||
|
||||
/*
|
||||
Checks the username/password combination from the request. Returns
|
||||
either an empty string (authentication failed) or the name of the
|
||||
authenticated user.
|
||||
|
||||
Supports MD5 and SHA1 password entries
|
||||
*/
|
||||
func (a *BasicAuth) CheckAuth(r *http.Request) string {
|
||||
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
|
||||
if len(s) != 2 || s[0] != "Basic" {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := base64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if a.Secrets(pair[0], pair[1]) {
|
||||
return pair[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
/*
|
||||
http.Handler for BasicAuth which initiates the authentication process
|
||||
(or requires reauthentication).
|
||||
*/
|
||||
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
|
||||
w.WriteHeader(401)
|
||||
w.Write([]byte("401 Unauthorized\n"))
|
||||
}
|
10
reload.go
10
reload.go
@ -1,4 +1,9 @@
|
||||
// Zero-downtime restarts in Go.
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
@ -30,7 +35,7 @@ type conn struct {
|
||||
net.Conn
|
||||
wg *sync.WaitGroup
|
||||
isclose bool
|
||||
lock sync.Mutex
|
||||
lock *sync.Mutex
|
||||
}
|
||||
|
||||
// Close current processing connection.
|
||||
@ -103,7 +108,6 @@ func WaitSignal(l net.Listener) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil // It'll never get here.
|
||||
}
|
||||
|
||||
// Kill current running os process.
|
||||
|
514
router.go
514
router.go
@ -1,10 +1,18 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
@ -27,11 +35,30 @@ const (
|
||||
FinishRouter
|
||||
)
|
||||
|
||||
const (
|
||||
routerTypeBeego = iota
|
||||
routerTypeRESTFul
|
||||
routerTypeHandler
|
||||
)
|
||||
|
||||
var (
|
||||
// supported http methods.
|
||||
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
|
||||
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head", "trace", "connect"}
|
||||
// these beego.Controller's methods shouldn't reflect to AutoRouter
|
||||
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
|
||||
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
|
||||
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
|
||||
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
|
||||
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
|
||||
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
|
||||
"GetControllerAndAction"}
|
||||
)
|
||||
|
||||
// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
|
||||
func ExceptMethodAppend(action string) {
|
||||
exceptMethod = append(exceptMethod, action)
|
||||
}
|
||||
|
||||
type controllerInfo struct {
|
||||
pattern string
|
||||
regex *regexp.Regexp
|
||||
@ -39,6 +66,10 @@ type controllerInfo struct {
|
||||
controllerType reflect.Type
|
||||
methods map[string]string
|
||||
hasMethod bool
|
||||
handler http.Handler
|
||||
runfunction FilterFunc
|
||||
routerType int
|
||||
isPrefix bool
|
||||
}
|
||||
|
||||
// ControllerRegistor containers registered router rules, controller handlers and filters.
|
||||
@ -71,13 +102,219 @@ func NewControllerRegistor() *ControllerRegistor {
|
||||
// Add("/api",&RestController{},"get,post:ApiFunc")
|
||||
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||||
func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
||||
parts := strings.Split(pattern, "/")
|
||||
j, params, parts := p.splitRoute(pattern)
|
||||
reflectVal := reflect.ValueOf(c)
|
||||
t := reflect.Indirect(reflectVal).Type()
|
||||
methods := make(map[string]string)
|
||||
if len(mappingMethods) > 0 {
|
||||
semi := strings.Split(mappingMethods[0], ";")
|
||||
for _, v := range semi {
|
||||
colon := strings.Split(v, ":")
|
||||
if len(colon) != 2 {
|
||||
panic("method mapping format is invalid")
|
||||
}
|
||||
comma := strings.Split(colon[0], ",")
|
||||
for _, m := range comma {
|
||||
if m == "*" || utils.InSlice(strings.ToLower(m), HTTPMETHOD) {
|
||||
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
|
||||
methods[strings.ToLower(m)] = colon[1]
|
||||
} else {
|
||||
panic(colon[1] + " method doesn't exist in the controller " + t.Name())
|
||||
}
|
||||
} else {
|
||||
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if j == 0 {
|
||||
//now create the Route
|
||||
route := &controllerInfo{}
|
||||
route.pattern = pattern
|
||||
route.controllerType = t
|
||||
route.methods = methods
|
||||
route.routerType = routerTypeBeego
|
||||
if len(methods) > 0 {
|
||||
route.hasMethod = true
|
||||
}
|
||||
p.fixrouters = append(p.fixrouters, route)
|
||||
} else { // add regexp routers
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
//TODO add error handling here to avoid panic
|
||||
panic(regexErr)
|
||||
}
|
||||
|
||||
//now create the Route
|
||||
|
||||
route := &controllerInfo{}
|
||||
route.regex = regex
|
||||
route.params = params
|
||||
route.pattern = pattern
|
||||
route.methods = methods
|
||||
route.routerType = routerTypeBeego
|
||||
if len(methods) > 0 {
|
||||
route.hasMethod = true
|
||||
}
|
||||
route.controllerType = t
|
||||
p.routers = append(p.routers, route)
|
||||
}
|
||||
}
|
||||
|
||||
// add get method
|
||||
// usage:
|
||||
// Get("/", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Get(pattern string, f FilterFunc) {
|
||||
p.AddMethod("get", pattern, f)
|
||||
}
|
||||
|
||||
// add post method
|
||||
// usage:
|
||||
// Post("/api", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Post(pattern string, f FilterFunc) {
|
||||
p.AddMethod("post", pattern, f)
|
||||
}
|
||||
|
||||
// add put method
|
||||
// usage:
|
||||
// Put("/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Put(pattern string, f FilterFunc) {
|
||||
p.AddMethod("put", pattern, f)
|
||||
}
|
||||
|
||||
// add delete method
|
||||
// usage:
|
||||
// Delete("/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Delete(pattern string, f FilterFunc) {
|
||||
p.AddMethod("delete", pattern, f)
|
||||
}
|
||||
|
||||
// add head method
|
||||
// usage:
|
||||
// Head("/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Head(pattern string, f FilterFunc) {
|
||||
p.AddMethod("head", pattern, f)
|
||||
}
|
||||
|
||||
// add patch method
|
||||
// usage:
|
||||
// Patch("/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Patch(pattern string, f FilterFunc) {
|
||||
p.AddMethod("patch", pattern, f)
|
||||
}
|
||||
|
||||
// add options method
|
||||
// usage:
|
||||
// Options("/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Options(pattern string, f FilterFunc) {
|
||||
p.AddMethod("options", pattern, f)
|
||||
}
|
||||
|
||||
// add all method
|
||||
// usage:
|
||||
// Any("/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) Any(pattern string, f FilterFunc) {
|
||||
p.AddMethod("*", pattern, f)
|
||||
}
|
||||
|
||||
// add http method router
|
||||
// usage:
|
||||
// AddMethod("get","/api/:id", func(ctx *context.Context){
|
||||
// ctx.Output.Body("hello world")
|
||||
// })
|
||||
func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) {
|
||||
if method != "*" && !utils.InSlice(strings.ToLower(method), HTTPMETHOD) {
|
||||
panic("not support http method: " + method)
|
||||
}
|
||||
route := &controllerInfo{}
|
||||
route.routerType = routerTypeRESTFul
|
||||
route.runfunction = f
|
||||
methods := make(map[string]string)
|
||||
if method == "*" {
|
||||
for _, val := range HTTPMETHOD {
|
||||
methods[val] = val
|
||||
}
|
||||
} else {
|
||||
methods[method] = method
|
||||
}
|
||||
route.methods = methods
|
||||
paramnums, params, parts := p.splitRoute(pattern)
|
||||
if paramnums == 0 {
|
||||
//now create the Route
|
||||
route.pattern = pattern
|
||||
p.fixrouters = append(p.fixrouters, route)
|
||||
} else {
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
panic(regexErr)
|
||||
}
|
||||
//now create the Route
|
||||
route.regex = regex
|
||||
route.params = params
|
||||
route.pattern = pattern
|
||||
p.routers = append(p.routers, route)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...interface{}) {
|
||||
paramnums, params, parts := p.splitRoute(pattern)
|
||||
route := &controllerInfo{}
|
||||
route.routerType = routerTypeHandler
|
||||
route.handler = h
|
||||
if len(options) > 0 {
|
||||
if v, ok := options[0].(bool); ok {
|
||||
route.isPrefix = v
|
||||
}
|
||||
}
|
||||
if paramnums == 0 {
|
||||
route.pattern = pattern
|
||||
p.fixrouters = append(p.fixrouters, route)
|
||||
} else {
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
panic(regexErr)
|
||||
}
|
||||
//now create the Route
|
||||
route.regex = regex
|
||||
route.params = params
|
||||
route.pattern = pattern
|
||||
p.routers = append(p.routers, route)
|
||||
}
|
||||
}
|
||||
|
||||
// analisys the patter to params & parts
|
||||
func (p *ControllerRegistor) splitRoute(pattern string) (paramnums int, params map[int]string, parts []string) {
|
||||
parts = strings.Split(pattern, "/")
|
||||
j := 0
|
||||
params := make(map[int]string)
|
||||
params = make(map[int]string)
|
||||
for i, part := range parts {
|
||||
if strings.HasPrefix(part, ":") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
//a user may choose to override the defult expression
|
||||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||||
if index := strings.Index(part, "("); index != -1 {
|
||||
@ -94,13 +331,17 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
||||
expr = `([\w]+)`
|
||||
part = part[:lindex]
|
||||
}
|
||||
//marth /user/:id! non-empty value
|
||||
} else if part[len(part)-1] == '!' {
|
||||
expr = `(.+)`
|
||||
part = part[:len(part)-1]
|
||||
}
|
||||
params[j] = part
|
||||
parts[i] = expr
|
||||
j++
|
||||
}
|
||||
if strings.HasPrefix(part, "*") {
|
||||
expr := "(.+)"
|
||||
expr := "(.*)"
|
||||
if part == "*.*" {
|
||||
params[j] = ":path"
|
||||
parts[i] = "([^.]+).([^.]+)"
|
||||
@ -155,71 +396,14 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
|
||||
parts[i] = string(out)
|
||||
}
|
||||
}
|
||||
reflectVal := reflect.ValueOf(c)
|
||||
t := reflect.Indirect(reflectVal).Type()
|
||||
methods := make(map[string]string)
|
||||
if len(mappingMethods) > 0 {
|
||||
semi := strings.Split(mappingMethods[0], ";")
|
||||
for _, v := range semi {
|
||||
colon := strings.Split(v, ":")
|
||||
if len(colon) != 2 {
|
||||
panic("method mapping format is invalid")
|
||||
}
|
||||
comma := strings.Split(colon[0], ",")
|
||||
for _, m := range comma {
|
||||
if m == "*" || utils.InSlice(strings.ToLower(m), HTTPMETHOD) {
|
||||
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
|
||||
methods[strings.ToLower(m)] = colon[1]
|
||||
} else {
|
||||
panic(colon[1] + " method doesn't exist in the controller " + t.Name())
|
||||
}
|
||||
} else {
|
||||
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if j == 0 {
|
||||
//now create the Route
|
||||
route := &controllerInfo{}
|
||||
route.pattern = pattern
|
||||
route.controllerType = t
|
||||
route.methods = methods
|
||||
if len(methods) > 0 {
|
||||
route.hasMethod = true
|
||||
}
|
||||
p.fixrouters = append(p.fixrouters, route)
|
||||
} else { // add regexp routers
|
||||
//recreate the url pattern, with parameters replaced
|
||||
//by regular expressions. then compile the regex
|
||||
pattern = strings.Join(parts, "/")
|
||||
regex, regexErr := regexp.Compile(pattern)
|
||||
if regexErr != nil {
|
||||
//TODO add error handling here to avoid panic
|
||||
panic(regexErr)
|
||||
return
|
||||
}
|
||||
|
||||
//now create the Route
|
||||
|
||||
route := &controllerInfo{}
|
||||
route.regex = regex
|
||||
route.params = params
|
||||
route.pattern = pattern
|
||||
route.methods = methods
|
||||
if len(methods) > 0 {
|
||||
route.hasMethod = true
|
||||
}
|
||||
route.controllerType = t
|
||||
p.routers = append(p.routers, route)
|
||||
}
|
||||
return j, params, parts
|
||||
}
|
||||
|
||||
// Add auto router to ControllerRegistor.
|
||||
// example beego.AddAuto(&MainContorlller{}),
|
||||
// MainController has method List and Page.
|
||||
// visit the url /main/list to exec List function
|
||||
// /main/page to exec Page function.
|
||||
// visit the url /main/list to execute List function
|
||||
// /main/page to execute Page function.
|
||||
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
||||
p.enableAuto = true
|
||||
reflectVal := reflect.ValueOf(c)
|
||||
@ -232,14 +416,43 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
||||
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||||
}
|
||||
for i := 0; i < rt.NumMethod(); i++ {
|
||||
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add auto router to ControllerRegistor with prefix.
|
||||
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
|
||||
// MainController has method List and Page.
|
||||
// visit the url /admin/main/list to execute List function
|
||||
// /admin/main/page to execute Page function.
|
||||
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
|
||||
p.enableAuto = true
|
||||
reflectVal := reflect.ValueOf(c)
|
||||
rt := reflectVal.Type()
|
||||
ct := reflect.Indirect(reflectVal).Type()
|
||||
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
|
||||
if _, ok := p.autoRouter[firstParam]; ok {
|
||||
return
|
||||
} else {
|
||||
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||||
}
|
||||
for i := 0; i < rt.NumMethod(); i++ {
|
||||
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [Deprecated] use InsertFilter.
|
||||
// Add FilterFunc with pattern for action.
|
||||
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
|
||||
mr := buildFilter(pattern, filter)
|
||||
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
|
||||
mr, err := buildFilter(pattern, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "BeforeRouter":
|
||||
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
|
||||
@ -253,13 +466,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
|
||||
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
|
||||
}
|
||||
p.enableFilter = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a FilterFunc with pattern rule and action constant.
|
||||
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) {
|
||||
mr := buildFilter(pattern, filter)
|
||||
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
|
||||
mr, err := buildFilter(pattern, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.filters[pos] = append(p.filters[pos], mr)
|
||||
p.enableFilter = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// UrlFor does another controller handler in this request function.
|
||||
@ -444,6 +662,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
var runrouter reflect.Type
|
||||
var findrouter bool
|
||||
var runMethod string
|
||||
var routerInfo *controllerInfo
|
||||
params := make(map[string]string)
|
||||
|
||||
w := &responseWriter{writer: rw}
|
||||
@ -485,101 +704,78 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
// session init
|
||||
if SessionOn {
|
||||
context.Input.CruSession = GlobalSessions.SessionStart(w, r)
|
||||
defer context.Input.CruSession.SessionRelease()
|
||||
defer func() {
|
||||
context.Input.CruSession.SessionRelease(w)
|
||||
}()
|
||||
}
|
||||
|
||||
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
|
||||
http.Error(w, "Method Not Allowed", 405)
|
||||
goto Admin
|
||||
}
|
||||
//static file server
|
||||
if serverStaticRouter(context) {
|
||||
goto Admin
|
||||
}
|
||||
|
||||
if !context.Input.IsGet() && !context.Input.IsHead() {
|
||||
if CopyRequestBody && !context.Input.IsUpload() {
|
||||
context.Input.CopyBody()
|
||||
}
|
||||
context.Input.ParseFormOrMulitForm(MaxMemory)
|
||||
}
|
||||
|
||||
if do_filter(BeforeRouter) {
|
||||
goto Admin
|
||||
}
|
||||
|
||||
//static file server
|
||||
for prefix, staticDir := range StaticDir {
|
||||
if r.URL.Path == "/favicon.ico" {
|
||||
file := staticDir + r.URL.Path
|
||||
http.ServeFile(w, r, file)
|
||||
w.started = true
|
||||
goto Admin
|
||||
}
|
||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||
file := staticDir + r.URL.Path[len(prefix):]
|
||||
finfo, err := os.Stat(file)
|
||||
if err != nil {
|
||||
if RunMode == "dev" {
|
||||
Warn(err)
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
goto Admin
|
||||
}
|
||||
//if the request is dir and DirectoryIndex is false then
|
||||
if finfo.IsDir() && !DirectoryIndex {
|
||||
middleware.Exception("403", rw, r, "403 Forbidden")
|
||||
goto Admin
|
||||
}
|
||||
|
||||
//This block obtained from (https://github.com/smithfox/beego) - it should probably get merged into astaxie/beego after a pull request
|
||||
isStaticFileToCompress := false
|
||||
if StaticExtensionsToGzip != nil && len(StaticExtensionsToGzip) > 0 {
|
||||
for _, statExtension := range StaticExtensionsToGzip {
|
||||
if strings.HasSuffix(strings.ToLower(file), strings.ToLower(statExtension)) {
|
||||
isStaticFileToCompress = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isStaticFileToCompress {
|
||||
if EnableGzip {
|
||||
w.contentEncoding = GetAcceptEncodingZip(r)
|
||||
}
|
||||
|
||||
memzipfile, err := OpenMemZipFile(file, w.contentEncoding)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.InitHeadContent(finfo.Size())
|
||||
|
||||
http.ServeContent(w, r, file, finfo.ModTime(), memzipfile)
|
||||
} else {
|
||||
http.ServeFile(w, r, file)
|
||||
}
|
||||
|
||||
w.started = true
|
||||
goto Admin
|
||||
}
|
||||
}
|
||||
|
||||
if do_filter(AfterStatic) {
|
||||
goto Admin
|
||||
}
|
||||
|
||||
if CopyRequestBody {
|
||||
context.Input.Body()
|
||||
if context.Input.RunController != nil && context.Input.RunMethod != "" {
|
||||
findrouter = true
|
||||
runMethod = context.Input.RunMethod
|
||||
runrouter = context.Input.RunController
|
||||
}
|
||||
|
||||
//first find path from the fixrouters to Improve Performance
|
||||
if !findrouter {
|
||||
for _, route := range p.fixrouters {
|
||||
n := len(requestPath)
|
||||
if requestPath == route.pattern {
|
||||
runMethod = p.getRunMethod(r.Method, context, route)
|
||||
if runMethod != "" {
|
||||
routerInfo = route
|
||||
runrouter = route.controllerType
|
||||
findrouter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// pattern /admin url /admin 200 /admin/ 404
|
||||
// pattern /admin url /admin 200 /admin/ 200
|
||||
// pattern /admin/ url /admin 301 /admin/ 200
|
||||
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
|
||||
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
|
||||
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
|
||||
http.Redirect(w, r, requestPath+"/", 301)
|
||||
goto Admin
|
||||
}
|
||||
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
|
||||
runMethod = p.getRunMethod(r.Method, context, route)
|
||||
if runMethod != "" {
|
||||
routerInfo = route
|
||||
runrouter = route.controllerType
|
||||
findrouter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if route.routerType == routerTypeHandler && route.isPrefix &&
|
||||
strings.HasPrefix(requestPath, route.pattern) {
|
||||
|
||||
routerInfo = route
|
||||
runrouter = route.controllerType
|
||||
findrouter = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//find regex's router
|
||||
@ -612,6 +808,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
runMethod = p.getRunMethod(r.Method, context, route)
|
||||
if runMethod != "" {
|
||||
routerInfo = route
|
||||
runrouter = route.controllerType
|
||||
context.Input.Params = params
|
||||
findrouter = true
|
||||
@ -678,14 +875,27 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
if findrouter {
|
||||
if r.Method == "POST" {
|
||||
r.ParseMultipartForm(MaxMemory)
|
||||
}
|
||||
//execute middleware filters
|
||||
if do_filter(BeforeExec) {
|
||||
goto Admin
|
||||
}
|
||||
isRunable := false
|
||||
if routerInfo != nil {
|
||||
if routerInfo.routerType == routerTypeRESTFul {
|
||||
if _, ok := routerInfo.methods[strings.ToLower(r.Method)]; ok {
|
||||
isRunable = true
|
||||
routerInfo.runfunction(context)
|
||||
} else {
|
||||
middleware.Exception("405", rw, r, "Method Not Allowed")
|
||||
goto Admin
|
||||
}
|
||||
} else if routerInfo.routerType == routerTypeHandler {
|
||||
isRunable = true
|
||||
routerInfo.handler.ServeHTTP(rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
if !isRunable {
|
||||
//Invoke the request handler
|
||||
vc := reflect.New(runrouter)
|
||||
execController, ok := vc.Interface().(ControllerInterface)
|
||||
@ -696,6 +906,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
//call the controller init function
|
||||
execController.Init(context, runrouter.Name(), runMethod, vc.Interface())
|
||||
|
||||
//call prepare function
|
||||
execController.Prepare()
|
||||
|
||||
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
|
||||
if EnableXSRF {
|
||||
execController.XsrfToken()
|
||||
@ -705,9 +918,6 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
}
|
||||
|
||||
//call prepare function
|
||||
execController.Prepare()
|
||||
|
||||
if !w.started {
|
||||
//exec main logic
|
||||
switch runMethod {
|
||||
@ -744,6 +954,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
// finish all runrouter. release resource
|
||||
execController.Finish()
|
||||
}
|
||||
|
||||
//execute middleware filters
|
||||
if do_filter(AfterExec) {
|
||||
@ -751,9 +962,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
}
|
||||
|
||||
Admin:
|
||||
do_filter(FinishRouter)
|
||||
|
||||
Admin:
|
||||
//admin module record QPS
|
||||
if EnableAdmin {
|
||||
timeend := time.Since(starttime)
|
||||
@ -815,7 +1025,6 @@ type responseWriter struct {
|
||||
writer http.ResponseWriter
|
||||
started bool
|
||||
status int
|
||||
contentEncoding string
|
||||
}
|
||||
|
||||
// Header returns the header map that will be sent by WriteHeader.
|
||||
@ -823,17 +1032,6 @@ func (w *responseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
// Init content-length header.
|
||||
func (w *responseWriter) InitHeadContent(contentlength int64) {
|
||||
if w.contentEncoding == "gzip" {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
} else if w.contentEncoding == "deflate" {
|
||||
w.Header().Set("Content-Encoding", "deflate")
|
||||
} else {
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(contentlength, 10))
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes the data to the connection as part of an HTTP reply,
|
||||
// and sets `started` to true.
|
||||
// started means the response has sent out.
|
||||
@ -849,3 +1047,13 @@ func (w *responseWriter) WriteHeader(code int) {
|
||||
w.started = true
|
||||
w.writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
// hijacker for http
|
||||
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj, ok := w.writer.(http.Hijacker)
|
||||
if !ok {
|
||||
println("supported?")
|
||||
return nil, nil, errors.New("webserver doesn't support hijacking")
|
||||
}
|
||||
return hj.Hijack()
|
||||
}
|
||||
|
@ -1,9 +1,17 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package beego
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/context"
|
||||
)
|
||||
|
||||
type TestController struct {
|
||||
@ -15,6 +23,10 @@ func (this *TestController) Get() {
|
||||
this.Ctx.Output.Body([]byte("ok"))
|
||||
}
|
||||
|
||||
func (this *TestController) Post() {
|
||||
this.Ctx.Output.Body([]byte(this.Ctx.Input.Query(":name")))
|
||||
}
|
||||
|
||||
func (this *TestController) List() {
|
||||
this.Ctx.Output.Body([]byte("i am list"))
|
||||
}
|
||||
@ -81,6 +93,18 @@ func TestUserFunc(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostFunc(t *testing.T) {
|
||||
r, _ := http.NewRequest("POST", "/astaxie", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := NewControllerRegistor()
|
||||
handler.Add("/:name", &TestController{})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "astaxie" {
|
||||
t.Errorf("post func should astaxie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoFunc(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/test/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@ -198,3 +222,59 @@ func TestPrepare(t *testing.T) {
|
||||
t.Errorf(w.Body.String() + "user define func can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoPrefix(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := NewControllerRegistor()
|
||||
handler.AddAutoPrefix("/admin", &TestController{})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "i am list" {
|
||||
t.Errorf("TestAutoPrefix can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGet(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/user", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := NewControllerRegistor()
|
||||
handler.Get("/user", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte("Get userlist"))
|
||||
})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "Get userlist" {
|
||||
t.Errorf("TestRouterGet can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterPost(t *testing.T) {
|
||||
r, _ := http.NewRequest("POST", "/user/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := NewControllerRegistor()
|
||||
handler.Post("/user/:id", func(ctx *context.Context) {
|
||||
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
|
||||
})
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "123" {
|
||||
t.Errorf("TestRouterPost can't run")
|
||||
}
|
||||
}
|
||||
|
||||
func sayhello(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("sayhello"))
|
||||
}
|
||||
|
||||
func TestRouterHandler(t *testing.T) {
|
||||
r, _ := http.NewRequest("POST", "/sayhi", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := NewControllerRegistor()
|
||||
handler.Handler("/sayhi", http.HandlerFunc(sayhello))
|
||||
handler.ServeHTTP(w, r)
|
||||
if w.Body.String() != "sayhello" {
|
||||
t.Errorf("TestRouterHandler can't run")
|
||||
}
|
||||
}
|
||||
|
@ -28,21 +28,21 @@ Then in you web app init the global session manager
|
||||
* Use **memory** as provider:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"")
|
||||
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
* Use **file** as provider, the last param is the path where you want file to be stored:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp")
|
||||
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie")
|
||||
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
@ -50,15 +50,24 @@ Then in you web app init the global session manager
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager(
|
||||
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value")
|
||||
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
* Use **Cookie** as provider:
|
||||
|
||||
func init() {
|
||||
globalSessions, _ = session.NewManager(
|
||||
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
|
||||
go globalSessions.GC()
|
||||
}
|
||||
|
||||
|
||||
Finally in the handlerfunc you can use it like this
|
||||
|
||||
func login(w http.ResponseWriter, r *http.Request) {
|
||||
sess := globalSessions.SessionStart(w, r)
|
||||
defer sess.SessionRelease()
|
||||
defer sess.SessionRelease(w)
|
||||
username := sess.Get("username")
|
||||
fmt.Println(username)
|
||||
if r.Method == "GET" {
|
||||
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
|
||||
|
||||
Writing a provider is easy. You only need to define two struct types
|
||||
(Session and Provider), which satisfy the interface definition.
|
||||
Maybe you will find the **memory** provider as good example.
|
||||
Maybe you will find the **memory** provider is a good example.
|
||||
|
||||
type SessionStore interface {
|
||||
Set(key, value interface{}) error //set session value
|
||||
Get(key interface{}) interface{} //get session value
|
||||
Delete(key interface{}) error //delete session value
|
||||
SessionID() string //back current sessionID
|
||||
SessionRelease() // release the resource & save data to provider
|
||||
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
|
||||
Flush() error //delete all data
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
SessionInit(maxlifetime int64, savePath string) error
|
||||
SessionInit(gclifetime int64, config string) error
|
||||
SessionRead(sid string) (SessionStore, error)
|
||||
SessionExist(sid string) bool
|
||||
SessionRegenerate(oldsid, sid string) (SessionStore, error)
|
||||
|
211
session/couchbase/sess_couchbase.go
Normal file
211
session/couchbase/sess_couchbase.go
Normal file
@ -0,0 +1,211 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/couchbaselabs/go-couchbase"
|
||||
|
||||
"github.com/astaxie/beego/session"
|
||||
)
|
||||
|
||||
var couchbpder = &CouchbaseProvider{}
|
||||
|
||||
type CouchbaseSessionStore struct {
|
||||
b *couchbase.Bucket
|
||||
sid string
|
||||
lock sync.RWMutex
|
||||
values map[interface{}]interface{}
|
||||
maxlifetime int64
|
||||
}
|
||||
|
||||
type CouchbaseProvider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
pool string
|
||||
bucket string
|
||||
b *couchbase.Bucket
|
||||
}
|
||||
|
||||
func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
|
||||
cs.lock.Lock()
|
||||
defer cs.lock.Unlock()
|
||||
cs.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
|
||||
cs.lock.RLock()
|
||||
defer cs.lock.RUnlock()
|
||||
if v, ok := cs.values[key]; ok {
|
||||
return v
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
|
||||
cs.lock.Lock()
|
||||
defer cs.lock.Unlock()
|
||||
delete(cs.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CouchbaseSessionStore) Flush() error {
|
||||
cs.lock.Lock()
|
||||
defer cs.lock.Unlock()
|
||||
cs.values = make(map[interface{}]interface{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CouchbaseSessionStore) SessionID() string {
|
||||
return cs.sid
|
||||
}
|
||||
|
||||
func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer cs.b.Close()
|
||||
|
||||
// if rs.values is empty, return directly
|
||||
if len(cs.values) < 1 {
|
||||
cs.b.Delete(cs.sid)
|
||||
return
|
||||
}
|
||||
|
||||
bo, err := session.EncodeGob(cs.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
|
||||
}
|
||||
|
||||
func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
|
||||
c, err := couchbase.Connect(cp.savePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pool, err := c.GetPool(cp.pool)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bucket, err := pool.GetBucket(cp.bucket)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return bucket
|
||||
}
|
||||
|
||||
// init couchbase session
|
||||
// savepath like couchbase server REST/JSON URL
|
||||
// e.g. http://host:port/, Pool, Bucket
|
||||
func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||
cp.maxlifetime = maxlifetime
|
||||
configs := strings.Split(savePath, ",")
|
||||
if len(configs) > 0 {
|
||||
cp.savePath = configs[0]
|
||||
}
|
||||
if len(configs) > 1 {
|
||||
cp.pool = configs[1]
|
||||
}
|
||||
if len(configs) > 2 {
|
||||
cp.bucket = configs[2]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// read couchbase session by sid
|
||||
func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, error) {
|
||||
cp.b = cp.getBucket()
|
||||
|
||||
var doc []byte
|
||||
|
||||
err := cp.b.Get(sid, &doc)
|
||||
var kv map[interface{}]interface{}
|
||||
if doc == nil {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = session.DecodeGob(doc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (cp *CouchbaseProvider) SessionExist(sid string) bool {
|
||||
cp.b = cp.getBucket()
|
||||
defer cp.b.Close()
|
||||
|
||||
var doc []byte
|
||||
|
||||
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
|
||||
cp.b = cp.getBucket()
|
||||
|
||||
var doc []byte
|
||||
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
|
||||
cp.b.Set(sid, int(cp.maxlifetime), "")
|
||||
} else {
|
||||
err := cp.b.Delete(oldsid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
|
||||
}
|
||||
|
||||
err := cp.b.Get(sid, &doc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var kv map[interface{}]interface{}
|
||||
if doc == nil {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = session.DecodeGob(doc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
|
||||
cp.b = cp.getBucket()
|
||||
defer cp.b.Close()
|
||||
|
||||
cp.b.Delete(sid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cp *CouchbaseProvider) SessionGC() {
|
||||
return
|
||||
}
|
||||
|
||||
func (cp *CouchbaseProvider) SessionAll() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func init() {
|
||||
session.Register("couchbase", couchbpder)
|
||||
}
|
212
session/memcache/sess_memcache.go
Normal file
212
session/memcache/sess_memcache.go
Normal file
@ -0,0 +1,212 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/astaxie/beego/session"
|
||||
|
||||
"github.com/beego/memcache"
|
||||
)
|
||||
|
||||
var mempder = &MemProvider{}
|
||||
|
||||
// memcache session store
|
||||
type MemcacheSessionStore struct {
|
||||
c *memcache.Connection
|
||||
sid string
|
||||
lock sync.RWMutex
|
||||
values map[interface{}]interface{}
|
||||
maxlifetime int64
|
||||
}
|
||||
|
||||
// set value in memcache session
|
||||
func (rs *MemcacheSessionStore) Set(key, value interface{}) error {
|
||||
rs.lock.Lock()
|
||||
defer rs.lock.Unlock()
|
||||
rs.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// get value in memcache session
|
||||
func (rs *MemcacheSessionStore) Get(key interface{}) interface{} {
|
||||
rs.lock.RLock()
|
||||
defer rs.lock.RUnlock()
|
||||
if v, ok := rs.values[key]; ok {
|
||||
return v
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// delete value in memcache session
|
||||
func (rs *MemcacheSessionStore) Delete(key interface{}) error {
|
||||
rs.lock.Lock()
|
||||
defer rs.lock.Unlock()
|
||||
delete(rs.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear all values in memcache session
|
||||
func (rs *MemcacheSessionStore) Flush() error {
|
||||
rs.lock.Lock()
|
||||
defer rs.lock.Unlock()
|
||||
rs.values = make(map[interface{}]interface{})
|
||||
return nil
|
||||
}
|
||||
|
||||
// get redis session id
|
||||
func (rs *MemcacheSessionStore) SessionID() string {
|
||||
return rs.sid
|
||||
}
|
||||
|
||||
// save session values to redis
|
||||
func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer rs.c.Close()
|
||||
// if rs.values is empty, return directly
|
||||
if len(rs.values) < 1 {
|
||||
rs.c.Delete(rs.sid)
|
||||
return
|
||||
}
|
||||
|
||||
b, err := session.EncodeGob(rs.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rs.c.Set(rs.sid, 0, uint64(rs.maxlifetime), b)
|
||||
}
|
||||
|
||||
// redis session provider
|
||||
type MemProvider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
poolsize int
|
||||
password string
|
||||
}
|
||||
|
||||
// init redis session
|
||||
// savepath like
|
||||
// e.g. 127.0.0.1:9090
|
||||
func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||
rp.maxlifetime = maxlifetime
|
||||
rp.savePath = savePath
|
||||
return nil
|
||||
}
|
||||
|
||||
// read redis session by sid
|
||||
func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
|
||||
conn, err := rp.connectInit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kvs, err := conn.Get(sid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var contain []byte
|
||||
if len(kvs) > 0 {
|
||||
contain = kvs[0].Value
|
||||
}
|
||||
var kv map[interface{}]interface{}
|
||||
if len(contain) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = session.DecodeGob(contain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rs := &MemcacheSessionStore{c: conn, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// check redis session exist by sid
|
||||
func (rp *MemProvider) SessionExist(sid string) bool {
|
||||
conn, err := rp.connectInit()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
if kvs, err := conn.Get(sid); err != nil || len(kvs) == 0 {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// generate new sid for redis session
|
||||
func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
|
||||
conn, err := rp.connectInit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var contain []byte
|
||||
if kvs, err := conn.Get(sid); err != nil || len(kvs) == 0 {
|
||||
// oldsid doesn't exists, set the new sid directly
|
||||
// ignore error here, since if it return error
|
||||
// the existed value will be 0
|
||||
conn.Set(sid, 0, uint64(rp.maxlifetime), []byte(""))
|
||||
} else {
|
||||
conn.Delete(oldsid)
|
||||
conn.Set(sid, 0, uint64(rp.maxlifetime), kvs[0].Value)
|
||||
contain = kvs[0].Value
|
||||
}
|
||||
|
||||
var kv map[interface{}]interface{}
|
||||
if len(contain) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = session.DecodeGob(contain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rs := &MemcacheSessionStore{c: conn, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// delete redis session by id
|
||||
func (rp *MemProvider) SessionDestroy(sid string) error {
|
||||
conn, err := rp.connectInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Delete(sid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Impelment method, no used.
|
||||
func (rp *MemProvider) SessionGC() {
|
||||
return
|
||||
}
|
||||
|
||||
// @todo
|
||||
func (rp *MemProvider) SessionAll() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// connect to memcache and keep the connection.
|
||||
func (rp *MemProvider) connectInit() (*memcache.Connection, error) {
|
||||
c, err := memcache.Connect(rp.savePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
session.Register("memcache", mempder)
|
||||
}
|
@ -1,22 +1,33 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package session
|
||||
|
||||
//CREATE TABLE `session` (
|
||||
// mysql session support need create table as sql:
|
||||
// CREATE TABLE `session` (
|
||||
// `session_key` char(64) NOT NULL,
|
||||
// `session_data` blob,
|
||||
// session_data` blob,
|
||||
// `session_expiry` int(11) unsigned NOT NULL,
|
||||
// PRIMARY KEY (`session_key`)
|
||||
//) ENGINE=MyISAM DEFAULT CHARSET=utf8;
|
||||
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/session"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
var mysqlpder = &MysqlProvider{}
|
||||
|
||||
// mysql session store
|
||||
type MysqlSessionStore struct {
|
||||
c *sql.DB
|
||||
sid string
|
||||
@ -24,6 +35,8 @@ type MysqlSessionStore struct {
|
||||
values map[interface{}]interface{}
|
||||
}
|
||||
|
||||
// set value in mysql session.
|
||||
// it is temp value in map.
|
||||
func (st *MysqlSessionStore) Set(key, value interface{}) error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
@ -31,6 +44,7 @@ func (st *MysqlSessionStore) Set(key, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get value from mysql session
|
||||
func (st *MysqlSessionStore) Get(key interface{}) interface{} {
|
||||
st.lock.RLock()
|
||||
defer st.lock.RUnlock()
|
||||
@ -39,9 +53,9 @@ func (st *MysqlSessionStore) Get(key interface{}) interface{} {
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete value in mysql session
|
||||
func (st *MysqlSessionStore) Delete(key interface{}) error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
@ -49,6 +63,7 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear all values in mysql session
|
||||
func (st *MysqlSessionStore) Flush() error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
@ -56,26 +71,31 @@ func (st *MysqlSessionStore) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get session id of this mysql session store
|
||||
func (st *MysqlSessionStore) SessionID() string {
|
||||
return st.sid
|
||||
}
|
||||
|
||||
func (st *MysqlSessionStore) SessionRelease() {
|
||||
// save mysql session values to database.
|
||||
// must call this method to save values to database.
|
||||
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer st.c.Close()
|
||||
if len(st.values) > 0 {
|
||||
b, err := encodeGob(st.values)
|
||||
b, err := session.EncodeGob(st.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid)
|
||||
}
|
||||
st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
|
||||
b, time.Now().Unix(), st.sid)
|
||||
|
||||
}
|
||||
|
||||
// mysql session provider
|
||||
type MysqlProvider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
}
|
||||
|
||||
// connect to mysql
|
||||
func (mp *MysqlProvider) connectInit() *sql.DB {
|
||||
db, e := sql.Open("mysql", mp.savePath)
|
||||
if e != nil {
|
||||
@ -84,25 +104,29 @@ func (mp *MysqlProvider) connectInit() *sql.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// init mysql session.
|
||||
// savepath is the connection string of mysql.
|
||||
func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||
mp.maxlifetime = maxlifetime
|
||||
mp.savePath = savePath
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
||||
// get mysql session by sid
|
||||
func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) {
|
||||
c := mp.connectInit()
|
||||
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
if err == sql.ErrNoRows {
|
||||
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix())
|
||||
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
|
||||
sid, "", time.Now().Unix())
|
||||
}
|
||||
var kv map[interface{}]interface{}
|
||||
if len(sessiondata) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = decodeGob(sessiondata)
|
||||
kv, err = session.DecodeGob(sessiondata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -111,8 +135,10 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// check mysql session exist
|
||||
func (mp *MysqlProvider) SessionExist(sid string) bool {
|
||||
c := mp.connectInit()
|
||||
defer c.Close()
|
||||
row := c.QueryRow("select session_data from session where session_key=?", sid)
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
@ -123,7 +149,8 @@ func (mp *MysqlProvider) SessionExist(sid string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||
// generate new sid for mysql session
|
||||
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
|
||||
c := mp.connectInit()
|
||||
row := c.QueryRow("select session_data from session where session_key=?", oldsid)
|
||||
var sessiondata []byte
|
||||
@ -136,7 +163,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
|
||||
if len(sessiondata) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = decodeGob(sessiondata)
|
||||
kv, err = session.DecodeGob(sessiondata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -145,6 +172,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// delete mysql session by sid
|
||||
func (mp *MysqlProvider) SessionDestroy(sid string) error {
|
||||
c := mp.connectInit()
|
||||
c.Exec("DELETE FROM session where session_key=?", sid)
|
||||
@ -152,6 +180,7 @@ func (mp *MysqlProvider) SessionDestroy(sid string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete expired values in mysql session
|
||||
func (mp *MysqlProvider) SessionGC() {
|
||||
c := mp.connectInit()
|
||||
c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
|
||||
@ -159,6 +188,7 @@ func (mp *MysqlProvider) SessionGC() {
|
||||
return
|
||||
}
|
||||
|
||||
// count values in mysql session
|
||||
func (mp *MysqlProvider) SessionAll() int {
|
||||
c := mp.connectInit()
|
||||
defer c.Close()
|
||||
@ -171,5 +201,5 @@ func (mp *MysqlProvider) SessionAll() int {
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("mysql", mysqlpder)
|
||||
session.Register("mysql", mysqlpder)
|
||||
}
|
235
session/postgres/sess_postgresql.go
Normal file
235
session/postgres/sess_postgresql.go
Normal file
@ -0,0 +1,235 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package session
|
||||
|
||||
/*
|
||||
|
||||
beego session provider for postgresql
|
||||
-------------------------------------
|
||||
|
||||
depends on github.com/lib/pq:
|
||||
|
||||
go install github.com/lib/pq
|
||||
|
||||
|
||||
needs this table in your database:
|
||||
|
||||
CREATE TABLE session (
|
||||
session_key char(64) NOT NULL,
|
||||
session_data bytea,
|
||||
session_expiry timestamp NOT NULL,
|
||||
CONSTRAINT session_key PRIMARY KEY(session_key)
|
||||
);
|
||||
|
||||
|
||||
will be activated with these settings in app.conf:
|
||||
|
||||
SessionOn = true
|
||||
SessionProvider = postgresql
|
||||
SessionSavePath = "user=a password=b dbname=c sslmode=disable"
|
||||
SessionName = session
|
||||
|
||||
*/
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/session"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
var postgresqlpder = &PostgresqlProvider{}
|
||||
|
||||
// postgresql session store
|
||||
type PostgresqlSessionStore struct {
|
||||
c *sql.DB
|
||||
sid string
|
||||
lock sync.RWMutex
|
||||
values map[interface{}]interface{}
|
||||
}
|
||||
|
||||
// set value in postgresql session.
|
||||
// it is temp value in map.
|
||||
func (st *PostgresqlSessionStore) Set(key, value interface{}) error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
st.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// get value from postgresql session
|
||||
func (st *PostgresqlSessionStore) Get(key interface{}) interface{} {
|
||||
st.lock.RLock()
|
||||
defer st.lock.RUnlock()
|
||||
if v, ok := st.values[key]; ok {
|
||||
return v
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// delete value in postgresql session
|
||||
func (st *PostgresqlSessionStore) Delete(key interface{}) error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
delete(st.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear all values in postgresql session
|
||||
func (st *PostgresqlSessionStore) Flush() error {
|
||||
st.lock.Lock()
|
||||
defer st.lock.Unlock()
|
||||
st.values = make(map[interface{}]interface{})
|
||||
return nil
|
||||
}
|
||||
|
||||
// get session id of this postgresql session store
|
||||
func (st *PostgresqlSessionStore) SessionID() string {
|
||||
return st.sid
|
||||
}
|
||||
|
||||
// save postgresql session values to database.
|
||||
// must call this method to save values to database.
|
||||
func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
defer st.c.Close()
|
||||
b, err := session.EncodeGob(st.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
|
||||
b, time.Now().Format(time.RFC3339), st.sid)
|
||||
|
||||
}
|
||||
|
||||
// postgresql session provider
|
||||
type PostgresqlProvider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
}
|
||||
|
||||
// connect to postgresql
|
||||
func (mp *PostgresqlProvider) connectInit() *sql.DB {
|
||||
db, e := sql.Open("postgres", mp.savePath)
|
||||
if e != nil {
|
||||
return nil
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// init postgresql session.
|
||||
// savepath is the connection string of postgresql.
|
||||
func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||
mp.maxlifetime = maxlifetime
|
||||
mp.savePath = savePath
|
||||
return nil
|
||||
}
|
||||
|
||||
// get postgresql session by sid
|
||||
func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, error) {
|
||||
c := mp.connectInit()
|
||||
row := c.QueryRow("select session_data from session where session_key=$1", sid)
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
if err == sql.ErrNoRows {
|
||||
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
|
||||
sid, "", time.Now().Format(time.RFC3339))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var kv map[interface{}]interface{}
|
||||
if len(sessiondata) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = session.DecodeGob(sessiondata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// check postgresql session exist
|
||||
func (mp *PostgresqlProvider) SessionExist(sid string) bool {
|
||||
c := mp.connectInit()
|
||||
defer c.Close()
|
||||
row := c.QueryRow("select session_data from session where session_key=$1", sid)
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// generate new sid for postgresql session
|
||||
func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
|
||||
c := mp.connectInit()
|
||||
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
|
||||
var sessiondata []byte
|
||||
err := row.Scan(&sessiondata)
|
||||
if err == sql.ErrNoRows {
|
||||
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
|
||||
oldsid, "", time.Now().Format(time.RFC3339))
|
||||
}
|
||||
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
|
||||
var kv map[interface{}]interface{}
|
||||
if len(sessiondata) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = session.DecodeGob(sessiondata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// delete postgresql session by sid
|
||||
func (mp *PostgresqlProvider) SessionDestroy(sid string) error {
|
||||
c := mp.connectInit()
|
||||
c.Exec("DELETE FROM session where session_key=$1", sid)
|
||||
c.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete expired values in postgresql session
|
||||
func (mp *PostgresqlProvider) SessionGC() {
|
||||
c := mp.connectInit()
|
||||
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// count values in postgresql session
|
||||
func (mp *PostgresqlProvider) SessionAll() int {
|
||||
c := mp.connectInit()
|
||||
defer c.Close()
|
||||
var total int
|
||||
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func init() {
|
||||
session.Register("postgresql", postgresqlpder)
|
||||
}
|
@ -1,27 +1,39 @@
|
||||
// Beego (http://beego.me/)
|
||||
// @description beego is an open-source, high-performance web framework for the Go programming language.
|
||||
// @link http://github.com/astaxie/beego for the canonical source repository
|
||||
// @license http://github.com/astaxie/beego/blob/master/LICENSE
|
||||
// @authors astaxie
|
||||
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/astaxie/beego/session"
|
||||
|
||||
"github.com/beego/redigo/redis"
|
||||
)
|
||||
|
||||
var redispder = &RedisProvider{}
|
||||
|
||||
// redis max pool size
|
||||
var MAX_POOL_SIZE = 100
|
||||
|
||||
var redisPool chan redis.Conn
|
||||
|
||||
// redis session store
|
||||
type RedisSessionStore struct {
|
||||
c redis.Conn
|
||||
p *redis.Pool
|
||||
sid string
|
||||
lock sync.RWMutex
|
||||
values map[interface{}]interface{}
|
||||
maxlifetime int64
|
||||
}
|
||||
|
||||
// set value in redis session
|
||||
func (rs *RedisSessionStore) Set(key, value interface{}) error {
|
||||
rs.lock.Lock()
|
||||
defer rs.lock.Unlock()
|
||||
@ -29,6 +41,7 @@ func (rs *RedisSessionStore) Set(key, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get value in redis session
|
||||
func (rs *RedisSessionStore) Get(key interface{}) interface{} {
|
||||
rs.lock.RLock()
|
||||
defer rs.lock.RUnlock()
|
||||
@ -37,9 +50,9 @@ func (rs *RedisSessionStore) Get(key interface{}) interface{} {
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete value in redis session
|
||||
func (rs *RedisSessionStore) Delete(key interface{}) error {
|
||||
rs.lock.Lock()
|
||||
defer rs.lock.Unlock()
|
||||
@ -47,6 +60,7 @@ func (rs *RedisSessionStore) Delete(key interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear all values in redis session
|
||||
func (rs *RedisSessionStore) Flush() error {
|
||||
rs.lock.Lock()
|
||||
defer rs.lock.Unlock()
|
||||
@ -54,22 +68,31 @@ func (rs *RedisSessionStore) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get redis session id
|
||||
func (rs *RedisSessionStore) SessionID() string {
|
||||
return rs.sid
|
||||
}
|
||||
|
||||
func (rs *RedisSessionStore) SessionRelease() {
|
||||
defer rs.c.Close()
|
||||
if len(rs.values) > 0 {
|
||||
b, err := encodeGob(rs.values)
|
||||
// save session values to redis
|
||||
func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
|
||||
c := rs.p.Get()
|
||||
defer c.Close()
|
||||
|
||||
// if rs.values is empty, return directly
|
||||
if len(rs.values) < 1 {
|
||||
c.Do("DEL", rs.sid)
|
||||
return
|
||||
}
|
||||
|
||||
b, err := session.EncodeGob(rs.values)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rs.c.Do("SET", rs.sid, string(b))
|
||||
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
|
||||
}
|
||||
|
||||
c.Do("SET", rs.sid, string(b), "EX", rs.maxlifetime)
|
||||
}
|
||||
|
||||
// redis session provider
|
||||
type RedisProvider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
@ -78,8 +101,9 @@ type RedisProvider struct {
|
||||
poollist *redis.Pool
|
||||
}
|
||||
|
||||
//savepath like redisserveraddr,poolsize,password
|
||||
//127.0.0.1:6379,100,astaxie
|
||||
// init redis session
|
||||
// savepath like redis server addr,pool size,password
|
||||
// e.g. 127.0.0.1:6379,100,astaxie
|
||||
func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||
rp.maxlifetime = maxlifetime
|
||||
configs := strings.Split(savePath, ",")
|
||||
@ -112,32 +136,35 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
|
||||
}
|
||||
return c, err
|
||||
}, rp.poolsize)
|
||||
return nil
|
||||
|
||||
return rp.poollist.Get().Err()
|
||||
}
|
||||
|
||||
func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
|
||||
// read redis session by sid
|
||||
func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) {
|
||||
c := rp.poollist.Get()
|
||||
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
|
||||
c.Do("SET", sid)
|
||||
}
|
||||
c.Do("EXPIRE", sid, rp.maxlifetime)
|
||||
defer c.Close()
|
||||
|
||||
kvs, err := redis.String(c.Do("GET", sid))
|
||||
var kv map[interface{}]interface{}
|
||||
if len(kvs) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = decodeGob([]byte(kvs))
|
||||
kv, err = session.DecodeGob([]byte(kvs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rs := &RedisSessionStore{c: c, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||
|
||||
rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// check redis session exist by sid
|
||||
func (rp *RedisProvider) SessionExist(sid string) bool {
|
||||
c := rp.poollist.Get()
|
||||
defer c.Close()
|
||||
|
||||
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
|
||||
return false
|
||||
} else {
|
||||
@ -145,44 +172,55 @@ func (rp *RedisProvider) SessionExist(sid string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
|
||||
// generate new sid for redis session
|
||||
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
|
||||
c := rp.poollist.Get()
|
||||
if existed, err := redis.Int(c.Do("EXISTS", oldsid)); err != nil || existed == 0 {
|
||||
c.Do("SET", oldsid)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 {
|
||||
// oldsid doesn't exists, set the new sid directly
|
||||
// ignore error here, since if it return error
|
||||
// the existed value will be 0
|
||||
c.Do("SET", sid, "", "EX", rp.maxlifetime)
|
||||
} else {
|
||||
c.Do("RENAME", oldsid, sid)
|
||||
c.Do("EXPIRE", sid, rp.maxlifetime)
|
||||
}
|
||||
|
||||
kvs, err := redis.String(c.Do("GET", sid))
|
||||
var kv map[interface{}]interface{}
|
||||
if len(kvs) == 0 {
|
||||
kv = make(map[interface{}]interface{})
|
||||
} else {
|
||||
kv, err = decodeGob([]byte(kvs))
|
||||
kv, err = session.DecodeGob([]byte(kvs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rs := &RedisSessionStore{c: c, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||
|
||||
rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// delete redis session by id
|
||||
func (rp *RedisProvider) SessionDestroy(sid string) error {
|
||||
c := rp.poollist.Get()
|
||||
defer c.Close()
|
||||
|
||||
c.Do("DEL", sid)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Impelment method, no used.
|
||||
func (rp *RedisProvider) SessionGC() {
|
||||
return
|
||||
}
|
||||
|
||||
//@todo
|
||||
// @todo
|
||||
func (rp *RedisProvider) SessionAll() int {
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("redis", redispder)
|
||||
session.Register("redis", redispder)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user