mirror of
https://github.com/astaxie/beego.git
synced 2025-06-11 14:10:39 +00:00
Merge pull request #4173 from AllenX2018/fix-bug-queryRow
Fix issue 3866
This commit is contained in:
59
pkg/client/cache/README.md
vendored
Normal file
59
pkg/client/cache/README.md
vendored
Normal file
@ -0,0 +1,59 @@
|
||||
## cache
|
||||
cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` .
|
||||
|
||||
|
||||
## How to install?
|
||||
|
||||
go get github.com/astaxie/beego/cache
|
||||
|
||||
|
||||
## What adapters are supported?
|
||||
|
||||
As of now this cache support memory, Memcache and Redis.
|
||||
|
||||
|
||||
## How to use it?
|
||||
|
||||
First you must import it
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego/cache"
|
||||
)
|
||||
|
||||
Then init a Cache (example with memory adapter)
|
||||
|
||||
bm, err := cache.NewCache("memory", `{"interval":60}`)
|
||||
|
||||
Use it like this:
|
||||
|
||||
bm.Put("astaxie", 1, 10 * time.Second)
|
||||
bm.Get("astaxie")
|
||||
bm.IsExist("astaxie")
|
||||
bm.Delete("astaxie")
|
||||
|
||||
|
||||
## Memory adapter
|
||||
|
||||
Configure memory adapter like this:
|
||||
|
||||
{"interval":60}
|
||||
|
||||
interval means the gc time. The cache will check at each time interval, whether item has expired.
|
||||
|
||||
|
||||
## Memcache adapter
|
||||
|
||||
Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client.
|
||||
|
||||
Configure like this:
|
||||
|
||||
{"conn":"127.0.0.1:11211"}
|
||||
|
||||
|
||||
## Redis adapter
|
||||
|
||||
Redis adapter use the [redigo](http://github.com/gomodule/redigo) client.
|
||||
|
||||
Configure like this:
|
||||
|
||||
{"conn":":6039"}
|
103
pkg/client/cache/cache.go
vendored
Normal file
103
pkg/client/cache/cache.go
vendored
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package cache provide a Cache interface and some implement engine
|
||||
// Usage:
|
||||
//
|
||||
// import(
|
||||
// "github.com/astaxie/beego/cache"
|
||||
// )
|
||||
//
|
||||
// bm, err := cache.NewCache("memory", `{"interval":60}`)
|
||||
//
|
||||
// Use it like this:
|
||||
//
|
||||
// bm.Put("astaxie", 1, 10 * time.Second)
|
||||
// bm.Get("astaxie")
|
||||
// bm.IsExist("astaxie")
|
||||
// bm.Delete("astaxie")
|
||||
//
|
||||
// more docs http://beego.me/docs/module/cache.md
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache interface contains all behaviors for cache adapter.
|
||||
// usage:
|
||||
// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go.
|
||||
// c,err := cache.NewCache("file","{....}")
|
||||
// c.Put("key",value, 3600 * time.Second)
|
||||
// v := c.Get("key")
|
||||
//
|
||||
// c.Incr("counter") // now is 1
|
||||
// c.Incr("counter") // now is 2
|
||||
// count := c.Get("counter").(int)
|
||||
type Cache interface {
|
||||
// Get a cached value by key.
|
||||
Get(key string) interface{}
|
||||
// GetMulti is a batch version of Get.
|
||||
GetMulti(keys []string) []interface{}
|
||||
// Set a cached value with key and expire time.
|
||||
Put(key string, val interface{}, timeout time.Duration) error
|
||||
// Delete cached value by key.
|
||||
Delete(key string) error
|
||||
// Increment a cached int value by key, as a counter.
|
||||
Incr(key string) error
|
||||
// Decrement a cached int value by key, as a counter.
|
||||
Decr(key string) error
|
||||
// Check if a cached value exists or not.
|
||||
IsExist(key string) bool
|
||||
// Clear all cache.
|
||||
ClearAll() error
|
||||
// Start gc routine based on config string settings.
|
||||
StartAndGC(config string) error
|
||||
}
|
||||
|
||||
// Instance is a function create a new Cache Instance
|
||||
type Instance func() Cache
|
||||
|
||||
var adapters = make(map[string]Instance)
|
||||
|
||||
// Register makes a cache adapter available by the adapter name.
|
||||
// If Register is called twice with the same name or if driver is nil,
|
||||
// it panics.
|
||||
func Register(name string, adapter Instance) {
|
||||
if adapter == nil {
|
||||
panic("cache: Register adapter is nil")
|
||||
}
|
||||
if _, ok := adapters[name]; ok {
|
||||
panic("cache: Register called twice for adapter " + name)
|
||||
}
|
||||
adapters[name] = adapter
|
||||
}
|
||||
|
||||
// NewCache creates a new cache driver by adapter name and config string.
|
||||
// config: must be in JSON format such as {"interval":360}.
|
||||
// Starts gc automatically.
|
||||
func NewCache(adapterName, config string) (adapter Cache, err error) {
|
||||
instanceFunc, ok := adapters[adapterName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName)
|
||||
return
|
||||
}
|
||||
adapter = instanceFunc()
|
||||
err = adapter.StartAndGC(config)
|
||||
if err != nil {
|
||||
adapter = nil
|
||||
}
|
||||
return
|
||||
}
|
191
pkg/client/cache/cache_test.go
vendored
Normal file
191
pkg/client/cache/cache_test.go
vendored
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCacheIncr(t *testing.T) {
|
||||
bm, err := NewCache("memory", `{"interval":20}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
//timeoutDuration := 10 * time.Second
|
||||
|
||||
bm.Put("edwardhey", 0, time.Second*20)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bm.Incr("edwardhey")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if bm.Get("edwardhey").(int) != 10 {
|
||||
t.Error("Incr err")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
bm, err := NewCache("memory", `{"interval":20}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
timeoutDuration := 10 * time.Second
|
||||
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
time.Sleep(30 * time.Second)
|
||||
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
|
||||
if err = bm.Incr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 2 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Decr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
bm.Delete("astaxie")
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
|
||||
//test GetMulti
|
||||
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
if v := bm.Get("astaxie"); v.(string) != "author" {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie1") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
|
||||
if len(vv) != 2 {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if vv[0].(string) != "author" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if vv[1].(string) != "author1" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCache(t *testing.T) {
|
||||
bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`)
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
timeoutDuration := 10 * time.Second
|
||||
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Incr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 2 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Decr Error", err)
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie"); v.(int) != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
bm.Delete("astaxie")
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
|
||||
//test string
|
||||
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
if v := bm.Get("astaxie"); v.(string) != "author" {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
//test GetMulti
|
||||
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie1") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
|
||||
if len(vv) != 2 {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if vv[0].(string) != "author" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if vv[1].(string) != "author1" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
|
||||
os.RemoveAll("cache")
|
||||
}
|
100
pkg/client/cache/conv.go
vendored
Normal file
100
pkg/client/cache/conv.go
vendored
Normal file
@ -0,0 +1,100 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// GetString converts interface to string.
|
||||
func GetString(v interface{}) string {
|
||||
switch result := v.(type) {
|
||||
case string:
|
||||
return result
|
||||
case []byte:
|
||||
return string(result)
|
||||
default:
|
||||
if v != nil {
|
||||
return fmt.Sprint(result)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetInt converts interface to int.
|
||||
func GetInt(v interface{}) int {
|
||||
switch result := v.(type) {
|
||||
case int:
|
||||
return result
|
||||
case int32:
|
||||
return int(result)
|
||||
case int64:
|
||||
return int(result)
|
||||
default:
|
||||
if d := GetString(v); d != "" {
|
||||
value, _ := strconv.Atoi(d)
|
||||
return value
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetInt64 converts interface to int64.
|
||||
func GetInt64(v interface{}) int64 {
|
||||
switch result := v.(type) {
|
||||
case int:
|
||||
return int64(result)
|
||||
case int32:
|
||||
return int64(result)
|
||||
case int64:
|
||||
return result
|
||||
default:
|
||||
|
||||
if d := GetString(v); d != "" {
|
||||
value, _ := strconv.ParseInt(d, 10, 64)
|
||||
return value
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetFloat64 converts interface to float64.
|
||||
func GetFloat64(v interface{}) float64 {
|
||||
switch result := v.(type) {
|
||||
case float64:
|
||||
return result
|
||||
default:
|
||||
if d := GetString(v); d != "" {
|
||||
value, _ := strconv.ParseFloat(d, 64)
|
||||
return value
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetBool converts interface to bool.
|
||||
func GetBool(v interface{}) bool {
|
||||
switch result := v.(type) {
|
||||
case bool:
|
||||
return result
|
||||
default:
|
||||
if d := GetString(v); d != "" {
|
||||
value, _ := strconv.ParseBool(d)
|
||||
return value
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
143
pkg/client/cache/conv_test.go
vendored
Normal file
143
pkg/client/cache/conv_test.go
vendored
Normal file
@ -0,0 +1,143 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetString(t *testing.T) {
|
||||
var t1 = "test1"
|
||||
if "test1" != GetString(t1) {
|
||||
t.Error("get string from string error")
|
||||
}
|
||||
var t2 = []byte("test2")
|
||||
if "test2" != GetString(t2) {
|
||||
t.Error("get string from byte array error")
|
||||
}
|
||||
var t3 = 1
|
||||
if "1" != GetString(t3) {
|
||||
t.Error("get string from int error")
|
||||
}
|
||||
var t4 int64 = 1
|
||||
if "1" != GetString(t4) {
|
||||
t.Error("get string from int64 error")
|
||||
}
|
||||
var t5 = 1.1
|
||||
if "1.1" != GetString(t5) {
|
||||
t.Error("get string from float64 error")
|
||||
}
|
||||
|
||||
if "" != GetString(nil) {
|
||||
t.Error("get string from nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInt(t *testing.T) {
|
||||
var t1 = 1
|
||||
if 1 != GetInt(t1) {
|
||||
t.Error("get int from int error")
|
||||
}
|
||||
var t2 int32 = 32
|
||||
if 32 != GetInt(t2) {
|
||||
t.Error("get int from int32 error")
|
||||
}
|
||||
var t3 int64 = 64
|
||||
if 64 != GetInt(t3) {
|
||||
t.Error("get int from int64 error")
|
||||
}
|
||||
var t4 = "128"
|
||||
if 128 != GetInt(t4) {
|
||||
t.Error("get int from num string error")
|
||||
}
|
||||
if 0 != GetInt(nil) {
|
||||
t.Error("get int from nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInt64(t *testing.T) {
|
||||
var i int64 = 1
|
||||
var t1 = 1
|
||||
if i != GetInt64(t1) {
|
||||
t.Error("get int64 from int error")
|
||||
}
|
||||
var t2 int32 = 1
|
||||
if i != GetInt64(t2) {
|
||||
t.Error("get int64 from int32 error")
|
||||
}
|
||||
var t3 int64 = 1
|
||||
if i != GetInt64(t3) {
|
||||
t.Error("get int64 from int64 error")
|
||||
}
|
||||
var t4 = "1"
|
||||
if i != GetInt64(t4) {
|
||||
t.Error("get int64 from num string error")
|
||||
}
|
||||
if 0 != GetInt64(nil) {
|
||||
t.Error("get int64 from nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFloat64(t *testing.T) {
|
||||
var f = 1.11
|
||||
var t1 float32 = 1.11
|
||||
if f != GetFloat64(t1) {
|
||||
t.Error("get float64 from float32 error")
|
||||
}
|
||||
var t2 = 1.11
|
||||
if f != GetFloat64(t2) {
|
||||
t.Error("get float64 from float64 error")
|
||||
}
|
||||
var t3 = "1.11"
|
||||
if f != GetFloat64(t3) {
|
||||
t.Error("get float64 from string error")
|
||||
}
|
||||
|
||||
var f2 float64 = 1
|
||||
var t4 = 1
|
||||
if f2 != GetFloat64(t4) {
|
||||
t.Error("get float64 from int error")
|
||||
}
|
||||
|
||||
if 0 != GetFloat64(nil) {
|
||||
t.Error("get float64 from nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBool(t *testing.T) {
|
||||
var t1 = true
|
||||
if !GetBool(t1) {
|
||||
t.Error("get bool from bool error")
|
||||
}
|
||||
var t2 = "true"
|
||||
if !GetBool(t2) {
|
||||
t.Error("get bool from string error")
|
||||
}
|
||||
if GetBool(nil) {
|
||||
t.Error("get bool from nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func byteArrayEquals(a []byte, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
257
pkg/client/cache/file.go
vendored
Normal file
257
pkg/client/cache/file.go
vendored
Normal file
@ -0,0 +1,257 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"encoding/gob"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileCacheItem is basic unit of file cache adapter which
|
||||
// contains data and expire time.
|
||||
type FileCacheItem struct {
|
||||
Data interface{}
|
||||
Lastaccess time.Time
|
||||
Expired time.Time
|
||||
}
|
||||
|
||||
// FileCache Config
|
||||
var (
|
||||
FileCachePath = "cache" // cache directory
|
||||
FileCacheFileSuffix = ".bin" // cache file suffix
|
||||
FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files.
|
||||
FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever.
|
||||
)
|
||||
|
||||
// FileCache is cache adapter for file storage.
|
||||
type FileCache struct {
|
||||
CachePath string
|
||||
FileSuffix string
|
||||
DirectoryLevel int
|
||||
EmbedExpiry int
|
||||
}
|
||||
|
||||
// NewFileCache creates a new file cache with no config.
|
||||
// The level and expiry need to be set in the method StartAndGC as config string.
|
||||
func NewFileCache() Cache {
|
||||
// return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
|
||||
return &FileCache{}
|
||||
}
|
||||
|
||||
// StartAndGC starts gc for file cache.
|
||||
// config must be in the format {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}
|
||||
func (fc *FileCache) StartAndGC(config string) error {
|
||||
|
||||
cfg := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(config), &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := cfg["CachePath"]; !ok {
|
||||
cfg["CachePath"] = FileCachePath
|
||||
}
|
||||
if _, ok := cfg["FileSuffix"]; !ok {
|
||||
cfg["FileSuffix"] = FileCacheFileSuffix
|
||||
}
|
||||
if _, ok := cfg["DirectoryLevel"]; !ok {
|
||||
cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel)
|
||||
}
|
||||
if _, ok := cfg["EmbedExpiry"]; !ok {
|
||||
cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10)
|
||||
}
|
||||
fc.CachePath = cfg["CachePath"]
|
||||
fc.FileSuffix = cfg["FileSuffix"]
|
||||
fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"])
|
||||
fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"])
|
||||
|
||||
fc.Init()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init makes new a dir for file cache if it does not already exist
|
||||
func (fc *FileCache) Init() {
|
||||
if ok, _ := exists(fc.CachePath); !ok { // todo : error handle
|
||||
_ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle
|
||||
}
|
||||
}
|
||||
|
||||
// getCachedFilename returns an md5 encoded file name.
|
||||
func (fc *FileCache) getCacheFileName(key string) string {
|
||||
m := md5.New()
|
||||
io.WriteString(m, key)
|
||||
keyMd5 := hex.EncodeToString(m.Sum(nil))
|
||||
cachePath := fc.CachePath
|
||||
switch fc.DirectoryLevel {
|
||||
case 2:
|
||||
cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4])
|
||||
case 1:
|
||||
cachePath = filepath.Join(cachePath, keyMd5[0:2])
|
||||
}
|
||||
|
||||
if ok, _ := exists(cachePath); !ok { // todo : error handle
|
||||
_ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle
|
||||
}
|
||||
|
||||
return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix))
|
||||
}
|
||||
|
||||
// Get value from file cache.
|
||||
// if nonexistent or expired return an empty string.
|
||||
func (fc *FileCache) Get(key string) interface{} {
|
||||
fileData, err := FileGetContents(fc.getCacheFileName(key))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var to FileCacheItem
|
||||
GobDecode(fileData, &to)
|
||||
if to.Expired.Before(time.Now()) {
|
||||
return ""
|
||||
}
|
||||
return to.Data
|
||||
}
|
||||
|
||||
// GetMulti gets values from file cache.
|
||||
// if nonexistent or expired return an empty string.
|
||||
func (fc *FileCache) GetMulti(keys []string) []interface{} {
|
||||
var rc []interface{}
|
||||
for _, key := range keys {
|
||||
rc = append(rc, fc.Get(key))
|
||||
}
|
||||
return rc
|
||||
}
|
||||
|
||||
// Put value into file cache.
|
||||
// timeout: how long this file should be kept in ms
|
||||
// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever.
|
||||
func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
|
||||
gob.Register(val)
|
||||
|
||||
item := FileCacheItem{Data: val}
|
||||
if timeout == time.Duration(fc.EmbedExpiry) {
|
||||
item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
|
||||
} else {
|
||||
item.Expired = time.Now().Add(timeout)
|
||||
}
|
||||
item.Lastaccess = time.Now()
|
||||
data, err := GobEncode(item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return FilePutContents(fc.getCacheFileName(key), data)
|
||||
}
|
||||
|
||||
// Delete file cache value.
|
||||
func (fc *FileCache) Delete(key string) error {
|
||||
filename := fc.getCacheFileName(key)
|
||||
if ok, _ := exists(filename); ok {
|
||||
return os.Remove(filename)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Incr increases cached int value.
|
||||
// fc value is saved forever unless deleted.
|
||||
func (fc *FileCache) Incr(key string) error {
|
||||
data := fc.Get(key)
|
||||
var incr int
|
||||
if reflect.TypeOf(data).Name() != "int" {
|
||||
incr = 0
|
||||
} else {
|
||||
incr = data.(int) + 1
|
||||
}
|
||||
fc.Put(key, incr, time.Duration(fc.EmbedExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decr decreases cached int value.
|
||||
func (fc *FileCache) Decr(key string) error {
|
||||
data := fc.Get(key)
|
||||
var decr int
|
||||
if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 {
|
||||
decr = 0
|
||||
} else {
|
||||
decr = data.(int) - 1
|
||||
}
|
||||
fc.Put(key, decr, time.Duration(fc.EmbedExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExist checks if value exists.
|
||||
func (fc *FileCache) IsExist(key string) bool {
|
||||
ret, _ := exists(fc.getCacheFileName(key))
|
||||
return ret
|
||||
}
|
||||
|
||||
// ClearAll cleans cached files (not implemented)
|
||||
func (fc *FileCache) ClearAll() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a file exists
|
||||
func exists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// FileGetContents Reads bytes from a file.
|
||||
// if non-existent, create this file.
|
||||
func FileGetContents(filename string) (data []byte, e error) {
|
||||
return ioutil.ReadFile(filename)
|
||||
}
|
||||
|
||||
// FilePutContents puts bytes into a file.
|
||||
// if non-existent, create this file.
|
||||
func FilePutContents(filename string, content []byte) error {
|
||||
return ioutil.WriteFile(filename, content, os.ModePerm)
|
||||
}
|
||||
|
||||
// GobEncode Gob encodes a file cache item.
|
||||
func GobEncode(data interface{}) ([]byte, error) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
// GobDecode Gob decodes a file cache item.
|
||||
func GobDecode(data []byte, to *FileCacheItem) error {
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
return dec.Decode(&to)
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("file", NewFileCache)
|
||||
}
|
189
pkg/client/cache/memcache/memcache.go
vendored
Normal file
189
pkg/client/cache/memcache/memcache.go
vendored
Normal file
@ -0,0 +1,189 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package memcache for cache provider
|
||||
//
|
||||
// depend on github.com/bradfitz/gomemcache/memcache
|
||||
//
|
||||
// go install github.com/bradfitz/gomemcache/memcache
|
||||
//
|
||||
// Usage:
|
||||
// import(
|
||||
// _ "github.com/astaxie/beego/cache/memcache"
|
||||
// "github.com/astaxie/beego/cache"
|
||||
// )
|
||||
//
|
||||
// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`)
|
||||
//
|
||||
// more docs http://beego.me/docs/module/cache.md
|
||||
package memcache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bradfitz/gomemcache/memcache"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/cache"
|
||||
)
|
||||
|
||||
// Cache Memcache adapter.
|
||||
type Cache struct {
|
||||
conn *memcache.Client
|
||||
conninfo []string
|
||||
}
|
||||
|
||||
// NewMemCache creates a new memcache adapter.
|
||||
func NewMemCache() cache.Cache {
|
||||
return &Cache{}
|
||||
}
|
||||
|
||||
// Get get value from memcache.
|
||||
func (rc *Cache) Get(key string) interface{} {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if item, err := rc.conn.Get(key); err == nil {
|
||||
return item.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMulti gets a value from a key in memcache.
|
||||
func (rc *Cache) GetMulti(keys []string) []interface{} {
|
||||
size := len(keys)
|
||||
var rv []interface{}
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
for i := 0; i < size; i++ {
|
||||
rv = append(rv, err)
|
||||
}
|
||||
return rv
|
||||
}
|
||||
}
|
||||
mv, err := rc.conn.GetMulti(keys)
|
||||
if err == nil {
|
||||
for _, v := range mv {
|
||||
rv = append(rv, v.Value)
|
||||
}
|
||||
return rv
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
rv = append(rv, err)
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
// Put puts a value into memcache.
|
||||
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)}
|
||||
if v, ok := val.([]byte); ok {
|
||||
item.Value = v
|
||||
} else if str, ok := val.(string); ok {
|
||||
item.Value = []byte(str)
|
||||
} else {
|
||||
return errors.New("val only support string and []byte")
|
||||
}
|
||||
return rc.conn.Set(&item)
|
||||
}
|
||||
|
||||
// Delete deletes a value in memcache.
|
||||
func (rc *Cache) Delete(key string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return rc.conn.Delete(key)
|
||||
}
|
||||
|
||||
// Incr increases counter.
|
||||
func (rc *Cache) Incr(key string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Increment(key, 1)
|
||||
return err
|
||||
}
|
||||
|
||||
// Decr decreases counter.
|
||||
func (rc *Cache) Decr(key string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Decrement(key, 1)
|
||||
return err
|
||||
}
|
||||
|
||||
// IsExist checks if a value exists in memcache.
|
||||
func (rc *Cache) IsExist(key string) bool {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Get(key)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ClearAll clears all cache in memcache.
|
||||
func (rc *Cache) ClearAll() error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return rc.conn.FlushAll()
|
||||
}
|
||||
|
||||
// StartAndGC starts the memcache adapter.
|
||||
// config: must be in the format {"conn":"connection info"}.
|
||||
// If an error occurs during connecting, an error is returned
|
||||
func (rc *Cache) StartAndGC(config string) error {
|
||||
var cf map[string]string
|
||||
json.Unmarshal([]byte(config), &cf)
|
||||
if _, ok := cf["conn"]; !ok {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
rc.conninfo = strings.Split(cf["conn"], ";")
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect to memcache and keep the connection.
|
||||
func (rc *Cache) connectInit() error {
|
||||
rc.conn = memcache.New(rc.conninfo...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
cache.Register("memcache", NewMemCache)
|
||||
}
|
117
pkg/client/cache/memcache/memcache_test.go
vendored
Normal file
117
pkg/client/cache/memcache/memcache_test.go
vendored
Normal file
@ -0,0 +1,117 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package memcache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
_ "github.com/bradfitz/gomemcache/memcache"
|
||||
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/cache"
|
||||
)
|
||||
|
||||
func TestMemcacheCache(t *testing.T) {
|
||||
|
||||
addr := os.Getenv("MEMCACHE_ADDR")
|
||||
if addr == "" {
|
||||
addr = "127.0.0.1:11211"
|
||||
}
|
||||
|
||||
bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr))
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
timeoutDuration := 10 * time.Second
|
||||
if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
time.Sleep(11 * time.Second)
|
||||
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
|
||||
if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Incr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Decr Error", err)
|
||||
}
|
||||
|
||||
if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
bm.Delete("astaxie")
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
|
||||
// test string
|
||||
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v := bm.Get("astaxie").([]byte); string(v) != "author" {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
// test GetMulti
|
||||
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie1") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
|
||||
if len(vv) != 2 {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
|
||||
// test clear all
|
||||
if err = bm.ClearAll(); err != nil {
|
||||
t.Error("clear all err")
|
||||
}
|
||||
}
|
256
pkg/client/cache/memory.go
vendored
Normal file
256
pkg/client/cache/memory.go
vendored
Normal file
@ -0,0 +1,256 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// Timer for how often to recycle the expired cache items in memory (in seconds)
|
||||
DefaultEvery = 60 // 1 minute
|
||||
)
|
||||
|
||||
// MemoryItem stores memory cache item.
|
||||
type MemoryItem struct {
|
||||
val interface{}
|
||||
createdTime time.Time
|
||||
lifespan time.Duration
|
||||
}
|
||||
|
||||
func (mi *MemoryItem) isExpire() bool {
|
||||
// 0 means forever
|
||||
if mi.lifespan == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Sub(mi.createdTime) > mi.lifespan
|
||||
}
|
||||
|
||||
// MemoryCache is a memory cache adapter.
|
||||
// Contains a RW locker for safe map storage.
|
||||
type MemoryCache struct {
|
||||
sync.RWMutex
|
||||
dur time.Duration
|
||||
items map[string]*MemoryItem
|
||||
Every int // run an expiration check Every clock time
|
||||
}
|
||||
|
||||
// NewMemoryCache returns a new MemoryCache.
|
||||
func NewMemoryCache() Cache {
|
||||
cache := MemoryCache{items: make(map[string]*MemoryItem)}
|
||||
return &cache
|
||||
}
|
||||
|
||||
// Get returns cache from memory.
|
||||
// If non-existent or expired, return nil.
|
||||
func (bc *MemoryCache) Get(name string) interface{} {
|
||||
bc.RLock()
|
||||
defer bc.RUnlock()
|
||||
if itm, ok := bc.items[name]; ok {
|
||||
if itm.isExpire() {
|
||||
return nil
|
||||
}
|
||||
return itm.val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMulti gets caches from memory.
|
||||
// If non-existent or expired, return nil.
|
||||
func (bc *MemoryCache) GetMulti(names []string) []interface{} {
|
||||
var rc []interface{}
|
||||
for _, name := range names {
|
||||
rc = append(rc, bc.Get(name))
|
||||
}
|
||||
return rc
|
||||
}
|
||||
|
||||
// Put puts cache into memory.
|
||||
// If lifespan is 0, it will never overwrite this value unless restarted
|
||||
func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error {
|
||||
bc.Lock()
|
||||
defer bc.Unlock()
|
||||
bc.items[name] = &MemoryItem{
|
||||
val: value,
|
||||
createdTime: time.Now(),
|
||||
lifespan: lifespan,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete cache in memory.
|
||||
func (bc *MemoryCache) Delete(name string) error {
|
||||
bc.Lock()
|
||||
defer bc.Unlock()
|
||||
if _, ok := bc.items[name]; !ok {
|
||||
return errors.New("key not exist")
|
||||
}
|
||||
delete(bc.items, name)
|
||||
if _, ok := bc.items[name]; ok {
|
||||
return errors.New("delete key error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Incr increases cache counter in memory.
|
||||
// Supports int,int32,int64,uint,uint32,uint64.
|
||||
func (bc *MemoryCache) Incr(key string) error {
|
||||
bc.Lock()
|
||||
defer bc.Unlock()
|
||||
itm, ok := bc.items[key]
|
||||
if !ok {
|
||||
return errors.New("key not exist")
|
||||
}
|
||||
switch val := itm.val.(type) {
|
||||
case int:
|
||||
itm.val = val + 1
|
||||
case int32:
|
||||
itm.val = val + 1
|
||||
case int64:
|
||||
itm.val = val + 1
|
||||
case uint:
|
||||
itm.val = val + 1
|
||||
case uint32:
|
||||
itm.val = val + 1
|
||||
case uint64:
|
||||
itm.val = val + 1
|
||||
default:
|
||||
return errors.New("item val is not (u)int (u)int32 (u)int64")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decr decreases counter in memory.
|
||||
func (bc *MemoryCache) Decr(key string) error {
|
||||
bc.Lock()
|
||||
defer bc.Unlock()
|
||||
itm, ok := bc.items[key]
|
||||
if !ok {
|
||||
return errors.New("key not exist")
|
||||
}
|
||||
switch val := itm.val.(type) {
|
||||
case int:
|
||||
itm.val = val - 1
|
||||
case int64:
|
||||
itm.val = val - 1
|
||||
case int32:
|
||||
itm.val = val - 1
|
||||
case uint:
|
||||
if val > 0 {
|
||||
itm.val = val - 1
|
||||
} else {
|
||||
return errors.New("item val is less than 0")
|
||||
}
|
||||
case uint32:
|
||||
if val > 0 {
|
||||
itm.val = val - 1
|
||||
} else {
|
||||
return errors.New("item val is less than 0")
|
||||
}
|
||||
case uint64:
|
||||
if val > 0 {
|
||||
itm.val = val - 1
|
||||
} else {
|
||||
return errors.New("item val is less than 0")
|
||||
}
|
||||
default:
|
||||
return errors.New("item val is not int int64 int32")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExist checks if cache exists in memory.
|
||||
func (bc *MemoryCache) IsExist(name string) bool {
|
||||
bc.RLock()
|
||||
defer bc.RUnlock()
|
||||
if v, ok := bc.items[name]; ok {
|
||||
return !v.isExpire()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ClearAll deletes all cache in memory.
|
||||
func (bc *MemoryCache) ClearAll() error {
|
||||
bc.Lock()
|
||||
defer bc.Unlock()
|
||||
bc.items = make(map[string]*MemoryItem)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartAndGC starts memory cache. Checks expiration in every clock time.
|
||||
func (bc *MemoryCache) StartAndGC(config string) error {
|
||||
var cf map[string]int
|
||||
json.Unmarshal([]byte(config), &cf)
|
||||
if _, ok := cf["interval"]; !ok {
|
||||
cf = make(map[string]int)
|
||||
cf["interval"] = DefaultEvery
|
||||
}
|
||||
dur := time.Duration(cf["interval"]) * time.Second
|
||||
bc.Every = cf["interval"]
|
||||
bc.dur = dur
|
||||
go bc.vacuum()
|
||||
return nil
|
||||
}
|
||||
|
||||
// check expiration.
|
||||
func (bc *MemoryCache) vacuum() {
|
||||
bc.RLock()
|
||||
every := bc.Every
|
||||
bc.RUnlock()
|
||||
|
||||
if every < 1 {
|
||||
return
|
||||
}
|
||||
for {
|
||||
<-time.After(bc.dur)
|
||||
bc.RLock()
|
||||
if bc.items == nil {
|
||||
bc.RUnlock()
|
||||
return
|
||||
}
|
||||
bc.RUnlock()
|
||||
if keys := bc.expiredKeys(); len(keys) != 0 {
|
||||
bc.clearItems(keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expiredKeys returns keys list which are expired.
|
||||
func (bc *MemoryCache) expiredKeys() (keys []string) {
|
||||
bc.RLock()
|
||||
defer bc.RUnlock()
|
||||
for key, itm := range bc.items {
|
||||
if itm.isExpire() {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ClearItems removes all items who's key is in keys
|
||||
func (bc *MemoryCache) clearItems(keys []string) {
|
||||
bc.Lock()
|
||||
defer bc.Unlock()
|
||||
for _, key := range keys {
|
||||
delete(bc.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("memory", NewMemoryCache)
|
||||
}
|
271
pkg/client/cache/redis/redis.go
vendored
Normal file
271
pkg/client/cache/redis/redis.go
vendored
Normal file
@ -0,0 +1,271 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package redis for cache provider
|
||||
//
|
||||
// depend on github.com/gomodule/redigo/redis
|
||||
//
|
||||
// go install github.com/gomodule/redigo/redis
|
||||
//
|
||||
// Usage:
|
||||
// import(
|
||||
// _ "github.com/astaxie/beego/cache/redis"
|
||||
// "github.com/astaxie/beego/cache"
|
||||
// )
|
||||
//
|
||||
// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`)
|
||||
//
|
||||
// more docs http://beego.me/docs/module/cache.md
|
||||
package redis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gomodule/redigo/redis"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/cache"
|
||||
)
|
||||
|
||||
var (
|
||||
// The collection name of redis for the cache adapter.
|
||||
DefaultKey = "beecacheRedis"
|
||||
)
|
||||
|
||||
// Cache is Redis cache adapter.
|
||||
type Cache struct {
|
||||
p *redis.Pool // redis connection pool
|
||||
conninfo string
|
||||
dbNum int
|
||||
key string
|
||||
password string
|
||||
maxIdle int
|
||||
|
||||
// Timeout value (less than the redis server's timeout value)
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewRedisCache creates a new redis cache with default collection name.
|
||||
func NewRedisCache() cache.Cache {
|
||||
return &Cache{key: DefaultKey}
|
||||
}
|
||||
|
||||
// Execute the redis commands. args[0] must be the key name
|
||||
func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
|
||||
if len(args) < 1 {
|
||||
return nil, errors.New("missing required arguments")
|
||||
}
|
||||
args[0] = rc.associate(args[0])
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
|
||||
return c.Do(commandName, args...)
|
||||
}
|
||||
|
||||
// associate with config key.
|
||||
func (rc *Cache) associate(originKey interface{}) string {
|
||||
return fmt.Sprintf("%s:%s", rc.key, originKey)
|
||||
}
|
||||
|
||||
// Get cache from redis.
|
||||
func (rc *Cache) Get(key string) interface{} {
|
||||
if v, err := rc.do("GET", key); err == nil {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMulti gets cache from redis.
|
||||
func (rc *Cache) GetMulti(keys []string) []interface{} {
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
var args []interface{}
|
||||
for _, key := range keys {
|
||||
args = append(args, rc.associate(key))
|
||||
}
|
||||
values, err := redis.Values(c.Do("MGET", args...))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Put puts cache into redis.
|
||||
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
|
||||
_, err := rc.do("SETEX", key, int64(timeout/time.Second), val)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete deletes a key's cache in redis.
|
||||
func (rc *Cache) Delete(key string) error {
|
||||
_, err := rc.do("DEL", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// IsExist checks cache's existence in redis.
|
||||
func (rc *Cache) IsExist(key string) bool {
|
||||
v, err := redis.Bool(rc.do("EXISTS", key))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Incr increases a key's counter in redis.
|
||||
func (rc *Cache) Incr(key string) error {
|
||||
_, err := redis.Bool(rc.do("INCRBY", key, 1))
|
||||
return err
|
||||
}
|
||||
|
||||
// Decr decreases a key's counter in redis.
|
||||
func (rc *Cache) Decr(key string) error {
|
||||
_, err := redis.Bool(rc.do("INCRBY", key, -1))
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearAll deletes all cache in the redis collection
|
||||
func (rc *Cache) ClearAll() error {
|
||||
cachedKeys, err := rc.Scan(rc.key + ":*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
for _, str := range cachedKeys {
|
||||
if _, err = c.Do("DEL", str); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Scan scans all keys matching a given pattern.
|
||||
func (rc *Cache) Scan(pattern string) (keys []string, err error) {
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
var (
|
||||
cursor uint64 = 0 // start
|
||||
result []interface{}
|
||||
list []string
|
||||
)
|
||||
for {
|
||||
result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
list, err = redis.Strings(result[1], nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keys = append(keys, list...)
|
||||
cursor, err = redis.Uint64(result[0], nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cursor == 0 { // over
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartAndGC starts the redis cache adapter.
|
||||
// config: must be in this format {"key":"collection key","conn":"connection info","dbNum":"0"}
|
||||
// Cached items in redis are stored forever, no garbage collection happens
|
||||
func (rc *Cache) StartAndGC(config string) error {
|
||||
var cf map[string]string
|
||||
json.Unmarshal([]byte(config), &cf)
|
||||
|
||||
if _, ok := cf["key"]; !ok {
|
||||
cf["key"] = DefaultKey
|
||||
}
|
||||
if _, ok := cf["conn"]; !ok {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
|
||||
// Format redis://<password>@<host>:<port>
|
||||
cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1)
|
||||
if i := strings.Index(cf["conn"], "@"); i > -1 {
|
||||
cf["password"] = cf["conn"][0:i]
|
||||
cf["conn"] = cf["conn"][i+1:]
|
||||
}
|
||||
|
||||
if _, ok := cf["dbNum"]; !ok {
|
||||
cf["dbNum"] = "0"
|
||||
}
|
||||
if _, ok := cf["password"]; !ok {
|
||||
cf["password"] = ""
|
||||
}
|
||||
if _, ok := cf["maxIdle"]; !ok {
|
||||
cf["maxIdle"] = "3"
|
||||
}
|
||||
if _, ok := cf["timeout"]; !ok {
|
||||
cf["timeout"] = "180s"
|
||||
}
|
||||
rc.key = cf["key"]
|
||||
rc.conninfo = cf["conn"]
|
||||
rc.dbNum, _ = strconv.Atoi(cf["dbNum"])
|
||||
rc.password = cf["password"]
|
||||
rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"])
|
||||
|
||||
if v, err := time.ParseDuration(cf["timeout"]); err == nil {
|
||||
rc.timeout = v
|
||||
} else {
|
||||
rc.timeout = 180 * time.Second
|
||||
}
|
||||
|
||||
rc.connectInit()
|
||||
|
||||
c := rc.p.Get()
|
||||
defer c.Close()
|
||||
|
||||
return c.Err()
|
||||
}
|
||||
|
||||
// connect to redis.
|
||||
func (rc *Cache) connectInit() {
|
||||
dialFunc := func() (c redis.Conn, err error) {
|
||||
c, err = redis.Dial("tcp", rc.conninfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rc.password != "" {
|
||||
if _, err := c.Do("AUTH", rc.password); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
_, selecterr := c.Do("SELECT", rc.dbNum)
|
||||
if selecterr != nil {
|
||||
c.Close()
|
||||
return nil, selecterr
|
||||
}
|
||||
return
|
||||
}
|
||||
// initialize a new pool
|
||||
rc.p = &redis.Pool{
|
||||
MaxIdle: rc.maxIdle,
|
||||
IdleTimeout: rc.timeout,
|
||||
Dial: dialFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
cache.Register("redis", NewRedisCache)
|
||||
}
|
159
pkg/client/cache/redis/redis_test.go
vendored
Normal file
159
pkg/client/cache/redis/redis_test.go
vendored
Normal file
@ -0,0 +1,159 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/cache"
|
||||
)
|
||||
|
||||
func TestRedisCache(t *testing.T) {
|
||||
|
||||
redisAddr := os.Getenv("REDIS_ADDR")
|
||||
if redisAddr == "" {
|
||||
redisAddr = "127.0.0.1:6379"
|
||||
}
|
||||
|
||||
bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr))
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
timeoutDuration := 10 * time.Second
|
||||
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
time.Sleep(11 * time.Second)
|
||||
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
|
||||
if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Incr("astaxie"); err != nil {
|
||||
t.Error("Incr Error", err)
|
||||
}
|
||||
|
||||
if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = bm.Decr("astaxie"); err != nil {
|
||||
t.Error("Decr Error", err)
|
||||
}
|
||||
|
||||
if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 {
|
||||
t.Error("get err")
|
||||
}
|
||||
bm.Delete("astaxie")
|
||||
if bm.IsExist("astaxie") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
|
||||
// test string
|
||||
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
// test GetMulti
|
||||
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !bm.IsExist("astaxie1") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
|
||||
if len(vv) != 2 {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if v, _ := redis.String(vv[0], nil); v != "author" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
if v, _ := redis.String(vv[1], nil); v != "author1" {
|
||||
t.Error("GetMulti ERROR")
|
||||
}
|
||||
|
||||
// test clear all
|
||||
if err = bm.ClearAll(); err != nil {
|
||||
t.Error("clear all err")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_Scan(t *testing.T) {
|
||||
timeoutDuration := 10 * time.Second
|
||||
|
||||
addr := os.Getenv("REDIS_ADDR")
|
||||
if addr == "" {
|
||||
addr = "127.0.0.1:6379"
|
||||
}
|
||||
|
||||
// init
|
||||
bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, addr))
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
// insert all
|
||||
for i := 0; i < 10000; i++ {
|
||||
if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
// scan all for the first time
|
||||
keys, err := bm.(*Cache).Scan(DefaultKey + ":*")
|
||||
if err != nil {
|
||||
t.Error("scan Error", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 10000, len(keys), "scan all error")
|
||||
|
||||
// clear all
|
||||
if err = bm.ClearAll(); err != nil {
|
||||
t.Error("clear all err")
|
||||
}
|
||||
|
||||
// scan all for the second time
|
||||
keys, err = bm.(*Cache).Scan(DefaultKey + ":*")
|
||||
if err != nil {
|
||||
t.Error("scan Error", err)
|
||||
}
|
||||
if len(keys) != 0 {
|
||||
t.Error("scan all err")
|
||||
}
|
||||
}
|
232
pkg/client/cache/ssdb/ssdb.go
vendored
Normal file
232
pkg/client/cache/ssdb/ssdb.go
vendored
Normal file
@ -0,0 +1,232 @@
|
||||
package ssdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ssdb/gossdb/ssdb"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/cache"
|
||||
)
|
||||
|
||||
// Cache SSDB adapter
|
||||
type Cache struct {
|
||||
conn *ssdb.Client
|
||||
conninfo []string
|
||||
}
|
||||
|
||||
//NewSsdbCache creates new ssdb adapter.
|
||||
func NewSsdbCache() cache.Cache {
|
||||
return &Cache{}
|
||||
}
|
||||
|
||||
// Get gets a key's value from memcache.
|
||||
func (rc *Cache) Get(key string) interface{} {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
value, err := rc.conn.Get(key)
|
||||
if err == nil {
|
||||
return value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMulti gets one or keys values from memcache.
|
||||
func (rc *Cache) GetMulti(keys []string) []interface{} {
|
||||
size := len(keys)
|
||||
var values []interface{}
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
for i := 0; i < size; i++ {
|
||||
values = append(values, err)
|
||||
}
|
||||
return values
|
||||
}
|
||||
}
|
||||
res, err := rc.conn.Do("multi_get", keys)
|
||||
resSize := len(res)
|
||||
if err == nil {
|
||||
for i := 1; i < resSize; i += 2 {
|
||||
values = append(values, res[i+1])
|
||||
}
|
||||
return values
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
values = append(values, err)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// DelMulti deletes one or more keys from memcache
|
||||
func (rc *Cache) DelMulti(keys []string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Do("multi_del", keys)
|
||||
return err
|
||||
}
|
||||
|
||||
// Put puts value into memcache.
|
||||
// value: must be of type string
|
||||
func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("value must string")
|
||||
}
|
||||
var resp []string
|
||||
var err error
|
||||
ttl := int(timeout / time.Second)
|
||||
if ttl < 0 {
|
||||
resp, err = rc.conn.Do("set", key, v)
|
||||
} else {
|
||||
resp, err = rc.conn.Do("setx", key, v, ttl)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp) == 2 && resp[0] == "ok" {
|
||||
return nil
|
||||
}
|
||||
return errors.New("bad response")
|
||||
}
|
||||
|
||||
// Delete deletes a value in memcache.
|
||||
func (rc *Cache) Delete(key string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Del(key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Incr increases a key's counter.
|
||||
func (rc *Cache) Incr(key string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Do("incr", key, 1)
|
||||
return err
|
||||
}
|
||||
|
||||
// Decr decrements a key's counter.
|
||||
func (rc *Cache) Decr(key string) error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := rc.conn.Do("incr", key, -1)
|
||||
return err
|
||||
}
|
||||
|
||||
// IsExist checks if a key exists in memcache.
|
||||
func (rc *Cache) IsExist(key string) bool {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
resp, err := rc.conn.Do("exists", key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if len(resp) == 2 && resp[1] == "1" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
// ClearAll clears all cached items in memcache.
|
||||
func (rc *Cache) ClearAll() error {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
keyStart, keyEnd, limit := "", "", 50
|
||||
resp, err := rc.Scan(keyStart, keyEnd, limit)
|
||||
for err == nil {
|
||||
size := len(resp)
|
||||
if size == 1 {
|
||||
return nil
|
||||
}
|
||||
keys := []string{}
|
||||
for i := 1; i < size; i += 2 {
|
||||
keys = append(keys, resp[i])
|
||||
}
|
||||
_, e := rc.conn.Do("multi_del", keys)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
keyStart = resp[size-2]
|
||||
resp, err = rc.Scan(keyStart, keyEnd, limit)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Scan key all cached in ssdb.
|
||||
func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) {
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// StartAndGC starts the memcache adapter.
|
||||
// config: must be in the format {"conn":"connection info"}.
|
||||
// If an error occurs during connection, an error is returned
|
||||
func (rc *Cache) StartAndGC(config string) error {
|
||||
var cf map[string]string
|
||||
json.Unmarshal([]byte(config), &cf)
|
||||
if _, ok := cf["conn"]; !ok {
|
||||
return errors.New("config has no conn key")
|
||||
}
|
||||
rc.conninfo = strings.Split(cf["conn"], ";")
|
||||
if rc.conn == nil {
|
||||
if err := rc.connectInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect to memcache and keep the connection.
|
||||
func (rc *Cache) connectInit() error {
|
||||
conninfoArray := strings.Split(rc.conninfo[0], ":")
|
||||
host := conninfoArray[0]
|
||||
port, e := strconv.Atoi(conninfoArray[1])
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
var err error
|
||||
rc.conn, err = ssdb.Connect(host, port)
|
||||
return err
|
||||
}
|
||||
|
||||
func init() {
|
||||
cache.Register("ssdb", NewSsdbCache)
|
||||
}
|
112
pkg/client/cache/ssdb/ssdb_test.go
vendored
Normal file
112
pkg/client/cache/ssdb/ssdb_test.go
vendored
Normal file
@ -0,0 +1,112 @@
|
||||
package ssdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/cache"
|
||||
)
|
||||
|
||||
func TestSsdbcacheCache(t *testing.T) {
|
||||
|
||||
ssdbAddr := os.Getenv("SSDB_ADDR")
|
||||
if ssdbAddr == "" {
|
||||
ssdbAddr = "127.0.0.1:8888"
|
||||
}
|
||||
|
||||
ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr))
|
||||
if err != nil {
|
||||
t.Error("init err")
|
||||
}
|
||||
|
||||
// test put and exist
|
||||
if ssdb.IsExist("ssdb") {
|
||||
t.Error("check err")
|
||||
}
|
||||
timeoutDuration := 10 * time.Second
|
||||
//timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent
|
||||
if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !ssdb.IsExist("ssdb") {
|
||||
t.Error("check err")
|
||||
}
|
||||
|
||||
// Get test done
|
||||
if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
|
||||
if v := ssdb.Get("ssdb"); v != "ssdb" {
|
||||
t.Error("get Error")
|
||||
}
|
||||
|
||||
//inc/dec test done
|
||||
if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if err = ssdb.Incr("ssdb"); err != nil {
|
||||
t.Error("incr Error", err)
|
||||
}
|
||||
|
||||
if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
if err = ssdb.Decr("ssdb"); err != nil {
|
||||
t.Error("decr error")
|
||||
}
|
||||
|
||||
// test del
|
||||
if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 {
|
||||
t.Error("get err")
|
||||
}
|
||||
if err := ssdb.Delete("ssdb"); err == nil {
|
||||
if ssdb.IsExist("ssdb") {
|
||||
t.Error("delete err")
|
||||
}
|
||||
}
|
||||
|
||||
//test string
|
||||
if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !ssdb.IsExist("ssdb") {
|
||||
t.Error("check err")
|
||||
}
|
||||
if v := ssdb.Get("ssdb").(string); v != "ssdb" {
|
||||
t.Error("get err")
|
||||
}
|
||||
|
||||
//test GetMulti done
|
||||
if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
if !ssdb.IsExist("ssdb1") {
|
||||
t.Error("check err")
|
||||
}
|
||||
vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"})
|
||||
if len(vv) != 2 {
|
||||
t.Error("getmulti error")
|
||||
}
|
||||
if vv[0].(string) != "ssdb" {
|
||||
t.Error("getmulti error")
|
||||
}
|
||||
if vv[1].(string) != "ssdb1" {
|
||||
t.Error("getmulti error")
|
||||
}
|
||||
|
||||
// test clear all done
|
||||
if err = ssdb.ClearAll(); err != nil {
|
||||
t.Error("clear all err")
|
||||
}
|
||||
if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") {
|
||||
t.Error("check err")
|
||||
}
|
||||
}
|
97
pkg/client/httplib/README.md
Normal file
97
pkg/client/httplib/README.md
Normal file
@ -0,0 +1,97 @@
|
||||
# httplib
|
||||
httplib is an libs help you to curl remote url.
|
||||
|
||||
# How to use?
|
||||
|
||||
## GET
|
||||
you can use Get to crawl data.
|
||||
|
||||
import "github.com/astaxie/beego/httplib"
|
||||
|
||||
str, err := httplib.Get("http://beego.me/").String()
|
||||
if err != nil {
|
||||
// error
|
||||
}
|
||||
fmt.Println(str)
|
||||
|
||||
## POST
|
||||
POST data to remote url
|
||||
|
||||
req := httplib.Post("http://beego.me/")
|
||||
req.Param("username","astaxie")
|
||||
req.Param("password","123456")
|
||||
str, err := req.String()
|
||||
if err != nil {
|
||||
// error
|
||||
}
|
||||
fmt.Println(str)
|
||||
|
||||
## Set timeout
|
||||
|
||||
The default timeout is `60` seconds, function prototype:
|
||||
|
||||
SetTimeout(connectTimeout, readWriteTimeout time.Duration)
|
||||
|
||||
Example:
|
||||
|
||||
// GET
|
||||
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
|
||||
|
||||
// POST
|
||||
httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
|
||||
|
||||
|
||||
## Debug
|
||||
|
||||
If you want to debug the request info, set the debug on
|
||||
|
||||
httplib.Get("http://beego.me/").Debug(true)
|
||||
|
||||
## Set HTTP Basic Auth
|
||||
|
||||
str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String()
|
||||
if err != nil {
|
||||
// error
|
||||
}
|
||||
fmt.Println(str)
|
||||
|
||||
## Set HTTPS
|
||||
|
||||
If request url is https, You can set the client support TSL:
|
||||
|
||||
httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
|
||||
|
||||
More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config
|
||||
|
||||
## Set HTTP Version
|
||||
|
||||
some servers need to specify the protocol version of HTTP
|
||||
|
||||
httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1")
|
||||
|
||||
## Set Cookie
|
||||
|
||||
some http request need setcookie. So set it like this:
|
||||
|
||||
cookie := &http.Cookie{}
|
||||
cookie.Name = "username"
|
||||
cookie.Value = "astaxie"
|
||||
httplib.Get("http://beego.me/").SetCookie(cookie)
|
||||
|
||||
## Upload file
|
||||
|
||||
httplib support mutil file upload, use `req.PostFile()`
|
||||
|
||||
req := httplib.Post("http://beego.me/")
|
||||
req.Param("username","astaxie")
|
||||
req.PostFile("uploadfile1", "httplib.pdf")
|
||||
str, err := req.String()
|
||||
if err != nil {
|
||||
// error
|
||||
}
|
||||
fmt.Println(str)
|
||||
|
||||
|
||||
See godoc for further documentation and examples.
|
||||
|
||||
* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib)
|
24
pkg/client/httplib/filter.go
Normal file
24
pkg/client/httplib/filter.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httplib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type FilterChain func(next Filter) Filter
|
||||
|
||||
type Filter func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error)
|
71
pkg/client/httplib/filter/opentracing/filter.go
Normal file
71
pkg/client/httplib/filter/opentracing/filter.go
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package opentracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/httplib"
|
||||
logKit "github.com/go-kit/kit/log"
|
||||
opentracingKit "github.com/go-kit/kit/tracing/opentracing"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
type FilterChainBuilder struct {
|
||||
// CustomSpanFunc users are able to custom their span
|
||||
CustomSpanFunc func(span opentracing.Span, ctx context.Context,
|
||||
req *httplib.BeegoHTTPRequest, resp *http.Response, err error)
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter {
|
||||
|
||||
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
||||
|
||||
method := req.GetRequest().Method
|
||||
|
||||
operationName := method + "#" + req.GetRequest().URL.String()
|
||||
span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName)
|
||||
defer span.Finish()
|
||||
|
||||
inject := opentracingKit.ContextToHTTP(opentracing.GlobalTracer(), logKit.NewNopLogger())
|
||||
inject(spanCtx, req.GetRequest())
|
||||
resp, err := next(spanCtx, req)
|
||||
|
||||
if resp != nil {
|
||||
span.SetTag("http.status_code", resp.StatusCode)
|
||||
}
|
||||
span.SetTag("http.method", method)
|
||||
span.SetTag("peer.hostname", req.GetRequest().URL.Host)
|
||||
span.SetTag("http.url", req.GetRequest().URL.String())
|
||||
span.SetTag("http.scheme", req.GetRequest().URL.Scheme)
|
||||
span.SetTag("span.kind", "client")
|
||||
span.SetTag("component", "beego")
|
||||
if err != nil {
|
||||
span.SetTag("error", true)
|
||||
span.SetTag("message", err.Error())
|
||||
} else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) {
|
||||
span.SetTag("error", true)
|
||||
}
|
||||
|
||||
span.SetTag("peer.address", req.GetRequest().RemoteAddr)
|
||||
span.SetTag("http.proto", req.GetRequest().Proto)
|
||||
|
||||
if builder.CustomSpanFunc != nil {
|
||||
builder.CustomSpanFunc(span, ctx, req, resp, err)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
}
|
42
pkg/client/httplib/filter/opentracing/filter_test.go
Normal file
42
pkg/client/httplib/filter/opentracing/filter_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package opentracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/httplib"
|
||||
)
|
||||
|
||||
func TestFilterChainBuilder_FilterChain(t *testing.T) {
|
||||
next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return &http.Response{
|
||||
StatusCode: 404,
|
||||
}, errors.New("hello")
|
||||
}
|
||||
builder := &FilterChainBuilder{}
|
||||
filter := builder.FilterChain(next)
|
||||
req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego")
|
||||
resp, err := filter(context.Background(), req)
|
||||
assert.NotNil(t, resp)
|
||||
assert.NotNil(t, err)
|
||||
}
|
75
pkg/client/httplib/filter/prometheus/filter.go
Normal file
75
pkg/client/httplib/filter/prometheus/filter.go
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/httplib"
|
||||
"github.com/astaxie/beego/pkg/server/web"
|
||||
)
|
||||
|
||||
type FilterChainBuilder struct {
|
||||
summaryVec prometheus.ObserverVec
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter {
|
||||
|
||||
builder.summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||
Name: "beego",
|
||||
Subsystem: "remote_http_request",
|
||||
ConstLabels: map[string]string{
|
||||
"server": web.BConfig.ServerName,
|
||||
"env": web.BConfig.RunMode,
|
||||
"appname": web.BConfig.AppName,
|
||||
},
|
||||
Help: "The statics info for remote http requests",
|
||||
}, []string{"proto", "scheme", "method", "host", "path", "status", "duration", "isError"})
|
||||
|
||||
return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
||||
startTime := time.Now()
|
||||
resp, err := next(ctx, req)
|
||||
endTime := time.Now()
|
||||
go builder.report(startTime, endTime, ctx, req, resp, err)
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time,
|
||||
ctx context.Context, req *httplib.BeegoHTTPRequest, resp *http.Response, err error) {
|
||||
|
||||
proto := req.GetRequest().Proto
|
||||
|
||||
scheme := req.GetRequest().URL.Scheme
|
||||
method := req.GetRequest().Method
|
||||
|
||||
host := req.GetRequest().URL.Host
|
||||
path := req.GetRequest().URL.Path
|
||||
|
||||
status := -1
|
||||
if resp != nil {
|
||||
status = resp.StatusCode
|
||||
}
|
||||
|
||||
dur := int(endTime.Sub(startTime) / time.Millisecond)
|
||||
|
||||
builder.summaryVec.WithLabelValues(proto, scheme, method, host, path,
|
||||
strconv.Itoa(status), strconv.Itoa(dur), strconv.FormatBool(err == nil))
|
||||
}
|
41
pkg/client/httplib/filter/prometheus/filter_test.go
Normal file
41
pkg/client/httplib/filter/prometheus/filter_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/httplib"
|
||||
)
|
||||
|
||||
func TestFilterChainBuilder_FilterChain(t *testing.T) {
|
||||
next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return &http.Response{
|
||||
StatusCode: 404,
|
||||
}, nil
|
||||
}
|
||||
builder := &FilterChainBuilder{}
|
||||
filter := builder.FilterChain(next)
|
||||
req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego")
|
||||
resp, err := filter(context.Background(), req)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Nil(t, err)
|
||||
}
|
689
pkg/client/httplib/httplib.go
Normal file
689
pkg/client/httplib/httplib.go
Normal file
@ -0,0 +1,689 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package httplib is used as http.Client
|
||||
// Usage:
|
||||
//
|
||||
// import "github.com/astaxie/beego/httplib"
|
||||
//
|
||||
// b := httplib.Post("http://beego.me/")
|
||||
// b.Param("username","astaxie")
|
||||
// b.Param("password","123456")
|
||||
// b.PostFile("uploadfile1", "httplib.pdf")
|
||||
// b.PostFile("uploadfile2", "httplib.txt")
|
||||
// str, err := b.String()
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// fmt.Println(str)
|
||||
//
|
||||
// more docs http://beego.me/docs/module/httplib.md
|
||||
package httplib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
var defaultSetting = BeegoHTTPSettings{
|
||||
UserAgent: "beegoServer",
|
||||
ConnectTimeout: 60 * time.Second,
|
||||
ReadWriteTimeout: 60 * time.Second,
|
||||
Gzip: true,
|
||||
DumpBody: true,
|
||||
}
|
||||
|
||||
var defaultCookieJar http.CookieJar
|
||||
var settingMutex sync.Mutex
|
||||
|
||||
// it will be the last filter and execute request.Do
|
||||
var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
|
||||
return req.doRequest(ctx)
|
||||
}
|
||||
|
||||
// createDefaultCookie creates a global cookiejar to store cookies.
|
||||
func createDefaultCookie() {
|
||||
settingMutex.Lock()
|
||||
defer settingMutex.Unlock()
|
||||
defaultCookieJar, _ = cookiejar.New(nil)
|
||||
}
|
||||
|
||||
// SetDefaultSetting overwrites default settings
|
||||
func SetDefaultSetting(setting BeegoHTTPSettings) {
|
||||
settingMutex.Lock()
|
||||
defer settingMutex.Unlock()
|
||||
defaultSetting = setting
|
||||
}
|
||||
|
||||
// NewBeegoRequest returns *BeegoHttpRequest with specific method
|
||||
func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
|
||||
var resp http.Response
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
log.Println("Httplib:", err)
|
||||
}
|
||||
req := http.Request{
|
||||
URL: u,
|
||||
Method: method,
|
||||
Header: make(http.Header),
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
}
|
||||
return &BeegoHTTPRequest{
|
||||
url: rawurl,
|
||||
req: &req,
|
||||
params: map[string][]string{},
|
||||
files: map[string]string{},
|
||||
setting: defaultSetting,
|
||||
resp: &resp,
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns *BeegoHttpRequest with GET method.
|
||||
func Get(url string) *BeegoHTTPRequest {
|
||||
return NewBeegoRequest(url, "GET")
|
||||
}
|
||||
|
||||
// Post returns *BeegoHttpRequest with POST method.
|
||||
func Post(url string) *BeegoHTTPRequest {
|
||||
return NewBeegoRequest(url, "POST")
|
||||
}
|
||||
|
||||
// Put returns *BeegoHttpRequest with PUT method.
|
||||
func Put(url string) *BeegoHTTPRequest {
|
||||
return NewBeegoRequest(url, "PUT")
|
||||
}
|
||||
|
||||
// Delete returns *BeegoHttpRequest DELETE method.
|
||||
func Delete(url string) *BeegoHTTPRequest {
|
||||
return NewBeegoRequest(url, "DELETE")
|
||||
}
|
||||
|
||||
// Head returns *BeegoHttpRequest with HEAD method.
|
||||
func Head(url string) *BeegoHTTPRequest {
|
||||
return NewBeegoRequest(url, "HEAD")
|
||||
}
|
||||
|
||||
// BeegoHTTPSettings is the http.Client setting
|
||||
type BeegoHTTPSettings struct {
|
||||
ShowDebug bool
|
||||
UserAgent string
|
||||
ConnectTimeout time.Duration
|
||||
ReadWriteTimeout time.Duration
|
||||
TLSClientConfig *tls.Config
|
||||
Proxy func(*http.Request) (*url.URL, error)
|
||||
Transport http.RoundTripper
|
||||
CheckRedirect func(req *http.Request, via []*http.Request) error
|
||||
EnableCookie bool
|
||||
Gzip bool
|
||||
DumpBody bool
|
||||
Retries int // if set to -1 means will retry forever
|
||||
RetryDelay time.Duration
|
||||
FilterChains []FilterChain
|
||||
}
|
||||
|
||||
// BeegoHTTPRequest provides more useful methods than http.Request for requesting a url.
|
||||
type BeegoHTTPRequest struct {
|
||||
url string
|
||||
req *http.Request
|
||||
params map[string][]string
|
||||
files map[string]string
|
||||
setting BeegoHTTPSettings
|
||||
resp *http.Response
|
||||
body []byte
|
||||
dump []byte
|
||||
}
|
||||
|
||||
// GetRequest returns the request object
|
||||
func (b *BeegoHTTPRequest) GetRequest() *http.Request {
|
||||
return b.req
|
||||
}
|
||||
|
||||
// Setting changes request settings
|
||||
func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
|
||||
b.setting = setting
|
||||
return b
|
||||
}
|
||||
|
||||
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
|
||||
func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
|
||||
b.req.SetBasicAuth(username, password)
|
||||
return b
|
||||
}
|
||||
|
||||
// SetEnableCookie sets enable/disable cookiejar
|
||||
func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
|
||||
b.setting.EnableCookie = enable
|
||||
return b
|
||||
}
|
||||
|
||||
// SetUserAgent sets User-Agent header field
|
||||
func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
|
||||
b.setting.UserAgent = useragent
|
||||
return b
|
||||
}
|
||||
|
||||
// Debug sets show debug or not when executing request.
|
||||
func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
|
||||
b.setting.ShowDebug = isdebug
|
||||
return b
|
||||
}
|
||||
|
||||
// Retries sets Retries times.
|
||||
// default is 0 (never retry)
|
||||
// -1 retry indefinitely (forever)
|
||||
// Other numbers specify the exact retry amount
|
||||
func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
|
||||
b.setting.Retries = times
|
||||
return b
|
||||
}
|
||||
|
||||
// RetryDelay sets the time to sleep between reconnection attempts
|
||||
func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest {
|
||||
b.setting.RetryDelay = delay
|
||||
return b
|
||||
}
|
||||
|
||||
// DumpBody sets the DumbBody field
|
||||
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
|
||||
b.setting.DumpBody = isdump
|
||||
return b
|
||||
}
|
||||
|
||||
// DumpRequest returns the DumpRequest
|
||||
func (b *BeegoHTTPRequest) DumpRequest() []byte {
|
||||
return b.dump
|
||||
}
|
||||
|
||||
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
|
||||
func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
|
||||
b.setting.ConnectTimeout = connectTimeout
|
||||
b.setting.ReadWriteTimeout = readWriteTimeout
|
||||
return b
|
||||
}
|
||||
|
||||
// SetTLSClientConfig sets TLS connection configuration if visiting HTTPS url.
|
||||
func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
|
||||
b.setting.TLSClientConfig = config
|
||||
return b
|
||||
}
|
||||
|
||||
// Header adds header item string in request.
|
||||
func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
|
||||
b.req.Header.Set(key, value)
|
||||
return b
|
||||
}
|
||||
|
||||
// SetHost set the request host
|
||||
func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
|
||||
b.req.Host = host
|
||||
return b
|
||||
}
|
||||
|
||||
// SetProtocolVersion sets the protocol version for incoming requests.
|
||||
// Client requests always use HTTP/1.1.
|
||||
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
|
||||
if len(vers) == 0 {
|
||||
vers = "HTTP/1.1"
|
||||
}
|
||||
|
||||
major, minor, ok := http.ParseHTTPVersion(vers)
|
||||
if ok {
|
||||
b.req.Proto = vers
|
||||
b.req.ProtoMajor = major
|
||||
b.req.ProtoMinor = minor
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// SetCookie adds a cookie to the request.
|
||||
func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
|
||||
b.req.Header.Add("Cookie", cookie.String())
|
||||
return b
|
||||
}
|
||||
|
||||
// SetTransport sets the transport field
|
||||
func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
|
||||
b.setting.Transport = transport
|
||||
return b
|
||||
}
|
||||
|
||||
// SetProxy sets the HTTP proxy
|
||||
// example:
|
||||
//
|
||||
// func(req *http.Request) (*url.URL, error) {
|
||||
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
|
||||
// return u, nil
|
||||
// }
|
||||
func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
|
||||
b.setting.Proxy = proxy
|
||||
return b
|
||||
}
|
||||
|
||||
// SetCheckRedirect specifies the policy for handling redirects.
|
||||
//
|
||||
// If CheckRedirect is nil, the Client uses its default policy,
|
||||
// which is to stop after 10 consecutive requests.
|
||||
func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest {
|
||||
b.setting.CheckRedirect = redirect
|
||||
return b
|
||||
}
|
||||
|
||||
// SetFilters will use the filter as the invocation filters
|
||||
func (b *BeegoHTTPRequest) SetFilters(fcs ...FilterChain) *BeegoHTTPRequest {
|
||||
b.setting.FilterChains = fcs
|
||||
return b
|
||||
}
|
||||
|
||||
// AddFilters adds filter
|
||||
func (b *BeegoHTTPRequest) AddFilters(fcs ...FilterChain) *BeegoHTTPRequest {
|
||||
b.setting.FilterChains = append(b.setting.FilterChains, fcs...)
|
||||
return b
|
||||
}
|
||||
|
||||
// Param adds query param in to request.
|
||||
// params build query string as ?key1=value1&key2=value2...
|
||||
func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
|
||||
if param, ok := b.params[key]; ok {
|
||||
b.params[key] = append(param, value)
|
||||
} else {
|
||||
b.params[key] = []string{value}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// PostFile adds a post file to the request
|
||||
func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
|
||||
b.files[formname] = filename
|
||||
return b
|
||||
}
|
||||
|
||||
// Body adds request raw body.
|
||||
// Supports string and []byte.
|
||||
func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
|
||||
switch t := data.(type) {
|
||||
case string:
|
||||
bf := bytes.NewBufferString(t)
|
||||
b.req.Body = ioutil.NopCloser(bf)
|
||||
b.req.ContentLength = int64(len(t))
|
||||
case []byte:
|
||||
bf := bytes.NewBuffer(t)
|
||||
b.req.Body = ioutil.NopCloser(bf)
|
||||
b.req.ContentLength = int64(len(t))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// XMLBody adds the request raw body encoded in XML.
|
||||
func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
|
||||
if b.req.Body == nil && obj != nil {
|
||||
byts, err := xml.Marshal(obj)
|
||||
if err != nil {
|
||||
return b, err
|
||||
}
|
||||
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
||||
b.req.ContentLength = int64(len(byts))
|
||||
b.req.Header.Set("Content-Type", "application/xml")
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// YAMLBody adds the request raw body encoded in YAML.
|
||||
func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
|
||||
if b.req.Body == nil && obj != nil {
|
||||
byts, err := yaml.Marshal(obj)
|
||||
if err != nil {
|
||||
return b, err
|
||||
}
|
||||
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
||||
b.req.ContentLength = int64(len(byts))
|
||||
b.req.Header.Set("Content-Type", "application/x+yaml")
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// JSONBody adds the request raw body encoded in JSON.
|
||||
func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
|
||||
if b.req.Body == nil && obj != nil {
|
||||
byts, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return b, err
|
||||
}
|
||||
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
|
||||
b.req.ContentLength = int64(len(byts))
|
||||
b.req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *BeegoHTTPRequest) buildURL(paramBody string) {
|
||||
// build GET url with query string
|
||||
if b.req.Method == "GET" && len(paramBody) > 0 {
|
||||
if strings.Contains(b.url, "?") {
|
||||
b.url += "&" + paramBody
|
||||
} else {
|
||||
b.url = b.url + "?" + paramBody
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// build POST/PUT/PATCH url and body
|
||||
if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil {
|
||||
// with files
|
||||
if len(b.files) > 0 {
|
||||
pr, pw := io.Pipe()
|
||||
bodyWriter := multipart.NewWriter(pw)
|
||||
go func() {
|
||||
for formname, filename := range b.files {
|
||||
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
|
||||
if err != nil {
|
||||
log.Println("Httplib:", err)
|
||||
}
|
||||
fh, err := os.Open(filename)
|
||||
if err != nil {
|
||||
log.Println("Httplib:", err)
|
||||
}
|
||||
// iocopy
|
||||
_, err = io.Copy(fileWriter, fh)
|
||||
fh.Close()
|
||||
if err != nil {
|
||||
log.Println("Httplib:", err)
|
||||
}
|
||||
}
|
||||
for k, v := range b.params {
|
||||
for _, vv := range v {
|
||||
bodyWriter.WriteField(k, vv)
|
||||
}
|
||||
}
|
||||
bodyWriter.Close()
|
||||
pw.Close()
|
||||
}()
|
||||
b.Header("Content-Type", bodyWriter.FormDataContentType())
|
||||
b.req.Body = ioutil.NopCloser(pr)
|
||||
b.Header("Transfer-Encoding", "chunked")
|
||||
return
|
||||
}
|
||||
|
||||
// with params
|
||||
if len(paramBody) > 0 {
|
||||
b.Header("Content-Type", "application/x-www-form-urlencoded")
|
||||
b.Body(paramBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
|
||||
if b.resp.StatusCode != 0 {
|
||||
return b.resp, nil
|
||||
}
|
||||
resp, err := b.DoRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.resp = resp
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DoRequest executes client.Do
|
||||
func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
|
||||
return b.DoRequestWithCtx(context.Background())
|
||||
}
|
||||
|
||||
func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Response, err error) {
|
||||
|
||||
root := doRequestFilter
|
||||
if len(b.setting.FilterChains) > 0 {
|
||||
for i := len(b.setting.FilterChains) - 1; i >= 0; i-- {
|
||||
root = b.setting.FilterChains[i](root)
|
||||
}
|
||||
}
|
||||
return root(ctx, b)
|
||||
}
|
||||
|
||||
func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) {
|
||||
var paramBody string
|
||||
if len(b.params) > 0 {
|
||||
var buf bytes.Buffer
|
||||
for k, v := range b.params {
|
||||
for _, vv := range v {
|
||||
buf.WriteString(url.QueryEscape(k))
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(url.QueryEscape(vv))
|
||||
buf.WriteByte('&')
|
||||
}
|
||||
}
|
||||
paramBody = buf.String()
|
||||
paramBody = paramBody[0 : len(paramBody)-1]
|
||||
}
|
||||
|
||||
b.buildURL(paramBody)
|
||||
urlParsed, err := url.Parse(b.url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.req.URL = urlParsed
|
||||
|
||||
trans := b.setting.Transport
|
||||
|
||||
if trans == nil {
|
||||
// create default transport
|
||||
trans = &http.Transport{
|
||||
TLSClientConfig: b.setting.TLSClientConfig,
|
||||
Proxy: b.setting.Proxy,
|
||||
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
|
||||
MaxIdleConnsPerHost: 100,
|
||||
}
|
||||
} else {
|
||||
// if b.transport is *http.Transport then set the settings.
|
||||
if t, ok := trans.(*http.Transport); ok {
|
||||
if t.TLSClientConfig == nil {
|
||||
t.TLSClientConfig = b.setting.TLSClientConfig
|
||||
}
|
||||
if t.Proxy == nil {
|
||||
t.Proxy = b.setting.Proxy
|
||||
}
|
||||
if t.Dial == nil {
|
||||
t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var jar http.CookieJar
|
||||
if b.setting.EnableCookie {
|
||||
if defaultCookieJar == nil {
|
||||
createDefaultCookie()
|
||||
}
|
||||
jar = defaultCookieJar
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: trans,
|
||||
Jar: jar,
|
||||
}
|
||||
|
||||
if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" {
|
||||
b.req.Header.Set("User-Agent", b.setting.UserAgent)
|
||||
}
|
||||
|
||||
if b.setting.CheckRedirect != nil {
|
||||
client.CheckRedirect = b.setting.CheckRedirect
|
||||
}
|
||||
|
||||
if b.setting.ShowDebug {
|
||||
dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
}
|
||||
b.dump = dump
|
||||
}
|
||||
// retries default value is 0, it will run once.
|
||||
// retries equal to -1, it will run forever until success
|
||||
// retries is setted, it will retries fixed times.
|
||||
// Sleeps for a 400ms inbetween calls to reduce spam
|
||||
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
|
||||
resp, err = client.Do(b.req)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(b.setting.RetryDelay)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// String returns the body string in response.
|
||||
// Calls Response inner.
|
||||
func (b *BeegoHTTPRequest) String() (string, error) {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Bytes returns the body []byte in response.
|
||||
// Calls Response inner.
|
||||
func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
|
||||
if b.body != nil {
|
||||
return b.body, nil
|
||||
}
|
||||
resp, err := b.getResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
reader, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.body, err = ioutil.ReadAll(reader)
|
||||
return b.body, err
|
||||
}
|
||||
b.body, err = ioutil.ReadAll(resp.Body)
|
||||
return b.body, err
|
||||
}
|
||||
|
||||
// ToFile saves the body data in response to one file.
|
||||
// Calls Response inner.
|
||||
func (b *BeegoHTTPRequest) ToFile(filename string) error {
|
||||
resp, err := b.getResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body == nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = pathExistAndMkdir(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = io.Copy(f, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the file directory exists. If it doesn't then it's created
|
||||
func pathExistAndMkdir(filename string) (err error) {
|
||||
filename = path.Dir(filename)
|
||||
_, err = os.Stat(filename)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
err = os.MkdirAll(filename, os.ModePerm)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ToJSON returns the map that marshals from the body bytes as json in response.
|
||||
// Calls Response inner.
|
||||
func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// ToXML returns the map that marshals from the body bytes as xml in response .
|
||||
// Calls Response inner.
|
||||
func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return xml.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// ToYAML returns the map that marshals from the body bytes as yaml in response .
|
||||
// Calls Response inner.
|
||||
func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return yaml.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// Response executes request client gets response manually.
|
||||
func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
|
||||
return b.getResponse()
|
||||
}
|
||||
|
||||
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
|
||||
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
|
||||
return func(netw, addr string) (net.Conn, error) {
|
||||
conn, err := net.DialTimeout(netw, addr, cTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = conn.SetDeadline(time.Now().Add(rwTimeout))
|
||||
return conn, err
|
||||
}
|
||||
}
|
286
pkg/client/httplib/httplib_test.go
Normal file
286
pkg/client/httplib/httplib_test.go
Normal file
@ -0,0 +1,286 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package httplib
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestResponse(t *testing.T) {
|
||||
req := Get("http://httpbin.org/get")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(resp)
|
||||
}
|
||||
|
||||
func TestDoRequest(t *testing.T) {
|
||||
req := Get("https://goolnk.com/33BD2j")
|
||||
retryAmount := 1
|
||||
req.Retries(1)
|
||||
req.RetryDelay(1400 * time.Millisecond)
|
||||
retryDelay := 1400 * time.Millisecond
|
||||
|
||||
req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error {
|
||||
return errors.New("Redirect triggered")
|
||||
}
|
||||
|
||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
|
||||
_, err := req.Response()
|
||||
if err == nil {
|
||||
t.Fatal("Response should have yielded an error")
|
||||
}
|
||||
|
||||
endTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
elapsedTime := endTime - startTime
|
||||
delayedTime := int64(retryAmount) * retryDelay.Milliseconds()
|
||||
|
||||
if elapsedTime < delayedTime {
|
||||
t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
req := Get("http://httpbin.org/get")
|
||||
b, err := req.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(b)
|
||||
|
||||
s, err := req.String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(s)
|
||||
|
||||
if string(b) != s {
|
||||
t.Fatal("request data not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimplePost(t *testing.T) {
|
||||
v := "smallfish"
|
||||
req := Post("http://httpbin.org/post")
|
||||
req.Param("username", v)
|
||||
|
||||
str, err := req.String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
|
||||
n := strings.Index(str, v)
|
||||
if n == -1 {
|
||||
t.Fatal(v + " not found in post")
|
||||
}
|
||||
}
|
||||
|
||||
//func TestPostFile(t *testing.T) {
|
||||
// v := "smallfish"
|
||||
// req := Post("http://httpbin.org/post")
|
||||
// req.Debug(true)
|
||||
// req.Param("username", v)
|
||||
// req.PostFile("uploadfile", "httplib_test.go")
|
||||
|
||||
// str, err := req.String()
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// t.Log(str)
|
||||
|
||||
// n := strings.Index(str, v)
|
||||
// if n == -1 {
|
||||
// t.Fatal(v + " not found in post")
|
||||
// }
|
||||
//}
|
||||
|
||||
func TestSimplePut(t *testing.T) {
|
||||
str, err := Put("http://httpbin.org/put").String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
}
|
||||
|
||||
func TestSimpleDelete(t *testing.T) {
|
||||
str, err := Delete("http://httpbin.org/delete").String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
}
|
||||
|
||||
func TestSimpleDeleteParam(t *testing.T) {
|
||||
str, err := Delete("http://httpbin.org/delete").Param("key", "val").String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
}
|
||||
|
||||
func TestWithCookie(t *testing.T) {
|
||||
v := "smallfish"
|
||||
str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
|
||||
str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
|
||||
n := strings.Index(str, v)
|
||||
if n == -1 {
|
||||
t.Fatal(v + " not found in cookie")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithBasicAuth(t *testing.T) {
|
||||
str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
n := strings.Index(str, "authenticated")
|
||||
if n == -1 {
|
||||
t.Fatal("authenticated not found in response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUserAgent(t *testing.T) {
|
||||
v := "beego"
|
||||
str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
|
||||
n := strings.Index(str, v)
|
||||
if n == -1 {
|
||||
t.Fatal(v + " not found in user-agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSetting(t *testing.T) {
|
||||
v := "beego"
|
||||
var setting BeegoHTTPSettings
|
||||
setting.EnableCookie = true
|
||||
setting.UserAgent = v
|
||||
setting.Transport = &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 50,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
setting.ReadWriteTimeout = 5 * time.Second
|
||||
SetDefaultSetting(setting)
|
||||
|
||||
str, err := Get("http://httpbin.org/get").String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
|
||||
n := strings.Index(str, v)
|
||||
if n == -1 {
|
||||
t.Fatal(v + " not found in user-agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToJson(t *testing.T) {
|
||||
req := Get("http://httpbin.org/ip")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(resp)
|
||||
|
||||
// httpbin will return http remote addr
|
||||
type IP struct {
|
||||
Origin string `json:"origin"`
|
||||
}
|
||||
var ip IP
|
||||
err = req.ToJSON(&ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(ip.Origin)
|
||||
ips := strings.Split(ip.Origin, ",")
|
||||
if len(ips) == 0 {
|
||||
t.Fatal("response is not valid ip")
|
||||
}
|
||||
for i := range ips {
|
||||
if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil {
|
||||
t.Fatal("response is not valid ip")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestToFile(t *testing.T) {
|
||||
f := "beego_testfile"
|
||||
req := Get("http://httpbin.org/ip")
|
||||
err := req.ToFile(f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f)
|
||||
b, err := ioutil.ReadFile(f)
|
||||
if n := strings.Index(string(b), "origin"); n == -1 {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFileDir(t *testing.T) {
|
||||
f := "./files/beego_testfile"
|
||||
req := Get("http://httpbin.org/ip")
|
||||
err := req.ToFile(f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll("./files")
|
||||
b, err := ioutil.ReadFile(f)
|
||||
if n := strings.Index(string(b), "origin"); n == -1 {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeader(t *testing.T) {
|
||||
req := Get("http://httpbin.org/headers")
|
||||
req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36")
|
||||
str, err := req.String()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(str)
|
||||
}
|
66
pkg/client/httplib/testing/client.go
Normal file
66
pkg/client/httplib/testing/client.go
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"github.com/astaxie/beego/pkg/client/httplib"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/config"
|
||||
)
|
||||
|
||||
var port = ""
|
||||
var baseURL = "http://localhost:"
|
||||
|
||||
// TestHTTPRequest beego test request client
|
||||
type TestHTTPRequest struct {
|
||||
httplib.BeegoHTTPRequest
|
||||
}
|
||||
|
||||
func getPort() string {
|
||||
if port == "" {
|
||||
config, err := config.NewConfig("ini", "../conf/app.conf")
|
||||
if err != nil {
|
||||
return "8080"
|
||||
}
|
||||
port = config.String("httpport")
|
||||
return port
|
||||
}
|
||||
return port
|
||||
}
|
||||
|
||||
// Get returns test client in GET method
|
||||
func Get(path string) *TestHTTPRequest {
|
||||
return &TestHTTPRequest{*httplib.Get(baseURL + getPort() + path)}
|
||||
}
|
||||
|
||||
// Post returns test client in POST method
|
||||
func Post(path string) *TestHTTPRequest {
|
||||
return &TestHTTPRequest{*httplib.Post(baseURL + getPort() + path)}
|
||||
}
|
||||
|
||||
// Put returns test client in PUT method
|
||||
func Put(path string) *TestHTTPRequest {
|
||||
return &TestHTTPRequest{*httplib.Put(baseURL + getPort() + path)}
|
||||
}
|
||||
|
||||
// Delete returns test client in DELETE method
|
||||
func Delete(path string) *TestHTTPRequest {
|
||||
return &TestHTTPRequest{*httplib.Delete(baseURL + getPort() + path)}
|
||||
}
|
||||
|
||||
// Head returns test client in HEAD method
|
||||
func Head(path string) *TestHTTPRequest {
|
||||
return &TestHTTPRequest{*httplib.Head(baseURL + getPort() + path)}
|
||||
}
|
159
pkg/client/orm/README.md
Normal file
159
pkg/client/orm/README.md
Normal file
@ -0,0 +1,159 @@
|
||||
# beego orm
|
||||
|
||||
[](https://drone.io/github.com/astaxie/beego/latest)
|
||||
|
||||
A powerful orm framework for go.
|
||||
|
||||
It is heavily influenced by Django ORM, SQLAlchemy.
|
||||
|
||||
**Support Database:**
|
||||
|
||||
* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
|
||||
* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq)
|
||||
* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
|
||||
|
||||
Passed all test, but need more feedback.
|
||||
|
||||
**Features:**
|
||||
|
||||
* full go type support
|
||||
* easy for usage, simple CRUD operation
|
||||
* auto join with relation table
|
||||
* cross DataBase compatible query
|
||||
* Raw SQL query / mapper without orm model
|
||||
* full test keep stable and strong
|
||||
|
||||
more features please read the docs
|
||||
|
||||
**Install:**
|
||||
|
||||
go get github.com/astaxie/beego/orm
|
||||
|
||||
## Changelog
|
||||
|
||||
* 2013-08-19: support table auto create
|
||||
* 2013-08-13: update test for database types
|
||||
* 2013-08-13: go type support, such as int8, uint8, byte, rune
|
||||
* 2013-08-13: date / datetime timezone support very well
|
||||
|
||||
## Quick Start
|
||||
|
||||
#### Simple Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/astaxie/beego/orm"
|
||||
_ "github.com/go-sql-driver/mysql" // import your used driver
|
||||
)
|
||||
|
||||
// Model Struct
|
||||
type User struct {
|
||||
Id int `orm:"auto"`
|
||||
Name string `orm:"size(100)"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
// register model
|
||||
orm.RegisterModel(new(User))
|
||||
|
||||
// set default database
|
||||
orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
|
||||
|
||||
// create table
|
||||
orm.RunSyncdb("default", false, true)
|
||||
}
|
||||
|
||||
func main() {
|
||||
o := orm.NewOrm()
|
||||
|
||||
user := User{Name: "slene"}
|
||||
|
||||
// insert
|
||||
id, err := o.Insert(&user)
|
||||
|
||||
// update
|
||||
user.Name = "astaxie"
|
||||
num, err := o.Update(&user)
|
||||
|
||||
// read one
|
||||
u := User{Id: user.Id}
|
||||
err = o.Read(&u)
|
||||
|
||||
// delete
|
||||
num, err = o.Delete(&u)
|
||||
}
|
||||
```
|
||||
|
||||
#### Next with relation
|
||||
|
||||
```go
|
||||
type Post struct {
|
||||
Id int `orm:"auto"`
|
||||
Title string `orm:"size(100)"`
|
||||
User *User `orm:"rel(fk)"`
|
||||
}
|
||||
|
||||
var posts []*Post
|
||||
qs := o.QueryTable("post")
|
||||
num, err := qs.Filter("User__Name", "slene").All(&posts)
|
||||
```
|
||||
|
||||
#### Use Raw sql
|
||||
|
||||
If you don't like ORM,use Raw SQL to query / mapping without ORM setting
|
||||
|
||||
```go
|
||||
var maps []Params
|
||||
num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps)
|
||||
if num > 0 {
|
||||
fmt.Println(maps[0]["id"])
|
||||
}
|
||||
```
|
||||
|
||||
#### Transaction
|
||||
|
||||
```go
|
||||
o.Begin()
|
||||
...
|
||||
user := User{Name: "slene"}
|
||||
id, err := o.Insert(&user)
|
||||
if err == nil {
|
||||
o.Commit()
|
||||
} else {
|
||||
o.Rollback()
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
#### Debug Log Queries
|
||||
|
||||
In development env, you can simple use
|
||||
|
||||
```go
|
||||
func main() {
|
||||
orm.Debug = true
|
||||
...
|
||||
```
|
||||
|
||||
enable log queries.
|
||||
|
||||
output include all queries, such as exec / prepare / transaction.
|
||||
|
||||
like this:
|
||||
|
||||
```go
|
||||
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene`
|
||||
...
|
||||
```
|
||||
|
||||
note: not recommend use this in product env.
|
||||
|
||||
## Docs
|
||||
|
||||
more details and examples in docs and test
|
||||
|
||||
[documents](http://beego.me/docs/mvc/model/overview.md)
|
||||
|
283
pkg/client/orm/cmd.go
Normal file
283
pkg/client/orm/cmd.go
Normal file
@ -0,0 +1,283 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type commander interface {
|
||||
Parse([]string)
|
||||
Run() error
|
||||
}
|
||||
|
||||
var (
|
||||
commands = make(map[string]commander)
|
||||
)
|
||||
|
||||
// print help.
|
||||
func printHelp(errs ...string) {
|
||||
content := `orm command usage:
|
||||
|
||||
syncdb - auto create tables
|
||||
sqlall - print sql of create tables
|
||||
help - print this help
|
||||
`
|
||||
|
||||
if len(errs) > 0 {
|
||||
fmt.Println(errs[0])
|
||||
}
|
||||
fmt.Println(content)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
// RunCommand listens for orm command and runs if command arguments have been passed.
|
||||
func RunCommand() {
|
||||
if len(os.Args) < 2 || os.Args[1] != "orm" {
|
||||
return
|
||||
}
|
||||
|
||||
BootStrap()
|
||||
|
||||
args := argString(os.Args[2:])
|
||||
name := args.Get(0)
|
||||
|
||||
if name == "help" {
|
||||
printHelp()
|
||||
}
|
||||
|
||||
if cmd, ok := commands[name]; ok {
|
||||
cmd.Parse(os.Args[3:])
|
||||
cmd.Run()
|
||||
os.Exit(0)
|
||||
} else {
|
||||
if name == "" {
|
||||
printHelp()
|
||||
} else {
|
||||
printHelp(fmt.Sprintf("unknown command %s", name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sync database struct command interface.
|
||||
type commandSyncDb struct {
|
||||
al *alias
|
||||
force bool
|
||||
verbose bool
|
||||
noInfo bool
|
||||
rtOnError bool
|
||||
}
|
||||
|
||||
// Parse the orm command line arguments.
|
||||
func (d *commandSyncDb) Parse(args []string) {
|
||||
var name string
|
||||
|
||||
flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError)
|
||||
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
|
||||
flagSet.BoolVar(&d.force, "force", false, "drop tables before create")
|
||||
flagSet.BoolVar(&d.verbose, "v", false, "verbose info")
|
||||
flagSet.Parse(args)
|
||||
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
// Run orm line command.
|
||||
func (d *commandSyncDb) Run() error {
|
||||
var drops []string
|
||||
if d.force {
|
||||
drops = getDbDropSQL(d.al)
|
||||
}
|
||||
|
||||
db := d.al.DB
|
||||
|
||||
if d.force {
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
query := drops[i]
|
||||
if !d.noInfo {
|
||||
fmt.Printf("drop table `%s`\n", mi.table)
|
||||
}
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
fmt.Printf(" %s\n\n", query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqls, indexes := getDbCreateSQL(d.al)
|
||||
|
||||
tables, err := d.al.DbBaser.GetTables(db)
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
if tables[mi.table] {
|
||||
if !d.noInfo {
|
||||
fmt.Printf("table `%s` already exists, skip\n", mi.table)
|
||||
}
|
||||
|
||||
var fields []*fieldInfo
|
||||
columns, err := d.al.DbBaser.GetColumns(db, mi.table)
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
if _, ok := columns[fi.column]; !ok {
|
||||
fields = append(fields, fi)
|
||||
}
|
||||
}
|
||||
|
||||
for _, fi := range fields {
|
||||
query := getColumnAddQuery(d.al, fi)
|
||||
|
||||
if !d.noInfo {
|
||||
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
|
||||
}
|
||||
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
fmt.Printf(" %s\n", query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range indexes[mi.table] {
|
||||
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
|
||||
if !d.noInfo {
|
||||
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
||||
}
|
||||
|
||||
query := idx.SQL
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
fmt.Printf(" %s\n", query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if !d.noInfo {
|
||||
fmt.Printf("create table `%s` \n", mi.table)
|
||||
}
|
||||
|
||||
queries := []string{sqls[i]}
|
||||
for _, idx := range indexes[mi.table] {
|
||||
queries = append(queries, idx.SQL)
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
_, err := db.Exec(query)
|
||||
if d.verbose {
|
||||
query = " " + strings.Join(strings.Split(query, "\n"), "\n ")
|
||||
fmt.Println(query)
|
||||
}
|
||||
if err != nil {
|
||||
if d.rtOnError {
|
||||
return err
|
||||
}
|
||||
fmt.Printf(" %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
if d.verbose {
|
||||
fmt.Println("")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// database creation commander interface implement.
|
||||
type commandSQLAll struct {
|
||||
al *alias
|
||||
}
|
||||
|
||||
// Parse orm command line arguments.
|
||||
func (d *commandSQLAll) Parse(args []string) {
|
||||
var name string
|
||||
|
||||
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
|
||||
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
|
||||
flagSet.Parse(args)
|
||||
|
||||
d.al = getDbAlias(name)
|
||||
}
|
||||
|
||||
// Run orm line command.
|
||||
func (d *commandSQLAll) Run() error {
|
||||
sqls, indexes := getDbCreateSQL(d.al)
|
||||
var all []string
|
||||
for i, mi := range modelCache.allOrdered() {
|
||||
queries := []string{sqls[i]}
|
||||
for _, idx := range indexes[mi.table] {
|
||||
queries = append(queries, idx.SQL)
|
||||
}
|
||||
sql := strings.Join(queries, "\n")
|
||||
all = append(all, sql)
|
||||
}
|
||||
fmt.Println(strings.Join(all, "\n\n"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
commands["syncdb"] = new(commandSyncDb)
|
||||
commands["sqlall"] = new(commandSQLAll)
|
||||
}
|
||||
|
||||
// RunSyncdb run syncdb command line.
|
||||
// name: Table's alias name (default is "default")
|
||||
// force: Run the next sql command even if the current gave an error
|
||||
// verbose: Print all information, useful for debugging
|
||||
func RunSyncdb(name string, force bool, verbose bool) error {
|
||||
BootStrap()
|
||||
|
||||
al := getDbAlias(name)
|
||||
cmd := new(commandSyncDb)
|
||||
cmd.al = al
|
||||
cmd.force = force
|
||||
cmd.noInfo = !verbose
|
||||
cmd.verbose = verbose
|
||||
cmd.rtOnError = true
|
||||
return cmd.Run()
|
||||
}
|
320
pkg/client/orm/cmd_utils.go
Normal file
320
pkg/client/orm/cmd_utils.go
Normal file
@ -0,0 +1,320 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type dbIndex struct {
|
||||
Table string
|
||||
Name string
|
||||
SQL string
|
||||
}
|
||||
|
||||
// create database drop sql.
|
||||
func getDbDropSQL(al *alias) (sqls []string) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
Q := al.DbBaser.TableQuote()
|
||||
|
||||
for _, mi := range modelCache.allOrdered() {
|
||||
sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
|
||||
}
|
||||
return sqls
|
||||
}
|
||||
|
||||
// get database column type string.
|
||||
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
|
||||
T := al.DbBaser.DbTypes()
|
||||
fieldType := fi.fieldType
|
||||
fieldSize := fi.size
|
||||
|
||||
checkColumn:
|
||||
switch fieldType {
|
||||
case TypeBooleanField:
|
||||
col = T["bool"]
|
||||
case TypeVarCharField:
|
||||
if al.Driver == DRPostgres && fi.toText {
|
||||
col = T["string-text"]
|
||||
} else {
|
||||
col = fmt.Sprintf(T["string"], fieldSize)
|
||||
}
|
||||
case TypeCharField:
|
||||
col = fmt.Sprintf(T["string-char"], fieldSize)
|
||||
case TypeTextField:
|
||||
col = T["string-text"]
|
||||
case TypeTimeField:
|
||||
col = T["time.Time-clock"]
|
||||
case TypeDateField:
|
||||
col = T["time.Time-date"]
|
||||
case TypeDateTimeField:
|
||||
col = T["time.Time"]
|
||||
case TypeBitField:
|
||||
col = T["int8"]
|
||||
case TypeSmallIntegerField:
|
||||
col = T["int16"]
|
||||
case TypeIntegerField:
|
||||
col = T["int32"]
|
||||
case TypeBigIntegerField:
|
||||
if al.Driver == DRSqlite {
|
||||
fieldType = TypeIntegerField
|
||||
goto checkColumn
|
||||
}
|
||||
col = T["int64"]
|
||||
case TypePositiveBitField:
|
||||
col = T["uint8"]
|
||||
case TypePositiveSmallIntegerField:
|
||||
col = T["uint16"]
|
||||
case TypePositiveIntegerField:
|
||||
col = T["uint32"]
|
||||
case TypePositiveBigIntegerField:
|
||||
col = T["uint64"]
|
||||
case TypeFloatField:
|
||||
col = T["float64"]
|
||||
case TypeDecimalField:
|
||||
s := T["float64-decimal"]
|
||||
if !strings.Contains(s, "%d") {
|
||||
col = s
|
||||
} else {
|
||||
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
||||
}
|
||||
case TypeJSONField:
|
||||
if al.Driver != DRPostgres {
|
||||
fieldType = TypeVarCharField
|
||||
goto checkColumn
|
||||
}
|
||||
col = T["json"]
|
||||
case TypeJsonbField:
|
||||
if al.Driver != DRPostgres {
|
||||
fieldType = TypeVarCharField
|
||||
goto checkColumn
|
||||
}
|
||||
col = T["jsonb"]
|
||||
case RelForeignKey, RelOneToOne:
|
||||
fieldType = fi.relModelInfo.fields.pk.fieldType
|
||||
fieldSize = fi.relModelInfo.fields.pk.size
|
||||
goto checkColumn
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// create alter sql string.
|
||||
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
||||
Q := al.DbBaser.TableQuote()
|
||||
typ := getColumnTyp(al, fi)
|
||||
|
||||
if !fi.null {
|
||||
typ += " " + "NOT NULL"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
|
||||
Q, fi.mi.table, Q,
|
||||
Q, fi.column, Q,
|
||||
typ, getColumnDefault(fi),
|
||||
)
|
||||
}
|
||||
|
||||
// create database creation string.
|
||||
func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
|
||||
if len(modelCache.cache) == 0 {
|
||||
fmt.Println("no Model found, need register your model")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
Q := al.DbBaser.TableQuote()
|
||||
T := al.DbBaser.DbTypes()
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
|
||||
tableIndexes = make(map[string][]dbIndex)
|
||||
|
||||
for _, mi := range modelCache.allOrdered() {
|
||||
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
|
||||
sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
|
||||
sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
|
||||
|
||||
sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
|
||||
|
||||
columns := make([]string, 0, len(mi.fields.fieldsDB))
|
||||
|
||||
sqlIndexes := [][]string{}
|
||||
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
|
||||
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
|
||||
col := getColumnTyp(al, fi)
|
||||
|
||||
if fi.auto {
|
||||
switch al.Driver {
|
||||
case DRSqlite, DRPostgres:
|
||||
column += T["auto"]
|
||||
default:
|
||||
column += col + " " + T["auto"]
|
||||
}
|
||||
} else if fi.pk {
|
||||
column += col + " " + T["pk"]
|
||||
} else {
|
||||
column += col
|
||||
|
||||
if !fi.null {
|
||||
column += " " + "NOT NULL"
|
||||
}
|
||||
|
||||
//if fi.initial.String() != "" {
|
||||
// column += " DEFAULT " + fi.initial.String()
|
||||
//}
|
||||
|
||||
// Append attribute DEFAULT
|
||||
column += getColumnDefault(fi)
|
||||
|
||||
if fi.unique {
|
||||
column += " " + "UNIQUE"
|
||||
}
|
||||
|
||||
if fi.index {
|
||||
sqlIndexes = append(sqlIndexes, []string{fi.column})
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(column, "%COL%") {
|
||||
column = strings.Replace(column, "%COL%", fi.column, -1)
|
||||
}
|
||||
|
||||
if fi.description != "" && al.Driver != DRSqlite {
|
||||
column += " " + fmt.Sprintf("COMMENT '%s'", fi.description)
|
||||
}
|
||||
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
if mi.model != nil {
|
||||
allnames := getTableUnique(mi.addrField)
|
||||
if !mi.manual && len(mi.uniques) > 0 {
|
||||
allnames = append(allnames, mi.uniques)
|
||||
}
|
||||
for _, names := range allnames {
|
||||
cols := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
|
||||
cols = append(cols, fi.column)
|
||||
} else {
|
||||
panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
|
||||
}
|
||||
}
|
||||
column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
}
|
||||
|
||||
sql += strings.Join(columns, ",\n")
|
||||
sql += "\n)"
|
||||
|
||||
if al.Driver == DRMySQL {
|
||||
var engine string
|
||||
if mi.model != nil {
|
||||
engine = getTableEngine(mi.addrField)
|
||||
}
|
||||
if engine == "" {
|
||||
engine = al.Engine
|
||||
}
|
||||
sql += " ENGINE=" + engine
|
||||
}
|
||||
|
||||
sql += ";"
|
||||
sqls = append(sqls, sql)
|
||||
|
||||
if mi.model != nil {
|
||||
for _, names := range getTableIndex(mi.addrField) {
|
||||
cols := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
|
||||
cols = append(cols, fi.column)
|
||||
} else {
|
||||
panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
|
||||
}
|
||||
}
|
||||
sqlIndexes = append(sqlIndexes, cols)
|
||||
}
|
||||
}
|
||||
|
||||
for _, names := range sqlIndexes {
|
||||
name := mi.table + "_" + strings.Join(names, "_")
|
||||
cols := strings.Join(names, sep)
|
||||
sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
|
||||
|
||||
index := dbIndex{}
|
||||
index.Table = mi.table
|
||||
index.Name = name
|
||||
index.SQL = sql
|
||||
|
||||
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
|
||||
func getColumnDefault(fi *fieldInfo) string {
|
||||
var (
|
||||
v, t, d string
|
||||
)
|
||||
|
||||
// Skip default attribute if field is in relations
|
||||
if fi.rel || fi.reverse {
|
||||
return v
|
||||
}
|
||||
|
||||
t = " DEFAULT '%s' "
|
||||
|
||||
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
|
||||
switch fi.fieldType {
|
||||
case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
|
||||
return v
|
||||
|
||||
case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
|
||||
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
|
||||
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
|
||||
TypeDecimalField:
|
||||
t = " DEFAULT %s "
|
||||
d = "0"
|
||||
case TypeBooleanField:
|
||||
t = " DEFAULT %s "
|
||||
d = "FALSE"
|
||||
case TypeJSONField, TypeJsonbField:
|
||||
d = "{}"
|
||||
}
|
||||
|
||||
if fi.colDefault {
|
||||
if !fi.initial.Exist() {
|
||||
v = fmt.Sprintf(t, "")
|
||||
} else {
|
||||
v = fmt.Sprintf(t, fi.initial.String())
|
||||
}
|
||||
} else {
|
||||
if !fi.null {
|
||||
v = fmt.Sprintf(t, d)
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
1957
pkg/client/orm/db.go
Normal file
1957
pkg/client/orm/db.go
Normal file
File diff suppressed because it is too large
Load Diff
555
pkg/client/orm/db_alias.go
Normal file
555
pkg/client/orm/db_alias.go
Normal file
@ -0,0 +1,555 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
)
|
||||
|
||||
// DriverType database driver constant int.
|
||||
type DriverType int
|
||||
|
||||
// Enum the Database driver
|
||||
const (
|
||||
_ DriverType = iota // int enum type
|
||||
DRMySQL // mysql
|
||||
DRSqlite // sqlite
|
||||
DROracle // oracle
|
||||
DRPostgres // pgsql
|
||||
DRTiDB // TiDB
|
||||
)
|
||||
|
||||
// database driver string.
|
||||
type driver string
|
||||
|
||||
// get type constant int of current driver..
|
||||
func (d driver) Type() DriverType {
|
||||
a, _ := dataBaseCache.get(string(d))
|
||||
return a.Driver
|
||||
}
|
||||
|
||||
// get name of current driver
|
||||
func (d driver) Name() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// check driver iis implemented Driver interface or not.
|
||||
var _ Driver = new(driver)
|
||||
|
||||
var (
|
||||
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
|
||||
drivers = map[string]DriverType{
|
||||
"mysql": DRMySQL,
|
||||
"postgres": DRPostgres,
|
||||
"sqlite3": DRSqlite,
|
||||
"tidb": DRTiDB,
|
||||
"oracle": DROracle,
|
||||
"oci8": DROracle, // github.com/mattn/go-oci8
|
||||
"ora": DROracle, // https://github.com/rana/ora
|
||||
}
|
||||
dbBasers = map[DriverType]dbBaser{
|
||||
DRMySQL: newdbBaseMysql(),
|
||||
DRSqlite: newdbBaseSqlite(),
|
||||
DROracle: newdbBaseOracle(),
|
||||
DRPostgres: newdbBasePostgres(),
|
||||
DRTiDB: newdbBaseTidb(),
|
||||
}
|
||||
)
|
||||
|
||||
// database alias cacher.
|
||||
type _dbCache struct {
|
||||
mux sync.RWMutex
|
||||
cache map[string]*alias
|
||||
}
|
||||
|
||||
// add database alias with original name.
|
||||
func (ac *_dbCache) add(name string, al *alias) (added bool) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
if _, ok := ac.cache[name]; !ok {
|
||||
ac.cache[name] = al
|
||||
added = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get database alias if cached.
|
||||
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
|
||||
ac.mux.RLock()
|
||||
defer ac.mux.RUnlock()
|
||||
al, ok = ac.cache[name]
|
||||
return
|
||||
}
|
||||
|
||||
// get default alias.
|
||||
func (ac *_dbCache) getDefault() (al *alias) {
|
||||
al, _ = ac.get("default")
|
||||
return
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
*sync.RWMutex
|
||||
DB *sql.DB
|
||||
stmtDecorators *lru.Cache
|
||||
stmtDecoratorsLimit int
|
||||
}
|
||||
|
||||
var _ dbQuerier = new(DB)
|
||||
var _ txer = new(DB)
|
||||
|
||||
func (d *DB) Begin() (*sql.Tx, error) {
|
||||
return d.DB.Begin()
|
||||
}
|
||||
|
||||
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
return d.DB.BeginTx(ctx, opts)
|
||||
}
|
||||
|
||||
// su must call release to release *sql.Stmt after using
|
||||
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
|
||||
d.RLock()
|
||||
c, ok := d.stmtDecorators.Get(query)
|
||||
if ok {
|
||||
c.(*stmtDecorator).acquire()
|
||||
d.RUnlock()
|
||||
return c.(*stmtDecorator), nil
|
||||
}
|
||||
d.RUnlock()
|
||||
|
||||
d.Lock()
|
||||
c, ok = d.stmtDecorators.Get(query)
|
||||
if ok {
|
||||
c.(*stmtDecorator).acquire()
|
||||
d.Unlock()
|
||||
return c.(*stmtDecorator), nil
|
||||
}
|
||||
|
||||
stmt, err := d.Prepare(query)
|
||||
if err != nil {
|
||||
d.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
sd := newStmtDecorator(stmt)
|
||||
sd.acquire()
|
||||
d.stmtDecorators.Add(query, sd)
|
||||
d.Unlock()
|
||||
|
||||
return sd, nil
|
||||
}
|
||||
|
||||
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
|
||||
return d.DB.Prepare(query)
|
||||
}
|
||||
|
||||
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
return d.DB.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return d.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
if d.stmtDecorators == nil {
|
||||
return d.DB.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
sd, err := d.getStmtDecorator(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sd.getStmt()
|
||||
defer sd.release()
|
||||
return stmt.ExecContext(ctx, args...)
|
||||
}
|
||||
|
||||
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return d.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
if d.stmtDecorators == nil {
|
||||
return d.DB.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
sd, err := d.getStmtDecorator(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sd.getStmt()
|
||||
defer sd.release()
|
||||
return stmt.QueryContext(ctx, args...)
|
||||
}
|
||||
|
||||
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
return d.QueryRowContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
if d.stmtDecorators == nil {
|
||||
return d.DB.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
sd, err := d.getStmtDecorator(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
stmt := sd.getStmt()
|
||||
defer sd.release()
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
}
|
||||
|
||||
type TxDB struct {
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
var _ dbQuerier = new(TxDB)
|
||||
var _ txEnder = new(TxDB)
|
||||
|
||||
func (t *TxDB) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
func (t *TxDB) Rollback() error {
|
||||
return t.tx.Rollback()
|
||||
}
|
||||
|
||||
var _ dbQuerier = new(TxDB)
|
||||
var _ txEnder = new(TxDB)
|
||||
|
||||
func (t *TxDB) Prepare(query string) (*sql.Stmt, error) {
|
||||
return t.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
return t.tx.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return t.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
return t.tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return t.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return t.tx.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
return t.QueryRowContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
return t.tx.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
type alias struct {
|
||||
Name string
|
||||
Driver DriverType
|
||||
DriverName string
|
||||
DataSource string
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
DB *DB
|
||||
DbBaser dbBaser
|
||||
TZ *time.Location
|
||||
Engine string
|
||||
}
|
||||
|
||||
func detectTZ(al *alias) {
|
||||
// orm timezone system match database
|
||||
// default use Local
|
||||
al.TZ = DefaultTimeLoc
|
||||
|
||||
if al.DriverName == "sphinx" {
|
||||
return
|
||||
}
|
||||
|
||||
switch al.Driver {
|
||||
case DRMySQL:
|
||||
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
|
||||
var tz string
|
||||
row.Scan(&tz)
|
||||
if len(tz) >= 8 {
|
||||
if tz[0] != '-' {
|
||||
tz = "+" + tz
|
||||
}
|
||||
t, err := time.Parse("-07:00:00", tz)
|
||||
if err == nil {
|
||||
if t.Location().String() != "" {
|
||||
al.TZ = t.Location()
|
||||
}
|
||||
} else {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// get default engine from current database
|
||||
row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
|
||||
var engine string
|
||||
var tx bool
|
||||
row.Scan(&engine, &tx)
|
||||
|
||||
if engine != "" {
|
||||
al.Engine = engine
|
||||
} else {
|
||||
al.Engine = "INNODB"
|
||||
}
|
||||
|
||||
case DRSqlite, DROracle:
|
||||
al.TZ = time.UTC
|
||||
|
||||
case DRPostgres:
|
||||
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
|
||||
var tz string
|
||||
row.Scan(&tz)
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err == nil {
|
||||
al.TZ = loc
|
||||
} else {
|
||||
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) {
|
||||
existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
|
||||
if _, ok := dataBaseCache.get(aliasName); ok {
|
||||
return nil, existErr
|
||||
}
|
||||
|
||||
al, err := newAliasWithDb(aliasName, driverName, db, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !dataBaseCache.add(aliasName, al) {
|
||||
return nil, existErr
|
||||
}
|
||||
|
||||
return al, nil
|
||||
}
|
||||
|
||||
func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...utils.KV) (*alias, error) {
|
||||
kvs := utils.NewKVs(params...)
|
||||
|
||||
var stmtCache *lru.Cache
|
||||
var stmtCacheSize int
|
||||
|
||||
maxStmtCacheSize := kvs.GetValueOr(hints.KeyMaxStmtCacheSize, 0).(int)
|
||||
if maxStmtCacheSize > 0 {
|
||||
_stmtCache, errC := newStmtDecoratorLruWithEvict(maxStmtCacheSize)
|
||||
if errC != nil {
|
||||
return nil, errC
|
||||
} else {
|
||||
stmtCache = _stmtCache
|
||||
stmtCacheSize = maxStmtCacheSize
|
||||
}
|
||||
}
|
||||
|
||||
al := new(alias)
|
||||
al.Name = aliasName
|
||||
al.DriverName = driverName
|
||||
al.DB = &DB{
|
||||
RWMutex: new(sync.RWMutex),
|
||||
DB: db,
|
||||
stmtDecorators: stmtCache,
|
||||
stmtDecoratorsLimit: stmtCacheSize,
|
||||
}
|
||||
|
||||
if dr, ok := drivers[driverName]; ok {
|
||||
al.DbBaser = dbBasers[dr]
|
||||
al.Driver = dr
|
||||
} else {
|
||||
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
|
||||
}
|
||||
|
||||
err := db.Ping()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
||||
}
|
||||
|
||||
detectTZ(al)
|
||||
|
||||
kvs.IfContains(hints.KeyMaxIdleConnections, func(value interface{}) {
|
||||
if m, ok := value.(int); ok {
|
||||
SetMaxIdleConns(al, m)
|
||||
}
|
||||
}).IfContains(hints.KeyMaxOpenConnections, func(value interface{}) {
|
||||
if m, ok := value.(int); ok {
|
||||
SetMaxOpenConns(al, m)
|
||||
}
|
||||
}).IfContains(hints.KeyConnMaxLifetime, func(value interface{}) {
|
||||
if m, ok := value.(time.Duration); ok {
|
||||
SetConnMaxLifetime(al, m)
|
||||
}
|
||||
})
|
||||
|
||||
return al, nil
|
||||
}
|
||||
|
||||
// AddAliasWthDB add a aliasName for the drivename
|
||||
func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...utils.KV) error {
|
||||
_, err := addAliasWthDB(aliasName, driverName, db, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
|
||||
func RegisterDataBase(aliasName, driverName, dataSource string, params ...utils.KV) error {
|
||||
var (
|
||||
err error
|
||||
db *sql.DB
|
||||
al *alias
|
||||
)
|
||||
|
||||
db, err = sql.Open(driverName, dataSource)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
|
||||
goto end
|
||||
}
|
||||
|
||||
al, err = addAliasWthDB(aliasName, driverName, db, params...)
|
||||
if err != nil {
|
||||
goto end
|
||||
}
|
||||
|
||||
al.DataSource = dataSource
|
||||
|
||||
end:
|
||||
if err != nil {
|
||||
if db != nil {
|
||||
db.Close()
|
||||
}
|
||||
DebugLog.Println(err.Error())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
|
||||
func RegisterDriver(driverName string, typ DriverType) error {
|
||||
if t, ok := drivers[driverName]; !ok {
|
||||
drivers[driverName] = typ
|
||||
} else {
|
||||
if t != typ {
|
||||
return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDataBaseTZ Change the database default used timezone
|
||||
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||
al.TZ = tz
|
||||
} else {
|
||||
return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
|
||||
func SetMaxIdleConns(al *alias, maxIdleConns int) {
|
||||
al.MaxIdleConns = maxIdleConns
|
||||
al.DB.DB.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
|
||||
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
|
||||
func SetMaxOpenConns(al *alias, maxOpenConns int) {
|
||||
al.MaxOpenConns = maxOpenConns
|
||||
al.DB.DB.SetMaxOpenConns(maxOpenConns)
|
||||
}
|
||||
|
||||
func SetConnMaxLifetime(al *alias, lifeTime time.Duration) {
|
||||
al.ConnMaxLifetime = lifeTime
|
||||
al.DB.DB.SetConnMaxLifetime(lifeTime)
|
||||
}
|
||||
|
||||
// GetDB Get *sql.DB from registered database by db alias name.
|
||||
// Use "default" as alias name if you not set.
|
||||
func GetDB(aliasNames ...string) (*sql.DB, error) {
|
||||
var name string
|
||||
if len(aliasNames) > 0 {
|
||||
name = aliasNames[0]
|
||||
} else {
|
||||
name = "default"
|
||||
}
|
||||
al, ok := dataBaseCache.get(name)
|
||||
if ok {
|
||||
return al.DB.DB, nil
|
||||
}
|
||||
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
|
||||
}
|
||||
|
||||
type stmtDecorator struct {
|
||||
wg sync.WaitGroup
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *stmtDecorator) getStmt() *sql.Stmt {
|
||||
return s.stmt
|
||||
}
|
||||
|
||||
// acquire will add one
|
||||
// since this method will be used inside read lock scope,
|
||||
// so we can not do more things here
|
||||
// we should think about refactor this
|
||||
func (s *stmtDecorator) acquire() {
|
||||
s.wg.Add(1)
|
||||
}
|
||||
|
||||
func (s *stmtDecorator) release() {
|
||||
s.wg.Done()
|
||||
}
|
||||
|
||||
// garbage recycle for stmt
|
||||
func (s *stmtDecorator) destroy() {
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
_ = s.stmt.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
|
||||
return &stmtDecorator{
|
||||
stmt: sqlStmt,
|
||||
}
|
||||
}
|
||||
|
||||
func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) {
|
||||
cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) {
|
||||
value.(*stmtDecorator).destroy()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cache, nil
|
||||
}
|
88
pkg/client/orm/db_alias_test.go
Normal file
88
pkg/client/orm/db_alias_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright 2020 beego-dev
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRegisterDataBase(t *testing.T) {
|
||||
err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source,
|
||||
hints.MaxIdleConnections(20),
|
||||
hints.MaxOpenConnections(300),
|
||||
hints.ConnMaxLifetime(time.Minute))
|
||||
assert.Nil(t, err)
|
||||
|
||||
al := getDbAlias("test-params")
|
||||
assert.NotNil(t, al)
|
||||
assert.Equal(t, al.MaxIdleConns, 20)
|
||||
assert.Equal(t, al.MaxOpenConns, 300)
|
||||
assert.Equal(t, al.ConnMaxLifetime, time.Minute)
|
||||
}
|
||||
|
||||
func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) {
|
||||
aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1"
|
||||
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(-1))
|
||||
assert.Nil(t, err)
|
||||
|
||||
al := getDbAlias(aliasName)
|
||||
assert.NotNil(t, al)
|
||||
assert.Equal(t, al.DB.stmtDecoratorsLimit, 0)
|
||||
}
|
||||
|
||||
func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) {
|
||||
aliasName := "TestRegisterDataBase_MaxStmtCacheSize0"
|
||||
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(0))
|
||||
assert.Nil(t, err)
|
||||
|
||||
al := getDbAlias(aliasName)
|
||||
assert.NotNil(t, al)
|
||||
assert.Equal(t, al.DB.stmtDecoratorsLimit, 0)
|
||||
}
|
||||
|
||||
func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) {
|
||||
aliasName := "TestRegisterDataBase_MaxStmtCacheSize1"
|
||||
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(1))
|
||||
assert.Nil(t, err)
|
||||
|
||||
al := getDbAlias(aliasName)
|
||||
assert.NotNil(t, al)
|
||||
assert.Equal(t, al.DB.stmtDecoratorsLimit, 1)
|
||||
}
|
||||
|
||||
func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) {
|
||||
aliasName := "TestRegisterDataBase_MaxStmtCacheSize841"
|
||||
err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, hints.MaxStmtCacheSize(841))
|
||||
assert.Nil(t, err)
|
||||
|
||||
al := getDbAlias(aliasName)
|
||||
assert.NotNil(t, al)
|
||||
assert.Equal(t, al.DB.stmtDecoratorsLimit, 841)
|
||||
}
|
||||
|
||||
func TestDBCache(t *testing.T) {
|
||||
dataBaseCache.add("test1", &alias{})
|
||||
dataBaseCache.add("default", &alias{})
|
||||
al := dataBaseCache.getDefault()
|
||||
assert.NotNil(t, al)
|
||||
al, ok := dataBaseCache.get("test1")
|
||||
assert.NotNil(t, al)
|
||||
assert.True(t, ok)
|
||||
}
|
190
pkg/client/orm/db_mysql.go
Normal file
190
pkg/client/orm/db_mysql.go
Normal file
@ -0,0 +1,190 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// mysql operators.
|
||||
var mysqlOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ?",
|
||||
"contains": "LIKE BINARY ?",
|
||||
"icontains": "LIKE ?",
|
||||
// "regex": "REGEXP BINARY ?",
|
||||
// "iregex": "REGEXP ?",
|
||||
"gt": "> ?",
|
||||
"gte": ">= ?",
|
||||
"lt": "< ?",
|
||||
"lte": "<= ?",
|
||||
"eq": "= ?",
|
||||
"ne": "!= ?",
|
||||
"startswith": "LIKE BINARY ?",
|
||||
"endswith": "LIKE BINARY ?",
|
||||
"istartswith": "LIKE ?",
|
||||
"iendswith": "LIKE ?",
|
||||
}
|
||||
|
||||
// mysql column field types.
|
||||
var mysqlTypes = map[string]string{
|
||||
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
"bool": "bool",
|
||||
"string": "varchar(%d)",
|
||||
"string-char": "char(%d)",
|
||||
"string-text": "longtext",
|
||||
"time.Time-date": "date",
|
||||
"time.Time": "datetime",
|
||||
"int8": "tinyint",
|
||||
"int16": "smallint",
|
||||
"int32": "integer",
|
||||
"int64": "bigint",
|
||||
"uint8": "tinyint unsigned",
|
||||
"uint16": "smallint unsigned",
|
||||
"uint32": "integer unsigned",
|
||||
"uint64": "bigint unsigned",
|
||||
"float64": "double precision",
|
||||
"float64-decimal": "numeric(%d, %d)",
|
||||
}
|
||||
|
||||
// mysql dbBaser implementation.
|
||||
type dbBaseMysql struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseMysql)
|
||||
|
||||
// get mysql operator.
|
||||
func (d *dbBaseMysql) OperatorSQL(operator string) string {
|
||||
return mysqlOperators[operator]
|
||||
}
|
||||
|
||||
// get mysql table field types.
|
||||
func (d *dbBaseMysql) DbTypes() map[string]string {
|
||||
return mysqlTypes
|
||||
}
|
||||
|
||||
// show table sql for mysql.
|
||||
func (d *dbBaseMysql) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||
}
|
||||
|
||||
// show columns sql of table for mysql.
|
||||
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// execute sql to check index exist.
|
||||
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||
var cnt int
|
||||
row.Scan(&cnt)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// InsertOrUpdate a row
|
||||
// If your primary key or unique column conflict will update
|
||||
// If no will insert
|
||||
// Add "`" for mysql sql building
|
||||
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||
var iouStr string
|
||||
argsMap := map[string]string{}
|
||||
|
||||
iouStr = "ON DUPLICATE KEY UPDATE"
|
||||
|
||||
//Get on the key-value pairs
|
||||
for _, v := range args {
|
||||
kv := strings.Split(v, "=")
|
||||
if len(kv) == 2 {
|
||||
argsMap[strings.ToLower(kv[0])] = kv[1]
|
||||
}
|
||||
}
|
||||
|
||||
isMulti := false
|
||||
names := make([]string, 0, len(mi.fields.dbcols)-1)
|
||||
Q := d.ins.TableQuote()
|
||||
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
marks := make([]string, len(names))
|
||||
updateValues := make([]interface{}, 0)
|
||||
updates := make([]string, len(names))
|
||||
|
||||
for i, v := range names {
|
||||
marks[i] = "?"
|
||||
valueStr := argsMap[strings.ToLower(v)]
|
||||
if valueStr != "" {
|
||||
updates[i] = "`" + v + "`" + "=" + valueStr
|
||||
} else {
|
||||
updates[i] = "`" + v + "`" + "=?"
|
||||
updateValues = append(updateValues, values[i])
|
||||
}
|
||||
}
|
||||
|
||||
values = append(values, updateValues...)
|
||||
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
qmarks := strings.Join(marks, ", ")
|
||||
qupdates := strings.Join(updates, ", ")
|
||||
columns := strings.Join(names, sep)
|
||||
|
||||
multi := len(values) / len(names)
|
||||
|
||||
if isMulti {
|
||||
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||
}
|
||||
//conflitValue maybe is a int,can`t use fmt.Sprintf
|
||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||
res, err := q.Exec(query, values...)
|
||||
if err == nil {
|
||||
if isMulti {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
lastInsertId, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
|
||||
return lastInsertId, ErrLastInsertIdUnavailable
|
||||
} else {
|
||||
return lastInsertId, nil
|
||||
}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err = row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
// create new mysql dbBaser.
|
||||
func newdbBaseMysql() dbBaser {
|
||||
b := new(dbBaseMysql)
|
||||
b.ins = b
|
||||
return b
|
||||
}
|
169
pkg/client/orm/db_oracle.go
Normal file
169
pkg/client/orm/db_oracle.go
Normal file
@ -0,0 +1,169 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
)
|
||||
|
||||
// oracle operators.
|
||||
var oracleOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"gt": "> ?",
|
||||
"gte": ">= ?",
|
||||
"lt": "< ?",
|
||||
"lte": "<= ?",
|
||||
"//iendswith": "LIKE ?",
|
||||
}
|
||||
|
||||
// oracle column field types.
|
||||
var oracleTypes = map[string]string{
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
"bool": "bool",
|
||||
"string": "VARCHAR2(%d)",
|
||||
"string-char": "CHAR(%d)",
|
||||
"string-text": "VARCHAR2(%d)",
|
||||
"time.Time-date": "DATE",
|
||||
"time.Time": "TIMESTAMP",
|
||||
"int8": "INTEGER",
|
||||
"int16": "INTEGER",
|
||||
"int32": "INTEGER",
|
||||
"int64": "INTEGER",
|
||||
"uint8": "INTEGER",
|
||||
"uint16": "INTEGER",
|
||||
"uint32": "INTEGER",
|
||||
"uint64": "INTEGER",
|
||||
"float64": "NUMBER",
|
||||
"float64-decimal": "NUMBER(%d, %d)",
|
||||
}
|
||||
|
||||
// oracle dbBaser
|
||||
type dbBaseOracle struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseOracle)
|
||||
|
||||
// create oracle dbBaser.
|
||||
func newdbBaseOracle() dbBaser {
|
||||
b := new(dbBaseOracle)
|
||||
b.ins = b
|
||||
return b
|
||||
}
|
||||
|
||||
// OperatorSQL get oracle operator.
|
||||
func (d *dbBaseOracle) OperatorSQL(operator string) string {
|
||||
return oracleOperators[operator]
|
||||
}
|
||||
|
||||
// DbTypes get oracle table field types.
|
||||
func (d *dbBaseOracle) DbTypes() map[string]string {
|
||||
return oracleTypes
|
||||
}
|
||||
|
||||
//ShowTablesQuery show all the tables in database
|
||||
func (d *dbBaseOracle) ShowTablesQuery() string {
|
||||
return "SELECT TABLE_NAME FROM USER_TABLES"
|
||||
}
|
||||
|
||||
// Oracle
|
||||
func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+
|
||||
"WHERE TABLE_NAME ='%s'", strings.ToUpper(table))
|
||||
}
|
||||
|
||||
// check index is exist
|
||||
func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
|
||||
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
|
||||
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
|
||||
|
||||
var cnt int
|
||||
row.Scan(&cnt)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
|
||||
var s []string
|
||||
Q := d.TableQuote()
|
||||
for _, index := range indexes {
|
||||
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
|
||||
s = append(s, tmp)
|
||||
}
|
||||
|
||||
var hint string
|
||||
|
||||
switch useIndex {
|
||||
case hints.KeyUseIndex, hints.KeyForceIndex:
|
||||
hint = `INDEX`
|
||||
case hints.KeyIgnoreIndex:
|
||||
hint = `NO_INDEX`
|
||||
default:
|
||||
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
|
||||
return ``
|
||||
}
|
||||
|
||||
return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`))
|
||||
}
|
||||
|
||||
// execute insert sql with given struct and given values.
|
||||
// insert the given values, not the field values in struct.
|
||||
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||
Q := d.ins.TableQuote()
|
||||
|
||||
marks := make([]string, len(names))
|
||||
for i := range marks {
|
||||
marks[i] = ":" + names[i]
|
||||
}
|
||||
|
||||
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||
qmarks := strings.Join(marks, ", ")
|
||||
columns := strings.Join(names, sep)
|
||||
|
||||
multi := len(values) / len(names)
|
||||
|
||||
if isMulti {
|
||||
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||
|
||||
d.ins.ReplaceMarks(&query)
|
||||
|
||||
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||
res, err := q.Exec(query, values...)
|
||||
if err == nil {
|
||||
if isMulti {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
lastInsertId, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
|
||||
return lastInsertId, ErrLastInsertIdUnavailable
|
||||
} else {
|
||||
return lastInsertId, nil
|
||||
}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
row := q.QueryRow(query, values...)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
195
pkg/client/orm/db_postgres.go
Normal file
195
pkg/client/orm/db_postgres.go
Normal file
@ -0,0 +1,195 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// postgresql operators.
|
||||
var postgresOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "= UPPER(?)",
|
||||
"contains": "LIKE ?",
|
||||
"icontains": "LIKE UPPER(?)",
|
||||
"gt": "> ?",
|
||||
"gte": ">= ?",
|
||||
"lt": "< ?",
|
||||
"lte": "<= ?",
|
||||
"eq": "= ?",
|
||||
"ne": "!= ?",
|
||||
"startswith": "LIKE ?",
|
||||
"endswith": "LIKE ?",
|
||||
"istartswith": "LIKE UPPER(?)",
|
||||
"iendswith": "LIKE UPPER(?)",
|
||||
}
|
||||
|
||||
// postgresql column field types.
|
||||
var postgresTypes = map[string]string{
|
||||
"auto": "serial NOT NULL PRIMARY KEY",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
"bool": "bool",
|
||||
"string": "varchar(%d)",
|
||||
"string-char": "char(%d)",
|
||||
"string-text": "text",
|
||||
"time.Time-date": "date",
|
||||
"time.Time": "timestamp with time zone",
|
||||
"int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`,
|
||||
"int16": "smallint",
|
||||
"int32": "integer",
|
||||
"int64": "bigint",
|
||||
"uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`,
|
||||
"uint16": `integer CHECK("%COL%" >= 0)`,
|
||||
"uint32": `bigint CHECK("%COL%" >= 0)`,
|
||||
"uint64": `bigint CHECK("%COL%" >= 0)`,
|
||||
"float64": "double precision",
|
||||
"float64-decimal": "numeric(%d, %d)",
|
||||
"json": "json",
|
||||
"jsonb": "jsonb",
|
||||
}
|
||||
|
||||
// postgresql dbBaser.
|
||||
type dbBasePostgres struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBasePostgres)
|
||||
|
||||
// get postgresql operator.
|
||||
func (d *dbBasePostgres) OperatorSQL(operator string) string {
|
||||
return postgresOperators[operator]
|
||||
}
|
||||
|
||||
// generate functioned sql string, such as contains(text).
|
||||
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||
switch operator {
|
||||
case "contains", "startswith", "endswith":
|
||||
*leftCol = fmt.Sprintf("%s::text", *leftCol)
|
||||
case "iexact", "icontains", "istartswith", "iendswith":
|
||||
*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
|
||||
}
|
||||
}
|
||||
|
||||
// postgresql unsupports updating joined record.
|
||||
func (d *dbBasePostgres) SupportUpdateJoin() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *dbBasePostgres) MaxLimit() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// postgresql quote is ".
|
||||
func (d *dbBasePostgres) TableQuote() string {
|
||||
return `"`
|
||||
}
|
||||
|
||||
// postgresql value placeholder is $n.
|
||||
// replace default ? to $n.
|
||||
func (d *dbBasePostgres) ReplaceMarks(query *string) {
|
||||
q := *query
|
||||
num := 0
|
||||
for _, c := range q {
|
||||
if c == '?' {
|
||||
num++
|
||||
}
|
||||
}
|
||||
if num == 0 {
|
||||
return
|
||||
}
|
||||
data := make([]byte, 0, len(q)+num)
|
||||
num = 1
|
||||
for i := 0; i < len(q); i++ {
|
||||
c := q[i]
|
||||
if c == '?' {
|
||||
data = append(data, '$')
|
||||
data = append(data, []byte(strconv.Itoa(num))...)
|
||||
num++
|
||||
} else {
|
||||
data = append(data, c)
|
||||
}
|
||||
}
|
||||
*query = string(data)
|
||||
}
|
||||
|
||||
// make returning sql support for postgresql.
|
||||
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
|
||||
fi := mi.fields.pk
|
||||
if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if query != nil {
|
||||
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sync auto key
|
||||
func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error {
|
||||
if len(autoFields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
Q := d.ins.TableQuote()
|
||||
for _, name := range autoFields {
|
||||
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
|
||||
mi.table, name,
|
||||
Q, name, Q,
|
||||
Q, mi.table, Q)
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// show table sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
|
||||
}
|
||||
|
||||
// show table columns sql for postgresql.
|
||||
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// get column types of postgresql.
|
||||
func (d *dbBasePostgres) DbTypes() map[string]string {
|
||||
return postgresTypes
|
||||
}
|
||||
|
||||
// check index exist in postgresql.
|
||||
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
|
||||
row := db.QueryRow(query)
|
||||
var cnt int
|
||||
row.Scan(&cnt)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// GenerateSpecifyIndex return a specifying index clause
|
||||
func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
|
||||
DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
|
||||
return ``
|
||||
}
|
||||
|
||||
// create new postgresql dbBaser.
|
||||
func newdbBasePostgres() dbBaser {
|
||||
b := new(dbBasePostgres)
|
||||
b.ins = b
|
||||
return b
|
||||
}
|
182
pkg/client/orm/db_sqlite.go
Normal file
182
pkg/client/orm/db_sqlite.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
)
|
||||
|
||||
// sqlite operators.
|
||||
var sqliteOperators = map[string]string{
|
||||
"exact": "= ?",
|
||||
"iexact": "LIKE ? ESCAPE '\\'",
|
||||
"contains": "LIKE ? ESCAPE '\\'",
|
||||
"icontains": "LIKE ? ESCAPE '\\'",
|
||||
"gt": "> ?",
|
||||
"gte": ">= ?",
|
||||
"lt": "< ?",
|
||||
"lte": "<= ?",
|
||||
"eq": "= ?",
|
||||
"ne": "!= ?",
|
||||
"startswith": "LIKE ? ESCAPE '\\'",
|
||||
"endswith": "LIKE ? ESCAPE '\\'",
|
||||
"istartswith": "LIKE ? ESCAPE '\\'",
|
||||
"iendswith": "LIKE ? ESCAPE '\\'",
|
||||
}
|
||||
|
||||
// sqlite column types.
|
||||
var sqliteTypes = map[string]string{
|
||||
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
|
||||
"pk": "NOT NULL PRIMARY KEY",
|
||||
"bool": "bool",
|
||||
"string": "varchar(%d)",
|
||||
"string-char": "character(%d)",
|
||||
"string-text": "text",
|
||||
"time.Time-date": "date",
|
||||
"time.Time": "datetime",
|
||||
"int8": "tinyint",
|
||||
"int16": "smallint",
|
||||
"int32": "integer",
|
||||
"int64": "bigint",
|
||||
"uint8": "tinyint unsigned",
|
||||
"uint16": "smallint unsigned",
|
||||
"uint32": "integer unsigned",
|
||||
"uint64": "bigint unsigned",
|
||||
"float64": "real",
|
||||
"float64-decimal": "decimal",
|
||||
}
|
||||
|
||||
// sqlite dbBaser.
|
||||
type dbBaseSqlite struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseSqlite)
|
||||
|
||||
// override base db read for update behavior as SQlite does not support syntax
|
||||
func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
|
||||
if isForUpdate {
|
||||
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
|
||||
}
|
||||
return d.dbBase.Read(q, mi, ind, tz, cols, false)
|
||||
}
|
||||
|
||||
// get sqlite operator.
|
||||
func (d *dbBaseSqlite) OperatorSQL(operator string) string {
|
||||
return sqliteOperators[operator]
|
||||
}
|
||||
|
||||
// generate functioned sql for sqlite.
|
||||
// only support DATE(text).
|
||||
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
|
||||
if fi.fieldType == TypeDateField {
|
||||
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
|
||||
}
|
||||
}
|
||||
|
||||
// unable updating joined record in sqlite.
|
||||
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// max int in sqlite.
|
||||
func (d *dbBaseSqlite) MaxLimit() uint64 {
|
||||
return 9223372036854775807
|
||||
}
|
||||
|
||||
// get column types in sqlite.
|
||||
func (d *dbBaseSqlite) DbTypes() map[string]string {
|
||||
return sqliteTypes
|
||||
}
|
||||
|
||||
// get show tables sql in sqlite.
|
||||
func (d *dbBaseSqlite) ShowTablesQuery() string {
|
||||
return "SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||
}
|
||||
|
||||
// get columns in sqlite.
|
||||
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
|
||||
query := d.ins.ShowColumnsQuery(table)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make(map[string][3]string)
|
||||
for rows.Next() {
|
||||
var tmp, name, typ, null sql.NullString
|
||||
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columns[name.String] = [3]string{name.String, typ.String, null.String}
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// get show columns sql in sqlite.
|
||||
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("pragma table_info('%s')", table)
|
||||
}
|
||||
|
||||
// check index exist in sqlite.
|
||||
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var tmp, index sql.NullString
|
||||
rows.Scan(&tmp, &index, &tmp, &tmp, &tmp)
|
||||
if name == index.String {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GenerateSpecifyIndex return a specifying index clause
|
||||
func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
|
||||
var s []string
|
||||
Q := d.TableQuote()
|
||||
for _, index := range indexes {
|
||||
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
|
||||
s = append(s, tmp)
|
||||
}
|
||||
|
||||
switch useIndex {
|
||||
case hints.KeyUseIndex, hints.KeyForceIndex:
|
||||
return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`))
|
||||
default:
|
||||
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
|
||||
return ``
|
||||
}
|
||||
}
|
||||
|
||||
// create new sqlite dbBaser.
|
||||
func newdbBaseSqlite() dbBaser {
|
||||
b := new(dbBaseSqlite)
|
||||
b.ins = b
|
||||
return b
|
||||
}
|
491
pkg/client/orm/db_tables.go
Normal file
491
pkg/client/orm/db_tables.go
Normal file
@ -0,0 +1,491 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// table info struct.
|
||||
type dbTable struct {
|
||||
id int
|
||||
index string
|
||||
name string
|
||||
names []string
|
||||
sel bool
|
||||
inner bool
|
||||
mi *modelInfo
|
||||
fi *fieldInfo
|
||||
jtl *dbTable
|
||||
}
|
||||
|
||||
// tables collection struct, contains some tables.
|
||||
type dbTables struct {
|
||||
tablesM map[string]*dbTable
|
||||
tables []*dbTable
|
||||
mi *modelInfo
|
||||
base dbBaser
|
||||
skipEnd bool
|
||||
}
|
||||
|
||||
// set table info to collection.
|
||||
// if not exist, create new.
|
||||
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
|
||||
name := strings.Join(names, ExprSep)
|
||||
if j, ok := t.tablesM[name]; ok {
|
||||
j.name = name
|
||||
j.mi = mi
|
||||
j.fi = fi
|
||||
j.inner = inner
|
||||
} else {
|
||||
i := len(t.tables) + 1
|
||||
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
||||
t.tablesM[name] = jt
|
||||
t.tables = append(t.tables, jt)
|
||||
}
|
||||
return t.tablesM[name]
|
||||
}
|
||||
|
||||
// add table info to collection.
|
||||
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
||||
name := strings.Join(names, ExprSep)
|
||||
if _, ok := t.tablesM[name]; !ok {
|
||||
i := len(t.tables) + 1
|
||||
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
||||
t.tablesM[name] = jt
|
||||
t.tables = append(t.tables, jt)
|
||||
return jt, true
|
||||
}
|
||||
return t.tablesM[name], false
|
||||
}
|
||||
|
||||
// get table info in collection.
|
||||
func (t *dbTables) get(name string) (*dbTable, bool) {
|
||||
j, ok := t.tablesM[name]
|
||||
return j, ok
|
||||
}
|
||||
|
||||
// get related fields info in recursive depth loop.
|
||||
// loop once, depth decreases one.
|
||||
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
|
||||
if depth < 0 || fi.fieldType == RelManyToMany {
|
||||
return related
|
||||
}
|
||||
|
||||
if prefix == "" {
|
||||
prefix = fi.name
|
||||
} else {
|
||||
prefix = prefix + ExprSep + fi.name
|
||||
}
|
||||
related = append(related, prefix)
|
||||
|
||||
depth--
|
||||
for _, fi := range fi.relModelInfo.fields.fieldsRel {
|
||||
related = t.loopDepth(depth, prefix, fi, related)
|
||||
}
|
||||
|
||||
return related
|
||||
}
|
||||
|
||||
// parse related fields.
|
||||
func (t *dbTables) parseRelated(rels []string, depth int) {
|
||||
|
||||
relsNum := len(rels)
|
||||
related := make([]string, relsNum)
|
||||
copy(related, rels)
|
||||
|
||||
relDepth := depth
|
||||
|
||||
if relsNum != 0 {
|
||||
relDepth = 0
|
||||
}
|
||||
|
||||
relDepth--
|
||||
for _, fi := range t.mi.fields.fieldsRel {
|
||||
related = t.loopDepth(relDepth, "", fi, related)
|
||||
}
|
||||
|
||||
for i, s := range related {
|
||||
var (
|
||||
exs = strings.Split(s, ExprSep)
|
||||
names = make([]string, 0, len(exs))
|
||||
mmi = t.mi
|
||||
cancel = true
|
||||
jtl *dbTable
|
||||
)
|
||||
|
||||
inner := true
|
||||
|
||||
for _, ex := range exs {
|
||||
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
|
||||
names = append(names, fi.name)
|
||||
mmi = fi.relModelInfo
|
||||
|
||||
if fi.null || t.skipEnd {
|
||||
inner = false
|
||||
}
|
||||
|
||||
jt := t.set(names, mmi, fi, inner)
|
||||
jt.jtl = jtl
|
||||
|
||||
if fi.reverse {
|
||||
cancel = false
|
||||
}
|
||||
|
||||
if cancel {
|
||||
jt.sel = depth > 0
|
||||
|
||||
if i < relsNum {
|
||||
jt.sel = true
|
||||
}
|
||||
}
|
||||
|
||||
jtl = jt
|
||||
|
||||
} else {
|
||||
panic(fmt.Errorf("unknown model/table name `%s`", ex))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generate join string.
|
||||
func (t *dbTables) getJoinSQL() (join string) {
|
||||
Q := t.base.TableQuote()
|
||||
|
||||
for _, jt := range t.tables {
|
||||
if jt.inner {
|
||||
join += "INNER JOIN "
|
||||
} else {
|
||||
join += "LEFT OUTER JOIN "
|
||||
}
|
||||
var (
|
||||
table string
|
||||
t1, t2 string
|
||||
c1, c2 string
|
||||
)
|
||||
t1 = "T0"
|
||||
if jt.jtl != nil {
|
||||
t1 = jt.jtl.index
|
||||
}
|
||||
t2 = jt.index
|
||||
table = jt.mi.table
|
||||
|
||||
switch {
|
||||
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
|
||||
c1 = jt.fi.mi.fields.pk.column
|
||||
for _, ffi := range jt.mi.fields.fieldsRel {
|
||||
if jt.fi.mi == ffi.relModelInfo {
|
||||
c2 = ffi.column
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
c1 = jt.fi.column
|
||||
c2 = jt.fi.relModelInfo.fields.pk.column
|
||||
|
||||
if jt.fi.reverse {
|
||||
c1 = jt.mi.fields.pk.column
|
||||
c2 = jt.fi.reverseFieldInfo.column
|
||||
}
|
||||
}
|
||||
|
||||
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
|
||||
t2, Q, c2, Q, t1, Q, c1, Q)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// parse orm model struct field tag expression.
|
||||
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
|
||||
var (
|
||||
jtl *dbTable
|
||||
fi *fieldInfo
|
||||
fiN *fieldInfo
|
||||
mmi = mi
|
||||
)
|
||||
|
||||
num := len(exprs) - 1
|
||||
var names []string
|
||||
|
||||
inner := true
|
||||
|
||||
loopFor:
|
||||
for i, ex := range exprs {
|
||||
|
||||
var ok, okN bool
|
||||
|
||||
if fiN != nil {
|
||||
fi = fiN
|
||||
ok = true
|
||||
fiN = nil
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
fi, ok = mmi.fields.GetByAny(ex)
|
||||
}
|
||||
|
||||
_ = okN
|
||||
|
||||
if ok {
|
||||
|
||||
isRel := fi.rel || fi.reverse
|
||||
|
||||
names = append(names, fi.name)
|
||||
|
||||
switch {
|
||||
case fi.rel:
|
||||
mmi = fi.relModelInfo
|
||||
if fi.fieldType == RelManyToMany {
|
||||
mmi = fi.relThroughModelInfo
|
||||
}
|
||||
case fi.reverse:
|
||||
mmi = fi.reverseFieldInfo.mi
|
||||
}
|
||||
|
||||
if i < num {
|
||||
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
||||
}
|
||||
|
||||
if isRel && (!fi.mi.isThrough || num != i) {
|
||||
if fi.null || t.skipEnd {
|
||||
inner = false
|
||||
}
|
||||
|
||||
if t.skipEnd && okN || !t.skipEnd {
|
||||
if t.skipEnd && okN && fiN.pk {
|
||||
goto loopEnd
|
||||
}
|
||||
|
||||
jt, _ := t.add(names, mmi, fi, inner)
|
||||
jt.jtl = jtl
|
||||
jtl = jt
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if num != i {
|
||||
continue
|
||||
}
|
||||
|
||||
loopEnd:
|
||||
|
||||
if i == 0 || jtl == nil {
|
||||
index = "T0"
|
||||
} else {
|
||||
index = jtl.index
|
||||
}
|
||||
|
||||
info = fi
|
||||
|
||||
if jtl == nil {
|
||||
name = fi.name
|
||||
} else {
|
||||
name = jtl.name + ExprSep + fi.name
|
||||
}
|
||||
|
||||
switch {
|
||||
case fi.rel:
|
||||
|
||||
case fi.reverse:
|
||||
switch fi.reverseFieldInfo.fieldType {
|
||||
case RelOneToOne, RelForeignKey:
|
||||
index = jtl.index
|
||||
info = fi.reverseFieldInfo.mi.fields.pk
|
||||
name = info.name
|
||||
}
|
||||
}
|
||||
|
||||
break loopFor
|
||||
|
||||
} else {
|
||||
index = ""
|
||||
name = ""
|
||||
info = nil
|
||||
success = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
success = index != "" && info != nil
|
||||
return
|
||||
}
|
||||
|
||||
// generate condition sql.
|
||||
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
|
||||
if cond == nil || cond.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
Q := t.base.TableQuote()
|
||||
|
||||
mi := t.mi
|
||||
|
||||
for i, p := range cond.params {
|
||||
if i > 0 {
|
||||
if p.isOr {
|
||||
where += "OR "
|
||||
} else {
|
||||
where += "AND "
|
||||
}
|
||||
}
|
||||
if p.isNot {
|
||||
where += "NOT "
|
||||
}
|
||||
if p.isCond {
|
||||
w, ps := t.getCondSQL(p.cond, true, tz)
|
||||
if w != "" {
|
||||
w = fmt.Sprintf("( %s) ", w)
|
||||
}
|
||||
where += w
|
||||
params = append(params, ps...)
|
||||
} else {
|
||||
exprs := p.exprs
|
||||
|
||||
num := len(exprs) - 1
|
||||
operator := ""
|
||||
if operators[exprs[num]] {
|
||||
operator = exprs[num]
|
||||
exprs = exprs[:num]
|
||||
}
|
||||
|
||||
index, _, fi, suc := t.parseExprs(mi, exprs)
|
||||
if !suc {
|
||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
|
||||
}
|
||||
|
||||
if operator == "" {
|
||||
operator = "exact"
|
||||
}
|
||||
|
||||
var operSQL string
|
||||
var args []interface{}
|
||||
if p.isRaw {
|
||||
operSQL = p.sql
|
||||
} else {
|
||||
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
|
||||
}
|
||||
|
||||
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
|
||||
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
|
||||
|
||||
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
|
||||
params = append(params, args...)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if !sub && where != "" {
|
||||
where = "WHERE " + where
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// generate group sql.
|
||||
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
|
||||
if len(groups) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
Q := t.base.TableQuote()
|
||||
|
||||
groupSqls := make([]string, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
exprs := strings.Split(group, ExprSep)
|
||||
|
||||
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
||||
if !suc {
|
||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
||||
}
|
||||
|
||||
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
|
||||
}
|
||||
|
||||
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
|
||||
return
|
||||
}
|
||||
|
||||
// generate order sql.
|
||||
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
||||
if len(orders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
Q := t.base.TableQuote()
|
||||
|
||||
orderSqls := make([]string, 0, len(orders))
|
||||
for _, order := range orders {
|
||||
asc := "ASC"
|
||||
if order[0] == '-' {
|
||||
asc = "DESC"
|
||||
order = order[1:]
|
||||
}
|
||||
exprs := strings.Split(order, ExprSep)
|
||||
|
||||
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
||||
if !suc {
|
||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
||||
}
|
||||
|
||||
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
|
||||
}
|
||||
|
||||
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
|
||||
return
|
||||
}
|
||||
|
||||
// generate limit sql.
|
||||
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
|
||||
if limit == 0 {
|
||||
limit = int64(DefaultRowsLimit)
|
||||
}
|
||||
if limit < 0 {
|
||||
// no limit
|
||||
if offset > 0 {
|
||||
maxLimit := t.base.MaxLimit()
|
||||
if maxLimit == 0 {
|
||||
limits = fmt.Sprintf("OFFSET %d", offset)
|
||||
} else {
|
||||
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
|
||||
}
|
||||
}
|
||||
} else if offset <= 0 {
|
||||
limits = fmt.Sprintf("LIMIT %d", limit)
|
||||
} else {
|
||||
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getIndexSql generate index sql.
|
||||
func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) (clause string) {
|
||||
if len(indexes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes)
|
||||
}
|
||||
|
||||
// crete new tables collection.
|
||||
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
|
||||
tables := &dbTables{}
|
||||
tables.tablesM = make(map[string]*dbTable)
|
||||
tables.mi = mi
|
||||
tables.base = base
|
||||
return tables
|
||||
}
|
63
pkg/client/orm/db_tidb.go
Normal file
63
pkg/client/orm/db_tidb.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright 2015 TiDB Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// mysql dbBaser implementation.
|
||||
type dbBaseTidb struct {
|
||||
dbBase
|
||||
}
|
||||
|
||||
var _ dbBaser = new(dbBaseTidb)
|
||||
|
||||
// get mysql operator.
|
||||
func (d *dbBaseTidb) OperatorSQL(operator string) string {
|
||||
return mysqlOperators[operator]
|
||||
}
|
||||
|
||||
// get mysql table field types.
|
||||
func (d *dbBaseTidb) DbTypes() map[string]string {
|
||||
return mysqlTypes
|
||||
}
|
||||
|
||||
// show table sql for mysql.
|
||||
func (d *dbBaseTidb) ShowTablesQuery() string {
|
||||
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
|
||||
}
|
||||
|
||||
// show columns sql of table for mysql.
|
||||
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
|
||||
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
|
||||
}
|
||||
|
||||
// execute sql to check index exist.
|
||||
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
|
||||
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
|
||||
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
|
||||
var cnt int
|
||||
row.Scan(&cnt)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// create new mysql dbBaser.
|
||||
func newdbBaseTidb() dbBaser {
|
||||
b := new(dbBaseTidb)
|
||||
b.ins = b
|
||||
return b
|
||||
}
|
177
pkg/client/orm/db_utils.go
Normal file
177
pkg/client/orm/db_utils.go
Normal file
@ -0,0 +1,177 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// get table alias.
|
||||
func getDbAlias(name string) *alias {
|
||||
if al, ok := dataBaseCache.get(name); ok {
|
||||
return al
|
||||
}
|
||||
panic(fmt.Errorf("unknown DataBase alias name %s", name))
|
||||
}
|
||||
|
||||
// get pk column info.
|
||||
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
|
||||
fi := mi.fields.pk
|
||||
|
||||
v := ind.FieldByIndex(fi.fieldIndex)
|
||||
if fi.fieldType&IsPositiveIntegerField > 0 {
|
||||
vu := v.Uint()
|
||||
exist = vu > 0
|
||||
value = vu
|
||||
} else if fi.fieldType&IsIntegerField > 0 {
|
||||
vu := v.Int()
|
||||
exist = true
|
||||
value = vu
|
||||
} else if fi.fieldType&IsRelField > 0 {
|
||||
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
|
||||
} else {
|
||||
vu := v.String()
|
||||
exist = vu != ""
|
||||
value = vu
|
||||
}
|
||||
|
||||
column = fi.column
|
||||
return
|
||||
}
|
||||
|
||||
// get fields description as flatted string.
|
||||
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
|
||||
|
||||
outFor:
|
||||
for _, arg := range args {
|
||||
val := reflect.ValueOf(arg)
|
||||
|
||||
if arg == nil {
|
||||
params = append(params, arg)
|
||||
continue
|
||||
}
|
||||
|
||||
kind := val.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
kind = val.Kind()
|
||||
arg = val.Interface()
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.String:
|
||||
v := val.String()
|
||||
if fi != nil {
|
||||
if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
|
||||
var t time.Time
|
||||
var err error
|
||||
if len(v) >= 19 {
|
||||
s := v[:19]
|
||||
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
|
||||
} else if len(v) >= 10 {
|
||||
s := v
|
||||
if len(v) > 10 {
|
||||
s = v[:10]
|
||||
}
|
||||
t, err = time.ParseInLocation(formatDate, s, tz)
|
||||
} else {
|
||||
s := v
|
||||
if len(s) > 8 {
|
||||
s = v[:8]
|
||||
}
|
||||
t, err = time.ParseInLocation(formatTime, s, tz)
|
||||
}
|
||||
if err == nil {
|
||||
if fi.fieldType == TypeDateField {
|
||||
v = t.In(tz).Format(formatDate)
|
||||
} else if fi.fieldType == TypeDateTimeField {
|
||||
v = t.In(tz).Format(formatDateTime)
|
||||
} else {
|
||||
v = t.In(tz).Format(formatTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg = v
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
arg = val.Int()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
arg = val.Uint()
|
||||
case reflect.Float32:
|
||||
arg, _ = StrTo(ToStr(arg)).Float64()
|
||||
case reflect.Float64:
|
||||
arg = val.Float()
|
||||
case reflect.Bool:
|
||||
arg = val.Bool()
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := arg.([]byte); ok {
|
||||
continue outFor
|
||||
}
|
||||
|
||||
var args []interface{}
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
v := val.Index(i)
|
||||
|
||||
var vu interface{}
|
||||
if v.CanInterface() {
|
||||
vu = v.Interface()
|
||||
}
|
||||
|
||||
if vu == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
args = append(args, vu)
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
p := getFlatParams(fi, args, tz)
|
||||
params = append(params, p...)
|
||||
}
|
||||
continue outFor
|
||||
case reflect.Struct:
|
||||
if v, ok := arg.(time.Time); ok {
|
||||
if fi != nil && fi.fieldType == TypeDateField {
|
||||
arg = v.In(tz).Format(formatDate)
|
||||
} else if fi != nil && fi.fieldType == TypeDateTimeField {
|
||||
arg = v.In(tz).Format(formatDateTime)
|
||||
} else if fi != nil && fi.fieldType == TypeTimeField {
|
||||
arg = v.In(tz).Format(formatTime)
|
||||
} else {
|
||||
arg = v.In(tz).Format(formatDateTime)
|
||||
}
|
||||
} else {
|
||||
typ := val.Type()
|
||||
name := getFullName(typ)
|
||||
var value interface{}
|
||||
if mmi, ok := modelCache.getByFullName(name); ok {
|
||||
if _, vu, exist := getExistPk(mmi, val); exist {
|
||||
value = vu
|
||||
}
|
||||
}
|
||||
arg = value
|
||||
|
||||
if arg == nil {
|
||||
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
params = append(params, arg)
|
||||
}
|
||||
return
|
||||
}
|
180
pkg/client/orm/do_nothing_orm.go
Normal file
180
pkg/client/orm/do_nothing_orm.go
Normal file
@ -0,0 +1,180 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
)
|
||||
|
||||
// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation
|
||||
// I think golang mocking interface is hard to use
|
||||
// this may help you to integrate with Ormer
|
||||
|
||||
var _ Ormer = new(DoNothingOrm)
|
||||
|
||||
type DoNothingOrm struct {
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Read(md interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) ReadForUpdate(md interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) DBStats() *sql.DBStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Insert(md interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Update(md interface{}, cols ...string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Delete(md interface{}, cols ...string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Raw(query string, args ...interface{}) RawSeter {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Driver() Driver {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) Begin() (TxOrmer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) BeginWithCtx(ctx context.Context) (TxOrmer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoNothingTxOrm is similar with DoNothingOrm, usually you use it to test
|
||||
type DoNothingTxOrm struct {
|
||||
DoNothingOrm
|
||||
}
|
||||
|
||||
func (d *DoNothingTxOrm) Commit() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoNothingTxOrm) Rollback() error {
|
||||
return nil
|
||||
}
|
134
pkg/client/orm/do_nothing_orm_test.go
Normal file
134
pkg/client/orm/do_nothing_orm_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDoNothingOrm(t *testing.T) {
|
||||
o := &DoNothingOrm{}
|
||||
err := o.DoTxWithCtxAndOpts(nil, nil, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = o.DoTxWithCtx(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = o.DoTx(nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = o.DoTxWithOpts(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Nil(t, o.Driver())
|
||||
|
||||
assert.Nil(t, o.QueryM2MWithCtx(nil, nil, ""))
|
||||
assert.Nil(t, o.QueryM2M(nil, ""))
|
||||
assert.Nil(t, o.ReadWithCtx(nil, nil))
|
||||
assert.Nil(t, o.Read(nil))
|
||||
|
||||
txOrm, err := o.BeginWithCtxAndOpts(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, txOrm)
|
||||
|
||||
txOrm, err = o.BeginWithCtx(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, txOrm)
|
||||
|
||||
txOrm, err = o.BeginWithOpts(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, txOrm)
|
||||
|
||||
txOrm, err = o.Begin()
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, txOrm)
|
||||
|
||||
assert.Nil(t, o.RawWithCtx(nil, ""))
|
||||
assert.Nil(t, o.Raw(""))
|
||||
|
||||
i, err := o.InsertMulti(0, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.Insert(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.InsertWithCtx(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.InsertOrUpdateWithCtx(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.InsertOrUpdate(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.InsertMultiWithCtx(nil, 0, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.LoadRelatedWithCtx(nil, nil, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.LoadRelated(nil, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
assert.Nil(t, o.QueryTableWithCtx(nil, nil))
|
||||
assert.Nil(t, o.QueryTable(nil))
|
||||
|
||||
assert.Nil(t, o.Read(nil))
|
||||
assert.Nil(t, o.ReadWithCtx(nil, nil))
|
||||
assert.Nil(t, o.ReadForUpdateWithCtx(nil, nil))
|
||||
assert.Nil(t, o.ReadForUpdate(nil))
|
||||
|
||||
ok, i, err := o.ReadOrCreate(nil, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
assert.False(t, ok)
|
||||
|
||||
ok, i, err = o.ReadOrCreateWithCtx(nil, nil, "")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
assert.False(t, ok)
|
||||
|
||||
i, err = o.Delete(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.DeleteWithCtx(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.Update(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
i, err = o.UpdateWithCtx(nil, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), i)
|
||||
|
||||
assert.Nil(t, o.DBStats())
|
||||
|
||||
to := &DoNothingTxOrm{}
|
||||
assert.Nil(t, to.Commit())
|
||||
assert.Nil(t, to.Rollback())
|
||||
}
|
40
pkg/client/orm/filter.go
Normal file
40
pkg/client/orm/filter.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// FilterChain is used to build a Filter
|
||||
// don't forget to call next(...) inside your Filter
|
||||
type FilterChain func(next Filter) Filter
|
||||
|
||||
// Filter's behavior is a little big strange.
|
||||
// it's only be called when users call methods of Ormer
|
||||
// return value is an array. it's a little bit hard to understand,
|
||||
// for example, the Ormer's Read method only return error
|
||||
// so the filter processing this method should return an array whose first element is error
|
||||
// and, Ormer's ReadOrCreateWithCtx return three values, so the Filter's result should contains three values
|
||||
type Filter func(ctx context.Context, inv *Invocation) []interface{}
|
||||
|
||||
var globalFilterChains = make([]FilterChain, 0, 4)
|
||||
|
||||
// AddGlobalFilterChain adds a new FilterChain
|
||||
// All orm instances built after this invocation will use this filterChain,
|
||||
// but instances built before this invocation will not be affected
|
||||
func AddGlobalFilterChain(filterChain ...FilterChain) {
|
||||
globalFilterChains = append(globalFilterChains, filterChain...)
|
||||
}
|
137
pkg/client/orm/filter/bean/default_value_filter.go
Normal file
137
pkg/client/orm/filter/bean/default_value_filter.go
Normal file
@ -0,0 +1,137 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bean
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/logs"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
"github.com/astaxie/beego/pkg/infrastructure/bean"
|
||||
)
|
||||
|
||||
// DefaultValueFilterChainBuilder only works for InsertXXX method,
|
||||
// But InsertOrUpdate and InsertOrUpdateWithCtx is more dangerous than other methods.
|
||||
// so we won't handle those two methods unless you set includeInsertOrUpdate to true
|
||||
// And if the element is not pointer, this filter doesn't work
|
||||
type DefaultValueFilterChainBuilder struct {
|
||||
factory bean.AutoWireBeanFactory
|
||||
compatibleWithOldStyle bool
|
||||
|
||||
// only the includeInsertOrUpdate is true, this filter will handle those two methods
|
||||
includeInsertOrUpdate bool
|
||||
}
|
||||
|
||||
// NewDefaultValueFilterChainBuilder will create an instance of DefaultValueFilterChainBuilder
|
||||
// In beego v1.x, the default value config looks like orm:default(xxxx)
|
||||
// But the default value in 2.x is default:xxx
|
||||
// so if you want to be compatible with v1.x, please pass true as compatibleWithOldStyle
|
||||
func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter,
|
||||
includeInsertOrUpdate bool,
|
||||
compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder {
|
||||
factory := bean.NewTagAutoWireBeanFactory()
|
||||
|
||||
if compatibleWithOldStyle {
|
||||
newParser := factory.FieldTagParser
|
||||
factory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata {
|
||||
if newParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" {
|
||||
return newParser(field)
|
||||
} else {
|
||||
res := &bean.FieldMetadata{}
|
||||
ormMeta := field.Tag.Get("orm")
|
||||
ormMetaParts := strings.Split(ormMeta, ";")
|
||||
for _, p := range ormMetaParts {
|
||||
if strings.HasPrefix(p, "default(") && strings.HasSuffix(p, ")") {
|
||||
res.DftValue = p[8 : len(p)-1]
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range typeAdapters {
|
||||
factory.Adapters[k] = v
|
||||
}
|
||||
|
||||
return &DefaultValueFilterChainBuilder{
|
||||
factory: factory,
|
||||
compatibleWithOldStyle: compatibleWithOldStyle,
|
||||
includeInsertOrUpdate: includeInsertOrUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
||||
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||
switch inv.Method {
|
||||
case "Insert", "InsertWithCtx":
|
||||
d.handleInsert(ctx, inv)
|
||||
break
|
||||
case "InsertOrUpdate", "InsertOrUpdateWithCtx":
|
||||
d.handleInsertOrUpdate(ctx, inv)
|
||||
break
|
||||
case "InsertMulti", "InsertMultiWithCtx":
|
||||
d.handleInsertMulti(ctx, inv)
|
||||
break
|
||||
}
|
||||
return next(ctx, inv)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) {
|
||||
d.setDefaultValue(ctx, inv.Args[0])
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) {
|
||||
if d.includeInsertOrUpdate {
|
||||
ins := inv.Args[0]
|
||||
if ins == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkName := inv.GetPkFieldName()
|
||||
pkField := reflect.Indirect(reflect.ValueOf(ins)).FieldByName(pkName)
|
||||
|
||||
if pkField.IsZero() {
|
||||
d.setDefaultValue(ctx, ins)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) {
|
||||
mds := inv.Args[1]
|
||||
|
||||
if t := reflect.TypeOf(mds).Kind(); t != reflect.Array && t != reflect.Slice {
|
||||
// do nothing
|
||||
return
|
||||
}
|
||||
|
||||
mdsArr := reflect.Indirect(reflect.ValueOf(mds))
|
||||
for i := 0; i < mdsArr.Len(); i++ {
|
||||
d.setDefaultValue(ctx, mdsArr.Index(i).Interface())
|
||||
}
|
||||
logs.Warn("%v", mdsArr.Index(0).Interface())
|
||||
}
|
||||
|
||||
func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) {
|
||||
err := d.factory.AutoWire(ctx, nil, ins)
|
||||
if err != nil {
|
||||
logs.Error("try to wire the bean for orm.Insert failed. "+
|
||||
"the default value is not set: %v, ", err)
|
||||
}
|
||||
}
|
72
pkg/client/orm/filter/bean/default_value_filter_test.go
Normal file
72
pkg/client/orm/filter/bean/default_value_filter_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bean
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
)
|
||||
|
||||
func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) {
|
||||
builder := NewDefaultValueFilterChainBuilder(nil, true, true)
|
||||
o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain)
|
||||
|
||||
// test insert
|
||||
entity := &DefaultValueTestEntity{}
|
||||
_, _ = o.Insert(entity)
|
||||
assert.Equal(t, 12, entity.Age)
|
||||
assert.Equal(t, 13, entity.AgeInOldStyle)
|
||||
assert.Equal(t, 0, entity.AgeIgnore)
|
||||
|
||||
// test InsertOrUpdate
|
||||
entity = &DefaultValueTestEntity{}
|
||||
orm.RegisterModel(entity)
|
||||
|
||||
_, _ = o.InsertOrUpdate(entity)
|
||||
assert.Equal(t, 12, entity.Age)
|
||||
assert.Equal(t, 13, entity.AgeInOldStyle)
|
||||
|
||||
// we won't set the default value because we find the pk is not Zero value
|
||||
entity.Id = 3
|
||||
entity.AgeInOldStyle = 0
|
||||
_, _ = o.InsertOrUpdate(entity)
|
||||
assert.Equal(t, 0, entity.AgeInOldStyle)
|
||||
|
||||
entity = &DefaultValueTestEntity{}
|
||||
|
||||
// the entity is not array, it will be ignored
|
||||
_, _ = o.InsertMulti(3, entity)
|
||||
assert.Equal(t, 0, entity.Age)
|
||||
assert.Equal(t, 0, entity.AgeInOldStyle)
|
||||
|
||||
_, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity})
|
||||
assert.Equal(t, 12, entity.Age)
|
||||
assert.Equal(t, 13, entity.AgeInOldStyle)
|
||||
|
||||
}
|
||||
|
||||
type defaultValueTestOrm struct {
|
||||
orm.DoNothingOrm
|
||||
}
|
||||
|
||||
type DefaultValueTestEntity struct {
|
||||
Id int
|
||||
Age int `default:"12"`
|
||||
AgeInOldStyle int `orm:"default(13);bee()"`
|
||||
AgeIgnore int
|
||||
}
|
71
pkg/client/orm/filter/opentracing/filter.go
Normal file
71
pkg/client/orm/filter/opentracing/filter.go
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package opentracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
)
|
||||
|
||||
// FilterChainBuilder provides an extension point
|
||||
// this Filter's behavior looks a little bit strange
|
||||
// for example:
|
||||
// if we want to trace QuerySetter
|
||||
// actually we trace invoking "QueryTable" and "QueryTableWithCtx"
|
||||
// the method Begin*, Commit and Rollback are ignored.
|
||||
// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them.
|
||||
type FilterChainBuilder struct {
|
||||
// CustomSpanFunc users are able to custom their span
|
||||
CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation)
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
||||
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||
operationName := builder.operationName(ctx, inv)
|
||||
if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" {
|
||||
return next(ctx, inv)
|
||||
}
|
||||
|
||||
span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName)
|
||||
defer span.Finish()
|
||||
res := next(spanCtx, inv)
|
||||
builder.buildSpan(span, spanCtx, inv)
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) buildSpan(span opentracing.Span, ctx context.Context, inv *orm.Invocation) {
|
||||
span.SetTag("orm.method", inv.Method)
|
||||
span.SetTag("orm.table", inv.GetTableName())
|
||||
span.SetTag("orm.insideTx", inv.InsideTx)
|
||||
span.SetTag("orm.txName", ctx.Value(orm.TxNameKey))
|
||||
span.SetTag("span.kind", "client")
|
||||
span.SetTag("component", "beego")
|
||||
|
||||
if builder.CustomSpanFunc != nil {
|
||||
builder.CustomSpanFunc(span, ctx, inv)
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string {
|
||||
if n, ok := ctx.Value(orm.TxNameKey).(string); ok {
|
||||
return inv.Method + "#tx(" + n + ")"
|
||||
}
|
||||
return inv.Method + "#" + inv.GetTableName()
|
||||
}
|
44
pkg/client/orm/filter/opentracing/filter_test.go
Normal file
44
pkg/client/orm/filter/opentracing/filter_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package opentracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
)
|
||||
|
||||
func TestFilterChainBuilder_FilterChain(t *testing.T) {
|
||||
next := func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||
inv.TxName = "Hello"
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
builder := &FilterChainBuilder{
|
||||
CustomSpanFunc: func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) {
|
||||
span.SetTag("hello", "hell")
|
||||
},
|
||||
}
|
||||
|
||||
inv := &orm.Invocation{
|
||||
Method: "Hello",
|
||||
TxStartTime: time.Now(),
|
||||
}
|
||||
builder.FilterChain(next)(context.Background(), inv)
|
||||
}
|
89
pkg/client/orm/filter/prometheus/filter.go
Normal file
89
pkg/client/orm/filter/prometheus/filter.go
Normal file
@ -0,0 +1,89 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
"github.com/astaxie/beego/pkg/server/web"
|
||||
)
|
||||
|
||||
// FilterChainBuilder is an extension point,
|
||||
// when we want to support some configuration,
|
||||
// please use this structure
|
||||
// this Filter's behavior looks a little bit strange
|
||||
// for example:
|
||||
// if we want to records the metrics of QuerySetter
|
||||
// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx"
|
||||
type FilterChainBuilder struct {
|
||||
summaryVec prometheus.ObserverVec
|
||||
}
|
||||
|
||||
func NewFilterChainBuilder() *FilterChainBuilder {
|
||||
summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||
Name: "beego",
|
||||
Subsystem: "orm_operation",
|
||||
ConstLabels: map[string]string{
|
||||
"server": web.BConfig.ServerName,
|
||||
"env": web.BConfig.RunMode,
|
||||
"appname": web.BConfig.AppName,
|
||||
},
|
||||
Help: "The statics info for orm operation",
|
||||
}, []string{"method", "name", "duration", "insideTx", "txName"})
|
||||
|
||||
prometheus.MustRegister(summaryVec)
|
||||
return &FilterChainBuilder{
|
||||
summaryVec: summaryVec,
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter {
|
||||
return func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||
startTime := time.Now()
|
||||
res := next(ctx, inv)
|
||||
endTime := time.Now()
|
||||
dur := (endTime.Sub(startTime)) / time.Millisecond
|
||||
|
||||
// if the TPS is too large, here may be some problem
|
||||
// thinking about using goroutine pool
|
||||
go builder.report(ctx, inv, dur)
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocation, dur time.Duration) {
|
||||
// start a transaction, we don't record it
|
||||
if strings.HasPrefix(inv.Method, "Begin") {
|
||||
return
|
||||
}
|
||||
if inv.Method == "Commit" || inv.Method == "Rollback" {
|
||||
builder.reportTxn(ctx, inv)
|
||||
return
|
||||
}
|
||||
builder.summaryVec.WithLabelValues(inv.Method, inv.GetTableName(), strconv.Itoa(int(dur)),
|
||||
strconv.FormatBool(inv.InsideTx), inv.TxName)
|
||||
}
|
||||
|
||||
func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) {
|
||||
dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond
|
||||
builder.summaryVec.WithLabelValues(inv.Method, inv.TxName, strconv.Itoa(int(dur)),
|
||||
strconv.FormatBool(inv.InsideTx), inv.TxName)
|
||||
}
|
61
pkg/client/orm/filter/prometheus/filter_test.go
Normal file
61
pkg/client/orm/filter/prometheus/filter_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
)
|
||||
|
||||
func TestFilterChainBuilder_FilterChain(t *testing.T) {
|
||||
builder := NewFilterChainBuilder()
|
||||
assert.NotNil(t, builder.summaryVec)
|
||||
|
||||
filter := builder.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} {
|
||||
inv.Method = "coming"
|
||||
return []interface{}{}
|
||||
})
|
||||
assert.NotNil(t, filter)
|
||||
|
||||
inv := &orm.Invocation{}
|
||||
filter(context.Background(), inv)
|
||||
assert.Equal(t, "coming", inv.Method)
|
||||
|
||||
inv = &orm.Invocation{
|
||||
Method: "Hello",
|
||||
TxStartTime: time.Now(),
|
||||
}
|
||||
builder.reportTxn(context.Background(), inv)
|
||||
|
||||
inv = &orm.Invocation{
|
||||
Method: "Begin",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
// it will be ignored
|
||||
builder.report(ctx, inv, time.Second)
|
||||
|
||||
inv.Method = "Commit"
|
||||
builder.report(ctx, inv, time.Second)
|
||||
|
||||
inv.Method = "Update"
|
||||
builder.report(ctx, inv, time.Second)
|
||||
|
||||
}
|
514
pkg/client/orm/filter_orm_decorator.go
Normal file
514
pkg/client/orm/filter_orm_decorator.go
Normal file
@ -0,0 +1,514 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
TxNameKey = "TxName"
|
||||
)
|
||||
|
||||
var _ Ormer = new(filterOrmDecorator)
|
||||
var _ TxOrmer = new(filterOrmDecorator)
|
||||
|
||||
type filterOrmDecorator struct {
|
||||
ormer
|
||||
TxBeginner
|
||||
TxCommitter
|
||||
|
||||
root Filter
|
||||
|
||||
insideTx bool
|
||||
txStartTime time.Time
|
||||
txName string
|
||||
}
|
||||
|
||||
func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer {
|
||||
res := &filterOrmDecorator{
|
||||
ormer: delegate,
|
||||
TxBeginner: delegate,
|
||||
root: func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
return inv.execute(ctx)
|
||||
},
|
||||
}
|
||||
|
||||
for i := len(filterChains) - 1; i >= 0; i-- {
|
||||
node := filterChains[i]
|
||||
res.root = node(res.root)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func NewFilterTxOrmDecorator(delegate TxOrmer, root Filter, txName string) TxOrmer {
|
||||
res := &filterOrmDecorator{
|
||||
ormer: delegate,
|
||||
TxCommitter: delegate,
|
||||
root: root,
|
||||
insideTx: true,
|
||||
txStartTime: time.Now(),
|
||||
txName: txName,
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Read(md interface{}, cols ...string) error {
|
||||
return f.ReadWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "ReadWithCtx",
|
||||
Args: []interface{}{md, cols},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
err := f.ormer.ReadWithCtx(c, md, cols...)
|
||||
return []interface{}{err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return f.convertError(res[0])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) ReadForUpdate(md interface{}, cols ...string) error {
|
||||
return f.ReadForUpdateWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "ReadForUpdateWithCtx",
|
||||
Args: []interface{}{md, cols},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
err := f.ormer.ReadForUpdateWithCtx(c, md, cols...)
|
||||
return []interface{}{err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return f.convertError(res[0])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
return f.ReadOrCreateWithCtx(context.Background(), md, col1, cols...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "ReadOrCreateWithCtx",
|
||||
Args: []interface{}{md, col1, cols},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
ok, res, err := f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...)
|
||||
return []interface{}{ok, res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(bool), res[1].(int64), f.convertError(res[2])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
return f.LoadRelatedWithCtx(context.Background(), md, name, args...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "LoadRelatedWithCtx",
|
||||
Args: []interface{}{md, name, args},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.ormer.LoadRelatedWithCtx(c, md, name, args...)
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(int64), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
return f.QueryM2MWithCtx(context.Background(), md, name)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
||||
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "QueryM2MWithCtx",
|
||||
Args: []interface{}{md, name},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res := f.ormer.QueryM2MWithCtx(c, md, name)
|
||||
return []interface{}{res}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
if res[0] == nil {
|
||||
return nil
|
||||
}
|
||||
return res[0].(QueryM2Mer)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
|
||||
return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
|
||||
var (
|
||||
name string
|
||||
md interface{}
|
||||
mi *modelInfo
|
||||
)
|
||||
|
||||
if table, ok := ptrStructOrTableName.(string); ok {
|
||||
name = table
|
||||
} else {
|
||||
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
|
||||
md = ptrStructOrTableName
|
||||
}
|
||||
|
||||
if m, ok := modelCache.getByFullName(name); ok {
|
||||
mi = m
|
||||
}
|
||||
|
||||
inv := &Invocation{
|
||||
Method: "QueryTableWithCtx",
|
||||
Args: []interface{}{ptrStructOrTableName},
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
Md: md,
|
||||
mi: mi,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName)
|
||||
return []interface{}{res}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
|
||||
if res[0] == nil {
|
||||
return nil
|
||||
}
|
||||
return res[0].(QuerySeter)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) DBStats() *sql.DBStats {
|
||||
inv := &Invocation{
|
||||
Method: "DBStats",
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res := f.ormer.DBStats()
|
||||
return []interface{}{res}
|
||||
},
|
||||
}
|
||||
res := f.root(context.Background(), inv)
|
||||
|
||||
if res[0] == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return res[0].(*sql.DBStats)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Insert(md interface{}) (int64, error) {
|
||||
return f.InsertWithCtx(context.Background(), md)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "InsertWithCtx",
|
||||
Args: []interface{}{md},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.ormer.InsertWithCtx(c, md)
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(int64), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||
return f.InsertOrUpdateWithCtx(context.Background(), md, colConflitAndArgs...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "InsertOrUpdateWithCtx",
|
||||
Args: []interface{}{md, colConflitAndArgs},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...)
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(int64), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
||||
return f.InsertMultiWithCtx(context.Background(), bulk, mds)
|
||||
}
|
||||
|
||||
// InsertMultiWithCtx uses the first element's model info
|
||||
func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
|
||||
var (
|
||||
md interface{}
|
||||
mi *modelInfo
|
||||
)
|
||||
|
||||
sind := reflect.Indirect(reflect.ValueOf(mds))
|
||||
|
||||
if (sind.Kind() == reflect.Array || sind.Kind() == reflect.Slice) && sind.Len() > 0 {
|
||||
ind := reflect.Indirect(sind.Index(0))
|
||||
md = ind.Interface()
|
||||
mi, _ = modelCache.getByMd(md)
|
||||
}
|
||||
|
||||
inv := &Invocation{
|
||||
Method: "InsertMultiWithCtx",
|
||||
Args: []interface{}{bulk, mds},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.ormer.InsertMultiWithCtx(c, bulk, mds)
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(int64), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Update(md interface{}, cols ...string) (int64, error) {
|
||||
return f.UpdateWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "UpdateWithCtx",
|
||||
Args: []interface{}{md, cols},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.ormer.UpdateWithCtx(c, md, cols...)
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(int64), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Delete(md interface{}, cols ...string) (int64, error) {
|
||||
return f.DeleteWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
mi, _ := modelCache.getByMd(md)
|
||||
inv := &Invocation{
|
||||
Method: "DeleteWithCtx",
|
||||
Args: []interface{}{md, cols},
|
||||
Md: md,
|
||||
mi: mi,
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.ormer.DeleteWithCtx(c, md, cols...)
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(int64), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Raw(query string, args ...interface{}) RawSeter {
|
||||
return f.RawWithCtx(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter {
|
||||
inv := &Invocation{
|
||||
Method: "RawWithCtx",
|
||||
Args: []interface{}{query, args},
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res := f.ormer.RawWithCtx(c, query, args...)
|
||||
return []interface{}{res}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
|
||||
if res[0] == nil {
|
||||
return nil
|
||||
}
|
||||
return res[0].(RawSeter)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Driver() Driver {
|
||||
inv := &Invocation{
|
||||
Method: "Driver",
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res := f.ormer.Driver()
|
||||
return []interface{}{res}
|
||||
},
|
||||
}
|
||||
res := f.root(context.Background(), inv)
|
||||
if res[0] == nil {
|
||||
return nil
|
||||
}
|
||||
return res[0].(Driver)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Begin() (TxOrmer, error) {
|
||||
return f.BeginWithCtxAndOpts(context.Background(), nil)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) BeginWithCtx(ctx context.Context) (TxOrmer, error) {
|
||||
return f.BeginWithCtxAndOpts(ctx, nil)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) {
|
||||
return f.BeginWithCtxAndOpts(context.Background(), opts)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
|
||||
inv := &Invocation{
|
||||
Method: "BeginWithCtxAndOpts",
|
||||
Args: []interface{}{opts},
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
f: func(c context.Context) []interface{} {
|
||||
res, err := f.TxBeginner.BeginWithCtxAndOpts(c, opts)
|
||||
res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(c))
|
||||
return []interface{}{res, err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return res[0].(TxOrmer), f.convertError(res[1])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return f.DoTxWithCtxAndOpts(context.Background(), nil, task)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return f.DoTxWithCtxAndOpts(ctx, nil, task)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return f.DoTxWithCtxAndOpts(context.Background(), opts, task)
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
inv := &Invocation{
|
||||
Method: "DoTxWithCtxAndOpts",
|
||||
Args: []interface{}{opts, task},
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
TxName: getTxNameFromCtx(ctx),
|
||||
f: func(c context.Context) []interface{} {
|
||||
err := doTxTemplate(f, c, opts, task)
|
||||
return []interface{}{err}
|
||||
},
|
||||
}
|
||||
res := f.root(ctx, inv)
|
||||
return f.convertError(res[0])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Commit() error {
|
||||
inv := &Invocation{
|
||||
Method: "Commit",
|
||||
Args: []interface{}{},
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
TxName: f.txName,
|
||||
f: func(c context.Context) []interface{} {
|
||||
err := f.TxCommitter.Commit()
|
||||
return []interface{}{err}
|
||||
},
|
||||
}
|
||||
res := f.root(context.Background(), inv)
|
||||
return f.convertError(res[0])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) Rollback() error {
|
||||
inv := &Invocation{
|
||||
Method: "Rollback",
|
||||
Args: []interface{}{},
|
||||
InsideTx: f.insideTx,
|
||||
TxStartTime: f.txStartTime,
|
||||
TxName: f.txName,
|
||||
f: func(c context.Context) []interface{} {
|
||||
err := f.TxCommitter.Rollback()
|
||||
return []interface{}{err}
|
||||
},
|
||||
}
|
||||
res := f.root(context.Background(), inv)
|
||||
return f.convertError(res[0])
|
||||
}
|
||||
|
||||
func (f *filterOrmDecorator) convertError(v interface{}) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.(error)
|
||||
}
|
||||
|
||||
func getTxNameFromCtx(ctx context.Context) string {
|
||||
txName := ""
|
||||
if n, ok := ctx.Value(TxNameKey).(string); ok {
|
||||
txName = n
|
||||
}
|
||||
return txName
|
||||
}
|
434
pkg/client/orm/filter_orm_decorator_test.go
Normal file
434
pkg/client/orm/filter_orm_decorator_test.go
Normal file
@ -0,0 +1,434 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilterOrmDecorator_Read(t *testing.T) {
|
||||
|
||||
register()
|
||||
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "ReadWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
|
||||
fte := &FilterTestEntity{}
|
||||
err := od.Read(fte)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "read error", err.Error())
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_BeginTx(t *testing.T) {
|
||||
register()
|
||||
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
if inv.Method == "BeginWithCtxAndOpts" {
|
||||
assert.Equal(t, 1, len(inv.Args))
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
} else if inv.Method == "Commit" {
|
||||
assert.Equal(t, 0, len(inv.Args))
|
||||
assert.Equal(t, "Commit_tx", inv.TxName)
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.True(t, inv.InsideTx)
|
||||
} else if inv.Method == "Rollback" {
|
||||
assert.Equal(t, 0, len(inv.Args))
|
||||
assert.Equal(t, "Rollback_tx", inv.TxName)
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.True(t, inv.InsideTx)
|
||||
} else {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
to, err := od.Begin()
|
||||
assert.True(t, validateBeginResult(t, to, err))
|
||||
|
||||
to, err = od.BeginWithOpts(nil)
|
||||
assert.True(t, validateBeginResult(t, to, err))
|
||||
|
||||
ctx := context.WithValue(context.Background(), TxNameKey, "Commit_tx")
|
||||
to, err = od.BeginWithCtx(ctx)
|
||||
assert.True(t, validateBeginResult(t, to, err))
|
||||
|
||||
err = to.Commit()
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "commit", err.Error())
|
||||
|
||||
ctx = context.WithValue(context.Background(), TxNameKey, "Rollback_tx")
|
||||
to, err = od.BeginWithCtxAndOpts(ctx, nil)
|
||||
assert.True(t, validateBeginResult(t, to, err))
|
||||
|
||||
err = to.Rollback()
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "rollback", err.Error())
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_DBStats(t *testing.T) {
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "DBStats", inv.Method)
|
||||
assert.Equal(t, 0, len(inv.Args))
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
res := od.DBStats()
|
||||
assert.NotNil(t, res)
|
||||
assert.Equal(t, -1, res.MaxOpenConnections)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_Delete(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "DeleteWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
res, err := od.Delete(&FilterTestEntity{})
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "delete error", err.Error())
|
||||
assert.Equal(t, int64(-2), res)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_DoTx(t *testing.T) {
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
if inv.Method == "DoTxWithCtxAndOpts" {
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
}
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
|
||||
err := od.DoTx(func(c context.Context, txOrm TxOrmer) error {
|
||||
return nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
err = od.DoTxWithCtx(context.Background(), func(c context.Context, txOrm TxOrmer) error {
|
||||
return nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
err = od.DoTxWithOpts(nil, func(c context.Context, txOrm TxOrmer) error {
|
||||
return nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
od = NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
if inv.Method == "DoTxWithCtxAndOpts" {
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.Equal(t, "do tx name", inv.TxName)
|
||||
assert.False(t, inv.InsideTx)
|
||||
}
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), TxNameKey, "do tx name")
|
||||
err = od.DoTxWithCtxAndOpts(ctx, nil, func(c context.Context, txOrm TxOrmer) error {
|
||||
return nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_Driver(t *testing.T) {
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "Driver", inv.Method)
|
||||
assert.Equal(t, 0, len(inv.Args))
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
res := od.Driver()
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_Insert(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "InsertWithCtx", inv.Method)
|
||||
assert.Equal(t, 1, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
|
||||
i, err := od.Insert(&FilterTestEntity{})
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "insert error", err.Error())
|
||||
assert.Equal(t, int64(100), i)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_InsertMulti(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "InsertMultiWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
|
||||
bulk := []*FilterTestEntity{&FilterTestEntity{}, &FilterTestEntity{}}
|
||||
i, err := od.InsertMulti(2, bulk)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "insert multi error", err.Error())
|
||||
assert.Equal(t, int64(2), i)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "InsertOrUpdateWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
i, err := od.InsertOrUpdate(&FilterTestEntity{})
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "insert or update error", err.Error())
|
||||
assert.Equal(t, int64(1), i)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_LoadRelated(t *testing.T) {
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "LoadRelatedWithCtx", inv.Method)
|
||||
assert.Equal(t, 3, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
i, err := od.LoadRelated(&FilterTestEntity{}, "hello")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "load related error", err.Error())
|
||||
assert.Equal(t, int64(99), i)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_QueryM2M(t *testing.T) {
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "QueryM2MWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
res := od.QueryM2M(&FilterTestEntity{}, "hello")
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_QueryTable(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "QueryTableWithCtx", inv.Method)
|
||||
assert.Equal(t, 1, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
res := od.QueryTable(&FilterTestEntity{})
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_Raw(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "RawWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
res := od.Raw("hh")
|
||||
assert.Nil(t, res)
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "ReadForUpdateWithCtx", inv.Method)
|
||||
assert.Equal(t, 2, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
err := od.ReadForUpdate(&FilterTestEntity{})
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "read for update error", err.Error())
|
||||
}
|
||||
|
||||
func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) {
|
||||
register()
|
||||
o := &filterMockOrm{}
|
||||
od := NewFilterOrmDecorator(o, func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
assert.Equal(t, "ReadOrCreateWithCtx", inv.Method)
|
||||
assert.Equal(t, 3, len(inv.Args))
|
||||
assert.Equal(t, "FILTER_TEST", inv.GetTableName())
|
||||
assert.False(t, inv.InsideTx)
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
ok, i, err := od.ReadOrCreate(&FilterTestEntity{}, "name")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "read or create error", err.Error())
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, int64(13), i)
|
||||
}
|
||||
|
||||
var _ Ormer = new(filterMockOrm)
|
||||
|
||||
// filterMockOrm is only used in this test file
|
||||
type filterMockOrm struct {
|
||||
DoNothingOrm
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
return true, 13, errors.New("read or create error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
return errors.New("read for update error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
return 99, errors.New("load related error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||
return 1, errors.New("insert or update error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
|
||||
return 2, errors.New("insert multi error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||
return 100, errors.New("insert error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(c context.Context, txOrm TxOrmer) error) error {
|
||||
return task(ctx, nil)
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
return -2, errors.New("delete error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
|
||||
return &filterMockOrm{}, errors.New("begin tx")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
return errors.New("read error")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) Commit() error {
|
||||
return errors.New("commit")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) Rollback() error {
|
||||
return errors.New("rollback")
|
||||
}
|
||||
|
||||
func (f *filterMockOrm) DBStats() *sql.DBStats {
|
||||
return &sql.DBStats{
|
||||
MaxOpenConnections: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func validateBeginResult(t *testing.T, to TxOrmer, err error) bool {
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "begin tx", err.Error())
|
||||
_, ok := to.(*filterOrmDecorator).TxCommitter.(*filterMockOrm)
|
||||
assert.True(t, ok)
|
||||
return true
|
||||
}
|
||||
|
||||
var filterTestEntityRegisterOnce sync.Once
|
||||
|
||||
type FilterTestEntity struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
func register() {
|
||||
filterTestEntityRegisterOnce.Do(func() {
|
||||
RegisterModel(&FilterTestEntity{})
|
||||
})
|
||||
}
|
||||
|
||||
func (f *FilterTestEntity) TableName() string {
|
||||
return "FILTER_TEST"
|
||||
}
|
32
pkg/client/orm/filter_test.go
Normal file
32
pkg/client/orm/filter_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddGlobalFilterChain(t *testing.T) {
|
||||
AddGlobalFilterChain(func(next Filter) Filter {
|
||||
return func(ctx context.Context, inv *Invocation) []interface{} {
|
||||
return next(ctx, inv)
|
||||
}
|
||||
})
|
||||
assert.Equal(t, 1, len(globalFilterChains))
|
||||
globalFilterChains = nil
|
||||
}
|
131
pkg/client/orm/hints/db_hints.go
Normal file
131
pkg/client/orm/hints/db_hints.go
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright 2020 beego-dev
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package hints
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
//db level
|
||||
KeyMaxIdleConnections = iota
|
||||
KeyMaxOpenConnections
|
||||
KeyConnMaxLifetime
|
||||
KeyMaxStmtCacheSize
|
||||
|
||||
//query level
|
||||
KeyForceIndex
|
||||
KeyUseIndex
|
||||
KeyIgnoreIndex
|
||||
KeyForUpdate
|
||||
KeyLimit
|
||||
KeyOffset
|
||||
KeyOrderBy
|
||||
KeyRelDepth
|
||||
)
|
||||
|
||||
type Hint struct {
|
||||
key interface{}
|
||||
value interface{}
|
||||
}
|
||||
|
||||
var _ utils.KV = new(Hint)
|
||||
|
||||
// GetKey return key
|
||||
func (s *Hint) GetKey() interface{} {
|
||||
return s.key
|
||||
}
|
||||
|
||||
// GetValue return value
|
||||
func (s *Hint) GetValue() interface{} {
|
||||
return s.value
|
||||
}
|
||||
|
||||
var _ utils.KV = new(Hint)
|
||||
|
||||
// MaxIdleConnections return a hint about MaxIdleConnections
|
||||
func MaxIdleConnections(v int) *Hint {
|
||||
return NewHint(KeyMaxIdleConnections, v)
|
||||
}
|
||||
|
||||
// MaxOpenConnections return a hint about MaxOpenConnections
|
||||
func MaxOpenConnections(v int) *Hint {
|
||||
return NewHint(KeyMaxOpenConnections, v)
|
||||
}
|
||||
|
||||
// ConnMaxLifetime return a hint about ConnMaxLifetime
|
||||
func ConnMaxLifetime(v time.Duration) *Hint {
|
||||
return NewHint(KeyConnMaxLifetime, v)
|
||||
}
|
||||
|
||||
// MaxStmtCacheSize return a hint about MaxStmtCacheSize
|
||||
func MaxStmtCacheSize(v int) *Hint {
|
||||
return NewHint(KeyMaxStmtCacheSize, v)
|
||||
}
|
||||
|
||||
// ForceIndex return a hint about ForceIndex
|
||||
func ForceIndex(indexes ...string) *Hint {
|
||||
return NewHint(KeyForceIndex, indexes)
|
||||
}
|
||||
|
||||
// UseIndex return a hint about UseIndex
|
||||
func UseIndex(indexes ...string) *Hint {
|
||||
return NewHint(KeyUseIndex, indexes)
|
||||
}
|
||||
|
||||
// IgnoreIndex return a hint about IgnoreIndex
|
||||
func IgnoreIndex(indexes ...string) *Hint {
|
||||
return NewHint(KeyIgnoreIndex, indexes)
|
||||
}
|
||||
|
||||
// ForUpdate return a hint about ForUpdate
|
||||
func ForUpdate() *Hint {
|
||||
return NewHint(KeyForUpdate, true)
|
||||
}
|
||||
|
||||
// DefaultRelDepth return a hint about DefaultRelDepth
|
||||
func DefaultRelDepth() *Hint {
|
||||
return NewHint(KeyRelDepth, true)
|
||||
}
|
||||
|
||||
// RelDepth return a hint about RelDepth
|
||||
func RelDepth(d int) *Hint {
|
||||
return NewHint(KeyRelDepth, d)
|
||||
}
|
||||
|
||||
// Limit return a hint about Limit
|
||||
func Limit(d int64) *Hint {
|
||||
return NewHint(KeyLimit, d)
|
||||
}
|
||||
|
||||
// Offset return a hint about Offset
|
||||
func Offset(d int64) *Hint {
|
||||
return NewHint(KeyOffset, d)
|
||||
}
|
||||
|
||||
// OrderBy return a hint about OrderBy
|
||||
func OrderBy(s string) *Hint {
|
||||
return NewHint(KeyOrderBy, s)
|
||||
}
|
||||
|
||||
// NewHint return a hint
|
||||
func NewHint(key interface{}, value interface{}) *Hint {
|
||||
return &Hint{
|
||||
key: key,
|
||||
value: value,
|
||||
}
|
||||
}
|
155
pkg/client/orm/hints/db_hints_test.go
Normal file
155
pkg/client/orm/hints/db_hints_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
// Copyright 2020 beego-dev
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package hints
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewHint_time(t *testing.T) {
|
||||
key := "qweqwe"
|
||||
value := time.Second
|
||||
hint := NewHint(key, value)
|
||||
|
||||
assert.Equal(t, hint.GetKey(), key)
|
||||
assert.Equal(t, hint.GetValue(), value)
|
||||
}
|
||||
|
||||
func TestNewHint_int(t *testing.T) {
|
||||
key := "qweqwe"
|
||||
value := 281230
|
||||
hint := NewHint(key, value)
|
||||
|
||||
assert.Equal(t, hint.GetKey(), key)
|
||||
assert.Equal(t, hint.GetValue(), value)
|
||||
}
|
||||
|
||||
func TestNewHint_float(t *testing.T) {
|
||||
key := "qweqwe"
|
||||
value := 21.2459753
|
||||
hint := NewHint(key, value)
|
||||
|
||||
assert.Equal(t, hint.GetKey(), key)
|
||||
assert.Equal(t, hint.GetValue(), value)
|
||||
}
|
||||
|
||||
func TestMaxOpenConnections(t *testing.T) {
|
||||
i := 887423
|
||||
hint := MaxOpenConnections(i)
|
||||
assert.Equal(t, hint.GetValue(), i)
|
||||
assert.Equal(t, hint.GetKey(), KeyMaxOpenConnections)
|
||||
}
|
||||
|
||||
func TestConnMaxLifetime(t *testing.T) {
|
||||
i := time.Hour
|
||||
hint := ConnMaxLifetime(i)
|
||||
assert.Equal(t, hint.GetValue(), i)
|
||||
assert.Equal(t, hint.GetKey(), KeyConnMaxLifetime)
|
||||
}
|
||||
|
||||
func TestMaxIdleConnections(t *testing.T) {
|
||||
i := 42316
|
||||
hint := MaxIdleConnections(i)
|
||||
assert.Equal(t, hint.GetValue(), i)
|
||||
assert.Equal(t, hint.GetKey(), KeyMaxIdleConnections)
|
||||
}
|
||||
|
||||
func TestMaxStmtCacheSize(t *testing.T) {
|
||||
i := 94157
|
||||
hint := MaxStmtCacheSize(i)
|
||||
assert.Equal(t, hint.GetValue(), i)
|
||||
assert.Equal(t, hint.GetKey(), KeyMaxStmtCacheSize)
|
||||
}
|
||||
|
||||
func TestForceIndex(t *testing.T) {
|
||||
s := []string{`f_index1`, `f_index2`, `f_index3`}
|
||||
hint := ForceIndex(s...)
|
||||
assert.Equal(t, hint.GetValue(), s)
|
||||
assert.Equal(t, hint.GetKey(), KeyForceIndex)
|
||||
}
|
||||
|
||||
func TestForceIndex_0(t *testing.T) {
|
||||
var s []string
|
||||
hint := ForceIndex(s...)
|
||||
assert.Equal(t, hint.GetValue(), s)
|
||||
assert.Equal(t, hint.GetKey(), KeyForceIndex)
|
||||
}
|
||||
|
||||
func TestIgnoreIndex(t *testing.T) {
|
||||
s := []string{`i_index1`, `i_index2`, `i_index3`}
|
||||
hint := IgnoreIndex(s...)
|
||||
assert.Equal(t, hint.GetValue(), s)
|
||||
assert.Equal(t, hint.GetKey(), KeyIgnoreIndex)
|
||||
}
|
||||
|
||||
func TestIgnoreIndex_0(t *testing.T) {
|
||||
var s []string
|
||||
hint := IgnoreIndex(s...)
|
||||
assert.Equal(t, hint.GetValue(), s)
|
||||
assert.Equal(t, hint.GetKey(), KeyIgnoreIndex)
|
||||
}
|
||||
|
||||
func TestUseIndex(t *testing.T) {
|
||||
s := []string{`u_index1`, `u_index2`, `u_index3`}
|
||||
hint := UseIndex(s...)
|
||||
assert.Equal(t, hint.GetValue(), s)
|
||||
assert.Equal(t, hint.GetKey(), KeyUseIndex)
|
||||
}
|
||||
|
||||
func TestUseIndex_0(t *testing.T) {
|
||||
var s []string
|
||||
hint := UseIndex(s...)
|
||||
assert.Equal(t, hint.GetValue(), s)
|
||||
assert.Equal(t, hint.GetKey(), KeyUseIndex)
|
||||
}
|
||||
|
||||
func TestForUpdate(t *testing.T) {
|
||||
hint := ForUpdate()
|
||||
assert.Equal(t, hint.GetValue(), true)
|
||||
assert.Equal(t, hint.GetKey(), KeyForUpdate)
|
||||
}
|
||||
|
||||
func TestDefaultRelDepth(t *testing.T) {
|
||||
hint := DefaultRelDepth()
|
||||
assert.Equal(t, hint.GetValue(), true)
|
||||
assert.Equal(t, hint.GetKey(), KeyRelDepth)
|
||||
}
|
||||
|
||||
func TestRelDepth(t *testing.T) {
|
||||
hint := RelDepth(157965)
|
||||
assert.Equal(t, hint.GetValue(), 157965)
|
||||
assert.Equal(t, hint.GetKey(), KeyRelDepth)
|
||||
}
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
hint := Limit(1579625)
|
||||
assert.Equal(t, hint.GetValue(), int64(1579625))
|
||||
assert.Equal(t, hint.GetKey(), KeyLimit)
|
||||
}
|
||||
|
||||
func TestOffset(t *testing.T) {
|
||||
hint := Offset(int64(1572123965))
|
||||
assert.Equal(t, hint.GetValue(), int64(1572123965))
|
||||
assert.Equal(t, hint.GetKey(), KeyOffset)
|
||||
}
|
||||
|
||||
func TestOrderBy(t *testing.T) {
|
||||
hint := OrderBy(`-ID`)
|
||||
assert.Equal(t, hint.GetValue(), `-ID`)
|
||||
assert.Equal(t, hint.GetKey(), KeyOrderBy)
|
||||
}
|
58
pkg/client/orm/invocation.go
Normal file
58
pkg/client/orm/invocation.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright 2020 beego
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Invocation represents an "Orm" invocation
|
||||
type Invocation struct {
|
||||
Method string
|
||||
// Md may be nil in some cases. It depends on method
|
||||
Md interface{}
|
||||
// the args are all arguments except context.Context
|
||||
Args []interface{}
|
||||
|
||||
mi *modelInfo
|
||||
// f is the Orm operation
|
||||
f func(ctx context.Context) []interface{}
|
||||
|
||||
// insideTx indicates whether this is inside a transaction
|
||||
InsideTx bool
|
||||
TxStartTime time.Time
|
||||
TxName string
|
||||
}
|
||||
|
||||
func (inv *Invocation) GetTableName() string {
|
||||
if inv.mi != nil {
|
||||
return inv.mi.table
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (inv *Invocation) execute(ctx context.Context) []interface{} {
|
||||
return inv.f(ctx)
|
||||
}
|
||||
|
||||
// GetPkFieldName return the primary key of this table
|
||||
// if not found, "" is returned
|
||||
func (inv *Invocation) GetPkFieldName() string {
|
||||
if inv.mi.fields.pk != nil {
|
||||
return inv.mi.fields.pk.name
|
||||
}
|
||||
return ""
|
||||
}
|
395
pkg/client/orm/migration/ddl.go
Normal file
395
pkg/client/orm/migration/ddl.go
Normal file
@ -0,0 +1,395 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/logs"
|
||||
)
|
||||
|
||||
// Index struct defines the structure of Index Columns
|
||||
type Index struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// Unique struct defines a single unique key combination
|
||||
type Unique struct {
|
||||
Definition string
|
||||
Columns []*Column
|
||||
}
|
||||
|
||||
//Column struct defines a single column of a table
|
||||
type Column struct {
|
||||
Name string
|
||||
Inc string
|
||||
Null string
|
||||
Default string
|
||||
Unsign string
|
||||
DataType string
|
||||
remove bool
|
||||
Modify bool
|
||||
}
|
||||
|
||||
// Foreign struct defines a single foreign relationship
|
||||
type Foreign struct {
|
||||
ForeignTable string
|
||||
ForeignColumn string
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
Column
|
||||
}
|
||||
|
||||
// RenameColumn struct allows renaming of columns
|
||||
type RenameColumn struct {
|
||||
OldName string
|
||||
OldNull string
|
||||
OldDefault string
|
||||
OldUnsign string
|
||||
OldDataType string
|
||||
NewName string
|
||||
Column
|
||||
}
|
||||
|
||||
// CreateTable creates the table on system
|
||||
func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) {
|
||||
m.TableName = tablename
|
||||
m.Engine = engine
|
||||
m.Charset = charset
|
||||
m.ModifyType = "create"
|
||||
}
|
||||
|
||||
// AlterTable set the ModifyType to alter
|
||||
func (m *Migration) AlterTable(tablename string) {
|
||||
m.TableName = tablename
|
||||
m.ModifyType = "alter"
|
||||
}
|
||||
|
||||
// NewCol creates a new standard column and attaches it to m struct
|
||||
func (m *Migration) NewCol(name string) *Column {
|
||||
col := &Column{Name: name}
|
||||
m.AddColumns(col)
|
||||
return col
|
||||
}
|
||||
|
||||
//PriCol creates a new primary column and attaches it to m struct
|
||||
func (m *Migration) PriCol(name string) *Column {
|
||||
col := &Column{Name: name}
|
||||
m.AddColumns(col)
|
||||
m.AddPrimary(col)
|
||||
return col
|
||||
}
|
||||
|
||||
//UniCol creates / appends columns to specified unique key and attaches it to m struct
|
||||
func (m *Migration) UniCol(uni, name string) *Column {
|
||||
col := &Column{Name: name}
|
||||
m.AddColumns(col)
|
||||
|
||||
uniqueOriginal := &Unique{}
|
||||
|
||||
for _, unique := range m.Uniques {
|
||||
if unique.Definition == uni {
|
||||
unique.AddColumnsToUnique(col)
|
||||
uniqueOriginal = unique
|
||||
}
|
||||
}
|
||||
if uniqueOriginal.Definition == "" {
|
||||
unique := &Unique{Definition: uni}
|
||||
unique.AddColumnsToUnique(col)
|
||||
m.AddUnique(unique)
|
||||
}
|
||||
|
||||
return col
|
||||
}
|
||||
|
||||
//ForeignCol creates a new foreign column and returns the instance of column
|
||||
func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) {
|
||||
|
||||
foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable}
|
||||
foreign.Name = colname
|
||||
m.AddForeign(foreign)
|
||||
return foreign
|
||||
}
|
||||
|
||||
//SetOnDelete sets the on delete of foreign
|
||||
func (foreign *Foreign) SetOnDelete(del string) *Foreign {
|
||||
foreign.OnDelete = "ON DELETE" + del
|
||||
return foreign
|
||||
}
|
||||
|
||||
//SetOnUpdate sets the on update of foreign
|
||||
func (foreign *Foreign) SetOnUpdate(update string) *Foreign {
|
||||
foreign.OnUpdate = "ON UPDATE" + update
|
||||
return foreign
|
||||
}
|
||||
|
||||
//Remove marks the columns to be removed.
|
||||
//it allows reverse m to create the column.
|
||||
func (c *Column) Remove() {
|
||||
c.remove = true
|
||||
}
|
||||
|
||||
//SetAuto enables auto_increment of column (can be used once)
|
||||
func (c *Column) SetAuto(inc bool) *Column {
|
||||
if inc {
|
||||
c.Inc = "auto_increment"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
//SetNullable sets the column to be null
|
||||
func (c *Column) SetNullable(null bool) *Column {
|
||||
if null {
|
||||
c.Null = ""
|
||||
|
||||
} else {
|
||||
c.Null = "NOT NULL"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
//SetDefault sets the default value, prepend with "DEFAULT "
|
||||
func (c *Column) SetDefault(def string) *Column {
|
||||
c.Default = "DEFAULT " + def
|
||||
return c
|
||||
}
|
||||
|
||||
//SetUnsigned sets the column to be unsigned int
|
||||
func (c *Column) SetUnsigned(unsign bool) *Column {
|
||||
if unsign {
|
||||
c.Unsign = "UNSIGNED"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
//SetDataType sets the dataType of the column
|
||||
func (c *Column) SetDataType(dataType string) *Column {
|
||||
c.DataType = dataType
|
||||
return c
|
||||
}
|
||||
|
||||
//SetOldNullable allows reverting to previous nullable on reverse ms
|
||||
func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn {
|
||||
if null {
|
||||
c.OldNull = ""
|
||||
|
||||
} else {
|
||||
c.OldNull = "NOT NULL"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
//SetOldDefault allows reverting to previous default on reverse ms
|
||||
func (c *RenameColumn) SetOldDefault(def string) *RenameColumn {
|
||||
c.OldDefault = def
|
||||
return c
|
||||
}
|
||||
|
||||
//SetOldUnsigned allows reverting to previous unsgined on reverse ms
|
||||
func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn {
|
||||
if unsign {
|
||||
c.OldUnsign = "UNSIGNED"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
//SetOldDataType allows reverting to previous datatype on reverse ms
|
||||
func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn {
|
||||
c.OldDataType = dataType
|
||||
return c
|
||||
}
|
||||
|
||||
//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m)
|
||||
func (c *Column) SetPrimary(m *Migration) *Column {
|
||||
m.Primary = append(m.Primary, c)
|
||||
return c
|
||||
}
|
||||
|
||||
//AddColumnsToUnique adds the columns to Unique Struct
|
||||
func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique {
|
||||
|
||||
unique.Columns = append(unique.Columns, columns...)
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
//AddColumns adds columns to m struct
|
||||
func (m *Migration) AddColumns(columns ...*Column) *Migration {
|
||||
|
||||
m.Columns = append(m.Columns, columns...)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
//AddPrimary adds the column to primary in m struct
|
||||
func (m *Migration) AddPrimary(primary *Column) *Migration {
|
||||
m.Primary = append(m.Primary, primary)
|
||||
return m
|
||||
}
|
||||
|
||||
//AddUnique adds the column to unique in m struct
|
||||
func (m *Migration) AddUnique(unique *Unique) *Migration {
|
||||
m.Uniques = append(m.Uniques, unique)
|
||||
return m
|
||||
}
|
||||
|
||||
//AddForeign adds the column to foreign in m struct
|
||||
func (m *Migration) AddForeign(foreign *Foreign) *Migration {
|
||||
m.Foreigns = append(m.Foreigns, foreign)
|
||||
return m
|
||||
}
|
||||
|
||||
//AddIndex adds the column to index in m struct
|
||||
func (m *Migration) AddIndex(index *Index) *Migration {
|
||||
m.Indexes = append(m.Indexes, index)
|
||||
return m
|
||||
}
|
||||
|
||||
//RenameColumn allows renaming of columns
|
||||
func (m *Migration) RenameColumn(from, to string) *RenameColumn {
|
||||
rename := &RenameColumn{OldName: from, NewName: to}
|
||||
m.Renames = append(m.Renames, rename)
|
||||
return rename
|
||||
}
|
||||
|
||||
//GetSQL returns the generated sql depending on ModifyType
|
||||
func (m *Migration) GetSQL() (sql string) {
|
||||
sql = ""
|
||||
switch m.ModifyType {
|
||||
case "create":
|
||||
{
|
||||
sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName)
|
||||
for index, column := range m.Columns {
|
||||
sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||
if len(m.Columns) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.Primary) > 0 {
|
||||
sql += fmt.Sprintf(",\n PRIMARY KEY( ")
|
||||
}
|
||||
for index, column := range m.Primary {
|
||||
sql += fmt.Sprintf(" `%s`", column.Name)
|
||||
if len(m.Primary) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
|
||||
}
|
||||
if len(m.Primary) > 0 {
|
||||
sql += fmt.Sprintf(")")
|
||||
}
|
||||
|
||||
for _, unique := range m.Uniques {
|
||||
sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition)
|
||||
for index, column := range unique.Columns {
|
||||
sql += fmt.Sprintf(" `%s`", column.Name)
|
||||
if len(unique.Columns) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
sql += fmt.Sprintf(")")
|
||||
}
|
||||
for _, foreign := range m.Foreigns {
|
||||
sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
|
||||
sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name)
|
||||
sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
|
||||
|
||||
}
|
||||
sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset)
|
||||
break
|
||||
}
|
||||
case "alter":
|
||||
{
|
||||
sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName)
|
||||
for index, column := range m.Columns {
|
||||
if !column.remove {
|
||||
logs.Info("col")
|
||||
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||
} else {
|
||||
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
|
||||
}
|
||||
|
||||
if len(m.Columns) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
for index, column := range m.Renames {
|
||||
sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||
if len(m.Renames) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
|
||||
for index, foreign := range m.Foreigns {
|
||||
sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
|
||||
sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name)
|
||||
sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
|
||||
if len(m.Foreigns) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
sql += ";"
|
||||
|
||||
break
|
||||
}
|
||||
case "reverse":
|
||||
{
|
||||
|
||||
sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName)
|
||||
for index, column := range m.Columns {
|
||||
if column.remove {
|
||||
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||
} else {
|
||||
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
|
||||
}
|
||||
if len(m.Columns) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.Primary) > 0 {
|
||||
sql += fmt.Sprintf("\n DROP PRIMARY KEY,")
|
||||
}
|
||||
|
||||
for index, unique := range m.Uniques {
|
||||
sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition)
|
||||
if len(m.Uniques) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
|
||||
}
|
||||
for index, column := range m.Renames {
|
||||
sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault)
|
||||
if len(m.Renames) > index+1 {
|
||||
sql += ","
|
||||
}
|
||||
}
|
||||
|
||||
for _, foreign := range m.Foreigns {
|
||||
sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
|
||||
sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
|
||||
sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name)
|
||||
}
|
||||
sql += ";"
|
||||
}
|
||||
case "delete":
|
||||
{
|
||||
sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
32
pkg/client/orm/migration/doc.go
Normal file
32
pkg/client/orm/migration/doc.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Package migration enables you to generate migrations back and forth. It generates both migrations.
|
||||
//
|
||||
// //Creates a table
|
||||
// m.CreateTable("tablename","InnoDB","utf8");
|
||||
//
|
||||
// //Alter a table
|
||||
// m.AlterTable("tablename")
|
||||
//
|
||||
// Standard Column Methods
|
||||
// * SetDataType
|
||||
// * SetNullable
|
||||
// * SetDefault
|
||||
// * SetUnsigned (use only on integer types unless produces error)
|
||||
//
|
||||
// //Sets a primary column, multiple calls allowed, standard column methods available
|
||||
// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true)
|
||||
//
|
||||
// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index
|
||||
// m.UniCol("index","column")
|
||||
//
|
||||
// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove
|
||||
// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false)
|
||||
// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false)
|
||||
//
|
||||
// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to
|
||||
// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)")
|
||||
// m.RenameColumn("from","to")...
|
||||
//
|
||||
// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately.
|
||||
// //Supports standard column methods, automatic reverse.
|
||||
// m.ForeignCol("local_col","foreign_col","foreign_table")
|
||||
package migration
|
330
pkg/client/orm/migration/migration.go
Normal file
330
pkg/client/orm/migration/migration.go
Normal file
@ -0,0 +1,330 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package migration is used for migration
|
||||
//
|
||||
// The table structure is as follow:
|
||||
//
|
||||
// CREATE TABLE `migrations` (
|
||||
// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key',
|
||||
// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique',
|
||||
// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back',
|
||||
// `statements` longtext COMMENT 'SQL statements for this migration',
|
||||
// `rollback_statements` longtext,
|
||||
// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back',
|
||||
// PRIMARY KEY (`id_migration`)
|
||||
// ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
package migration
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm"
|
||||
"github.com/astaxie/beego/pkg/infrastructure/logs"
|
||||
)
|
||||
|
||||
// const the data format for the bee generate migration datatype
|
||||
const (
|
||||
DateFormat = "20060102_150405"
|
||||
DBDateFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
// Migrationer is an interface for all Migration struct
|
||||
type Migrationer interface {
|
||||
Up()
|
||||
Down()
|
||||
Reset()
|
||||
Exec(name, status string) error
|
||||
GetCreated() int64
|
||||
}
|
||||
|
||||
//Migration defines the migrations by either SQL or DDL
|
||||
type Migration struct {
|
||||
sqls []string
|
||||
Created string
|
||||
TableName string
|
||||
Engine string
|
||||
Charset string
|
||||
ModifyType string
|
||||
Columns []*Column
|
||||
Indexes []*Index
|
||||
Primary []*Column
|
||||
Uniques []*Unique
|
||||
Foreigns []*Foreign
|
||||
Renames []*RenameColumn
|
||||
RemoveColumns []*Column
|
||||
RemoveIndexes []*Index
|
||||
RemoveUniques []*Unique
|
||||
RemoveForeigns []*Foreign
|
||||
}
|
||||
|
||||
var (
|
||||
migrationMap map[string]Migrationer
|
||||
)
|
||||
|
||||
func init() {
|
||||
migrationMap = make(map[string]Migrationer)
|
||||
}
|
||||
|
||||
// Up implement in the Inheritance struct for upgrade
|
||||
func (m *Migration) Up() {
|
||||
|
||||
switch m.ModifyType {
|
||||
case "reverse":
|
||||
m.ModifyType = "alter"
|
||||
case "delete":
|
||||
m.ModifyType = "create"
|
||||
}
|
||||
m.sqls = append(m.sqls, m.GetSQL())
|
||||
}
|
||||
|
||||
// Down implement in the Inheritance struct for down
|
||||
func (m *Migration) Down() {
|
||||
|
||||
switch m.ModifyType {
|
||||
case "alter":
|
||||
m.ModifyType = "reverse"
|
||||
case "create":
|
||||
m.ModifyType = "delete"
|
||||
}
|
||||
m.sqls = append(m.sqls, m.GetSQL())
|
||||
}
|
||||
|
||||
//Migrate adds the SQL to the execution list
|
||||
func (m *Migration) Migrate(migrationType string) {
|
||||
m.ModifyType = migrationType
|
||||
m.sqls = append(m.sqls, m.GetSQL())
|
||||
}
|
||||
|
||||
// SQL add sql want to execute
|
||||
func (m *Migration) SQL(sql string) {
|
||||
m.sqls = append(m.sqls, sql)
|
||||
}
|
||||
|
||||
// Reset the sqls
|
||||
func (m *Migration) Reset() {
|
||||
m.sqls = make([]string, 0)
|
||||
}
|
||||
|
||||
// Exec execute the sql already add in the sql
|
||||
func (m *Migration) Exec(name, status string) error {
|
||||
o := orm.NewOrm()
|
||||
for _, s := range m.sqls {
|
||||
logs.Info("exec sql:", s)
|
||||
r := o.Raw(s)
|
||||
_, err := r.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return m.addOrUpdateRecord(name, status)
|
||||
}
|
||||
|
||||
func (m *Migration) addOrUpdateRecord(name, status string) error {
|
||||
o := orm.NewOrm()
|
||||
if status == "down" {
|
||||
status = "rollback"
|
||||
p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name)
|
||||
return err
|
||||
}
|
||||
status = "update"
|
||||
p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCreated get the unixtime from the Created
|
||||
func (m *Migration) GetCreated() int64 {
|
||||
t, err := time.Parse(DateFormat, m.Created)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
// Register register the Migration in the map
|
||||
func Register(name string, m Migrationer) error {
|
||||
if _, ok := migrationMap[name]; ok {
|
||||
return errors.New("already exist name:" + name)
|
||||
}
|
||||
migrationMap[name] = m
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upgrade upgrade the migration from lasttime
|
||||
func Upgrade(lasttime int64) error {
|
||||
sm := sortMap(migrationMap)
|
||||
i := 0
|
||||
migs, _ := getAllMigrations()
|
||||
for _, v := range sm {
|
||||
if _, ok := migs[v.name]; !ok {
|
||||
logs.Info("start upgrade", v.name)
|
||||
v.m.Reset()
|
||||
v.m.Up()
|
||||
err := v.m.Exec(v.name, "up")
|
||||
if err != nil {
|
||||
logs.Error("execute error:", err)
|
||||
time.Sleep(2 * time.Second)
|
||||
return err
|
||||
}
|
||||
logs.Info("end upgrade:", v.name)
|
||||
i++
|
||||
}
|
||||
}
|
||||
logs.Info("total success upgrade:", i, " migration")
|
||||
time.Sleep(2 * time.Second)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback rollback the migration by the name
|
||||
func Rollback(name string) error {
|
||||
if v, ok := migrationMap[name]; ok {
|
||||
logs.Info("start rollback")
|
||||
v.Reset()
|
||||
v.Down()
|
||||
err := v.Exec(name, "down")
|
||||
if err != nil {
|
||||
logs.Error("execute error:", err)
|
||||
time.Sleep(2 * time.Second)
|
||||
return err
|
||||
}
|
||||
logs.Info("end rollback")
|
||||
time.Sleep(2 * time.Second)
|
||||
return nil
|
||||
}
|
||||
logs.Error("not exist the migrationMap name:" + name)
|
||||
time.Sleep(2 * time.Second)
|
||||
return errors.New("not exist the migrationMap name:" + name)
|
||||
}
|
||||
|
||||
// Reset reset all migration
|
||||
// run all migration's down function
|
||||
func Reset() error {
|
||||
sm := sortMap(migrationMap)
|
||||
i := 0
|
||||
for j := len(sm) - 1; j >= 0; j-- {
|
||||
v := sm[j]
|
||||
if isRollBack(v.name) {
|
||||
logs.Info("skip the", v.name)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
logs.Info("start reset:", v.name)
|
||||
v.m.Reset()
|
||||
v.m.Down()
|
||||
err := v.m.Exec(v.name, "down")
|
||||
if err != nil {
|
||||
logs.Error("execute error:", err)
|
||||
time.Sleep(2 * time.Second)
|
||||
return err
|
||||
}
|
||||
i++
|
||||
logs.Info("end reset:", v.name)
|
||||
}
|
||||
logs.Info("total success reset:", i, " migration")
|
||||
time.Sleep(2 * time.Second)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh first Reset, then Upgrade
|
||||
func Refresh() error {
|
||||
err := Reset()
|
||||
if err != nil {
|
||||
logs.Error("execute error:", err)
|
||||
time.Sleep(2 * time.Second)
|
||||
return err
|
||||
}
|
||||
err = Upgrade(0)
|
||||
return err
|
||||
}
|
||||
|
||||
type dataSlice []data
|
||||
|
||||
type data struct {
|
||||
created int64
|
||||
name string
|
||||
m Migrationer
|
||||
}
|
||||
|
||||
// Len is part of sort.Interface.
|
||||
func (d dataSlice) Len() int {
|
||||
return len(d)
|
||||
}
|
||||
|
||||
// Swap is part of sort.Interface.
|
||||
func (d dataSlice) Swap(i, j int) {
|
||||
d[i], d[j] = d[j], d[i]
|
||||
}
|
||||
|
||||
// Less is part of sort.Interface. We use count as the value to sort by
|
||||
func (d dataSlice) Less(i, j int) bool {
|
||||
return d[i].created < d[j].created
|
||||
}
|
||||
|
||||
func sortMap(m map[string]Migrationer) dataSlice {
|
||||
s := make(dataSlice, 0, len(m))
|
||||
for k, v := range m {
|
||||
d := data{}
|
||||
d.created = v.GetCreated()
|
||||
d.name = k
|
||||
d.m = v
|
||||
s = append(s, d)
|
||||
}
|
||||
sort.Sort(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func isRollBack(name string) bool {
|
||||
o := orm.NewOrm()
|
||||
var maps []orm.Params
|
||||
num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps)
|
||||
if err != nil {
|
||||
logs.Info("get name has error", err)
|
||||
return false
|
||||
}
|
||||
if num <= 0 {
|
||||
return false
|
||||
}
|
||||
if maps[0]["status"] == "rollback" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
func getAllMigrations() (map[string]string, error) {
|
||||
o := orm.NewOrm()
|
||||
var maps []orm.Params
|
||||
migs := make(map[string]string)
|
||||
num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps)
|
||||
if err != nil {
|
||||
logs.Info("get name has error", err)
|
||||
return migs, err
|
||||
}
|
||||
if num > 0 {
|
||||
for _, v := range maps {
|
||||
name := v["name"].(string)
|
||||
migs[name] = v["status"].(string)
|
||||
}
|
||||
}
|
||||
return migs, nil
|
||||
}
|
62
pkg/client/orm/model_utils_test.go
Normal file
62
pkg/client/orm/model_utils_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type Interface struct {
|
||||
Id int
|
||||
Name string
|
||||
|
||||
Index1 string
|
||||
Index2 string
|
||||
|
||||
Unique1 string
|
||||
Unique2 string
|
||||
}
|
||||
|
||||
func (i *Interface) TableIndex() [][]string {
|
||||
return [][]string{{"index1"}, {"index2"}}
|
||||
}
|
||||
|
||||
func (i *Interface) TableUnique() [][]string {
|
||||
return [][]string{{"unique1"}, {"unique2"}}
|
||||
}
|
||||
|
||||
func (i *Interface) TableName() string {
|
||||
return "INTERFACE_"
|
||||
}
|
||||
|
||||
func (i *Interface) TableEngine() string {
|
||||
return "innodb"
|
||||
}
|
||||
|
||||
func TestDbBase_GetTables(t *testing.T) {
|
||||
RegisterModel(&Interface{})
|
||||
mi, ok := modelCache.get("INTERFACE_")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, mi)
|
||||
|
||||
engine := getTableEngine(mi.addrField)
|
||||
assert.Equal(t, "innodb", engine)
|
||||
uniques := getTableUnique(mi.addrField)
|
||||
assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques)
|
||||
indexes := getTableIndex(mi.addrField)
|
||||
assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes)
|
||||
}
|
108
pkg/client/orm/models.go
Normal file
108
pkg/client/orm/models.go
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
odCascade = "cascade"
|
||||
odSetNULL = "set_null"
|
||||
odSetDefault = "set_default"
|
||||
odDoNothing = "do_nothing"
|
||||
defaultStructTagName = "orm"
|
||||
defaultStructTagDelim = ";"
|
||||
)
|
||||
|
||||
var (
|
||||
modelCache = &_modelCache{
|
||||
cache: make(map[string]*modelInfo),
|
||||
cacheByFullName: make(map[string]*modelInfo),
|
||||
}
|
||||
)
|
||||
|
||||
// model info collection
|
||||
type _modelCache struct {
|
||||
sync.RWMutex // only used outsite for bootStrap
|
||||
orders []string
|
||||
cache map[string]*modelInfo
|
||||
cacheByFullName map[string]*modelInfo
|
||||
done bool
|
||||
}
|
||||
|
||||
// get all model info
|
||||
func (mc *_modelCache) all() map[string]*modelInfo {
|
||||
m := make(map[string]*modelInfo, len(mc.cache))
|
||||
for k, v := range mc.cache {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// get ordered model info
|
||||
func (mc *_modelCache) allOrdered() []*modelInfo {
|
||||
m := make([]*modelInfo, 0, len(mc.orders))
|
||||
for _, table := range mc.orders {
|
||||
m = append(m, mc.cache[table])
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// get model info by table name
|
||||
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
|
||||
mi, ok = mc.cache[table]
|
||||
return
|
||||
}
|
||||
|
||||
// get model info by full name
|
||||
func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
|
||||
mi, ok = mc.cacheByFullName[name]
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *_modelCache) getByMd(md interface{}) (*modelInfo, bool) {
|
||||
val := reflect.ValueOf(md)
|
||||
ind := reflect.Indirect(val)
|
||||
typ := ind.Type()
|
||||
name := getFullName(typ)
|
||||
return mc.getByFullName(name)
|
||||
}
|
||||
|
||||
// set model info to collection
|
||||
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
|
||||
mii := mc.cache[table]
|
||||
mc.cache[table] = mi
|
||||
mc.cacheByFullName[mi.fullName] = mi
|
||||
if mii == nil {
|
||||
mc.orders = append(mc.orders, table)
|
||||
}
|
||||
return mii
|
||||
}
|
||||
|
||||
// clean all model info.
|
||||
func (mc *_modelCache) clean() {
|
||||
mc.orders = make([]string, 0)
|
||||
mc.cache = make(map[string]*modelInfo)
|
||||
mc.cacheByFullName = make(map[string]*modelInfo)
|
||||
mc.done = false
|
||||
}
|
||||
|
||||
// ResetModelCache Clean model cache. Then you can re-RegisterModel.
|
||||
// Common use this api for test case.
|
||||
func ResetModelCache() {
|
||||
modelCache.clean()
|
||||
}
|
347
pkg/client/orm/models_boot.go
Normal file
347
pkg/client/orm/models_boot.go
Normal file
@ -0,0 +1,347 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// register models.
|
||||
// PrefixOrSuffix means table name prefix or suffix.
|
||||
// isPrefix whether the prefix is prefix or suffix
|
||||
func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
|
||||
val := reflect.ValueOf(model)
|
||||
typ := reflect.Indirect(val).Type()
|
||||
|
||||
if val.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
|
||||
}
|
||||
// For this case:
|
||||
// u := &User{}
|
||||
// registerModel(&u)
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ))
|
||||
}
|
||||
|
||||
table := getTableName(val)
|
||||
|
||||
if PrefixOrSuffix != "" {
|
||||
if isPrefix {
|
||||
table = PrefixOrSuffix + table
|
||||
} else {
|
||||
table = table + PrefixOrSuffix
|
||||
}
|
||||
}
|
||||
// models's fullname is pkgpath + struct name
|
||||
name := getFullName(typ)
|
||||
if _, ok := modelCache.getByFullName(name); ok {
|
||||
fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
if _, ok := modelCache.get(table); ok {
|
||||
fmt.Printf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
mi := newModelInfo(val)
|
||||
if mi.fields.pk == nil {
|
||||
outFor:
|
||||
for _, fi := range mi.fields.fieldsDB {
|
||||
if strings.ToLower(fi.name) == "id" {
|
||||
switch fi.addrValue.Elem().Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
fi.auto = true
|
||||
fi.pk = true
|
||||
mi.fields.pk = fi
|
||||
break outFor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mi.fields.pk == nil {
|
||||
fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
mi.table = table
|
||||
mi.pkg = typ.PkgPath()
|
||||
mi.model = model
|
||||
mi.manual = true
|
||||
|
||||
modelCache.set(table, mi)
|
||||
}
|
||||
|
||||
// bootstrap models
|
||||
func bootStrap() {
|
||||
if modelCache.done {
|
||||
return
|
||||
}
|
||||
var (
|
||||
err error
|
||||
models map[string]*modelInfo
|
||||
)
|
||||
if dataBaseCache.getDefault() == nil {
|
||||
err = fmt.Errorf("must have one register DataBase alias named `default`")
|
||||
goto end
|
||||
}
|
||||
|
||||
// set rel and reverse model
|
||||
// RelManyToMany set the relTable
|
||||
models = modelCache.all()
|
||||
for _, mi := range models {
|
||||
for _, fi := range mi.fields.columns {
|
||||
if fi.rel || fi.reverse {
|
||||
elm := fi.addrValue.Type().Elem()
|
||||
if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
|
||||
elm = elm.Elem()
|
||||
}
|
||||
// check the rel or reverse model already register
|
||||
name := getFullName(elm)
|
||||
mii, ok := modelCache.getByFullName(name)
|
||||
if !ok || mii.pkg != elm.PkgPath() {
|
||||
err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
|
||||
goto end
|
||||
}
|
||||
fi.relModelInfo = mii
|
||||
|
||||
switch fi.fieldType {
|
||||
case RelManyToMany:
|
||||
if fi.relThrough != "" {
|
||||
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
|
||||
pn := fi.relThrough[:i]
|
||||
rmi, ok := modelCache.getByFullName(fi.relThrough)
|
||||
if !ok || pn != rmi.pkg {
|
||||
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
|
||||
goto end
|
||||
}
|
||||
fi.relThroughModelInfo = rmi
|
||||
fi.relTable = rmi.table
|
||||
} else {
|
||||
err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
|
||||
goto end
|
||||
}
|
||||
} else {
|
||||
i := newM2MModelInfo(mi, mii)
|
||||
if fi.relTable != "" {
|
||||
i.table = fi.relTable
|
||||
}
|
||||
if v := modelCache.set(i.table, i); v != nil {
|
||||
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
|
||||
goto end
|
||||
}
|
||||
fi.relTable = i.table
|
||||
fi.relThroughModelInfo = i
|
||||
}
|
||||
|
||||
fi.relThroughModelInfo.isThrough = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check the rel filed while the relModelInfo also has filed point to current model
|
||||
// if not exist, add a new field to the relModelInfo
|
||||
models = modelCache.all()
|
||||
for _, mi := range models {
|
||||
for _, fi := range mi.fields.fieldsRel {
|
||||
switch fi.fieldType {
|
||||
case RelForeignKey, RelOneToOne, RelManyToMany:
|
||||
inModel := false
|
||||
for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
|
||||
if ffi.relModelInfo == mi {
|
||||
inModel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !inModel {
|
||||
rmi := fi.relModelInfo
|
||||
ffi := new(fieldInfo)
|
||||
ffi.name = mi.name
|
||||
ffi.column = ffi.name
|
||||
ffi.fullName = rmi.fullName + "." + ffi.name
|
||||
ffi.reverse = true
|
||||
ffi.relModelInfo = mi
|
||||
ffi.mi = rmi
|
||||
if fi.fieldType == RelOneToOne {
|
||||
ffi.fieldType = RelReverseOne
|
||||
} else {
|
||||
ffi.fieldType = RelReverseMany
|
||||
}
|
||||
if !rmi.fields.Add(ffi) {
|
||||
added := false
|
||||
for cnt := 0; cnt < 5; cnt++ {
|
||||
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
|
||||
ffi.column = ffi.name
|
||||
ffi.fullName = rmi.fullName + "." + ffi.name
|
||||
if added = rmi.fields.Add(ffi); added {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
models = modelCache.all()
|
||||
for _, mi := range models {
|
||||
for _, fi := range mi.fields.fieldsRel {
|
||||
switch fi.fieldType {
|
||||
case RelManyToMany:
|
||||
for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
|
||||
switch ffi.fieldType {
|
||||
case RelOneToOne, RelForeignKey:
|
||||
if ffi.relModelInfo == fi.relModelInfo {
|
||||
fi.reverseFieldInfoTwo = ffi
|
||||
}
|
||||
if ffi.relModelInfo == mi {
|
||||
fi.reverseField = ffi.name
|
||||
fi.reverseFieldInfo = ffi
|
||||
}
|
||||
}
|
||||
}
|
||||
if fi.reverseFieldInfoTwo == nil {
|
||||
err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
|
||||
fi.relThroughModelInfo.fullName)
|
||||
goto end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
models = modelCache.all()
|
||||
for _, mi := range models {
|
||||
for _, fi := range mi.fields.fieldsReverse {
|
||||
switch fi.fieldType {
|
||||
case RelReverseOne:
|
||||
found := false
|
||||
mForA:
|
||||
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
|
||||
if ffi.relModelInfo == mi {
|
||||
found = true
|
||||
fi.reverseField = ffi.name
|
||||
fi.reverseFieldInfo = ffi
|
||||
|
||||
ffi.reverseField = fi.name
|
||||
ffi.reverseFieldInfo = fi
|
||||
break mForA
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
|
||||
goto end
|
||||
}
|
||||
case RelReverseMany:
|
||||
found := false
|
||||
mForB:
|
||||
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
|
||||
if ffi.relModelInfo == mi {
|
||||
found = true
|
||||
fi.reverseField = ffi.name
|
||||
fi.reverseFieldInfo = ffi
|
||||
|
||||
ffi.reverseField = fi.name
|
||||
ffi.reverseFieldInfo = fi
|
||||
|
||||
break mForB
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mForC:
|
||||
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
|
||||
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
|
||||
fi.relTable != "" && fi.relTable == ffi.relTable ||
|
||||
fi.relThrough == "" && fi.relTable == ""
|
||||
if ffi.relModelInfo == mi && conditions {
|
||||
found = true
|
||||
|
||||
fi.reverseField = ffi.reverseFieldInfoTwo.name
|
||||
fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
|
||||
fi.relThroughModelInfo = ffi.relThroughModelInfo
|
||||
fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
|
||||
fi.reverseFieldInfoM2M = ffi
|
||||
ffi.reverseFieldInfoM2M = fi
|
||||
|
||||
break mForC
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
|
||||
goto end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
end:
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
debug.PrintStack()
|
||||
os.Exit(2)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterModel register models
|
||||
func RegisterModel(models ...interface{}) {
|
||||
if modelCache.done {
|
||||
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
|
||||
}
|
||||
RegisterModelWithPrefix("", models...)
|
||||
}
|
||||
|
||||
// RegisterModelWithPrefix register models with a prefix
|
||||
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
|
||||
if modelCache.done {
|
||||
panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap"))
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
registerModel(prefix, model, true)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterModelWithSuffix register models with a suffix
|
||||
func RegisterModelWithSuffix(suffix string, models ...interface{}) {
|
||||
if modelCache.done {
|
||||
panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap"))
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
registerModel(suffix, model, false)
|
||||
}
|
||||
}
|
||||
|
||||
// BootStrap bootstrap models.
|
||||
// make all model parsed and can not add more models
|
||||
func BootStrap() {
|
||||
modelCache.Lock()
|
||||
defer modelCache.Unlock()
|
||||
if modelCache.done {
|
||||
return
|
||||
}
|
||||
bootStrap()
|
||||
modelCache.done = true
|
||||
}
|
783
pkg/client/orm/models_fields.go
Normal file
783
pkg/client/orm/models_fields.go
Normal file
@ -0,0 +1,783 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Define the Type enum
|
||||
const (
|
||||
TypeBooleanField = 1 << iota
|
||||
TypeVarCharField
|
||||
TypeCharField
|
||||
TypeTextField
|
||||
TypeTimeField
|
||||
TypeDateField
|
||||
TypeDateTimeField
|
||||
TypeBitField
|
||||
TypeSmallIntegerField
|
||||
TypeIntegerField
|
||||
TypeBigIntegerField
|
||||
TypePositiveBitField
|
||||
TypePositiveSmallIntegerField
|
||||
TypePositiveIntegerField
|
||||
TypePositiveBigIntegerField
|
||||
TypeFloatField
|
||||
TypeDecimalField
|
||||
TypeJSONField
|
||||
TypeJsonbField
|
||||
RelForeignKey
|
||||
RelOneToOne
|
||||
RelManyToMany
|
||||
RelReverseOne
|
||||
RelReverseMany
|
||||
)
|
||||
|
||||
// Define some logic enum
|
||||
const (
|
||||
IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
|
||||
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
|
||||
IsRelField = ^-RelReverseMany >> 18 << 19
|
||||
IsFieldType = ^-RelReverseMany<<1 + 1
|
||||
)
|
||||
|
||||
// BooleanField A true/false field.
|
||||
type BooleanField bool
|
||||
|
||||
// Value return the BooleanField
|
||||
func (e BooleanField) Value() bool {
|
||||
return bool(e)
|
||||
}
|
||||
|
||||
// Set will set the BooleanField
|
||||
func (e *BooleanField) Set(d bool) {
|
||||
*e = BooleanField(d)
|
||||
}
|
||||
|
||||
// String format the Bool to string
|
||||
func (e *BooleanField) String() string {
|
||||
return strconv.FormatBool(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return BooleanField the type
|
||||
func (e *BooleanField) FieldType() int {
|
||||
return TypeBooleanField
|
||||
}
|
||||
|
||||
// SetRaw set the interface to bool
|
||||
func (e *BooleanField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case bool:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Bool()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return the current value
|
||||
func (e *BooleanField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify the BooleanField implement the Fielder interface
|
||||
var _ Fielder = new(BooleanField)
|
||||
|
||||
// CharField A string field
|
||||
// required values tag: size
|
||||
// The size is enforced at the database level and in models’s validation.
|
||||
// eg: `orm:"size(120)"`
|
||||
type CharField string
|
||||
|
||||
// Value return the CharField's Value
|
||||
func (e CharField) Value() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Set CharField value
|
||||
func (e *CharField) Set(d string) {
|
||||
*e = CharField(d)
|
||||
}
|
||||
|
||||
// String return the CharField
|
||||
func (e *CharField) String() string {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// FieldType return the enum type
|
||||
func (e *CharField) FieldType() int {
|
||||
return TypeVarCharField
|
||||
}
|
||||
|
||||
// SetRaw set the interface to string
|
||||
func (e *CharField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case string:
|
||||
e.Set(d)
|
||||
default:
|
||||
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return the CharField value
|
||||
func (e *CharField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify CharField implement Fielder
|
||||
var _ Fielder = new(CharField)
|
||||
|
||||
// TimeField A time, represented in go by a time.Time instance.
|
||||
// only time values like 10:00:00
|
||||
// Has a few extra, optional attr tag:
|
||||
//
|
||||
// auto_now:
|
||||
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
|
||||
// Note that the current date is always used; it’s not just a default value that you can override.
|
||||
//
|
||||
// auto_now_add:
|
||||
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
|
||||
// Note that the current date is always used; it’s not just a default value that you can override.
|
||||
//
|
||||
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
|
||||
type TimeField time.Time
|
||||
|
||||
// Value return the time.Time
|
||||
func (e TimeField) Value() time.Time {
|
||||
return time.Time(e)
|
||||
}
|
||||
|
||||
// Set set the TimeField's value
|
||||
func (e *TimeField) Set(d time.Time) {
|
||||
*e = TimeField(d)
|
||||
}
|
||||
|
||||
// String convert time to string
|
||||
func (e *TimeField) String() string {
|
||||
return e.Value().String()
|
||||
}
|
||||
|
||||
// FieldType return enum type Date
|
||||
func (e *TimeField) FieldType() int {
|
||||
return TypeDateField
|
||||
}
|
||||
|
||||
// SetRaw convert the interface to time.Time. Allow string and time.Time
|
||||
func (e *TimeField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case time.Time:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := timeParse(d, formatTime)
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return time value
|
||||
func (e *TimeField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
var _ Fielder = new(TimeField)
|
||||
|
||||
// DateField A date, represented in go by a time.Time instance.
|
||||
// only date values like 2006-01-02
|
||||
// Has a few extra, optional attr tag:
|
||||
//
|
||||
// auto_now:
|
||||
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
|
||||
// Note that the current date is always used; it’s not just a default value that you can override.
|
||||
//
|
||||
// auto_now_add:
|
||||
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
|
||||
// Note that the current date is always used; it’s not just a default value that you can override.
|
||||
//
|
||||
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
|
||||
type DateField time.Time
|
||||
|
||||
// Value return the time.Time
|
||||
func (e DateField) Value() time.Time {
|
||||
return time.Time(e)
|
||||
}
|
||||
|
||||
// Set set the DateField's value
|
||||
func (e *DateField) Set(d time.Time) {
|
||||
*e = DateField(d)
|
||||
}
|
||||
|
||||
// String convert datetime to string
|
||||
func (e *DateField) String() string {
|
||||
return e.Value().String()
|
||||
}
|
||||
|
||||
// FieldType return enum type Date
|
||||
func (e *DateField) FieldType() int {
|
||||
return TypeDateField
|
||||
}
|
||||
|
||||
// SetRaw convert the interface to time.Time. Allow string and time.Time
|
||||
func (e *DateField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case time.Time:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := timeParse(d, formatDate)
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return Date value
|
||||
func (e *DateField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify DateField implement fielder interface
|
||||
var _ Fielder = new(DateField)
|
||||
|
||||
// DateTimeField A date, represented in go by a time.Time instance.
|
||||
// datetime values like 2006-01-02 15:04:05
|
||||
// Takes the same extra arguments as DateField.
|
||||
type DateTimeField time.Time
|
||||
|
||||
// Value return the datetime value
|
||||
func (e DateTimeField) Value() time.Time {
|
||||
return time.Time(e)
|
||||
}
|
||||
|
||||
// Set set the time.Time to datetime
|
||||
func (e *DateTimeField) Set(d time.Time) {
|
||||
*e = DateTimeField(d)
|
||||
}
|
||||
|
||||
// String return the time's String
|
||||
func (e *DateTimeField) String() string {
|
||||
return e.Value().String()
|
||||
}
|
||||
|
||||
// FieldType return the enum TypeDateTimeField
|
||||
func (e *DateTimeField) FieldType() int {
|
||||
return TypeDateTimeField
|
||||
}
|
||||
|
||||
// SetRaw convert the string or time.Time to DateTimeField
|
||||
func (e *DateTimeField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case time.Time:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := timeParse(d, formatDateTime)
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return the datetime value
|
||||
func (e *DateTimeField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify datetime implement fielder
|
||||
var _ Fielder = new(DateTimeField)
|
||||
|
||||
// FloatField A floating-point number represented in go by a float32 value.
|
||||
type FloatField float64
|
||||
|
||||
// Value return the FloatField value
|
||||
func (e FloatField) Value() float64 {
|
||||
return float64(e)
|
||||
}
|
||||
|
||||
// Set the Float64
|
||||
func (e *FloatField) Set(d float64) {
|
||||
*e = FloatField(d)
|
||||
}
|
||||
|
||||
// String return the string
|
||||
func (e *FloatField) String() string {
|
||||
return ToStr(e.Value(), -1, 32)
|
||||
}
|
||||
|
||||
// FieldType return the enum type
|
||||
func (e *FloatField) FieldType() int {
|
||||
return TypeFloatField
|
||||
}
|
||||
|
||||
// SetRaw converter interface Float64 float32 or string to FloatField
|
||||
func (e *FloatField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case float32:
|
||||
e.Set(float64(d))
|
||||
case float64:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Float64()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return the FloatField value
|
||||
func (e *FloatField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify FloatField implement Fielder
|
||||
var _ Fielder = new(FloatField)
|
||||
|
||||
// SmallIntegerField -32768 to 32767
|
||||
type SmallIntegerField int16
|
||||
|
||||
// Value return int16 value
|
||||
func (e SmallIntegerField) Value() int16 {
|
||||
return int16(e)
|
||||
}
|
||||
|
||||
// Set the SmallIntegerField value
|
||||
func (e *SmallIntegerField) Set(d int16) {
|
||||
*e = SmallIntegerField(d)
|
||||
}
|
||||
|
||||
// String convert smallint to string
|
||||
func (e *SmallIntegerField) String() string {
|
||||
return ToStr(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return enum type SmallIntegerField
|
||||
func (e *SmallIntegerField) FieldType() int {
|
||||
return TypeSmallIntegerField
|
||||
}
|
||||
|
||||
// SetRaw convert interface int16/string to int16
|
||||
func (e *SmallIntegerField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case int16:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Int16()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return smallint value
|
||||
func (e *SmallIntegerField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify SmallIntegerField implement Fielder
|
||||
var _ Fielder = new(SmallIntegerField)
|
||||
|
||||
// IntegerField -2147483648 to 2147483647
|
||||
type IntegerField int32
|
||||
|
||||
// Value return the int32
|
||||
func (e IntegerField) Value() int32 {
|
||||
return int32(e)
|
||||
}
|
||||
|
||||
// Set IntegerField value
|
||||
func (e *IntegerField) Set(d int32) {
|
||||
*e = IntegerField(d)
|
||||
}
|
||||
|
||||
// String convert Int32 to string
|
||||
func (e *IntegerField) String() string {
|
||||
return ToStr(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return the enum type
|
||||
func (e *IntegerField) FieldType() int {
|
||||
return TypeIntegerField
|
||||
}
|
||||
|
||||
// SetRaw convert interface int32/string to int32
|
||||
func (e *IntegerField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case int32:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Int32()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return IntegerField value
|
||||
func (e *IntegerField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify IntegerField implement Fielder
|
||||
var _ Fielder = new(IntegerField)
|
||||
|
||||
// BigIntegerField -9223372036854775808 to 9223372036854775807.
|
||||
type BigIntegerField int64
|
||||
|
||||
// Value return int64
|
||||
func (e BigIntegerField) Value() int64 {
|
||||
return int64(e)
|
||||
}
|
||||
|
||||
// Set the BigIntegerField value
|
||||
func (e *BigIntegerField) Set(d int64) {
|
||||
*e = BigIntegerField(d)
|
||||
}
|
||||
|
||||
// String convert BigIntegerField to string
|
||||
func (e *BigIntegerField) String() string {
|
||||
return ToStr(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (e *BigIntegerField) FieldType() int {
|
||||
return TypeBigIntegerField
|
||||
}
|
||||
|
||||
// SetRaw convert interface int64/string to int64
|
||||
func (e *BigIntegerField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case int64:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Int64()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return BigIntegerField value
|
||||
func (e *BigIntegerField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify BigIntegerField implement Fielder
|
||||
var _ Fielder = new(BigIntegerField)
|
||||
|
||||
// PositiveSmallIntegerField 0 to 65535
|
||||
type PositiveSmallIntegerField uint16
|
||||
|
||||
// Value return uint16
|
||||
func (e PositiveSmallIntegerField) Value() uint16 {
|
||||
return uint16(e)
|
||||
}
|
||||
|
||||
// Set PositiveSmallIntegerField value
|
||||
func (e *PositiveSmallIntegerField) Set(d uint16) {
|
||||
*e = PositiveSmallIntegerField(d)
|
||||
}
|
||||
|
||||
// String convert uint16 to string
|
||||
func (e *PositiveSmallIntegerField) String() string {
|
||||
return ToStr(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (e *PositiveSmallIntegerField) FieldType() int {
|
||||
return TypePositiveSmallIntegerField
|
||||
}
|
||||
|
||||
// SetRaw convert Interface uint16/string to uint16
|
||||
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case uint16:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Uint16()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue returns PositiveSmallIntegerField value
|
||||
func (e *PositiveSmallIntegerField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify PositiveSmallIntegerField implement Fielder
|
||||
var _ Fielder = new(PositiveSmallIntegerField)
|
||||
|
||||
// PositiveIntegerField 0 to 4294967295
|
||||
type PositiveIntegerField uint32
|
||||
|
||||
// Value return PositiveIntegerField value. Uint32
|
||||
func (e PositiveIntegerField) Value() uint32 {
|
||||
return uint32(e)
|
||||
}
|
||||
|
||||
// Set the PositiveIntegerField value
|
||||
func (e *PositiveIntegerField) Set(d uint32) {
|
||||
*e = PositiveIntegerField(d)
|
||||
}
|
||||
|
||||
// String convert PositiveIntegerField to string
|
||||
func (e *PositiveIntegerField) String() string {
|
||||
return ToStr(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (e *PositiveIntegerField) FieldType() int {
|
||||
return TypePositiveIntegerField
|
||||
}
|
||||
|
||||
// SetRaw convert interface uint32/string to Uint32
|
||||
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case uint32:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Uint32()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return the PositiveIntegerField Value
|
||||
func (e *PositiveIntegerField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify PositiveIntegerField implement Fielder
|
||||
var _ Fielder = new(PositiveIntegerField)
|
||||
|
||||
// PositiveBigIntegerField 0 to 18446744073709551615
|
||||
type PositiveBigIntegerField uint64
|
||||
|
||||
// Value return uint64
|
||||
func (e PositiveBigIntegerField) Value() uint64 {
|
||||
return uint64(e)
|
||||
}
|
||||
|
||||
// Set PositiveBigIntegerField value
|
||||
func (e *PositiveBigIntegerField) Set(d uint64) {
|
||||
*e = PositiveBigIntegerField(d)
|
||||
}
|
||||
|
||||
// String convert PositiveBigIntegerField to string
|
||||
func (e *PositiveBigIntegerField) String() string {
|
||||
return ToStr(e.Value())
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (e *PositiveBigIntegerField) FieldType() int {
|
||||
return TypePositiveIntegerField
|
||||
}
|
||||
|
||||
// SetRaw convert interface uint64/string to Uint64
|
||||
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case uint64:
|
||||
e.Set(d)
|
||||
case string:
|
||||
v, err := StrTo(d).Uint64()
|
||||
if err == nil {
|
||||
e.Set(v)
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return PositiveBigIntegerField value
|
||||
func (e *PositiveBigIntegerField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify PositiveBigIntegerField implement Fielder
|
||||
var _ Fielder = new(PositiveBigIntegerField)
|
||||
|
||||
// TextField A large text field.
|
||||
type TextField string
|
||||
|
||||
// Value return TextField value
|
||||
func (e TextField) Value() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Set the TextField value
|
||||
func (e *TextField) Set(d string) {
|
||||
*e = TextField(d)
|
||||
}
|
||||
|
||||
// String convert TextField to string
|
||||
func (e *TextField) String() string {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (e *TextField) FieldType() int {
|
||||
return TypeTextField
|
||||
}
|
||||
|
||||
// SetRaw convert interface string to string
|
||||
func (e *TextField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case string:
|
||||
e.Set(d)
|
||||
default:
|
||||
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return TextField value
|
||||
func (e *TextField) RawValue() interface{} {
|
||||
return e.Value()
|
||||
}
|
||||
|
||||
// verify TextField implement Fielder
|
||||
var _ Fielder = new(TextField)
|
||||
|
||||
// JSONField postgres json field.
|
||||
type JSONField string
|
||||
|
||||
// Value return JSONField value
|
||||
func (j JSONField) Value() string {
|
||||
return string(j)
|
||||
}
|
||||
|
||||
// Set the JSONField value
|
||||
func (j *JSONField) Set(d string) {
|
||||
*j = JSONField(d)
|
||||
}
|
||||
|
||||
// String convert JSONField to string
|
||||
func (j *JSONField) String() string {
|
||||
return j.Value()
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (j *JSONField) FieldType() int {
|
||||
return TypeJSONField
|
||||
}
|
||||
|
||||
// SetRaw convert interface string to string
|
||||
func (j *JSONField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case string:
|
||||
j.Set(d)
|
||||
default:
|
||||
return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return JSONField value
|
||||
func (j *JSONField) RawValue() interface{} {
|
||||
return j.Value()
|
||||
}
|
||||
|
||||
// verify JSONField implement Fielder
|
||||
var _ Fielder = new(JSONField)
|
||||
|
||||
// JsonbField postgres json field.
|
||||
type JsonbField string
|
||||
|
||||
// Value return JsonbField value
|
||||
func (j JsonbField) Value() string {
|
||||
return string(j)
|
||||
}
|
||||
|
||||
// Set the JsonbField value
|
||||
func (j *JsonbField) Set(d string) {
|
||||
*j = JsonbField(d)
|
||||
}
|
||||
|
||||
// String convert JsonbField to string
|
||||
func (j *JsonbField) String() string {
|
||||
return j.Value()
|
||||
}
|
||||
|
||||
// FieldType return enum type
|
||||
func (j *JsonbField) FieldType() int {
|
||||
return TypeJsonbField
|
||||
}
|
||||
|
||||
// SetRaw convert interface string to string
|
||||
func (j *JsonbField) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case string:
|
||||
j.Set(d)
|
||||
default:
|
||||
return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RawValue return JsonbField value
|
||||
func (j *JsonbField) RawValue() interface{} {
|
||||
return j.Value()
|
||||
}
|
||||
|
||||
// verify JsonbField implement Fielder
|
||||
var _ Fielder = new(JsonbField)
|
473
pkg/client/orm/models_info_f.go
Normal file
473
pkg/client/orm/models_info_f.go
Normal file
@ -0,0 +1,473 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errSkipField = errors.New("skip field")
|
||||
|
||||
// field info collection
|
||||
type fields struct {
|
||||
pk *fieldInfo
|
||||
columns map[string]*fieldInfo
|
||||
fields map[string]*fieldInfo
|
||||
fieldsLow map[string]*fieldInfo
|
||||
fieldsByType map[int][]*fieldInfo
|
||||
fieldsRel []*fieldInfo
|
||||
fieldsReverse []*fieldInfo
|
||||
fieldsDB []*fieldInfo
|
||||
rels []*fieldInfo
|
||||
orders []string
|
||||
dbcols []string
|
||||
}
|
||||
|
||||
// add field info
|
||||
func (f *fields) Add(fi *fieldInfo) (added bool) {
|
||||
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
|
||||
f.columns[fi.column] = fi
|
||||
f.fields[fi.name] = fi
|
||||
f.fieldsLow[strings.ToLower(fi.name)] = fi
|
||||
} else {
|
||||
return
|
||||
}
|
||||
if _, ok := f.fieldsByType[fi.fieldType]; !ok {
|
||||
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
|
||||
}
|
||||
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
|
||||
f.orders = append(f.orders, fi.column)
|
||||
if fi.dbcol {
|
||||
f.dbcols = append(f.dbcols, fi.column)
|
||||
f.fieldsDB = append(f.fieldsDB, fi)
|
||||
}
|
||||
if fi.rel {
|
||||
f.fieldsRel = append(f.fieldsRel, fi)
|
||||
}
|
||||
if fi.reverse {
|
||||
f.fieldsReverse = append(f.fieldsReverse, fi)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// get field info by name
|
||||
func (f *fields) GetByName(name string) *fieldInfo {
|
||||
return f.fields[name]
|
||||
}
|
||||
|
||||
// get field info by column name
|
||||
func (f *fields) GetByColumn(column string) *fieldInfo {
|
||||
return f.columns[column]
|
||||
}
|
||||
|
||||
// get field info by string, name is prior
|
||||
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
|
||||
if fi, ok := f.fields[name]; ok {
|
||||
return fi, ok
|
||||
}
|
||||
if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok {
|
||||
return fi, ok
|
||||
}
|
||||
if fi, ok := f.columns[name]; ok {
|
||||
return fi, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// create new field info collection
|
||||
func newFields() *fields {
|
||||
f := new(fields)
|
||||
f.fields = make(map[string]*fieldInfo)
|
||||
f.fieldsLow = make(map[string]*fieldInfo)
|
||||
f.columns = make(map[string]*fieldInfo)
|
||||
f.fieldsByType = make(map[int][]*fieldInfo)
|
||||
return f
|
||||
}
|
||||
|
||||
// single field info
|
||||
type fieldInfo struct {
|
||||
mi *modelInfo
|
||||
fieldIndex []int
|
||||
fieldType int
|
||||
dbcol bool // table column fk and onetoone
|
||||
inModel bool
|
||||
name string
|
||||
fullName string
|
||||
column string
|
||||
addrValue reflect.Value
|
||||
sf reflect.StructField
|
||||
auto bool
|
||||
pk bool
|
||||
null bool
|
||||
index bool
|
||||
unique bool
|
||||
colDefault bool // whether has default tag
|
||||
initial StrTo // store the default value
|
||||
size int
|
||||
toText bool
|
||||
autoNow bool
|
||||
autoNowAdd bool
|
||||
rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true
|
||||
reverse bool
|
||||
reverseField string
|
||||
reverseFieldInfo *fieldInfo
|
||||
reverseFieldInfoTwo *fieldInfo
|
||||
reverseFieldInfoM2M *fieldInfo
|
||||
relTable string
|
||||
relThrough string
|
||||
relThroughModelInfo *modelInfo
|
||||
relModelInfo *modelInfo
|
||||
digits int
|
||||
decimals int
|
||||
isFielder bool // implement Fielder interface
|
||||
onDelete string
|
||||
description string
|
||||
}
|
||||
|
||||
// new field info
|
||||
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) {
|
||||
var (
|
||||
tag string
|
||||
tagValue string
|
||||
initial StrTo // store the default value
|
||||
fieldType int
|
||||
attrs map[string]bool
|
||||
tags map[string]string
|
||||
addrField reflect.Value
|
||||
)
|
||||
|
||||
fi = new(fieldInfo)
|
||||
|
||||
// if field which CanAddr is the follow type
|
||||
// A value is addressable if it is an element of a slice,
|
||||
// an element of an addressable array, a field of an
|
||||
// addressable struct, or the result of dereferencing a pointer.
|
||||
addrField = field
|
||||
if field.CanAddr() && field.Kind() != reflect.Ptr {
|
||||
addrField = field.Addr()
|
||||
if _, ok := addrField.Interface().(Fielder); !ok {
|
||||
if field.Kind() == reflect.Slice {
|
||||
addrField = field
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName))
|
||||
|
||||
if _, ok := attrs["-"]; ok {
|
||||
return nil, errSkipField
|
||||
}
|
||||
|
||||
digits := tags["digits"]
|
||||
decimals := tags["decimals"]
|
||||
size := tags["size"]
|
||||
onDelete := tags["on_delete"]
|
||||
|
||||
initial.Clear()
|
||||
if v, ok := tags["default"]; ok {
|
||||
initial.Set(v)
|
||||
}
|
||||
|
||||
checkType:
|
||||
switch f := addrField.Interface().(type) {
|
||||
case Fielder:
|
||||
fi.isFielder = true
|
||||
if field.Kind() == reflect.Ptr {
|
||||
err = fmt.Errorf("the model Fielder can not be use ptr")
|
||||
goto end
|
||||
}
|
||||
fieldType = f.FieldType()
|
||||
if fieldType&IsRelField > 0 {
|
||||
err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42")
|
||||
goto end
|
||||
}
|
||||
default:
|
||||
tag = "rel"
|
||||
tagValue = tags[tag]
|
||||
if tagValue != "" {
|
||||
switch tagValue {
|
||||
case "fk":
|
||||
fieldType = RelForeignKey
|
||||
break checkType
|
||||
case "one":
|
||||
fieldType = RelOneToOne
|
||||
break checkType
|
||||
case "m2m":
|
||||
fieldType = RelManyToMany
|
||||
if tv := tags["rel_table"]; tv != "" {
|
||||
fi.relTable = tv
|
||||
} else if tv := tags["rel_through"]; tv != "" {
|
||||
fi.relThrough = tv
|
||||
}
|
||||
break checkType
|
||||
default:
|
||||
err = fmt.Errorf("rel only allow these value: fk, one, m2m")
|
||||
goto wrongTag
|
||||
}
|
||||
}
|
||||
tag = "reverse"
|
||||
tagValue = tags[tag]
|
||||
if tagValue != "" {
|
||||
switch tagValue {
|
||||
case "one":
|
||||
fieldType = RelReverseOne
|
||||
break checkType
|
||||
case "many":
|
||||
fieldType = RelReverseMany
|
||||
if tv := tags["rel_table"]; tv != "" {
|
||||
fi.relTable = tv
|
||||
} else if tv := tags["rel_through"]; tv != "" {
|
||||
fi.relThrough = tv
|
||||
}
|
||||
break checkType
|
||||
default:
|
||||
err = fmt.Errorf("reverse only allow these value: one, many")
|
||||
goto wrongTag
|
||||
}
|
||||
}
|
||||
|
||||
fieldType, err = getFieldType(addrField)
|
||||
if err != nil {
|
||||
goto end
|
||||
}
|
||||
if fieldType == TypeVarCharField {
|
||||
switch tags["type"] {
|
||||
case "char":
|
||||
fieldType = TypeCharField
|
||||
case "text":
|
||||
fieldType = TypeTextField
|
||||
case "json":
|
||||
fieldType = TypeJSONField
|
||||
case "jsonb":
|
||||
fieldType = TypeJsonbField
|
||||
}
|
||||
}
|
||||
if fieldType == TypeFloatField && (digits != "" || decimals != "") {
|
||||
fieldType = TypeDecimalField
|
||||
}
|
||||
if fieldType == TypeDateTimeField && tags["type"] == "date" {
|
||||
fieldType = TypeDateField
|
||||
}
|
||||
if fieldType == TypeTimeField && tags["type"] == "time" {
|
||||
fieldType = TypeTimeField
|
||||
}
|
||||
}
|
||||
|
||||
// check the rel and reverse type
|
||||
// rel should Ptr
|
||||
// reverse should slice []*struct
|
||||
switch fieldType {
|
||||
case RelForeignKey, RelOneToOne, RelReverseOne:
|
||||
if field.Kind() != reflect.Ptr {
|
||||
err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name())
|
||||
goto end
|
||||
}
|
||||
case RelManyToMany, RelReverseMany:
|
||||
if field.Kind() != reflect.Slice {
|
||||
err = fmt.Errorf("rel/reverse:many field must be slice")
|
||||
goto end
|
||||
} else {
|
||||
if field.Type().Elem().Kind() != reflect.Ptr {
|
||||
err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name())
|
||||
goto end
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fieldType&IsFieldType == 0 {
|
||||
err = fmt.Errorf("wrong field type")
|
||||
goto end
|
||||
}
|
||||
|
||||
fi.fieldType = fieldType
|
||||
fi.name = sf.Name
|
||||
fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
|
||||
fi.addrValue = addrField
|
||||
fi.sf = sf
|
||||
fi.fullName = mi.fullName + mName + "." + sf.Name
|
||||
|
||||
fi.description = tags["description"]
|
||||
fi.null = attrs["null"]
|
||||
fi.index = attrs["index"]
|
||||
fi.auto = attrs["auto"]
|
||||
fi.pk = attrs["pk"]
|
||||
fi.unique = attrs["unique"]
|
||||
|
||||
// Mark object property if there is attribute "default" in the orm configuration
|
||||
if _, ok := tags["default"]; ok {
|
||||
fi.colDefault = true
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case RelManyToMany, RelReverseMany, RelReverseOne:
|
||||
fi.null = false
|
||||
fi.index = false
|
||||
fi.auto = false
|
||||
fi.pk = false
|
||||
fi.unique = false
|
||||
default:
|
||||
fi.dbcol = true
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case RelForeignKey, RelOneToOne, RelManyToMany:
|
||||
fi.rel = true
|
||||
if fieldType == RelOneToOne {
|
||||
fi.unique = true
|
||||
}
|
||||
case RelReverseMany, RelReverseOne:
|
||||
fi.reverse = true
|
||||
}
|
||||
|
||||
if fi.rel && fi.dbcol {
|
||||
switch onDelete {
|
||||
case odCascade, odDoNothing:
|
||||
case odSetDefault:
|
||||
if !initial.Exist() {
|
||||
err = errors.New("on_delete: set_default need set field a default value")
|
||||
goto end
|
||||
}
|
||||
case odSetNULL:
|
||||
if !fi.null {
|
||||
err = errors.New("on_delete: set_null need set field null")
|
||||
goto end
|
||||
}
|
||||
default:
|
||||
if onDelete == "" {
|
||||
onDelete = odCascade
|
||||
} else {
|
||||
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
|
||||
goto end
|
||||
}
|
||||
}
|
||||
|
||||
fi.onDelete = onDelete
|
||||
}
|
||||
|
||||
switch fieldType {
|
||||
case TypeBooleanField:
|
||||
case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
|
||||
if size != "" {
|
||||
v, e := StrTo(size).Int32()
|
||||
if e != nil {
|
||||
err = fmt.Errorf("wrong size value `%s`", size)
|
||||
} else {
|
||||
fi.size = int(v)
|
||||
}
|
||||
} else {
|
||||
fi.size = 255
|
||||
fi.toText = true
|
||||
}
|
||||
case TypeTextField:
|
||||
fi.index = false
|
||||
fi.unique = false
|
||||
case TypeTimeField, TypeDateField, TypeDateTimeField:
|
||||
if attrs["auto_now"] {
|
||||
fi.autoNow = true
|
||||
} else if attrs["auto_now_add"] {
|
||||
fi.autoNowAdd = true
|
||||
}
|
||||
case TypeFloatField:
|
||||
case TypeDecimalField:
|
||||
d1 := digits
|
||||
d2 := decimals
|
||||
v1, er1 := StrTo(d1).Int8()
|
||||
v2, er2 := StrTo(d2).Int8()
|
||||
if er1 != nil || er2 != nil {
|
||||
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
|
||||
goto end
|
||||
}
|
||||
fi.digits = int(v1)
|
||||
fi.decimals = int(v2)
|
||||
default:
|
||||
switch {
|
||||
case fieldType&IsIntegerField > 0:
|
||||
case fieldType&IsRelField > 0:
|
||||
}
|
||||
}
|
||||
|
||||
if fieldType&IsIntegerField == 0 {
|
||||
if fi.auto {
|
||||
err = fmt.Errorf("non-integer type cannot set auto")
|
||||
goto end
|
||||
}
|
||||
}
|
||||
|
||||
if fi.auto || fi.pk {
|
||||
if fi.auto {
|
||||
switch addrField.Elem().Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
default:
|
||||
err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind())
|
||||
goto end
|
||||
}
|
||||
fi.pk = true
|
||||
}
|
||||
fi.null = false
|
||||
fi.index = false
|
||||
fi.unique = false
|
||||
}
|
||||
|
||||
if fi.unique {
|
||||
fi.index = false
|
||||
}
|
||||
|
||||
// can not set default for these type
|
||||
if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
|
||||
initial.Clear()
|
||||
}
|
||||
|
||||
if initial.Exist() {
|
||||
v := initial
|
||||
switch fieldType {
|
||||
case TypeBooleanField:
|
||||
_, err = v.Bool()
|
||||
case TypeFloatField, TypeDecimalField:
|
||||
_, err = v.Float64()
|
||||
case TypeBitField:
|
||||
_, err = v.Int8()
|
||||
case TypeSmallIntegerField:
|
||||
_, err = v.Int16()
|
||||
case TypeIntegerField:
|
||||
_, err = v.Int32()
|
||||
case TypeBigIntegerField:
|
||||
_, err = v.Int64()
|
||||
case TypePositiveBitField:
|
||||
_, err = v.Uint8()
|
||||
case TypePositiveSmallIntegerField:
|
||||
_, err = v.Uint16()
|
||||
case TypePositiveIntegerField:
|
||||
_, err = v.Uint32()
|
||||
case TypePositiveBigIntegerField:
|
||||
_, err = v.Uint64()
|
||||
}
|
||||
if err != nil {
|
||||
tag, tagValue = "default", tags["default"]
|
||||
goto wrongTag
|
||||
}
|
||||
}
|
||||
|
||||
fi.initial = initial
|
||||
end:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
wrongTag:
|
||||
return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err)
|
||||
}
|
148
pkg/client/orm/models_info_m.go
Normal file
148
pkg/client/orm/models_info_m.go
Normal file
@ -0,0 +1,148 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// single model info
|
||||
type modelInfo struct {
|
||||
pkg string
|
||||
name string
|
||||
fullName string
|
||||
table string
|
||||
model interface{}
|
||||
fields *fields
|
||||
manual bool
|
||||
addrField reflect.Value // store the original struct value
|
||||
uniques []string
|
||||
isThrough bool
|
||||
}
|
||||
|
||||
// new model info
|
||||
func newModelInfo(val reflect.Value) (mi *modelInfo) {
|
||||
mi = &modelInfo{}
|
||||
mi.fields = newFields()
|
||||
ind := reflect.Indirect(val)
|
||||
mi.addrField = val
|
||||
mi.name = ind.Type().Name()
|
||||
mi.fullName = getFullName(ind.Type())
|
||||
addModelFields(mi, ind, "", []int{})
|
||||
return
|
||||
}
|
||||
|
||||
// index: FieldByIndex returns the nested field corresponding to index
|
||||
func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
|
||||
var (
|
||||
err error
|
||||
fi *fieldInfo
|
||||
sf reflect.StructField
|
||||
)
|
||||
|
||||
for i := 0; i < ind.NumField(); i++ {
|
||||
field := ind.Field(i)
|
||||
sf = ind.Type().Field(i)
|
||||
// if the field is unexported skip
|
||||
if sf.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
// add anonymous struct fields
|
||||
if sf.Anonymous {
|
||||
addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
|
||||
continue
|
||||
}
|
||||
|
||||
fi, err = newFieldInfo(mi, field, sf, mName)
|
||||
if err == errSkipField {
|
||||
err = nil
|
||||
continue
|
||||
} else if err != nil {
|
||||
break
|
||||
}
|
||||
//record current field index
|
||||
fi.fieldIndex = append(fi.fieldIndex, index...)
|
||||
fi.fieldIndex = append(fi.fieldIndex, i)
|
||||
fi.mi = mi
|
||||
fi.inModel = true
|
||||
if !mi.fields.Add(fi) {
|
||||
err = fmt.Errorf("duplicate column name: %s", fi.column)
|
||||
break
|
||||
}
|
||||
if fi.pk {
|
||||
if mi.fields.pk != nil {
|
||||
err = fmt.Errorf("one model must have one pk field only")
|
||||
break
|
||||
} else {
|
||||
mi.fields.pk = fi
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
|
||||
os.Exit(2)
|
||||
}
|
||||
}
|
||||
|
||||
// combine related model info to new model info.
|
||||
// prepare for relation models query.
|
||||
func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
|
||||
mi = new(modelInfo)
|
||||
mi.fields = newFields()
|
||||
mi.table = m1.table + "_" + m2.table + "s"
|
||||
mi.name = camelString(mi.table)
|
||||
mi.fullName = m1.pkg + "." + mi.name
|
||||
|
||||
fa := new(fieldInfo) // pk
|
||||
f1 := new(fieldInfo) // m1 table RelForeignKey
|
||||
f2 := new(fieldInfo) // m2 table RelForeignKey
|
||||
fa.fieldType = TypeBigIntegerField
|
||||
fa.auto = true
|
||||
fa.pk = true
|
||||
fa.dbcol = true
|
||||
fa.name = "Id"
|
||||
fa.column = "id"
|
||||
fa.fullName = mi.fullName + "." + fa.name
|
||||
|
||||
f1.dbcol = true
|
||||
f2.dbcol = true
|
||||
f1.fieldType = RelForeignKey
|
||||
f2.fieldType = RelForeignKey
|
||||
f1.name = camelString(m1.table)
|
||||
f2.name = camelString(m2.table)
|
||||
f1.fullName = mi.fullName + "." + f1.name
|
||||
f2.fullName = mi.fullName + "." + f2.name
|
||||
f1.column = m1.table + "_id"
|
||||
f2.column = m2.table + "_id"
|
||||
f1.rel = true
|
||||
f2.rel = true
|
||||
f1.relTable = m1.table
|
||||
f2.relTable = m2.table
|
||||
f1.relModelInfo = m1
|
||||
f2.relModelInfo = m2
|
||||
f1.mi = mi
|
||||
f2.mi = mi
|
||||
|
||||
mi.fields.Add(fa)
|
||||
mi.fields.Add(f1)
|
||||
mi.fields.Add(f2)
|
||||
mi.fields.pk = fa
|
||||
|
||||
mi.uniques = []string{f1.column, f2.column}
|
||||
return
|
||||
}
|
525
pkg/client/orm/models_test.go
Normal file
525
pkg/client/orm/models_test.go
Normal file
@ -0,0 +1,525 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
// As tidb can't use go get, so disable the tidb testing now
|
||||
// _ "github.com/pingcap/tidb"
|
||||
)
|
||||
|
||||
// A slice string field.
|
||||
type SliceStringField []string
|
||||
|
||||
func (e SliceStringField) Value() []string {
|
||||
return []string(e)
|
||||
}
|
||||
|
||||
func (e *SliceStringField) Set(d []string) {
|
||||
*e = SliceStringField(d)
|
||||
}
|
||||
|
||||
func (e *SliceStringField) Add(v string) {
|
||||
*e = append(*e, v)
|
||||
}
|
||||
|
||||
func (e *SliceStringField) String() string {
|
||||
return strings.Join(e.Value(), ",")
|
||||
}
|
||||
|
||||
func (e *SliceStringField) FieldType() int {
|
||||
return TypeVarCharField
|
||||
}
|
||||
|
||||
func (e *SliceStringField) SetRaw(value interface{}) error {
|
||||
f := func(str string) {
|
||||
if len(str) > 0 {
|
||||
parts := strings.Split(str, ",")
|
||||
v := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
v = append(v, strings.TrimSpace(p))
|
||||
}
|
||||
e.Set(v)
|
||||
}
|
||||
}
|
||||
|
||||
switch d := value.(type) {
|
||||
case []string:
|
||||
e.Set(d)
|
||||
case string:
|
||||
f(d)
|
||||
case []byte:
|
||||
f(string(d))
|
||||
default:
|
||||
return fmt.Errorf("<SliceStringField.SetRaw> unknown value `%v`", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *SliceStringField) RawValue() interface{} {
|
||||
return e.String()
|
||||
}
|
||||
|
||||
var _ Fielder = new(SliceStringField)
|
||||
|
||||
// A json field.
|
||||
type JSONFieldTest struct {
|
||||
Name string
|
||||
Data string
|
||||
}
|
||||
|
||||
func (e *JSONFieldTest) String() string {
|
||||
data, _ := json.Marshal(e)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func (e *JSONFieldTest) FieldType() int {
|
||||
return TypeTextField
|
||||
}
|
||||
|
||||
func (e *JSONFieldTest) SetRaw(value interface{}) error {
|
||||
switch d := value.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(d), e)
|
||||
case []byte:
|
||||
return json.Unmarshal(d, e)
|
||||
default:
|
||||
return fmt.Errorf("<JSONField.SetRaw> unknown value `%v`", value)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *JSONFieldTest) RawValue() interface{} {
|
||||
return e.String()
|
||||
}
|
||||
|
||||
var _ Fielder = new(JSONFieldTest)
|
||||
|
||||
type Data struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Boolean bool
|
||||
Char string `orm:"size(50)"`
|
||||
Text string `orm:"type(text)"`
|
||||
JSON string `orm:"type(json);default({\"name\":\"json\"})"`
|
||||
Jsonb string `orm:"type(jsonb)"`
|
||||
Time time.Time `orm:"type(time)"`
|
||||
Date time.Time `orm:"type(date)"`
|
||||
DateTime time.Time `orm:"column(datetime)"`
|
||||
Byte byte
|
||||
Rune rune
|
||||
Int int
|
||||
Int8 int8
|
||||
Int16 int16
|
||||
Int32 int32
|
||||
Int64 int64
|
||||
Uint uint
|
||||
Uint8 uint8
|
||||
Uint16 uint16
|
||||
Uint32 uint32
|
||||
Uint64 uint64
|
||||
Float32 float32
|
||||
Float64 float64
|
||||
Decimal float64 `orm:"digits(8);decimals(4)"`
|
||||
}
|
||||
|
||||
type DataNull struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Boolean bool `orm:"null"`
|
||||
Char string `orm:"null;size(50)"`
|
||||
Text string `orm:"null;type(text)"`
|
||||
JSON string `orm:"type(json);null"`
|
||||
Jsonb string `orm:"type(jsonb);null"`
|
||||
Time time.Time `orm:"null;type(time)"`
|
||||
Date time.Time `orm:"null;type(date)"`
|
||||
DateTime time.Time `orm:"null;column(datetime)"`
|
||||
Byte byte `orm:"null"`
|
||||
Rune rune `orm:"null"`
|
||||
Int int `orm:"null"`
|
||||
Int8 int8 `orm:"null"`
|
||||
Int16 int16 `orm:"null"`
|
||||
Int32 int32 `orm:"null"`
|
||||
Int64 int64 `orm:"null"`
|
||||
Uint uint `orm:"null"`
|
||||
Uint8 uint8 `orm:"null"`
|
||||
Uint16 uint16 `orm:"null"`
|
||||
Uint32 uint32 `orm:"null"`
|
||||
Uint64 uint64 `orm:"null"`
|
||||
Float32 float32 `orm:"null"`
|
||||
Float64 float64 `orm:"null"`
|
||||
Decimal float64 `orm:"digits(8);decimals(4);null"`
|
||||
NullString sql.NullString `orm:"null"`
|
||||
NullBool sql.NullBool `orm:"null"`
|
||||
NullFloat64 sql.NullFloat64 `orm:"null"`
|
||||
NullInt64 sql.NullInt64 `orm:"null"`
|
||||
BooleanPtr *bool `orm:"null"`
|
||||
CharPtr *string `orm:"null;size(50)"`
|
||||
TextPtr *string `orm:"null;type(text)"`
|
||||
BytePtr *byte `orm:"null"`
|
||||
RunePtr *rune `orm:"null"`
|
||||
IntPtr *int `orm:"null"`
|
||||
Int8Ptr *int8 `orm:"null"`
|
||||
Int16Ptr *int16 `orm:"null"`
|
||||
Int32Ptr *int32 `orm:"null"`
|
||||
Int64Ptr *int64 `orm:"null"`
|
||||
UintPtr *uint `orm:"null"`
|
||||
Uint8Ptr *uint8 `orm:"null"`
|
||||
Uint16Ptr *uint16 `orm:"null"`
|
||||
Uint32Ptr *uint32 `orm:"null"`
|
||||
Uint64Ptr *uint64 `orm:"null"`
|
||||
Float32Ptr *float32 `orm:"null"`
|
||||
Float64Ptr *float64 `orm:"null"`
|
||||
DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
|
||||
TimePtr *time.Time `orm:"null;type(time)"`
|
||||
DatePtr *time.Time `orm:"null;type(date)"`
|
||||
DateTimePtr *time.Time `orm:"null"`
|
||||
}
|
||||
|
||||
type String string
|
||||
type Boolean bool
|
||||
type Byte byte
|
||||
type Rune rune
|
||||
type Int int
|
||||
type Int8 int8
|
||||
type Int16 int16
|
||||
type Int32 int32
|
||||
type Int64 int64
|
||||
type Uint uint
|
||||
type Uint8 uint8
|
||||
type Uint16 uint16
|
||||
type Uint32 uint32
|
||||
type Uint64 uint64
|
||||
type Float32 float64
|
||||
type Float64 float64
|
||||
|
||||
type DataCustom struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Boolean Boolean
|
||||
Char string `orm:"size(50)"`
|
||||
Text string `orm:"type(text)"`
|
||||
Byte Byte
|
||||
Rune Rune
|
||||
Int Int
|
||||
Int8 Int8
|
||||
Int16 Int16
|
||||
Int32 Int32
|
||||
Int64 Int64
|
||||
Uint Uint
|
||||
Uint8 Uint8
|
||||
Uint16 Uint16
|
||||
Uint32 Uint32
|
||||
Uint64 Uint64
|
||||
Float32 Float32
|
||||
Float64 Float64
|
||||
Decimal Float64 `orm:"digits(8);decimals(4)"`
|
||||
}
|
||||
|
||||
// only for mysql
|
||||
type UserBig struct {
|
||||
ID uint64 `orm:"column(id)"`
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int `orm:"column(id)"`
|
||||
UserName string `orm:"size(30);unique"`
|
||||
Email string `orm:"size(100)"`
|
||||
Password string `orm:"size(100)"`
|
||||
Status int16 `orm:"column(Status)"`
|
||||
IsStaff bool
|
||||
IsActive bool `orm:"default(true)"`
|
||||
Created time.Time `orm:"auto_now_add;type(date)"`
|
||||
Updated time.Time `orm:"auto_now"`
|
||||
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
|
||||
Posts []*Post `orm:"reverse(many)" json:"-"`
|
||||
ShouldSkip string `orm:"-"`
|
||||
Nums int
|
||||
Langs SliceStringField `orm:"size(100)"`
|
||||
Extra JSONFieldTest `orm:"type(text)"`
|
||||
unexport bool `orm:"-"`
|
||||
unexportBool bool
|
||||
}
|
||||
|
||||
func (u *User) TableIndex() [][]string {
|
||||
return [][]string{
|
||||
{"Id", "UserName"},
|
||||
{"Id", "Created"},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) TableUnique() [][]string {
|
||||
return [][]string{
|
||||
{"UserName", "Email"},
|
||||
}
|
||||
}
|
||||
|
||||
func NewUser() *User {
|
||||
obj := new(User)
|
||||
return obj
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Age int16
|
||||
Money float64
|
||||
User *User `orm:"reverse(one)" json:"-"`
|
||||
BestPost *Post `orm:"rel(one);null"`
|
||||
}
|
||||
|
||||
func (u *Profile) TableName() string {
|
||||
return "user_profile"
|
||||
}
|
||||
|
||||
func NewProfile() *Profile {
|
||||
obj := new(Profile)
|
||||
return obj
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID int `orm:"column(id)"`
|
||||
User *User `orm:"rel(fk)"`
|
||||
Title string `orm:"size(60)"`
|
||||
Content string `orm:"type(text)"`
|
||||
Created time.Time `orm:"auto_now_add"`
|
||||
Updated time.Time `orm:"auto_now"`
|
||||
Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.PostTags)"`
|
||||
}
|
||||
|
||||
func (u *Post) TableIndex() [][]string {
|
||||
return [][]string{
|
||||
{"Id", "Created"},
|
||||
}
|
||||
}
|
||||
|
||||
func NewPost() *Post {
|
||||
obj := new(Post)
|
||||
return obj
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Name string `orm:"size(30)"`
|
||||
BestPost *Post `orm:"rel(one);null"`
|
||||
Posts []*Post `orm:"reverse(many)" json:"-"`
|
||||
}
|
||||
|
||||
func NewTag() *Tag {
|
||||
obj := new(Tag)
|
||||
return obj
|
||||
}
|
||||
|
||||
type PostTags struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Post *Post `orm:"rel(fk)"`
|
||||
Tag *Tag `orm:"rel(fk)"`
|
||||
}
|
||||
|
||||
func (m *PostTags) TableName() string {
|
||||
return "prefix_post_tags"
|
||||
}
|
||||
|
||||
type Comment struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Post *Post `orm:"rel(fk);column(post)"`
|
||||
Content string `orm:"type(text)"`
|
||||
Parent *Comment `orm:"null;rel(fk)"`
|
||||
Created time.Time `orm:"auto_now_add"`
|
||||
}
|
||||
|
||||
func NewComment() *Comment {
|
||||
obj := new(Comment)
|
||||
return obj
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID int `orm:"column(gid);size(32)"`
|
||||
Name string
|
||||
Permissions []*Permission `orm:"reverse(many)" json:"-"`
|
||||
}
|
||||
|
||||
type Permission struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Name string
|
||||
Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/pkg/client/orm.GroupPermissions)"`
|
||||
}
|
||||
|
||||
type GroupPermissions struct {
|
||||
ID int `orm:"column(id)"`
|
||||
Group *Group `orm:"rel(fk)"`
|
||||
Permission *Permission `orm:"rel(fk)"`
|
||||
}
|
||||
|
||||
type ModelID struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
type ModelBase struct {
|
||||
ModelID
|
||||
|
||||
Created time.Time `orm:"auto_now_add;type(datetime)"`
|
||||
Updated time.Time `orm:"auto_now;type(datetime)"`
|
||||
}
|
||||
|
||||
type InLine struct {
|
||||
// Common Fields
|
||||
ModelBase
|
||||
|
||||
// Other Fields
|
||||
Name string `orm:"unique"`
|
||||
Email string
|
||||
}
|
||||
|
||||
type Index struct {
|
||||
// Common Fields
|
||||
Id int `orm:"column(id)"`
|
||||
|
||||
// Other Fields
|
||||
F1 int `orm:"column(f1);index"`
|
||||
F2 int `orm:"column(f2);index"`
|
||||
}
|
||||
|
||||
func NewInLine() *InLine {
|
||||
return new(InLine)
|
||||
}
|
||||
|
||||
type InLineOneToOne struct {
|
||||
// Common Fields
|
||||
ModelBase
|
||||
|
||||
Note string
|
||||
InLine *InLine `orm:"rel(fk);column(inline)"`
|
||||
}
|
||||
|
||||
func NewInLineOneToOne() *InLineOneToOne {
|
||||
return new(InLineOneToOne)
|
||||
}
|
||||
|
||||
type IntegerPk struct {
|
||||
ID int64 `orm:"pk"`
|
||||
Value string
|
||||
}
|
||||
|
||||
type UintPk struct {
|
||||
ID uint32 `orm:"pk"`
|
||||
Name string
|
||||
}
|
||||
|
||||
type PtrPk struct {
|
||||
ID *IntegerPk `orm:"pk;rel(one)"`
|
||||
Positive bool
|
||||
}
|
||||
|
||||
type StrPk struct {
|
||||
Id string `orm:"column(id);size(64);pk"`
|
||||
Value string
|
||||
}
|
||||
|
||||
var DBARGS = struct {
|
||||
Driver string
|
||||
Source string
|
||||
Debug string
|
||||
}{
|
||||
os.Getenv("ORM_DRIVER"),
|
||||
os.Getenv("ORM_SOURCE"),
|
||||
os.Getenv("ORM_DEBUG"),
|
||||
}
|
||||
|
||||
var (
|
||||
IsMysql = DBARGS.Driver == "mysql"
|
||||
IsSqlite = DBARGS.Driver == "sqlite3"
|
||||
IsPostgres = DBARGS.Driver == "postgres"
|
||||
IsTidb = DBARGS.Driver == "tidb"
|
||||
)
|
||||
|
||||
var (
|
||||
dORM Ormer
|
||||
dDbBaser dbBaser
|
||||
)
|
||||
|
||||
var (
|
||||
helpinfo = `need driver and source!
|
||||
|
||||
Default DB Drivers.
|
||||
|
||||
driver: url
|
||||
mysql: https://github.com/go-sql-driver/mysql
|
||||
sqlite3: https://github.com/mattn/go-sqlite3
|
||||
postgres: https://github.com/lib/pq
|
||||
tidb: https://github.com/pingcap/tidb
|
||||
|
||||
usage:
|
||||
|
||||
go get -u github.com/astaxie/beego/pkg/client/orm
|
||||
go get -u github.com/go-sql-driver/mysql
|
||||
go get -u github.com/mattn/go-sqlite3
|
||||
go get -u github.com/lib/pq
|
||||
go get -u github.com/pingcap/tidb
|
||||
|
||||
#### MySQL
|
||||
mysql -u root -e 'create database orm_test;'
|
||||
export ORM_DRIVER=mysql
|
||||
export ORM_SOURCE="root:@/orm_test?charset=utf8"
|
||||
go test -v github.com/astaxie/beego/pkg/client/orm
|
||||
|
||||
|
||||
#### Sqlite3
|
||||
export ORM_DRIVER=sqlite3
|
||||
export ORM_SOURCE='file:memory_test?mode=memory'
|
||||
go test -v github.com/astaxie/beego/pkg/client/orm
|
||||
|
||||
|
||||
#### PostgreSQL
|
||||
psql -c 'create database orm_test;' -U postgres
|
||||
export ORM_DRIVER=postgres
|
||||
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
|
||||
go test -v github.com/astaxie/beego/pkg/client/orm
|
||||
|
||||
#### TiDB
|
||||
export ORM_DRIVER=tidb
|
||||
export ORM_SOURCE='memory://test/test'
|
||||
go test -v github.com/astaxie/beego/pgk/orm
|
||||
|
||||
`
|
||||
)
|
||||
|
||||
func init() {
|
||||
Debug, _ = StrTo(DBARGS.Debug).Bool()
|
||||
|
||||
if DBARGS.Driver == "" || DBARGS.Source == "" {
|
||||
fmt.Println(helpinfo)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, hints.MaxIdleConnections(20))
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("can not register database: %v", err))
|
||||
}
|
||||
|
||||
alias := getDbAlias("default")
|
||||
if alias.Driver == DRMySQL {
|
||||
alias.Engine = "INNODB"
|
||||
}
|
||||
|
||||
}
|
227
pkg/client/orm/models_utils.go
Normal file
227
pkg/client/orm/models_utils.go
Normal file
@ -0,0 +1,227 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 1 is attr
|
||||
// 2 is tag
|
||||
var supportTag = map[string]int{
|
||||
"-": 1,
|
||||
"null": 1,
|
||||
"index": 1,
|
||||
"unique": 1,
|
||||
"pk": 1,
|
||||
"auto": 1,
|
||||
"auto_now": 1,
|
||||
"auto_now_add": 1,
|
||||
"size": 2,
|
||||
"column": 2,
|
||||
"default": 2,
|
||||
"rel": 2,
|
||||
"reverse": 2,
|
||||
"rel_table": 2,
|
||||
"rel_through": 2,
|
||||
"digits": 2,
|
||||
"decimals": 2,
|
||||
"on_delete": 2,
|
||||
"type": 2,
|
||||
"description": 2,
|
||||
}
|
||||
|
||||
// get reflect.Type name with package path.
|
||||
func getFullName(typ reflect.Type) string {
|
||||
return typ.PkgPath() + "." + typ.Name()
|
||||
}
|
||||
|
||||
// getTableName get struct table name.
|
||||
// If the struct implement the TableName, then get the result as tablename
|
||||
// else use the struct name which will apply snakeString.
|
||||
func getTableName(val reflect.Value) string {
|
||||
if fun := val.MethodByName("TableName"); fun.IsValid() {
|
||||
vals := fun.Call([]reflect.Value{})
|
||||
// has return and the first val is string
|
||||
if len(vals) > 0 && vals[0].Kind() == reflect.String {
|
||||
return vals[0].String()
|
||||
}
|
||||
}
|
||||
return snakeString(reflect.Indirect(val).Type().Name())
|
||||
}
|
||||
|
||||
// get table engine, myisam or innodb.
|
||||
func getTableEngine(val reflect.Value) string {
|
||||
fun := val.MethodByName("TableEngine")
|
||||
if fun.IsValid() {
|
||||
vals := fun.Call([]reflect.Value{})
|
||||
if len(vals) > 0 && vals[0].Kind() == reflect.String {
|
||||
return vals[0].String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// get table index from method.
|
||||
func getTableIndex(val reflect.Value) [][]string {
|
||||
fun := val.MethodByName("TableIndex")
|
||||
if fun.IsValid() {
|
||||
vals := fun.Call([]reflect.Value{})
|
||||
if len(vals) > 0 && vals[0].CanInterface() {
|
||||
if d, ok := vals[0].Interface().([][]string); ok {
|
||||
return d
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get table unique from method
|
||||
func getTableUnique(val reflect.Value) [][]string {
|
||||
fun := val.MethodByName("TableUnique")
|
||||
if fun.IsValid() {
|
||||
vals := fun.Call([]reflect.Value{})
|
||||
if len(vals) > 0 && vals[0].CanInterface() {
|
||||
if d, ok := vals[0].Interface().([][]string); ok {
|
||||
return d
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// get snaked column name
|
||||
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
|
||||
column := col
|
||||
if col == "" {
|
||||
column = nameStrategyMap[nameStrategy](sf.Name)
|
||||
}
|
||||
switch ft {
|
||||
case RelForeignKey, RelOneToOne:
|
||||
if len(col) == 0 {
|
||||
column = column + "_id"
|
||||
}
|
||||
case RelManyToMany, RelReverseMany, RelReverseOne:
|
||||
column = sf.Name
|
||||
}
|
||||
return column
|
||||
}
|
||||
|
||||
// return field type as type constant from reflect.Value
|
||||
func getFieldType(val reflect.Value) (ft int, err error) {
|
||||
switch val.Type() {
|
||||
case reflect.TypeOf(new(int8)):
|
||||
ft = TypeBitField
|
||||
case reflect.TypeOf(new(int16)):
|
||||
ft = TypeSmallIntegerField
|
||||
case reflect.TypeOf(new(int32)),
|
||||
reflect.TypeOf(new(int)):
|
||||
ft = TypeIntegerField
|
||||
case reflect.TypeOf(new(int64)):
|
||||
ft = TypeBigIntegerField
|
||||
case reflect.TypeOf(new(uint8)):
|
||||
ft = TypePositiveBitField
|
||||
case reflect.TypeOf(new(uint16)):
|
||||
ft = TypePositiveSmallIntegerField
|
||||
case reflect.TypeOf(new(uint32)),
|
||||
reflect.TypeOf(new(uint)):
|
||||
ft = TypePositiveIntegerField
|
||||
case reflect.TypeOf(new(uint64)):
|
||||
ft = TypePositiveBigIntegerField
|
||||
case reflect.TypeOf(new(float32)),
|
||||
reflect.TypeOf(new(float64)):
|
||||
ft = TypeFloatField
|
||||
case reflect.TypeOf(new(bool)):
|
||||
ft = TypeBooleanField
|
||||
case reflect.TypeOf(new(string)):
|
||||
ft = TypeVarCharField
|
||||
case reflect.TypeOf(new(time.Time)):
|
||||
ft = TypeDateTimeField
|
||||
default:
|
||||
elm := reflect.Indirect(val)
|
||||
switch elm.Kind() {
|
||||
case reflect.Int8:
|
||||
ft = TypeBitField
|
||||
case reflect.Int16:
|
||||
ft = TypeSmallIntegerField
|
||||
case reflect.Int32, reflect.Int:
|
||||
ft = TypeIntegerField
|
||||
case reflect.Int64:
|
||||
ft = TypeBigIntegerField
|
||||
case reflect.Uint8:
|
||||
ft = TypePositiveBitField
|
||||
case reflect.Uint16:
|
||||
ft = TypePositiveSmallIntegerField
|
||||
case reflect.Uint32, reflect.Uint:
|
||||
ft = TypePositiveIntegerField
|
||||
case reflect.Uint64:
|
||||
ft = TypePositiveBigIntegerField
|
||||
case reflect.Float32, reflect.Float64:
|
||||
ft = TypeFloatField
|
||||
case reflect.Bool:
|
||||
ft = TypeBooleanField
|
||||
case reflect.String:
|
||||
ft = TypeVarCharField
|
||||
default:
|
||||
if elm.Interface() == nil {
|
||||
panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val))
|
||||
}
|
||||
switch elm.Interface().(type) {
|
||||
case sql.NullInt64:
|
||||
ft = TypeBigIntegerField
|
||||
case sql.NullFloat64:
|
||||
ft = TypeFloatField
|
||||
case sql.NullBool:
|
||||
ft = TypeBooleanField
|
||||
case sql.NullString:
|
||||
ft = TypeVarCharField
|
||||
case time.Time:
|
||||
ft = TypeDateTimeField
|
||||
}
|
||||
}
|
||||
}
|
||||
if ft&IsFieldType == 0 {
|
||||
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// parse struct tag string
|
||||
func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
|
||||
attrs = make(map[string]bool)
|
||||
tags = make(map[string]string)
|
||||
for _, v := range strings.Split(data, defaultStructTagDelim) {
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if t := strings.ToLower(v); supportTag[t] == 1 {
|
||||
attrs[t] = true
|
||||
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
|
||||
name := t[:i]
|
||||
if supportTag[name] == 2 {
|
||||
v = v[i+1 : len(v)-1]
|
||||
tags[name] = v
|
||||
}
|
||||
} else {
|
||||
DebugLog.Println("unsupport orm tag", v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
629
pkg/client/orm/orm.go
Normal file
629
pkg/client/orm/orm.go
Normal file
@ -0,0 +1,629 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
|
||||
// Simple Usage
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "fmt"
|
||||
// "github.com/astaxie/beego/pkg/client/orm"
|
||||
// _ "github.com/go-sql-driver/mysql" // import your used driver
|
||||
// )
|
||||
//
|
||||
// // Model Struct
|
||||
// type User struct {
|
||||
// Id int `orm:"auto"`
|
||||
// Name string `orm:"size(100)"`
|
||||
// }
|
||||
//
|
||||
// func init() {
|
||||
// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
|
||||
// }
|
||||
//
|
||||
// func main() {
|
||||
// o := orm.NewOrm()
|
||||
// user := User{Name: "slene"}
|
||||
// // insert
|
||||
// id, err := o.Insert(&user)
|
||||
// // update
|
||||
// user.Name = "astaxie"
|
||||
// num, err := o.Update(&user)
|
||||
// // read one
|
||||
// u := User{Id: user.Id}
|
||||
// err = o.Read(&u)
|
||||
// // delete
|
||||
// num, err = o.Delete(&u)
|
||||
// }
|
||||
//
|
||||
// more docs: http://beego.me/docs/mvc/model/overview.md
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/logs"
|
||||
)
|
||||
|
||||
// DebugQueries define the debug
|
||||
const (
|
||||
DebugQueries = iota
|
||||
)
|
||||
|
||||
// Define common vars
|
||||
var (
|
||||
Debug = false
|
||||
DebugLog = NewLog(os.Stdout)
|
||||
DefaultRowsLimit = -1
|
||||
DefaultRelsDepth = 2
|
||||
DefaultTimeLoc = time.Local
|
||||
ErrTxDone = errors.New("<TxOrmer.Commit/Rollback> transaction already done")
|
||||
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
|
||||
ErrNoRows = errors.New("<QuerySeter> no row found")
|
||||
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
|
||||
ErrArgs = errors.New("<Ormer> args error may be empty")
|
||||
ErrNotImplement = errors.New("have not implement")
|
||||
|
||||
ErrLastInsertIdUnavailable = errors.New("<Ormer> last insert id is unavailable")
|
||||
)
|
||||
|
||||
// Params stores the Params
|
||||
type Params map[string]interface{}
|
||||
|
||||
// ParamsList stores paramslist
|
||||
type ParamsList []interface{}
|
||||
|
||||
type ormBase struct {
|
||||
alias *alias
|
||||
db dbQuerier
|
||||
}
|
||||
|
||||
var _ DQL = new(ormBase)
|
||||
var _ DML = new(ormBase)
|
||||
var _ DriverGetter = new(ormBase)
|
||||
|
||||
// get model info and model reflect value
|
||||
func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
|
||||
val := reflect.ValueOf(md)
|
||||
ind = reflect.Indirect(val)
|
||||
typ := ind.Type()
|
||||
if needPtr && val.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
|
||||
}
|
||||
name := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFullName(name); ok {
|
||||
return mi, ind
|
||||
}
|
||||
panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
|
||||
}
|
||||
|
||||
// get field info from model info by given field name
|
||||
func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
||||
fi, ok := mi.fields.GetByAny(name)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
|
||||
}
|
||||
return fi
|
||||
}
|
||||
|
||||
// read data to model
|
||||
func (o *ormBase) Read(md interface{}, cols ...string) error {
|
||||
return o.ReadWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
||||
}
|
||||
|
||||
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
||||
func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error {
|
||||
return o.ReadForUpdateWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
|
||||
}
|
||||
|
||||
// Try to read a row from the database, or insert one if it doesn't exist
|
||||
func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...)
|
||||
}
|
||||
func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
|
||||
cols = append([]string{col1}, cols...)
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
||||
if err == ErrNoRows {
|
||||
// Create
|
||||
id, err := o.InsertWithCtx(ctx, md)
|
||||
return err == nil, id, err
|
||||
}
|
||||
|
||||
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
|
||||
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
|
||||
id = int64(vid.Uint())
|
||||
} else if mi.fields.pk.rel {
|
||||
return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
|
||||
} else {
|
||||
id = vid.Int()
|
||||
}
|
||||
|
||||
return false, id, err
|
||||
}
|
||||
|
||||
// insert model data to database
|
||||
func (o *ormBase) Insert(md interface{}) (int64, error) {
|
||||
return o.InsertWithCtx(context.Background(), md)
|
||||
}
|
||||
func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// set auto pk field
|
||||
func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) {
|
||||
if mi.fields.pk.auto {
|
||||
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
|
||||
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
||||
} else {
|
||||
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// insert some models to database
|
||||
func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) {
|
||||
return o.InsertMultiWithCtx(context.Background(), bulk, mds)
|
||||
}
|
||||
func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
|
||||
var cnt int64
|
||||
|
||||
sind := reflect.Indirect(reflect.ValueOf(mds))
|
||||
|
||||
switch sind.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
if sind.Len() == 0 {
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
default:
|
||||
return cnt, ErrArgs
|
||||
}
|
||||
|
||||
if bulk <= 1 {
|
||||
for i := 0; i < sind.Len(); i++ {
|
||||
ind := reflect.Indirect(sind.Index(i))
|
||||
mi, _ := o.getMiInd(ind.Interface(), false)
|
||||
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
|
||||
if err != nil {
|
||||
return cnt, err
|
||||
}
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
cnt++
|
||||
}
|
||||
} else {
|
||||
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
|
||||
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
|
||||
}
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// InsertOrUpdate data to database
|
||||
func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) {
|
||||
return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...)
|
||||
}
|
||||
func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
|
||||
o.setPk(mi, ind, id)
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// update model to database.
|
||||
// cols set the columns those want to update.
|
||||
func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) {
|
||||
return o.UpdateWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||
}
|
||||
|
||||
// delete model in database
|
||||
// cols shows the delete conditions values read from. default is pk
|
||||
func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) {
|
||||
return o.DeleteWithCtx(context.Background(), md, cols...)
|
||||
}
|
||||
func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return num, err
|
||||
}
|
||||
if num > 0 {
|
||||
o.setPk(mi, ind, 0)
|
||||
}
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// create a models to models queryer
|
||||
func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer {
|
||||
return o.QueryM2MWithCtx(context.Background(), md, name)
|
||||
}
|
||||
func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
fi := o.getFieldInfo(mi, name)
|
||||
|
||||
switch {
|
||||
case fi.fieldType == RelManyToMany:
|
||||
case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough:
|
||||
default:
|
||||
panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName))
|
||||
}
|
||||
|
||||
return newQueryM2M(md, o, mi, fi, ind)
|
||||
}
|
||||
|
||||
// load related models to md model.
|
||||
// args are limit, offset int and order string.
|
||||
//
|
||||
// example:
|
||||
// orm.LoadRelated(post,"Tags")
|
||||
// for _,tag := range post.Tags{...}
|
||||
//
|
||||
// make sure the relation is defined in model struct tags.
|
||||
func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
return o.LoadRelatedWithCtx(context.Background(), md, name, args...)
|
||||
}
|
||||
func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) {
|
||||
_, fi, ind, qseter := o.queryRelated(md, name)
|
||||
|
||||
qs := qseter.(*querySet)
|
||||
|
||||
var relDepth int
|
||||
var limit, offset int64
|
||||
var order string
|
||||
|
||||
kvs := utils.NewKVs(args...)
|
||||
kvs.IfContains(hints.KeyRelDepth, func(value interface{}) {
|
||||
if v, ok := value.(bool); ok {
|
||||
if v {
|
||||
relDepth = DefaultRelsDepth
|
||||
}
|
||||
} else if v, ok := value.(int); ok {
|
||||
relDepth = v
|
||||
}
|
||||
}).IfContains(hints.KeyLimit, func(value interface{}) {
|
||||
if v, ok := value.(int64); ok {
|
||||
limit = v
|
||||
}
|
||||
}).IfContains(hints.KeyOffset, func(value interface{}) {
|
||||
if v, ok := value.(int64); ok {
|
||||
offset = v
|
||||
}
|
||||
}).IfContains(hints.KeyOrderBy, func(value interface{}) {
|
||||
if v, ok := value.(string); ok {
|
||||
order = v
|
||||
}
|
||||
})
|
||||
|
||||
switch fi.fieldType {
|
||||
case RelOneToOne, RelForeignKey, RelReverseOne:
|
||||
limit = 1
|
||||
offset = 0
|
||||
}
|
||||
|
||||
qs.limit = limit
|
||||
qs.offset = offset
|
||||
qs.relDepth = relDepth
|
||||
|
||||
if len(order) > 0 {
|
||||
qs.orders = []string{order}
|
||||
}
|
||||
|
||||
find := ind.FieldByIndex(fi.fieldIndex)
|
||||
|
||||
var nums int64
|
||||
var err error
|
||||
switch fi.fieldType {
|
||||
case RelOneToOne, RelForeignKey, RelReverseOne:
|
||||
val := reflect.New(find.Type().Elem())
|
||||
container := val.Interface()
|
||||
err = qs.One(container)
|
||||
if err == nil {
|
||||
find.Set(val)
|
||||
nums = 1
|
||||
}
|
||||
default:
|
||||
nums, err = qs.All(find.Addr().Interface())
|
||||
}
|
||||
|
||||
return nums, err
|
||||
}
|
||||
|
||||
// get QuerySeter for related models to md model
|
||||
func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
|
||||
mi, ind := o.getMiInd(md, true)
|
||||
fi := o.getFieldInfo(mi, name)
|
||||
|
||||
_, _, exist := getExistPk(mi, ind)
|
||||
if !exist {
|
||||
panic(ErrMissPK)
|
||||
}
|
||||
|
||||
var qs *querySet
|
||||
|
||||
switch fi.fieldType {
|
||||
case RelOneToOne, RelForeignKey, RelManyToMany:
|
||||
if !fi.inModel {
|
||||
break
|
||||
}
|
||||
qs = o.getRelQs(md, mi, fi)
|
||||
case RelReverseOne, RelReverseMany:
|
||||
if !fi.inModel {
|
||||
break
|
||||
}
|
||||
qs = o.getReverseQs(md, mi, fi)
|
||||
}
|
||||
|
||||
if qs == nil {
|
||||
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
|
||||
}
|
||||
|
||||
return mi, fi, ind, qs
|
||||
}
|
||||
|
||||
// get reverse relation QuerySeter
|
||||
func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
switch fi.fieldType {
|
||||
case RelReverseOne, RelReverseMany:
|
||||
default:
|
||||
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
|
||||
}
|
||||
|
||||
var q *querySet
|
||||
|
||||
if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
|
||||
q = newQuerySet(o, fi.relModelInfo).(*querySet)
|
||||
q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
|
||||
} else {
|
||||
q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
|
||||
q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// get relation QuerySeter
|
||||
func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
||||
switch fi.fieldType {
|
||||
case RelOneToOne, RelForeignKey, RelManyToMany:
|
||||
default:
|
||||
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
|
||||
}
|
||||
|
||||
q := newQuerySet(o, fi.relModelInfo).(*querySet)
|
||||
q.cond = NewCondition()
|
||||
|
||||
if fi.fieldType == RelManyToMany {
|
||||
q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
|
||||
} else {
|
||||
q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// return a QuerySeter for table operations.
|
||||
// table name can be string or struct.
|
||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||
func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||
return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName)
|
||||
}
|
||||
func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||
var name string
|
||||
if table, ok := ptrStructOrTableName.(string); ok {
|
||||
name = nameStrategyMap[defaultNameStrategy](table)
|
||||
if mi, ok := modelCache.get(name); ok {
|
||||
qs = newQuerySet(o, mi)
|
||||
}
|
||||
} else {
|
||||
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
|
||||
if mi, ok := modelCache.getByFullName(name); ok {
|
||||
qs = newQuerySet(o, mi)
|
||||
}
|
||||
}
|
||||
if qs == nil {
|
||||
panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// return a raw query seter for raw sql string.
|
||||
func (o *ormBase) Raw(query string, args ...interface{}) RawSeter {
|
||||
return o.RawWithCtx(context.Background(), query, args...)
|
||||
}
|
||||
func (o *ormBase) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter {
|
||||
return newRawSet(o, query, args)
|
||||
}
|
||||
|
||||
// return current using database Driver
|
||||
func (o *ormBase) Driver() Driver {
|
||||
return driver(o.alias.Name)
|
||||
}
|
||||
|
||||
// return sql.DBStats for current database
|
||||
func (o *ormBase) DBStats() *sql.DBStats {
|
||||
if o.alias != nil && o.alias.DB != nil {
|
||||
stats := o.alias.DB.DB.Stats()
|
||||
return &stats
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type orm struct {
|
||||
ormBase
|
||||
}
|
||||
|
||||
var _ Ormer = new(orm)
|
||||
|
||||
func (o *orm) Begin() (TxOrmer, error) {
|
||||
return o.BeginWithCtx(context.Background())
|
||||
}
|
||||
|
||||
func (o *orm) BeginWithCtx(ctx context.Context) (TxOrmer, error) {
|
||||
return o.BeginWithCtxAndOpts(ctx, nil)
|
||||
}
|
||||
|
||||
func (o *orm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) {
|
||||
return o.BeginWithCtxAndOpts(context.Background(), opts)
|
||||
}
|
||||
|
||||
func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
|
||||
tx, err := o.db.(txer).BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_txOrm := &txOrm{
|
||||
ormBase: ormBase{
|
||||
alias: o.alias,
|
||||
db: &TxDB{tx: tx},
|
||||
},
|
||||
}
|
||||
|
||||
var taskTxOrm TxOrmer = _txOrm
|
||||
return taskTxOrm, nil
|
||||
}
|
||||
|
||||
func (o *orm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return o.DoTxWithCtx(context.Background(), task)
|
||||
}
|
||||
|
||||
func (o *orm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return o.DoTxWithCtxAndOpts(ctx, nil, task)
|
||||
}
|
||||
|
||||
func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return o.DoTxWithCtxAndOpts(context.Background(), opts, task)
|
||||
}
|
||||
|
||||
func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
return doTxTemplate(o, ctx, opts, task)
|
||||
}
|
||||
|
||||
func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions,
|
||||
task func(ctx context.Context, txOrm TxOrmer) error) error {
|
||||
_txOrm, err := o.BeginWithCtxAndOpts(ctx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
panicked := true
|
||||
defer func() {
|
||||
if panicked || err != nil {
|
||||
e := _txOrm.Rollback()
|
||||
if e != nil {
|
||||
logs.Error("rollback transaction failed: %v,%v", e, panicked)
|
||||
}
|
||||
} else {
|
||||
e := _txOrm.Commit()
|
||||
if e != nil {
|
||||
logs.Error("commit transaction failed: %v,%v", e, panicked)
|
||||
}
|
||||
}
|
||||
}()
|
||||
var taskTxOrm = _txOrm
|
||||
err = task(ctx, taskTxOrm)
|
||||
panicked = false
|
||||
return err
|
||||
}
|
||||
|
||||
type txOrm struct {
|
||||
ormBase
|
||||
}
|
||||
|
||||
var _ TxOrmer = new(txOrm)
|
||||
|
||||
func (t *txOrm) Commit() error {
|
||||
return t.db.(txEnder).Commit()
|
||||
}
|
||||
|
||||
func (t *txOrm) Rollback() error {
|
||||
return t.db.(txEnder).Rollback()
|
||||
}
|
||||
|
||||
// NewOrm create new orm
|
||||
func NewOrm() Ormer {
|
||||
BootStrap() // execute only once
|
||||
return NewOrmUsingDB(`default`)
|
||||
}
|
||||
|
||||
// NewOrmUsingDB create new orm with the name
|
||||
func NewOrmUsingDB(aliasName string) Ormer {
|
||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||
return newDBWithAlias(al)
|
||||
} else {
|
||||
panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName))
|
||||
}
|
||||
}
|
||||
|
||||
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
|
||||
func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...utils.KV) (Ormer, error) {
|
||||
al, err := newAliasWithDb(aliasName, driverName, db, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newDBWithAlias(al), nil
|
||||
}
|
||||
|
||||
func newDBWithAlias(al *alias) Ormer {
|
||||
o := new(orm)
|
||||
o.alias = al
|
||||
|
||||
if Debug {
|
||||
o.db = newDbQueryLog(al, al.DB)
|
||||
} else {
|
||||
o.db = al.DB
|
||||
}
|
||||
|
||||
if len(globalFilterChains) > 0 {
|
||||
return NewFilterOrmDecorator(o, globalFilterChains...)
|
||||
}
|
||||
return o
|
||||
}
|
153
pkg/client/orm/orm_conds.go
Normal file
153
pkg/client/orm/orm_conds.go
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExprSep define the expression separation
|
||||
const (
|
||||
ExprSep = "__"
|
||||
)
|
||||
|
||||
type condValue struct {
|
||||
exprs []string
|
||||
args []interface{}
|
||||
cond *Condition
|
||||
isOr bool
|
||||
isNot bool
|
||||
isCond bool
|
||||
isRaw bool
|
||||
sql string
|
||||
}
|
||||
|
||||
// Condition struct.
|
||||
// work for WHERE conditions.
|
||||
type Condition struct {
|
||||
params []condValue
|
||||
}
|
||||
|
||||
// NewCondition return new condition struct
|
||||
func NewCondition() *Condition {
|
||||
c := &Condition{}
|
||||
return c
|
||||
}
|
||||
|
||||
// Raw add raw sql to condition
|
||||
func (c Condition) Raw(expr string, sql string) *Condition {
|
||||
if len(sql) == 0 {
|
||||
panic(fmt.Errorf("<Condition.Raw> sql cannot empty"))
|
||||
}
|
||||
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true})
|
||||
return &c
|
||||
}
|
||||
|
||||
// And add expression to condition
|
||||
func (c Condition) And(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.And> args cannot empty"))
|
||||
}
|
||||
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
|
||||
return &c
|
||||
}
|
||||
|
||||
// AndNot add NOT expression to condition
|
||||
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
|
||||
}
|
||||
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
|
||||
return &c
|
||||
}
|
||||
|
||||
// AndCond combine a condition to current condition
|
||||
func (c *Condition) AndCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
panic(fmt.Errorf("<Condition.AndCond> cannot use self as sub cond"))
|
||||
}
|
||||
if cond != nil {
|
||||
c.params = append(c.params, condValue{cond: cond, isCond: true})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// AndNotCond combine a AND NOT condition to current condition
|
||||
func (c *Condition) AndNotCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
panic(fmt.Errorf("<Condition.AndNotCond> cannot use self as sub cond"))
|
||||
}
|
||||
|
||||
if cond != nil {
|
||||
c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Or add OR expression to condition
|
||||
func (c Condition) Or(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
|
||||
}
|
||||
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
|
||||
return &c
|
||||
}
|
||||
|
||||
// OrNot add OR NOT expression to condition
|
||||
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
|
||||
if expr == "" || len(args) == 0 {
|
||||
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
|
||||
}
|
||||
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
|
||||
return &c
|
||||
}
|
||||
|
||||
// OrCond combine a OR condition to current condition
|
||||
func (c *Condition) OrCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
panic(fmt.Errorf("<Condition.OrCond> cannot use self as sub cond"))
|
||||
}
|
||||
if cond != nil {
|
||||
c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// OrNotCond combine a OR NOT condition to current condition
|
||||
func (c *Condition) OrNotCond(cond *Condition) *Condition {
|
||||
c = c.clone()
|
||||
if c == cond {
|
||||
panic(fmt.Errorf("<Condition.OrNotCond> cannot use self as sub cond"))
|
||||
}
|
||||
|
||||
if cond != nil {
|
||||
c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true})
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// IsEmpty check the condition arguments are empty or not.
|
||||
func (c *Condition) IsEmpty() bool {
|
||||
return len(c.params) == 0
|
||||
}
|
||||
|
||||
// clone clone a condition
|
||||
func (c Condition) clone() *Condition {
|
||||
return &c
|
||||
}
|
207
pkg/client/orm/orm_log.go
Normal file
207
pkg/client/orm/orm_log.go
Normal file
@ -0,0 +1,207 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Log implement the log.Logger
|
||||
type Log struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
//costomer log func
|
||||
var LogFunc func(query map[string]interface{})
|
||||
|
||||
// NewLog set io.Writer to create a Logger.
|
||||
func NewLog(out io.Writer) *Log {
|
||||
d := new(Log)
|
||||
d.Logger = log.New(out, "[ORM]", log.LstdFlags)
|
||||
return d
|
||||
}
|
||||
|
||||
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
|
||||
var logMap = make(map[string]interface{})
|
||||
sub := time.Now().Sub(t) / 1e5
|
||||
elsp := float64(int(sub)) / 10.0
|
||||
logMap["cost_time"] = elsp
|
||||
flag := " OK"
|
||||
if err != nil {
|
||||
flag = "FAIL"
|
||||
}
|
||||
logMap["flag"] = flag
|
||||
con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query)
|
||||
cons := make([]string, 0, len(args))
|
||||
for _, arg := range args {
|
||||
cons = append(cons, fmt.Sprintf("%v", arg))
|
||||
}
|
||||
if len(cons) > 0 {
|
||||
con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `"))
|
||||
}
|
||||
if err != nil {
|
||||
con += " - " + err.Error()
|
||||
}
|
||||
logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `"))
|
||||
if LogFunc != nil {
|
||||
LogFunc(logMap)
|
||||
}
|
||||
DebugLog.Println(con)
|
||||
}
|
||||
|
||||
// statement query logger struct.
|
||||
// if dev mode, use stmtQueryLog, or use stmtQuerier.
|
||||
type stmtQueryLog struct {
|
||||
alias *alias
|
||||
query string
|
||||
stmt stmtQuerier
|
||||
}
|
||||
|
||||
var _ stmtQuerier = new(stmtQueryLog)
|
||||
|
||||
func (d *stmtQueryLog) Close() error {
|
||||
a := time.Now()
|
||||
err := d.stmt.Close()
|
||||
debugLogQueies(d.alias, "st.Close", d.query, a, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) {
|
||||
a := time.Now()
|
||||
res, err := d.stmt.Exec(args...)
|
||||
debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) {
|
||||
a := time.Now()
|
||||
res, err := d.stmt.Query(args...)
|
||||
debugLogQueies(d.alias, "st.Query", d.query, a, err, args...)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row {
|
||||
a := time.Now()
|
||||
res := d.stmt.QueryRow(args...)
|
||||
debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...)
|
||||
return res
|
||||
}
|
||||
|
||||
func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
|
||||
d := new(stmtQueryLog)
|
||||
d.stmt = stmt
|
||||
d.alias = alias
|
||||
d.query = query
|
||||
return d
|
||||
}
|
||||
|
||||
// database query logger struct.
|
||||
// if dev mode, use dbQueryLog, or use dbQuerier.
|
||||
type dbQueryLog struct {
|
||||
alias *alias
|
||||
db dbQuerier
|
||||
tx txer
|
||||
txe txEnder
|
||||
}
|
||||
|
||||
var _ dbQuerier = new(dbQueryLog)
|
||||
var _ txer = new(dbQueryLog)
|
||||
var _ txEnder = new(dbQueryLog)
|
||||
|
||||
func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
|
||||
return d.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||
a := time.Now()
|
||||
stmt, err := d.db.PrepareContext(ctx, query)
|
||||
debugLogQueies(d.alias, "db.Prepare", query, a, err)
|
||||
return stmt, err
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return d.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
a := time.Now()
|
||||
res, err := d.db.ExecContext(ctx, query, args...)
|
||||
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return d.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
a := time.Now()
|
||||
res, err := d.db.QueryContext(ctx, query, args...)
|
||||
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
return d.QueryRowContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
a := time.Now()
|
||||
res := d.db.QueryRowContext(ctx, query, args...)
|
||||
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
|
||||
return res
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) Begin() (*sql.Tx, error) {
|
||||
return d.BeginTx(context.Background(), nil)
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
a := time.Now()
|
||||
tx, err := d.db.(txer).BeginTx(ctx, opts)
|
||||
debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err)
|
||||
return tx, err
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) Commit() error {
|
||||
a := time.Now()
|
||||
err := d.db.(txEnder).Commit()
|
||||
debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) Rollback() error {
|
||||
a := time.Now()
|
||||
err := d.db.(txEnder).Rollback()
|
||||
debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *dbQueryLog) SetDB(db dbQuerier) {
|
||||
d.db = db
|
||||
}
|
||||
|
||||
func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier {
|
||||
d := new(dbQueryLog)
|
||||
d.alias = alias
|
||||
d.db = db
|
||||
return d
|
||||
}
|
87
pkg/client/orm/orm_object.go
Normal file
87
pkg/client/orm/orm_object.go
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// an insert queryer struct
|
||||
type insertSet struct {
|
||||
mi *modelInfo
|
||||
orm *ormBase
|
||||
stmt stmtQuerier
|
||||
closed bool
|
||||
}
|
||||
|
||||
var _ Inserter = new(insertSet)
|
||||
|
||||
// insert model ignore it's registered or not.
|
||||
func (o *insertSet) Insert(md interface{}) (int64, error) {
|
||||
if o.closed {
|
||||
return 0, ErrStmtClosed
|
||||
}
|
||||
val := reflect.ValueOf(md)
|
||||
ind := reflect.Indirect(val)
|
||||
typ := ind.Type()
|
||||
name := getFullName(typ)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
|
||||
}
|
||||
if name != o.mi.fullName {
|
||||
panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
|
||||
}
|
||||
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ)
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
if id > 0 {
|
||||
if o.mi.fields.pk.auto {
|
||||
if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
|
||||
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
|
||||
} else {
|
||||
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// close insert queryer statement
|
||||
func (o *insertSet) Close() error {
|
||||
if o.closed {
|
||||
return ErrStmtClosed
|
||||
}
|
||||
o.closed = true
|
||||
return o.stmt.Close()
|
||||
}
|
||||
|
||||
// create new insert queryer.
|
||||
func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) {
|
||||
bi := new(insertSet)
|
||||
bi.orm = orm
|
||||
bi.mi = mi
|
||||
st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if Debug {
|
||||
bi.stmt = newStmtQueryLog(orm.alias, st, query)
|
||||
} else {
|
||||
bi.stmt = st
|
||||
}
|
||||
return bi, nil
|
||||
}
|
140
pkg/client/orm/orm_querym2m.go
Normal file
140
pkg/client/orm/orm_querym2m.go
Normal file
@ -0,0 +1,140 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import "reflect"
|
||||
|
||||
// model to model struct
|
||||
type queryM2M struct {
|
||||
md interface{}
|
||||
mi *modelInfo
|
||||
fi *fieldInfo
|
||||
qs *querySet
|
||||
ind reflect.Value
|
||||
}
|
||||
|
||||
// add models to origin models when creating queryM2M.
|
||||
// example:
|
||||
// m2m := orm.QueryM2M(post,"Tag")
|
||||
// m2m.Add(&Tag1{},&Tag2{})
|
||||
// for _,tag := range post.Tags{}
|
||||
//
|
||||
// make sure the relation is defined in post model struct tag.
|
||||
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
||||
fi := o.fi
|
||||
mi := fi.relThroughModelInfo
|
||||
mfi := fi.reverseFieldInfo
|
||||
rfi := fi.reverseFieldInfoTwo
|
||||
|
||||
orm := o.qs.orm
|
||||
dbase := orm.alias.DbBaser
|
||||
|
||||
var models []interface{}
|
||||
var otherValues []interface{}
|
||||
var otherNames []string
|
||||
|
||||
for _, colname := range mi.fields.dbcols {
|
||||
if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
|
||||
mi.fields.columns[colname] != mi.fields.pk {
|
||||
otherNames = append(otherNames, colname)
|
||||
}
|
||||
}
|
||||
for i, md := range mds {
|
||||
if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 {
|
||||
otherValues = append(otherValues, md)
|
||||
mds = append(mds[:i], mds[i+1:]...)
|
||||
}
|
||||
}
|
||||
for _, md := range mds {
|
||||
val := reflect.ValueOf(md)
|
||||
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
v := val.Index(i)
|
||||
if v.CanInterface() {
|
||||
models = append(models, v.Interface())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = append(models, md)
|
||||
}
|
||||
}
|
||||
|
||||
_, v1, exist := getExistPk(o.mi, o.ind)
|
||||
if !exist {
|
||||
panic(ErrMissPK)
|
||||
}
|
||||
|
||||
names := []string{mfi.column, rfi.column}
|
||||
|
||||
values := make([]interface{}, 0, len(models)*2)
|
||||
for _, md := range models {
|
||||
|
||||
ind := reflect.Indirect(reflect.ValueOf(md))
|
||||
var v2 interface{}
|
||||
if ind.Kind() != reflect.Struct {
|
||||
v2 = ind.Interface()
|
||||
} else {
|
||||
_, v2, exist = getExistPk(fi.relModelInfo, ind)
|
||||
if !exist {
|
||||
panic(ErrMissPK)
|
||||
}
|
||||
}
|
||||
values = append(values, v1, v2)
|
||||
|
||||
}
|
||||
names = append(names, otherNames...)
|
||||
values = append(values, otherValues...)
|
||||
return dbase.InsertValue(orm.db, mi, true, names, values)
|
||||
}
|
||||
|
||||
// remove models following the origin model relationship
|
||||
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
||||
fi := o.fi
|
||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||
|
||||
return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
|
||||
}
|
||||
|
||||
// check model is existed in relationship of origin model
|
||||
func (o *queryM2M) Exist(md interface{}) bool {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
|
||||
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
|
||||
}
|
||||
|
||||
// clean all models in related of origin model
|
||||
func (o *queryM2M) Clear() (int64, error) {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
|
||||
}
|
||||
|
||||
// count all related models of origin model
|
||||
func (o *queryM2M) Count() (int64, error) {
|
||||
fi := o.fi
|
||||
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
|
||||
}
|
||||
|
||||
var _ QueryM2Mer = new(queryM2M)
|
||||
|
||||
// create new M2M queryer.
|
||||
func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
|
||||
qm2m := new(queryM2M)
|
||||
qm2m.md = md
|
||||
qm2m.mi = mi
|
||||
qm2m.fi = fi
|
||||
qm2m.ind = ind
|
||||
qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet)
|
||||
return qm2m
|
||||
}
|
325
pkg/client/orm/orm_queryset.go
Normal file
325
pkg/client/orm/orm_queryset.go
Normal file
@ -0,0 +1,325 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/astaxie/beego/pkg/client/orm/hints"
|
||||
)
|
||||
|
||||
type colValue struct {
|
||||
value int64
|
||||
opt operator
|
||||
}
|
||||
|
||||
type operator int
|
||||
|
||||
// define Col operations
|
||||
const (
|
||||
ColAdd operator = iota
|
||||
ColMinus
|
||||
ColMultiply
|
||||
ColExcept
|
||||
ColBitAnd
|
||||
ColBitRShift
|
||||
ColBitLShift
|
||||
ColBitXOR
|
||||
ColBitOr
|
||||
)
|
||||
|
||||
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
|
||||
// Params{
|
||||
// "Nums": ColValue(Col_Add, 10),
|
||||
// }
|
||||
func ColValue(opt operator, value interface{}) interface{} {
|
||||
switch opt {
|
||||
case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift,
|
||||
ColBitLShift, ColBitXOR, ColBitOr:
|
||||
default:
|
||||
panic(fmt.Errorf("orm.ColValue wrong operator"))
|
||||
}
|
||||
v, err := StrTo(ToStr(value)).Int64()
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err))
|
||||
}
|
||||
var val colValue
|
||||
val.value = v
|
||||
val.opt = opt
|
||||
return val
|
||||
}
|
||||
|
||||
// real query struct
|
||||
type querySet struct {
|
||||
mi *modelInfo
|
||||
cond *Condition
|
||||
related []string
|
||||
relDepth int
|
||||
limit int64
|
||||
offset int64
|
||||
groups []string
|
||||
orders []string
|
||||
distinct bool
|
||||
forUpdate bool
|
||||
useIndex int
|
||||
indexes []string
|
||||
orm *ormBase
|
||||
ctx context.Context
|
||||
forContext bool
|
||||
}
|
||||
|
||||
var _ QuerySeter = new(querySet)
|
||||
|
||||
// add condition expression to QuerySeter.
|
||||
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
}
|
||||
o.cond = o.cond.And(expr, args...)
|
||||
return &o
|
||||
}
|
||||
|
||||
// add raw sql to querySeter.
|
||||
func (o querySet) FilterRaw(expr string, sql string) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
}
|
||||
o.cond = o.cond.Raw(expr, sql)
|
||||
return &o
|
||||
}
|
||||
|
||||
// add NOT condition to querySeter.
|
||||
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
|
||||
if o.cond == nil {
|
||||
o.cond = NewCondition()
|
||||
}
|
||||
o.cond = o.cond.AndNot(expr, args...)
|
||||
return &o
|
||||
}
|
||||
|
||||
// set offset number
|
||||
func (o *querySet) setOffset(num interface{}) {
|
||||
o.offset = ToInt64(num)
|
||||
}
|
||||
|
||||
// add LIMIT value.
|
||||
// args[0] means offset, e.g. LIMIT num,offset.
|
||||
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
|
||||
o.limit = ToInt64(limit)
|
||||
if len(args) > 0 {
|
||||
o.setOffset(args[0])
|
||||
}
|
||||
return &o
|
||||
}
|
||||
|
||||
// add OFFSET value
|
||||
func (o querySet) Offset(offset interface{}) QuerySeter {
|
||||
o.setOffset(offset)
|
||||
return &o
|
||||
}
|
||||
|
||||
// add GROUP expression
|
||||
func (o querySet) GroupBy(exprs ...string) QuerySeter {
|
||||
o.groups = exprs
|
||||
return &o
|
||||
}
|
||||
|
||||
// add ORDER expression.
|
||||
// "column" means ASC, "-column" means DESC.
|
||||
func (o querySet) OrderBy(exprs ...string) QuerySeter {
|
||||
o.orders = exprs
|
||||
return &o
|
||||
}
|
||||
|
||||
// add DISTINCT to SELECT
|
||||
func (o querySet) Distinct() QuerySeter {
|
||||
o.distinct = true
|
||||
return &o
|
||||
}
|
||||
|
||||
// add FOR UPDATE to SELECT
|
||||
func (o querySet) ForUpdate() QuerySeter {
|
||||
o.forUpdate = true
|
||||
return &o
|
||||
}
|
||||
|
||||
// ForceIndex force index for query
|
||||
func (o querySet) ForceIndex(indexes ...string) QuerySeter {
|
||||
o.useIndex = hints.KeyForceIndex
|
||||
o.indexes = indexes
|
||||
return &o
|
||||
}
|
||||
|
||||
// UseIndex use index for query
|
||||
func (o querySet) UseIndex(indexes ...string) QuerySeter {
|
||||
o.useIndex = hints.KeyUseIndex
|
||||
o.indexes = indexes
|
||||
return &o
|
||||
}
|
||||
|
||||
// IgnoreIndex ignore index for query
|
||||
func (o querySet) IgnoreIndex(indexes ...string) QuerySeter {
|
||||
o.useIndex = hints.KeyIgnoreIndex
|
||||
o.indexes = indexes
|
||||
return &o
|
||||
}
|
||||
|
||||
// set relation model to query together.
|
||||
// it will query relation models and assign to parent model.
|
||||
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
|
||||
if len(params) == 0 {
|
||||
o.relDepth = DefaultRelsDepth
|
||||
} else {
|
||||
for _, p := range params {
|
||||
switch val := p.(type) {
|
||||
case string:
|
||||
o.related = append(o.related, val)
|
||||
case int:
|
||||
o.relDepth = val
|
||||
default:
|
||||
panic(fmt.Errorf("<QuerySeter.RelatedSel> wrong param kind: %v", val))
|
||||
}
|
||||
}
|
||||
}
|
||||
return &o
|
||||
}
|
||||
|
||||
// set condition to QuerySeter.
|
||||
func (o querySet) SetCond(cond *Condition) QuerySeter {
|
||||
o.cond = cond
|
||||
return &o
|
||||
}
|
||||
|
||||
// get condition from QuerySeter
|
||||
func (o querySet) GetCond() *Condition {
|
||||
return o.cond
|
||||
}
|
||||
|
||||
// return QuerySeter execution result number
|
||||
func (o *querySet) Count() (int64, error) {
|
||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// check result empty or not after QuerySeter executed
|
||||
func (o *querySet) Exist() bool {
|
||||
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
return cnt > 0
|
||||
}
|
||||
|
||||
// execute update with parameters
|
||||
func (o *querySet) Update(values Params) (int64, error) {
|
||||
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// execute delete
|
||||
func (o *querySet) Delete() (int64, error) {
|
||||
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// return a insert queryer.
|
||||
// it can be used in times.
|
||||
// example:
|
||||
// i,err := sq.PrepareInsert()
|
||||
// i.Add(&user1{},&user2{})
|
||||
func (o *querySet) PrepareInsert() (Inserter, error) {
|
||||
return newInsertSet(o.orm, o.mi)
|
||||
}
|
||||
|
||||
// query all data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
}
|
||||
|
||||
// query one row data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
func (o *querySet) One(container interface{}, cols ...string) error {
|
||||
o.limit = 1
|
||||
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if num == 0 {
|
||||
return ErrNoRows
|
||||
}
|
||||
|
||||
if num > 1 {
|
||||
return ErrMultiRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// query all data and map to []map[string]interface.
|
||||
// expres means condition expression.
|
||||
// it converts data to []map[column]value.
|
||||
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all data and map to [][]interface
|
||||
// it converts data to [][column_index]value
|
||||
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all data and map to []interface.
|
||||
// it's designed for one row record set, auto change to []value, not [][column]value.
|
||||
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
|
||||
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
|
||||
}
|
||||
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||
panic(ErrNotImplement)
|
||||
}
|
||||
|
||||
// set context to QuerySeter.
|
||||
func (o querySet) WithContext(ctx context.Context) QuerySeter {
|
||||
o.ctx = ctx
|
||||
o.forContext = true
|
||||
return &o
|
||||
}
|
||||
|
||||
// create new QuerySeter.
|
||||
func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
|
||||
o := new(querySet)
|
||||
o.mi = mi
|
||||
o.orm = orm
|
||||
return o
|
||||
}
|
900
pkg/client/orm/orm_raw.go
Normal file
900
pkg/client/orm/orm_raw.go
Normal file
@ -0,0 +1,900 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// raw sql string prepared statement
|
||||
type rawPrepare struct {
|
||||
rs *rawSet
|
||||
stmt stmtQuerier
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) {
|
||||
if o.closed {
|
||||
return nil, ErrStmtClosed
|
||||
}
|
||||
flatParams := getFlatParams(nil, args, o.rs.orm.alias.TZ)
|
||||
return o.stmt.Exec(flatParams...)
|
||||
}
|
||||
|
||||
func (o *rawPrepare) Close() error {
|
||||
o.closed = true
|
||||
return o.stmt.Close()
|
||||
}
|
||||
|
||||
func newRawPreparer(rs *rawSet) (RawPreparer, error) {
|
||||
o := new(rawPrepare)
|
||||
o.rs = rs
|
||||
|
||||
query := rs.query
|
||||
rs.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
st, err := rs.orm.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if Debug {
|
||||
o.stmt = newStmtQueryLog(rs.orm.alias, st, query)
|
||||
} else {
|
||||
o.stmt = st
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// raw query seter
|
||||
type rawSet struct {
|
||||
query string
|
||||
args []interface{}
|
||||
orm *ormBase
|
||||
}
|
||||
|
||||
var _ RawSeter = new(rawSet)
|
||||
|
||||
// set args for every query
|
||||
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
|
||||
o.args = args
|
||||
return &o
|
||||
}
|
||||
|
||||
// execute raw sql and return sql.Result
|
||||
func (o *rawSet) Exec() (sql.Result, error) {
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
return o.orm.db.Exec(query, args...)
|
||||
}
|
||||
|
||||
// set field value to row container
|
||||
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
|
||||
switch ind.Kind() {
|
||||
case reflect.Bool:
|
||||
if value == nil {
|
||||
ind.SetBool(false)
|
||||
} else if v, ok := value.(bool); ok {
|
||||
ind.SetBool(v)
|
||||
} else {
|
||||
v, _ := StrTo(ToStr(value)).Bool()
|
||||
ind.SetBool(v)
|
||||
}
|
||||
|
||||
case reflect.String:
|
||||
if value == nil {
|
||||
ind.SetString("")
|
||||
} else {
|
||||
ind.SetString(ToStr(value))
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if value == nil {
|
||||
ind.SetInt(0)
|
||||
} else {
|
||||
val := reflect.ValueOf(value)
|
||||
switch val.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
ind.SetInt(val.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
ind.SetInt(int64(val.Uint()))
|
||||
default:
|
||||
v, _ := StrTo(ToStr(value)).Int64()
|
||||
ind.SetInt(v)
|
||||
}
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if value == nil {
|
||||
ind.SetUint(0)
|
||||
} else {
|
||||
val := reflect.ValueOf(value)
|
||||
switch val.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
ind.SetUint(uint64(val.Int()))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
ind.SetUint(val.Uint())
|
||||
default:
|
||||
v, _ := StrTo(ToStr(value)).Uint64()
|
||||
ind.SetUint(v)
|
||||
}
|
||||
}
|
||||
case reflect.Float64, reflect.Float32:
|
||||
if value == nil {
|
||||
ind.SetFloat(0)
|
||||
} else {
|
||||
val := reflect.ValueOf(value)
|
||||
switch val.Kind() {
|
||||
case reflect.Float64:
|
||||
ind.SetFloat(val.Float())
|
||||
default:
|
||||
v, _ := StrTo(ToStr(value)).Float64()
|
||||
ind.SetFloat(v)
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Struct:
|
||||
if value == nil {
|
||||
ind.Set(reflect.Zero(ind.Type()))
|
||||
return
|
||||
}
|
||||
switch ind.Interface().(type) {
|
||||
case time.Time:
|
||||
var str string
|
||||
switch d := value.(type) {
|
||||
case time.Time:
|
||||
o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ)
|
||||
ind.Set(reflect.ValueOf(d))
|
||||
case []byte:
|
||||
str = string(d)
|
||||
case string:
|
||||
str = d
|
||||
}
|
||||
if str != "" {
|
||||
if len(str) >= 19 {
|
||||
str = str[:19]
|
||||
t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
|
||||
if err == nil {
|
||||
t = t.In(DefaultTimeLoc)
|
||||
ind.Set(reflect.ValueOf(t))
|
||||
}
|
||||
} else if len(str) >= 10 {
|
||||
str = str[:10]
|
||||
t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
|
||||
if err == nil {
|
||||
ind.Set(reflect.ValueOf(t))
|
||||
}
|
||||
}
|
||||
}
|
||||
case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool:
|
||||
indi := reflect.New(ind.Type()).Interface()
|
||||
sc, ok := indi.(sql.Scanner)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
err := sc.Scan(value)
|
||||
if err == nil {
|
||||
ind.Set(reflect.Indirect(reflect.ValueOf(sc)))
|
||||
}
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
if value == nil {
|
||||
ind.Set(reflect.Zero(ind.Type()))
|
||||
break
|
||||
}
|
||||
ind.Set(reflect.New(ind.Type().Elem()))
|
||||
o.setFieldValue(reflect.Indirect(ind), value)
|
||||
}
|
||||
}
|
||||
|
||||
// set field value in loop for slice container
|
||||
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
|
||||
nInds := *nIndsPtr
|
||||
|
||||
cur := 0
|
||||
for i := 0; i < len(sInds); i++ {
|
||||
sInd := sInds[i]
|
||||
eTyp := eTyps[i]
|
||||
|
||||
typ := eTyp
|
||||
isPtr := false
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
typ = typ.Elem()
|
||||
}
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
var nInd reflect.Value
|
||||
if init {
|
||||
nInd = reflect.New(sInd.Type()).Elem()
|
||||
} else {
|
||||
nInd = nInds[i]
|
||||
}
|
||||
|
||||
val := reflect.New(typ)
|
||||
ind := val.Elem()
|
||||
|
||||
tpName := ind.Type().String()
|
||||
|
||||
if ind.Kind() == reflect.Struct {
|
||||
if tpName == "time.Time" {
|
||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||
if isPtr && value == nil {
|
||||
val = reflect.New(val.Type()).Elem()
|
||||
} else {
|
||||
o.setFieldValue(ind, value)
|
||||
}
|
||||
cur++
|
||||
}
|
||||
|
||||
} else {
|
||||
value := reflect.ValueOf(refs[cur]).Elem().Interface()
|
||||
if isPtr && value == nil {
|
||||
val = reflect.New(val.Type()).Elem()
|
||||
} else {
|
||||
o.setFieldValue(ind, value)
|
||||
}
|
||||
cur++
|
||||
}
|
||||
|
||||
if nInd.Kind() == reflect.Slice {
|
||||
if isPtr {
|
||||
nInd = reflect.Append(nInd, val)
|
||||
} else {
|
||||
nInd = reflect.Append(nInd, ind)
|
||||
}
|
||||
} else {
|
||||
if isPtr {
|
||||
nInd.Set(val)
|
||||
} else {
|
||||
nInd.Set(ind)
|
||||
}
|
||||
}
|
||||
|
||||
nInds[i] = nInd
|
||||
}
|
||||
}
|
||||
|
||||
// query data and map to container
|
||||
func (o *rawSet) QueryRow(containers ...interface{}) error {
|
||||
var (
|
||||
refs = make([]interface{}, 0, len(containers))
|
||||
sInds []reflect.Value
|
||||
eTyps []reflect.Type
|
||||
sMi *modelInfo
|
||||
)
|
||||
structMode := false
|
||||
for _, container := range containers {
|
||||
val := reflect.ValueOf(container)
|
||||
ind := reflect.Indirect(val)
|
||||
|
||||
if val.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> all args must be use ptr"))
|
||||
}
|
||||
|
||||
etyp := ind.Type()
|
||||
typ := etyp
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
sInds = append(sInds, ind)
|
||||
eTyps = append(eTyps, etyp)
|
||||
|
||||
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||
if len(containers) > 1 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||
}
|
||||
|
||||
structMode = true
|
||||
fn := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFullName(fn); ok {
|
||||
sMi = mi
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
rows, err := o.orm.db.Query(query, args...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ErrNoRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
if structMode {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
columnsMp := make(map[string]interface{}, len(columns))
|
||||
|
||||
refs = make([]interface{}, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
var ref interface{}
|
||||
columnsMp[col] = &ref
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ind := sInds[0]
|
||||
|
||||
if ind.Kind() == reflect.Ptr {
|
||||
if ind.IsNil() || !ind.IsValid() {
|
||||
ind.Set(reflect.New(eTyps[0].Elem()))
|
||||
}
|
||||
ind = ind.Elem()
|
||||
}
|
||||
|
||||
if sMi != nil {
|
||||
for _, col := range columns {
|
||||
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||
field := ind.FieldByIndex(fi.fieldIndex)
|
||||
if fi.fieldType&IsRelField > 0 {
|
||||
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
|
||||
field.Set(mf)
|
||||
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
|
||||
}
|
||||
if fi.isFielder {
|
||||
fd := field.Addr().Interface().(Fielder)
|
||||
err := fd.SetRaw(value)
|
||||
if err != nil {
|
||||
return errors.Errorf("set raw error:%s", err)
|
||||
}
|
||||
} else {
|
||||
o.setFieldValue(field, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// define recursive function
|
||||
var recursiveSetField func(rv reflect.Value)
|
||||
recursiveSetField = func(rv reflect.Value) {
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
f := rv.Field(i)
|
||||
fe := rv.Type().Field(i)
|
||||
|
||||
// check if the field is a Struct
|
||||
// recursive the Struct type
|
||||
if fe.Type.Kind() == reflect.Struct {
|
||||
recursiveSetField(f)
|
||||
}
|
||||
|
||||
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
|
||||
var col string
|
||||
if col = tags["column"]; col == "" {
|
||||
col = nameStrategyMap[nameStrategy](fe.Name)
|
||||
}
|
||||
if v, ok := columnsMp[col]; ok {
|
||||
value := reflect.ValueOf(v).Elem().Interface()
|
||||
o.setFieldValue(f, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// init call the recursive function
|
||||
recursiveSetField(ind)
|
||||
}
|
||||
|
||||
} else {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
return ErrNoRows
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// query data rows and map to container
|
||||
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
||||
var (
|
||||
refs = make([]interface{}, 0, len(containers))
|
||||
sInds []reflect.Value
|
||||
eTyps []reflect.Type
|
||||
sMi *modelInfo
|
||||
)
|
||||
structMode := false
|
||||
for _, container := range containers {
|
||||
val := reflect.ValueOf(container)
|
||||
sInd := reflect.Indirect(val)
|
||||
if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRows> all args must be use ptr slice"))
|
||||
}
|
||||
|
||||
etyp := sInd.Type().Elem()
|
||||
typ := etyp
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
sInds = append(sInds, sInd)
|
||||
eTyps = append(eTyps, etyp)
|
||||
|
||||
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
|
||||
if len(containers) > 1 {
|
||||
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
|
||||
}
|
||||
|
||||
structMode = true
|
||||
fn := getFullName(typ)
|
||||
if mi, ok := modelCache.getByFullName(fn); ok {
|
||||
sMi = mi
|
||||
}
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
rows, err := o.orm.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var cnt int64
|
||||
nInds := make([]reflect.Value, len(sInds))
|
||||
sInd := sInds[0]
|
||||
|
||||
for rows.Next() {
|
||||
|
||||
if structMode {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
columnsMp := make(map[string]interface{}, len(columns))
|
||||
|
||||
refs = make([]interface{}, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
var ref interface{}
|
||||
columnsMp[col] = &ref
|
||||
refs = append(refs, &ref)
|
||||
}
|
||||
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if cnt == 0 && !sInd.IsNil() {
|
||||
sInd.Set(reflect.New(sInd.Type()).Elem())
|
||||
}
|
||||
|
||||
var ind reflect.Value
|
||||
if eTyps[0].Kind() == reflect.Ptr {
|
||||
ind = reflect.New(eTyps[0].Elem())
|
||||
} else {
|
||||
ind = reflect.New(eTyps[0])
|
||||
}
|
||||
|
||||
if ind.Kind() == reflect.Ptr {
|
||||
ind = ind.Elem()
|
||||
}
|
||||
|
||||
if sMi != nil {
|
||||
for _, col := range columns {
|
||||
if fi := sMi.fields.GetByColumn(col); fi != nil {
|
||||
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
|
||||
field := ind.FieldByIndex(fi.fieldIndex)
|
||||
if fi.fieldType&IsRelField > 0 {
|
||||
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
|
||||
field.Set(mf)
|
||||
field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
|
||||
}
|
||||
if fi.isFielder {
|
||||
fd := field.Addr().Interface().(Fielder)
|
||||
err := fd.SetRaw(value)
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("set raw error:%s", err)
|
||||
}
|
||||
} else {
|
||||
o.setFieldValue(field, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// define recursive function
|
||||
var recursiveSetField func(rv reflect.Value)
|
||||
recursiveSetField = func(rv reflect.Value) {
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
f := rv.Field(i)
|
||||
fe := rv.Type().Field(i)
|
||||
|
||||
// check if the field is a Struct
|
||||
// recursive the Struct type
|
||||
if fe.Type.Kind() == reflect.Struct {
|
||||
recursiveSetField(f)
|
||||
}
|
||||
|
||||
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
|
||||
var col string
|
||||
if col = tags["column"]; col == "" {
|
||||
col = nameStrategyMap[nameStrategy](fe.Name)
|
||||
}
|
||||
if v, ok := columnsMp[col]; ok {
|
||||
value := reflect.ValueOf(v).Elem().Interface()
|
||||
o.setFieldValue(f, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// init call the recursive function
|
||||
recursiveSetField(ind)
|
||||
}
|
||||
|
||||
if eTyps[0].Kind() == reflect.Ptr {
|
||||
ind = ind.Addr()
|
||||
}
|
||||
|
||||
sInd = reflect.Append(sInd, ind)
|
||||
|
||||
} else {
|
||||
if err := rows.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
if cnt > 0 {
|
||||
|
||||
if structMode {
|
||||
sInds[0].Set(sInd)
|
||||
} else {
|
||||
for i, sInd := range sInds {
|
||||
nInd := nInds[i]
|
||||
sInd.Set(nInd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) {
|
||||
var (
|
||||
maps []Params
|
||||
lists []ParamsList
|
||||
list ParamsList
|
||||
)
|
||||
|
||||
typ := 0
|
||||
switch container.(type) {
|
||||
case *[]Params:
|
||||
typ = 1
|
||||
case *[]ParamsList:
|
||||
typ = 2
|
||||
case *ParamsList:
|
||||
typ = 3
|
||||
default:
|
||||
panic(fmt.Errorf("<RawSeter> unsupport read values type `%T`", container))
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
|
||||
var rs *sql.Rows
|
||||
rs, err := o.orm.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
indexs []int
|
||||
)
|
||||
|
||||
for rs.Next() {
|
||||
if cnt == 0 {
|
||||
columns, err := rs.Columns()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(needCols) > 0 {
|
||||
indexs = make([]int, 0, len(needCols))
|
||||
} else {
|
||||
indexs = make([]int, 0, len(columns))
|
||||
}
|
||||
|
||||
cols = columns
|
||||
refs = make([]interface{}, len(cols))
|
||||
for i := range refs {
|
||||
var ref sql.NullString
|
||||
refs[i] = &ref
|
||||
|
||||
if len(needCols) > 0 {
|
||||
for _, c := range needCols {
|
||||
if c == cols[i] {
|
||||
indexs = append(indexs, i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
indexs = append(indexs, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rs.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case 1:
|
||||
params := make(Params, len(cols))
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
params[cols[i]] = value.String
|
||||
} else {
|
||||
params[cols[i]] = nil
|
||||
}
|
||||
}
|
||||
maps = append(maps, params)
|
||||
case 2:
|
||||
params := make(ParamsList, 0, len(cols))
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
params = append(params, value.String)
|
||||
} else {
|
||||
params = append(params, nil)
|
||||
}
|
||||
}
|
||||
lists = append(lists, params)
|
||||
case 3:
|
||||
for _, i := range indexs {
|
||||
ref := refs[i]
|
||||
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
list = append(list, value.String)
|
||||
} else {
|
||||
list = append(list, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
switch v := container.(type) {
|
||||
case *[]Params:
|
||||
*v = maps
|
||||
case *[]ParamsList:
|
||||
*v = lists
|
||||
case *ParamsList:
|
||||
*v = list
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) {
|
||||
var (
|
||||
maps Params
|
||||
ind *reflect.Value
|
||||
)
|
||||
|
||||
var typ int
|
||||
switch container.(type) {
|
||||
case *Params:
|
||||
typ = 1
|
||||
default:
|
||||
typ = 2
|
||||
vl := reflect.ValueOf(container)
|
||||
id := reflect.Indirect(vl)
|
||||
if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container))
|
||||
}
|
||||
|
||||
ind = &id
|
||||
}
|
||||
|
||||
query := o.query
|
||||
o.orm.alias.DbBaser.ReplaceMarks(&query)
|
||||
|
||||
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
|
||||
|
||||
rs, err := o.orm.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer rs.Close()
|
||||
|
||||
var (
|
||||
refs []interface{}
|
||||
cnt int64
|
||||
cols []string
|
||||
)
|
||||
|
||||
var (
|
||||
keyIndex = -1
|
||||
valueIndex = -1
|
||||
)
|
||||
|
||||
for rs.Next() {
|
||||
if cnt == 0 {
|
||||
columns, err := rs.Columns()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cols = columns
|
||||
refs = make([]interface{}, len(cols))
|
||||
for i := range refs {
|
||||
if keyCol == cols[i] {
|
||||
keyIndex = i
|
||||
}
|
||||
if typ == 1 || keyIndex == i {
|
||||
var ref sql.NullString
|
||||
refs[i] = &ref
|
||||
} else {
|
||||
var ref interface{}
|
||||
refs[i] = &ref
|
||||
}
|
||||
if valueCol == cols[i] {
|
||||
valueIndex = i
|
||||
}
|
||||
}
|
||||
if keyIndex == -1 || valueIndex == -1 {
|
||||
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
|
||||
}
|
||||
}
|
||||
|
||||
if err := rs.Scan(refs...); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if cnt == 0 {
|
||||
switch typ {
|
||||
case 1:
|
||||
maps = make(Params)
|
||||
}
|
||||
}
|
||||
|
||||
key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String
|
||||
|
||||
switch typ {
|
||||
case 1:
|
||||
value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString)
|
||||
if value.Valid {
|
||||
maps[key] = value.String
|
||||
} else {
|
||||
maps[key] = nil
|
||||
}
|
||||
|
||||
default:
|
||||
if id := ind.FieldByName(camelString(key)); id.IsValid() {
|
||||
o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
cnt++
|
||||
}
|
||||
|
||||
if typ == 1 {
|
||||
v, _ := container.(*Params)
|
||||
*v = maps
|
||||
}
|
||||
|
||||
return cnt, nil
|
||||
}
|
||||
|
||||
// query data to []map[string]interface
|
||||
func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query data to [][]interface
|
||||
func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query data to []interface
|
||||
func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) {
|
||||
return o.readValues(container, cols)
|
||||
}
|
||||
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) {
|
||||
return o.queryRowsTo(result, keyCol, valueCol)
|
||||
}
|
||||
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) {
|
||||
return o.queryRowsTo(ptrStruct, keyCol, valueCol)
|
||||
}
|
||||
|
||||
// return prepared raw statement for used in times.
|
||||
func (o *rawSet) Prepare() (RawPreparer, error) {
|
||||
return newRawPreparer(o)
|
||||
}
|
||||
|
||||
func newRawSet(orm *ormBase, query string, args []interface{}) RawSeter {
|
||||
o := new(rawSet)
|
||||
o.query = query
|
||||
o.args = args
|
||||
o.orm = orm
|
||||
return o
|
||||
}
|
2602
pkg/client/orm/orm_test.go
Normal file
2602
pkg/client/orm/orm_test.go
Normal file
File diff suppressed because it is too large
Load Diff
62
pkg/client/orm/qb.go
Normal file
62
pkg/client/orm/qb.go
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import "errors"
|
||||
|
||||
// QueryBuilder is the Query builder interface
|
||||
type QueryBuilder interface {
|
||||
Select(fields ...string) QueryBuilder
|
||||
ForUpdate() QueryBuilder
|
||||
From(tables ...string) QueryBuilder
|
||||
InnerJoin(table string) QueryBuilder
|
||||
LeftJoin(table string) QueryBuilder
|
||||
RightJoin(table string) QueryBuilder
|
||||
On(cond string) QueryBuilder
|
||||
Where(cond string) QueryBuilder
|
||||
And(cond string) QueryBuilder
|
||||
Or(cond string) QueryBuilder
|
||||
In(vals ...string) QueryBuilder
|
||||
OrderBy(fields ...string) QueryBuilder
|
||||
Asc() QueryBuilder
|
||||
Desc() QueryBuilder
|
||||
Limit(limit int) QueryBuilder
|
||||
Offset(offset int) QueryBuilder
|
||||
GroupBy(fields ...string) QueryBuilder
|
||||
Having(cond string) QueryBuilder
|
||||
Update(tables ...string) QueryBuilder
|
||||
Set(kv ...string) QueryBuilder
|
||||
Delete(tables ...string) QueryBuilder
|
||||
InsertInto(table string, fields ...string) QueryBuilder
|
||||
Values(vals ...string) QueryBuilder
|
||||
Subquery(sub string, alias string) string
|
||||
String() string
|
||||
}
|
||||
|
||||
// NewQueryBuilder return the QueryBuilder
|
||||
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
|
||||
if driver == "mysql" {
|
||||
qb = new(MySQLQueryBuilder)
|
||||
} else if driver == "tidb" {
|
||||
qb = new(TiDBQueryBuilder)
|
||||
} else if driver == "postgres" {
|
||||
err = errors.New("postgres query builder is not supported yet")
|
||||
} else if driver == "sqlite" {
|
||||
err = errors.New("sqlite query builder is not supported yet")
|
||||
} else {
|
||||
err = errors.New("unknown driver for query builder")
|
||||
}
|
||||
return
|
||||
}
|
185
pkg/client/orm/qb_mysql.go
Normal file
185
pkg/client/orm/qb_mysql.go
Normal file
@ -0,0 +1,185 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CommaSpace is the separation
|
||||
const CommaSpace = ", "
|
||||
|
||||
// MySQLQueryBuilder is the SQL build
|
||||
type MySQLQueryBuilder struct {
|
||||
Tokens []string
|
||||
}
|
||||
|
||||
// Select will join the fields
|
||||
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// ForUpdate add the FOR UPDATE clause
|
||||
func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "FOR UPDATE")
|
||||
return qb
|
||||
}
|
||||
|
||||
// From join the tables
|
||||
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// InnerJoin INNER JOIN the table
|
||||
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
|
||||
return qb
|
||||
}
|
||||
|
||||
// LeftJoin LEFT JOIN the table
|
||||
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
|
||||
return qb
|
||||
}
|
||||
|
||||
// RightJoin RIGHT JOIN the table
|
||||
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
|
||||
return qb
|
||||
}
|
||||
|
||||
// On join with on cond
|
||||
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "ON", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// Where join the Where cond
|
||||
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "WHERE", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// And join the and cond
|
||||
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "AND", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// Or join the or cond
|
||||
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "OR", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// In join the IN (vals)
|
||||
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
|
||||
return qb
|
||||
}
|
||||
|
||||
// OrderBy join the Order by fields
|
||||
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Asc join the asc
|
||||
func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "ASC")
|
||||
return qb
|
||||
}
|
||||
|
||||
// Desc join the desc
|
||||
func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "DESC")
|
||||
return qb
|
||||
}
|
||||
|
||||
// Limit join the limit num
|
||||
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Offset join the offset num
|
||||
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
|
||||
return qb
|
||||
}
|
||||
|
||||
// GroupBy join the Group by fields
|
||||
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Having join the Having cond
|
||||
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "HAVING", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// Update join the update table
|
||||
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Set join the set kv
|
||||
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Delete join the Delete tables
|
||||
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "DELETE")
|
||||
if len(tables) != 0 {
|
||||
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
// InsertInto join the insert SQL
|
||||
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
|
||||
if len(fields) != 0 {
|
||||
fieldsStr := strings.Join(fields, CommaSpace)
|
||||
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
// Values join the Values(vals)
|
||||
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
|
||||
valsStr := strings.Join(vals, CommaSpace)
|
||||
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
|
||||
return qb
|
||||
}
|
||||
|
||||
// Subquery join the sub as alias
|
||||
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
|
||||
return fmt.Sprintf("(%s) AS %s", sub, alias)
|
||||
}
|
||||
|
||||
// String join all Tokens
|
||||
func (qb *MySQLQueryBuilder) String() string {
|
||||
return strings.Join(qb.Tokens, " ")
|
||||
}
|
182
pkg/client/orm/qb_tidb.go
Normal file
182
pkg/client/orm/qb_tidb.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright 2015 TiDB Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TiDBQueryBuilder is the SQL build
|
||||
type TiDBQueryBuilder struct {
|
||||
Tokens []string
|
||||
}
|
||||
|
||||
// Select will join the fields
|
||||
func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// ForUpdate add the FOR UPDATE clause
|
||||
func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "FOR UPDATE")
|
||||
return qb
|
||||
}
|
||||
|
||||
// From join the tables
|
||||
func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// InnerJoin INNER JOIN the table
|
||||
func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
|
||||
return qb
|
||||
}
|
||||
|
||||
// LeftJoin LEFT JOIN the table
|
||||
func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
|
||||
return qb
|
||||
}
|
||||
|
||||
// RightJoin RIGHT JOIN the table
|
||||
func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
|
||||
return qb
|
||||
}
|
||||
|
||||
// On join with on cond
|
||||
func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "ON", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// Where join the Where cond
|
||||
func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "WHERE", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// And join the and cond
|
||||
func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "AND", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// Or join the or cond
|
||||
func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "OR", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// In join the IN (vals)
|
||||
func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
|
||||
return qb
|
||||
}
|
||||
|
||||
// OrderBy join the Order by fields
|
||||
func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Asc join the asc
|
||||
func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "ASC")
|
||||
return qb
|
||||
}
|
||||
|
||||
// Desc join the desc
|
||||
func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "DESC")
|
||||
return qb
|
||||
}
|
||||
|
||||
// Limit join the limit num
|
||||
func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Offset join the offset num
|
||||
func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
|
||||
return qb
|
||||
}
|
||||
|
||||
// GroupBy join the Group by fields
|
||||
func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Having join the Having cond
|
||||
func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "HAVING", cond)
|
||||
return qb
|
||||
}
|
||||
|
||||
// Update join the update table
|
||||
func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Set join the set kv
|
||||
func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
|
||||
return qb
|
||||
}
|
||||
|
||||
// Delete join the Delete tables
|
||||
func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "DELETE")
|
||||
if len(tables) != 0 {
|
||||
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
// InsertInto join the insert SQL
|
||||
func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
|
||||
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
|
||||
if len(fields) != 0 {
|
||||
fieldsStr := strings.Join(fields, CommaSpace)
|
||||
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
|
||||
}
|
||||
return qb
|
||||
}
|
||||
|
||||
// Values join the Values(vals)
|
||||
func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
|
||||
valsStr := strings.Join(vals, CommaSpace)
|
||||
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
|
||||
return qb
|
||||
}
|
||||
|
||||
// Subquery join the sub as alias
|
||||
func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
|
||||
return fmt.Sprintf("(%s) AS %s", sub, alias)
|
||||
}
|
||||
|
||||
// String join all Tokens
|
||||
func (qb *TiDBQueryBuilder) String() string {
|
||||
return strings.Join(qb.Tokens, " ")
|
||||
}
|
584
pkg/client/orm/types.go
Normal file
584
pkg/client/orm/types.go
Normal file
@ -0,0 +1,584 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/pkg/infrastructure/utils"
|
||||
)
|
||||
|
||||
// TableNaming is usually used by model
|
||||
// when you custom your table name, please implement this interfaces
|
||||
// for example:
|
||||
// type User struct {
|
||||
// ...
|
||||
// }
|
||||
// func (u *User) TableName() string {
|
||||
// return "USER_TABLE"
|
||||
// }
|
||||
type TableNameI interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
// TableEngineI is usually used by model
|
||||
// when you want to use specific engine, like myisam, you can implement this interface
|
||||
// for example:
|
||||
// type User struct {
|
||||
// ...
|
||||
// }
|
||||
// func (u *User) TableEngine() string {
|
||||
// return "myisam"
|
||||
// }
|
||||
type TableEngineI interface {
|
||||
TableEngine() string
|
||||
}
|
||||
|
||||
// TableIndexI is usually used by model
|
||||
// when you want to create indexes, you can implement this interface
|
||||
// for example:
|
||||
// type User struct {
|
||||
// ...
|
||||
// }
|
||||
// func (u *User) TableIndex() [][]string {
|
||||
// return [][]string{{"Name"}}
|
||||
// }
|
||||
type TableIndexI interface {
|
||||
TableIndex() [][]string
|
||||
}
|
||||
|
||||
// TableUniqueI is usually used by model
|
||||
// when you want to create unique indexes, you can implement this interface
|
||||
// for example:
|
||||
// type User struct {
|
||||
// ...
|
||||
// }
|
||||
// func (u *User) TableUnique() [][]string {
|
||||
// return [][]string{{"Email"}}
|
||||
// }
|
||||
type TableUniqueI interface {
|
||||
TableUnique() [][]string
|
||||
}
|
||||
|
||||
// Driver define database driver
|
||||
type Driver interface {
|
||||
Name() string
|
||||
Type() DriverType
|
||||
}
|
||||
|
||||
// Fielder define field info
|
||||
type Fielder interface {
|
||||
String() string
|
||||
FieldType() int
|
||||
SetRaw(interface{}) error
|
||||
RawValue() interface{}
|
||||
}
|
||||
|
||||
type TxBeginner interface {
|
||||
//self control transaction
|
||||
Begin() (TxOrmer, error)
|
||||
BeginWithCtx(ctx context.Context) (TxOrmer, error)
|
||||
BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error)
|
||||
BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error)
|
||||
|
||||
//closure control transaction
|
||||
DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error
|
||||
DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error
|
||||
DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error
|
||||
DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error
|
||||
}
|
||||
|
||||
type TxCommitter interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
//Data Manipulation Language
|
||||
type DML interface {
|
||||
// insert model data to database
|
||||
// for example:
|
||||
// user := new(User)
|
||||
// id, err = Ormer.Insert(user)
|
||||
// user must be a pointer and Insert will set user's pk field
|
||||
Insert(md interface{}) (int64, error)
|
||||
InsertWithCtx(ctx context.Context, md interface{}) (int64, error)
|
||||
// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
|
||||
// if colu type is integer : can use(+-*/), string : convert(colu,"value")
|
||||
// postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
|
||||
// if colu type is integer : can use(+-*/), string : colu || "value"
|
||||
InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
|
||||
InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error)
|
||||
// insert some models to database
|
||||
InsertMulti(bulk int, mds interface{}) (int64, error)
|
||||
InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error)
|
||||
// update model to database.
|
||||
// cols set the columns those want to update.
|
||||
// find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
|
||||
// for example:
|
||||
// user := User{Id: 2}
|
||||
// user.Langs = append(user.Langs, "zh-CN", "en-US")
|
||||
// user.Extra.Name = "beego"
|
||||
// user.Extra.Data = "orm"
|
||||
// num, err = Ormer.Update(&user, "Langs", "Extra")
|
||||
Update(md interface{}, cols ...string) (int64, error)
|
||||
UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error)
|
||||
// delete model in database
|
||||
Delete(md interface{}, cols ...string) (int64, error)
|
||||
DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error)
|
||||
|
||||
// return a raw query seter for raw sql string.
|
||||
// for example:
|
||||
// ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
|
||||
// // update user testing's name to slene
|
||||
Raw(query string, args ...interface{}) RawSeter
|
||||
RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter
|
||||
}
|
||||
|
||||
// Data Query Language
|
||||
type DQL interface {
|
||||
// read data to model
|
||||
// for example:
|
||||
// this will find User by Id field
|
||||
// u = &User{Id: user.Id}
|
||||
// err = Ormer.Read(u)
|
||||
// this will find User by UserName field
|
||||
// u = &User{UserName: "astaxie", Password: "pass"}
|
||||
// err = Ormer.Read(u, "UserName")
|
||||
Read(md interface{}, cols ...string) error
|
||||
ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error
|
||||
|
||||
// Like Read(), but with "FOR UPDATE" clause, useful in transaction.
|
||||
// Some databases are not support this feature.
|
||||
ReadForUpdate(md interface{}, cols ...string) error
|
||||
ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error
|
||||
|
||||
// Try to read a row from the database, or insert one if it doesn't exist
|
||||
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
|
||||
ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error)
|
||||
|
||||
// load related models to md model.
|
||||
// args are limit, offset int and order string.
|
||||
//
|
||||
// example:
|
||||
// Ormer.LoadRelated(post,"Tags")
|
||||
// for _,tag := range post.Tags{...}
|
||||
// hints.DefaultRelDepth useDefaultRelsDepth ; or depth 0
|
||||
// hints.RelDepth loadRelationDepth
|
||||
// hints.Limit limit default limit 1000
|
||||
// hints.Offset int offset default offset 0
|
||||
// hints.OrderBy string order for example : "-Id"
|
||||
// make sure the relation is defined in model struct tags.
|
||||
LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error)
|
||||
LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error)
|
||||
|
||||
// create a models to models queryer
|
||||
// for example:
|
||||
// post := Post{Id: 4}
|
||||
// m2m := Ormer.QueryM2M(&post, "Tags")
|
||||
QueryM2M(md interface{}, name string) QueryM2Mer
|
||||
QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer
|
||||
|
||||
// return a QuerySeter for table operations.
|
||||
// table name can be string or struct.
|
||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||
QueryTable(ptrStructOrTableName interface{}) QuerySeter
|
||||
QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter
|
||||
|
||||
DBStats() *sql.DBStats
|
||||
}
|
||||
|
||||
type DriverGetter interface {
|
||||
Driver() Driver
|
||||
}
|
||||
|
||||
type ormer interface {
|
||||
DQL
|
||||
DML
|
||||
DriverGetter
|
||||
}
|
||||
|
||||
type Ormer interface {
|
||||
ormer
|
||||
TxBeginner
|
||||
}
|
||||
|
||||
type TxOrmer interface {
|
||||
ormer
|
||||
TxCommitter
|
||||
}
|
||||
|
||||
// Inserter insert prepared statement
|
||||
type Inserter interface {
|
||||
Insert(interface{}) (int64, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// QuerySeter query seter
|
||||
type QuerySeter interface {
|
||||
// add condition expression to QuerySeter.
|
||||
// for example:
|
||||
// filter by UserName == 'slene'
|
||||
// qs.Filter("UserName", "slene")
|
||||
// sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28
|
||||
// Filter("profile__Age", 28)
|
||||
// // time compare
|
||||
// qs.Filter("created", time.Now())
|
||||
Filter(string, ...interface{}) QuerySeter
|
||||
// add raw sql to querySeter.
|
||||
// for example:
|
||||
// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
|
||||
// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18)
|
||||
FilterRaw(string, string) QuerySeter
|
||||
// add NOT condition to querySeter.
|
||||
// have the same usage as Filter
|
||||
Exclude(string, ...interface{}) QuerySeter
|
||||
// set condition to QuerySeter.
|
||||
// sql's where condition
|
||||
// cond := orm.NewCondition()
|
||||
// cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
|
||||
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
|
||||
// num, err := qs.SetCond(cond1).Count()
|
||||
SetCond(*Condition) QuerySeter
|
||||
// get condition from QuerySeter.
|
||||
// sql's where condition
|
||||
// cond := orm.NewCondition()
|
||||
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
|
||||
// qs = qs.SetCond(cond)
|
||||
// cond = qs.GetCond()
|
||||
// cond := cond.Or("profile__age__gt", 2000)
|
||||
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
|
||||
// num, err := qs.SetCond(cond).Count()
|
||||
GetCond() *Condition
|
||||
// add LIMIT value.
|
||||
// args[0] means offset, e.g. LIMIT num,offset.
|
||||
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
|
||||
// if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
|
||||
// for example:
|
||||
// qs.Limit(10, 2)
|
||||
// // sql-> limit 10 offset 2
|
||||
Limit(limit interface{}, args ...interface{}) QuerySeter
|
||||
// add OFFSET value
|
||||
// same as Limit function's args[0]
|
||||
Offset(offset interface{}) QuerySeter
|
||||
// add GROUP BY expression
|
||||
// for example:
|
||||
// qs.GroupBy("id")
|
||||
GroupBy(exprs ...string) QuerySeter
|
||||
// add ORDER expression.
|
||||
// "column" means ASC, "-column" means DESC.
|
||||
// for example:
|
||||
// qs.OrderBy("-status")
|
||||
OrderBy(exprs ...string) QuerySeter
|
||||
// add FORCE INDEX expression.
|
||||
// for example:
|
||||
// qs.ForceIndex(`idx_name1`,`idx_name2`)
|
||||
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
|
||||
ForceIndex(indexes ...string) QuerySeter
|
||||
// add USE INDEX expression.
|
||||
// for example:
|
||||
// qs.UseIndex(`idx_name1`,`idx_name2`)
|
||||
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
|
||||
UseIndex(indexes ...string) QuerySeter
|
||||
// add IGNORE INDEX expression.
|
||||
// for example:
|
||||
// qs.IgnoreIndex(`idx_name1`,`idx_name2`)
|
||||
// ForceIndex, UseIndex , IgnoreIndex are mutually exclusive
|
||||
IgnoreIndex(indexes ...string) QuerySeter
|
||||
// set relation model to query together.
|
||||
// it will query relation models and assign to parent model.
|
||||
// for example:
|
||||
// // will load all related fields use left join .
|
||||
// qs.RelatedSel().One(&user)
|
||||
// // will load related field only profile
|
||||
// qs.RelatedSel("profile").One(&user)
|
||||
// user.Profile.Age = 32
|
||||
RelatedSel(params ...interface{}) QuerySeter
|
||||
// Set Distinct
|
||||
// for example:
|
||||
// o.QueryTable("policy").Filter("Groups__Group__Users__User", user).
|
||||
// Distinct().
|
||||
// All(&permissions)
|
||||
Distinct() QuerySeter
|
||||
// set FOR UPDATE to query.
|
||||
// for example:
|
||||
// o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users)
|
||||
ForUpdate() QuerySeter
|
||||
// return QuerySeter execution result number
|
||||
// for example:
|
||||
// num, err = qs.Filter("profile__age__gt", 28).Count()
|
||||
Count() (int64, error)
|
||||
// check result empty or not after QuerySeter executed
|
||||
// the same as QuerySeter.Count > 0
|
||||
Exist() bool
|
||||
// execute update with parameters
|
||||
// for example:
|
||||
// num, err = qs.Filter("user_name", "slene").Update(Params{
|
||||
// "Nums": ColValue(Col_Minus, 50),
|
||||
// }) // user slene's Nums will minus 50
|
||||
// num, err = qs.Filter("UserName", "slene").Update(Params{
|
||||
// "user_name": "slene2"
|
||||
// }) // user slene's name will change to slene2
|
||||
Update(values Params) (int64, error)
|
||||
// delete from table
|
||||
// for example:
|
||||
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
|
||||
// //delete two user who's name is testing1 or testing2
|
||||
Delete() (int64, error)
|
||||
// return a insert queryer.
|
||||
// it can be used in times.
|
||||
// example:
|
||||
// i,err := sq.PrepareInsert()
|
||||
// num, err = i.Insert(&user1) // user table will add one record user1 at once
|
||||
// num, err = i.Insert(&user2) // user table will add one record user2 at once
|
||||
// err = i.Close() //don't forget call Close
|
||||
PrepareInsert() (Inserter, error)
|
||||
// query all data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
// for example:
|
||||
// var users []*User
|
||||
// qs.All(&users) // users[0],users[1],users[2] ...
|
||||
All(container interface{}, cols ...string) (int64, error)
|
||||
// query one row data and map to containers.
|
||||
// cols means the columns when querying.
|
||||
// for example:
|
||||
// var user User
|
||||
// qs.One(&user) //user.UserName == "slene"
|
||||
One(container interface{}, cols ...string) error
|
||||
// query all data and map to []map[string]interface.
|
||||
// expres means condition expression.
|
||||
// it converts data to []map[column]value.
|
||||
// for example:
|
||||
// var maps []Params
|
||||
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
|
||||
Values(results *[]Params, exprs ...string) (int64, error)
|
||||
// query all data and map to [][]interface
|
||||
// it converts data to [][column_index]value
|
||||
// for example:
|
||||
// var list []ParamsList
|
||||
// qs.ValuesList(&list) // list[0][1] == "slene"
|
||||
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
|
||||
// query all data and map to []interface.
|
||||
// it's designed for one column record set, auto change to []value, not [][column]value.
|
||||
// for example:
|
||||
// var list ParamsList
|
||||
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
|
||||
ValuesFlat(result *ParamsList, expr string) (int64, error)
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
|
||||
}
|
||||
|
||||
// QueryM2Mer model to model query struct
|
||||
// all operations are on the m2m table only, will not affect the origin model table
|
||||
type QueryM2Mer interface {
|
||||
// add models to origin models when creating queryM2M.
|
||||
// example:
|
||||
// m2m := orm.QueryM2M(post,"Tag")
|
||||
// m2m.Add(&Tag1{},&Tag2{})
|
||||
// for _,tag := range post.Tags{}{ ... }
|
||||
// param could also be any of the follow
|
||||
// []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}}
|
||||
// &Tag{Id:5,Name: "TestTag3"}
|
||||
// []interface{}{&Tag{Id:6,Name: "TestTag4"}}
|
||||
// insert one or more rows to m2m table
|
||||
// make sure the relation is defined in post model struct tag.
|
||||
Add(...interface{}) (int64, error)
|
||||
// remove models following the origin model relationship
|
||||
// only delete rows from m2m table
|
||||
// for example:
|
||||
// tag3 := &Tag{Id:5,Name: "TestTag3"}
|
||||
// num, err = m2m.Remove(tag3)
|
||||
Remove(...interface{}) (int64, error)
|
||||
// check model is existed in relationship of origin model
|
||||
Exist(interface{}) bool
|
||||
// clean all models in related of origin model
|
||||
Clear() (int64, error)
|
||||
// count all related models of origin model
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
// RawPreparer raw query statement
|
||||
type RawPreparer interface {
|
||||
Exec(...interface{}) (sql.Result, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// RawSeter raw query seter
|
||||
// create From Ormer.Raw
|
||||
// for example:
|
||||
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
|
||||
// rs := Ormer.Raw(sql, 1)
|
||||
type RawSeter interface {
|
||||
// execute sql and get result
|
||||
Exec() (sql.Result, error)
|
||||
// query data and map to container
|
||||
// for example:
|
||||
// var name string
|
||||
// var id int
|
||||
// rs.QueryRow(&id,&name) // id==2 name=="slene"
|
||||
QueryRow(containers ...interface{}) error
|
||||
|
||||
// query data rows and map to container
|
||||
// var ids []int
|
||||
// var names []int
|
||||
// query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
|
||||
// num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
|
||||
QueryRows(containers ...interface{}) (int64, error)
|
||||
SetArgs(...interface{}) RawSeter
|
||||
// query data to []map[string]interface
|
||||
// see QuerySeter's Values
|
||||
Values(container *[]Params, cols ...string) (int64, error)
|
||||
// query data to [][]interface
|
||||
// see QuerySeter's ValuesList
|
||||
ValuesList(container *[]ParamsList, cols ...string) (int64, error)
|
||||
// query data to []interface
|
||||
// see QuerySeter's ValuesFlat
|
||||
ValuesFlat(container *ParamsList, cols ...string) (int64, error)
|
||||
// query all rows into map[string]interface with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to map[string]interface{}{
|
||||
// "total": 100,
|
||||
// "found": 200,
|
||||
// }
|
||||
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
|
||||
// query all rows into struct with specify key and value column name.
|
||||
// keyCol = "name", valueCol = "value"
|
||||
// table data
|
||||
// name | value
|
||||
// total | 100
|
||||
// found | 200
|
||||
// to struct {
|
||||
// Total int
|
||||
// Found int
|
||||
// }
|
||||
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
|
||||
|
||||
// return prepared raw statement for used in times.
|
||||
// for example:
|
||||
// pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
|
||||
// r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
|
||||
Prepare() (RawPreparer, error)
|
||||
}
|
||||
|
||||
// stmtQuerier statement querier
|
||||
type stmtQuerier interface {
|
||||
Close() error
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||
Query(args ...interface{}) (*sql.Rows, error)
|
||||
// QueryContext(args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(args ...interface{}) *sql.Row
|
||||
// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// db querier
|
||||
type dbQuerier interface {
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// type DB interface {
|
||||
// Begin() (*sql.Tx, error)
|
||||
// Prepare(query string) (stmtQuerier, error)
|
||||
// Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
// Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
// QueryRow(query string, args ...interface{}) *sql.Row
|
||||
// }
|
||||
|
||||
// transaction beginner
|
||||
type txer interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// transaction ending
|
||||
type txEnder interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// base database struct
|
||||
type dbBaser interface {
|
||||
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
|
||||
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
|
||||
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
|
||||
|
||||
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
|
||||
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
|
||||
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
|
||||
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
|
||||
|
||||
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
|
||||
|
||||
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
|
||||
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
|
||||
|
||||
SupportUpdateJoin() bool
|
||||
OperatorSQL(string) string
|
||||
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
|
||||
GenerateOperatorLeftCol(*fieldInfo, string, *string)
|
||||
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
|
||||
MaxLimit() uint64
|
||||
TableQuote() string
|
||||
ReplaceMarks(*string)
|
||||
HasReturningID(*modelInfo, *string) bool
|
||||
TimeFromDB(*time.Time, *time.Location)
|
||||
TimeToDB(*time.Time, *time.Location)
|
||||
DbTypes() map[string]string
|
||||
GetTables(dbQuerier) (map[string]bool, error)
|
||||
GetColumns(dbQuerier, string) (map[string][3]string, error)
|
||||
ShowTablesQuery() string
|
||||
ShowColumnsQuery(string) string
|
||||
IndexExists(dbQuerier, string, string) bool
|
||||
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
|
||||
setval(dbQuerier, *modelInfo, []string) error
|
||||
|
||||
GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string
|
||||
}
|
319
pkg/client/orm/utils.go
Normal file
319
pkg/client/orm/utils.go
Normal file
@ -0,0 +1,319 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fn func(string) string
|
||||
|
||||
var (
|
||||
nameStrategyMap = map[string]fn{
|
||||
defaultNameStrategy: snakeString,
|
||||
SnakeAcronymNameStrategy: snakeStringWithAcronym,
|
||||
}
|
||||
defaultNameStrategy = "snakeString"
|
||||
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
|
||||
nameStrategy = defaultNameStrategy
|
||||
)
|
||||
|
||||
// StrTo is the target string
|
||||
type StrTo string
|
||||
|
||||
// Set string
|
||||
func (f *StrTo) Set(v string) {
|
||||
if v != "" {
|
||||
*f = StrTo(v)
|
||||
} else {
|
||||
f.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
// Clear string
|
||||
func (f *StrTo) Clear() {
|
||||
*f = StrTo(0x1E)
|
||||
}
|
||||
|
||||
// Exist check string exist
|
||||
func (f StrTo) Exist() bool {
|
||||
return string(f) != string(0x1E)
|
||||
}
|
||||
|
||||
// Bool string to bool
|
||||
func (f StrTo) Bool() (bool, error) {
|
||||
return strconv.ParseBool(f.String())
|
||||
}
|
||||
|
||||
// Float32 string to float32
|
||||
func (f StrTo) Float32() (float32, error) {
|
||||
v, err := strconv.ParseFloat(f.String(), 32)
|
||||
return float32(v), err
|
||||
}
|
||||
|
||||
// Float64 string to float64
|
||||
func (f StrTo) Float64() (float64, error) {
|
||||
return strconv.ParseFloat(f.String(), 64)
|
||||
}
|
||||
|
||||
// Int string to int
|
||||
func (f StrTo) Int() (int, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int(v), err
|
||||
}
|
||||
|
||||
// Int8 string to int8
|
||||
func (f StrTo) Int8() (int8, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 8)
|
||||
return int8(v), err
|
||||
}
|
||||
|
||||
// Int16 string to int16
|
||||
func (f StrTo) Int16() (int16, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 16)
|
||||
return int16(v), err
|
||||
}
|
||||
|
||||
// Int32 string to int32
|
||||
func (f StrTo) Int32() (int32, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 32)
|
||||
return int32(v), err
|
||||
}
|
||||
|
||||
// Int64 string to int64
|
||||
func (f StrTo) Int64() (int64, error) {
|
||||
v, err := strconv.ParseInt(f.String(), 10, 64)
|
||||
if err != nil {
|
||||
i := new(big.Int)
|
||||
ni, ok := i.SetString(f.String(), 10) // octal
|
||||
if !ok {
|
||||
return v, err
|
||||
}
|
||||
return ni.Int64(), nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// Uint string to uint
|
||||
func (f StrTo) Uint() (uint, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint(v), err
|
||||
}
|
||||
|
||||
// Uint8 string to uint8
|
||||
func (f StrTo) Uint8() (uint8, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 8)
|
||||
return uint8(v), err
|
||||
}
|
||||
|
||||
// Uint16 string to uint16
|
||||
func (f StrTo) Uint16() (uint16, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 16)
|
||||
return uint16(v), err
|
||||
}
|
||||
|
||||
// Uint32 string to uint32
|
||||
func (f StrTo) Uint32() (uint32, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 32)
|
||||
return uint32(v), err
|
||||
}
|
||||
|
||||
// Uint64 string to uint64
|
||||
func (f StrTo) Uint64() (uint64, error) {
|
||||
v, err := strconv.ParseUint(f.String(), 10, 64)
|
||||
if err != nil {
|
||||
i := new(big.Int)
|
||||
ni, ok := i.SetString(f.String(), 10)
|
||||
if !ok {
|
||||
return v, err
|
||||
}
|
||||
return ni.Uint64(), nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// String string to string
|
||||
func (f StrTo) String() string {
|
||||
if f.Exist() {
|
||||
return string(f)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ToStr interface to string
|
||||
func ToStr(value interface{}, args ...int) (s string) {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
s = strconv.FormatBool(v)
|
||||
case float32:
|
||||
s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
|
||||
case float64:
|
||||
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
|
||||
case int:
|
||||
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
|
||||
case int8:
|
||||
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
|
||||
case int16:
|
||||
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
|
||||
case int32:
|
||||
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
|
||||
case int64:
|
||||
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
|
||||
case uint:
|
||||
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
|
||||
case uint8:
|
||||
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
|
||||
case uint16:
|
||||
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
|
||||
case uint32:
|
||||
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
|
||||
case uint64:
|
||||
s = strconv.FormatUint(v, argInt(args).Get(0, 10))
|
||||
case string:
|
||||
s = v
|
||||
case []byte:
|
||||
s = string(v)
|
||||
default:
|
||||
s = fmt.Sprintf("%v", v)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ToInt64 interface to int64
|
||||
func ToInt64(value interface{}) (d int64) {
|
||||
val := reflect.ValueOf(value)
|
||||
switch value.(type) {
|
||||
case int, int8, int16, int32, int64:
|
||||
d = val.Int()
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
d = int64(val.Uint())
|
||||
default:
|
||||
panic(fmt.Errorf("ToInt64 need numeric not `%T`", value))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func snakeStringWithAcronym(s string) string {
|
||||
data := make([]byte, 0, len(s)*2)
|
||||
num := len(s)
|
||||
for i := 0; i < num; i++ {
|
||||
d := s[i]
|
||||
before := false
|
||||
after := false
|
||||
if i > 0 {
|
||||
before = s[i-1] >= 'a' && s[i-1] <= 'z'
|
||||
}
|
||||
if i+1 < num {
|
||||
after = s[i+1] >= 'a' && s[i+1] <= 'z'
|
||||
}
|
||||
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
|
||||
data = append(data, '_')
|
||||
}
|
||||
data = append(data, d)
|
||||
}
|
||||
return strings.ToLower(string(data[:]))
|
||||
}
|
||||
|
||||
// snake string, XxYy to xx_yy , XxYY to xx_y_y
|
||||
func snakeString(s string) string {
|
||||
data := make([]byte, 0, len(s)*2)
|
||||
j := false
|
||||
num := len(s)
|
||||
for i := 0; i < num; i++ {
|
||||
d := s[i]
|
||||
if i > 0 && d >= 'A' && d <= 'Z' && j {
|
||||
data = append(data, '_')
|
||||
}
|
||||
if d != '_' {
|
||||
j = true
|
||||
}
|
||||
data = append(data, d)
|
||||
}
|
||||
return strings.ToLower(string(data[:]))
|
||||
}
|
||||
|
||||
// SetNameStrategy set different name strategy
|
||||
func SetNameStrategy(s string) {
|
||||
if SnakeAcronymNameStrategy != s {
|
||||
nameStrategy = defaultNameStrategy
|
||||
}
|
||||
nameStrategy = s
|
||||
}
|
||||
|
||||
// camel string, xx_yy to XxYy
|
||||
func camelString(s string) string {
|
||||
data := make([]byte, 0, len(s))
|
||||
flag, num := true, len(s)-1
|
||||
for i := 0; i <= num; i++ {
|
||||
d := s[i]
|
||||
if d == '_' {
|
||||
flag = true
|
||||
continue
|
||||
} else if flag {
|
||||
if d >= 'a' && d <= 'z' {
|
||||
d = d - 32
|
||||
}
|
||||
flag = false
|
||||
}
|
||||
data = append(data, d)
|
||||
}
|
||||
return string(data[:])
|
||||
}
|
||||
|
||||
type argString []string
|
||||
|
||||
// get string by index from string slice
|
||||
func (a argString) Get(i int, args ...string) (r string) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
} else if len(args) > 0 {
|
||||
r = args[0]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type argInt []int
|
||||
|
||||
// get int by index from int slice
|
||||
func (a argInt) Get(i int, args ...int) (r int) {
|
||||
if i >= 0 && i < len(a) {
|
||||
r = a[i]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
r = args[0]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// parse time to string with location
|
||||
func timeParse(dateString, format string) (time.Time, error) {
|
||||
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
|
||||
return tp, err
|
||||
}
|
||||
|
||||
// get pointer indirect type
|
||||
func indirectType(v reflect.Type) reflect.Type {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr:
|
||||
return indirectType(v.Elem())
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
70
pkg/client/orm/utils_test.go
Normal file
70
pkg/client/orm/utils_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
// Copyright 2014 beego Author. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package orm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCamelString(t *testing.T) {
|
||||
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
|
||||
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
|
||||
|
||||
answer := make(map[string]string)
|
||||
for i, v := range snake {
|
||||
answer[v] = camel[i]
|
||||
}
|
||||
|
||||
for _, v := range snake {
|
||||
res := camelString(v)
|
||||
if res != answer[v] {
|
||||
t.Error("Unit Test Fail:", v, res, answer[v])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnakeString(t *testing.T) {
|
||||
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"}
|
||||
snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"}
|
||||
|
||||
answer := make(map[string]string)
|
||||
for i, v := range camel {
|
||||
answer[v] = snake[i]
|
||||
}
|
||||
|
||||
for _, v := range camel {
|
||||
res := snakeString(v)
|
||||
if res != answer[v] {
|
||||
t.Error("Unit Test Fail:", v, res, answer[v])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnakeStringWithAcronym(t *testing.T) {
|
||||
camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"}
|
||||
snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"}
|
||||
|
||||
answer := make(map[string]string)
|
||||
for i, v := range camel {
|
||||
answer[v] = snake[i]
|
||||
}
|
||||
|
||||
for _, v := range camel {
|
||||
res := snakeStringWithAcronym(v)
|
||||
if res != answer[v] {
|
||||
t.Error("Unit Test Fail:", v, res, answer[v])
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user