mirror of
https://github.com/beego/bee.git
synced 2024-11-14 17:00:54 +00:00
413 lines
8.9 KiB
Go
413 lines
8.9 KiB
Go
|
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||
|
//
|
||
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||
|
//
|
||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||
|
|
||
|
package mysql
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"database/sql/driver"
|
||
|
"errors"
|
||
|
"net"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
type mysqlConn struct {
|
||
|
buf buffer
|
||
|
netConn net.Conn
|
||
|
affectedRows uint64
|
||
|
insertId uint64
|
||
|
cfg *config
|
||
|
maxPacketAllowed int
|
||
|
maxWriteSize int
|
||
|
flags clientFlag
|
||
|
status statusFlag
|
||
|
sequence uint8
|
||
|
parseTime bool
|
||
|
strict bool
|
||
|
}
|
||
|
|
||
|
// config holds the settings a mysqlConn is opened with. Only params, loc and
// interpolateParams are read in this file; the other fields are consumed
// elsewhere in the package (presumably by DSN parsing and the handshake —
// confirm there).
type config struct {
	user, passwd string
	net, addr    string
	dbname       string

	// params holds additional DSN key/value pairs, applied by handleParams
	// once the connection is established.
	params map[string]string

	// loc is the location time.Time arguments are converted to before they
	// are interpolated into a query.
	loc *time.Location

	tls     *tls.Config
	timeout time.Duration

	collation uint8

	allowAllFiles, allowOldPasswords, allowCleartextPasswords bool

	clientFoundRows, columnsWithAlias bool

	// interpolateParams lets Exec and Query substitute placeholders on the
	// client instead of returning driver.ErrSkip.
	interpolateParams bool
}
|
||
|
|
||
|
// Handles parameters set in DSN after the connection is established
|
||
|
func (mc *mysqlConn) handleParams() (err error) {
|
||
|
for param, val := range mc.cfg.params {
|
||
|
switch param {
|
||
|
// Charset
|
||
|
case "charset":
|
||
|
charsets := strings.Split(val, ",")
|
||
|
for i := range charsets {
|
||
|
// ignore errors here - a charset may not exist
|
||
|
err = mc.exec("SET NAMES " + charsets[i])
|
||
|
if err == nil {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// time.Time parsing
|
||
|
case "parseTime":
|
||
|
var isBool bool
|
||
|
mc.parseTime, isBool = readBool(val)
|
||
|
if !isBool {
|
||
|
return errors.New("Invalid Bool value: " + val)
|
||
|
}
|
||
|
|
||
|
// Strict mode
|
||
|
case "strict":
|
||
|
var isBool bool
|
||
|
mc.strict, isBool = readBool(val)
|
||
|
if !isBool {
|
||
|
return errors.New("Invalid Bool value: " + val)
|
||
|
}
|
||
|
|
||
|
// Compression
|
||
|
case "compress":
|
||
|
err = errors.New("Compression not implemented yet")
|
||
|
return
|
||
|
|
||
|
// System Vars
|
||
|
default:
|
||
|
err = mc.exec("SET " + param + "=" + val + "")
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
||
|
if mc.netConn == nil {
|
||
|
errLog.Print(ErrInvalidConn)
|
||
|
return nil, driver.ErrBadConn
|
||
|
}
|
||
|
err := mc.exec("START TRANSACTION")
|
||
|
if err == nil {
|
||
|
return &mysqlTx{mc}, err
|
||
|
}
|
||
|
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
func (mc *mysqlConn) Close() (err error) {
|
||
|
// Makes Close idempotent
|
||
|
if mc.netConn != nil {
|
||
|
err = mc.writeCommandPacket(comQuit)
|
||
|
}
|
||
|
|
||
|
mc.cleanup()
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Closes the network connection and unsets internal variables. Do not call this
|
||
|
// function after successfully authentication, call Close instead. This function
|
||
|
// is called before auth or on auth failure because MySQL will have already
|
||
|
// closed the network connection.
|
||
|
func (mc *mysqlConn) cleanup() {
|
||
|
// Makes cleanup idempotent
|
||
|
if mc.netConn != nil {
|
||
|
if err := mc.netConn.Close(); err != nil {
|
||
|
errLog.Print(err)
|
||
|
}
|
||
|
mc.netConn = nil
|
||
|
}
|
||
|
mc.cfg = nil
|
||
|
mc.buf.rd = nil
|
||
|
}
|
||
|
|
||
|
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||
|
if mc.netConn == nil {
|
||
|
errLog.Print(ErrInvalidConn)
|
||
|
return nil, driver.ErrBadConn
|
||
|
}
|
||
|
// Send command
|
||
|
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
stmt := &mysqlStmt{
|
||
|
mc: mc,
|
||
|
}
|
||
|
|
||
|
// Read Result
|
||
|
columnCount, err := stmt.readPrepareResultPacket()
|
||
|
if err == nil {
|
||
|
if stmt.paramCount > 0 {
|
||
|
if err = mc.readUntilEOF(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if columnCount > 0 {
|
||
|
err = mc.readUntilEOF()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return stmt, err
|
||
|
}
|
||
|
|
||
|
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
|
||
|
buf := mc.buf.takeCompleteBuffer()
|
||
|
if buf == nil {
|
||
|
// can not take the buffer. Something must be wrong with the connection
|
||
|
errLog.Print(ErrBusyBuffer)
|
||
|
return "", driver.ErrBadConn
|
||
|
}
|
||
|
buf = buf[:0]
|
||
|
argPos := 0
|
||
|
|
||
|
for i := 0; i < len(query); i++ {
|
||
|
q := strings.IndexByte(query[i:], '?')
|
||
|
if q == -1 {
|
||
|
buf = append(buf, query[i:]...)
|
||
|
break
|
||
|
}
|
||
|
buf = append(buf, query[i:i+q]...)
|
||
|
i += q
|
||
|
|
||
|
arg := args[argPos]
|
||
|
argPos++
|
||
|
|
||
|
if arg == nil {
|
||
|
buf = append(buf, "NULL"...)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
switch v := arg.(type) {
|
||
|
case int64:
|
||
|
buf = strconv.AppendInt(buf, v, 10)
|
||
|
case float64:
|
||
|
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||
|
case bool:
|
||
|
if v {
|
||
|
buf = append(buf, '1')
|
||
|
} else {
|
||
|
buf = append(buf, '0')
|
||
|
}
|
||
|
case time.Time:
|
||
|
if v.IsZero() {
|
||
|
buf = append(buf, "'0000-00-00'"...)
|
||
|
} else {
|
||
|
v := v.In(mc.cfg.loc)
|
||
|
v = v.Add(time.Nanosecond * 500) // To round under microsecond
|
||
|
year := v.Year()
|
||
|
year100 := year / 100
|
||
|
year1 := year % 100
|
||
|
month := v.Month()
|
||
|
day := v.Day()
|
||
|
hour := v.Hour()
|
||
|
minute := v.Minute()
|
||
|
second := v.Second()
|
||
|
micro := v.Nanosecond() / 1000
|
||
|
|
||
|
buf = append(buf, []byte{
|
||
|
'\'',
|
||
|
digits10[year100], digits01[year100],
|
||
|
digits10[year1], digits01[year1],
|
||
|
'-',
|
||
|
digits10[month], digits01[month],
|
||
|
'-',
|
||
|
digits10[day], digits01[day],
|
||
|
' ',
|
||
|
digits10[hour], digits01[hour],
|
||
|
':',
|
||
|
digits10[minute], digits01[minute],
|
||
|
':',
|
||
|
digits10[second], digits01[second],
|
||
|
}...)
|
||
|
|
||
|
if micro != 0 {
|
||
|
micro10000 := micro / 10000
|
||
|
micro100 := micro / 100 % 100
|
||
|
micro1 := micro % 100
|
||
|
buf = append(buf, []byte{
|
||
|
'.',
|
||
|
digits10[micro10000], digits01[micro10000],
|
||
|
digits10[micro100], digits01[micro100],
|
||
|
digits10[micro1], digits01[micro1],
|
||
|
}...)
|
||
|
}
|
||
|
buf = append(buf, '\'')
|
||
|
}
|
||
|
case []byte:
|
||
|
if v == nil {
|
||
|
buf = append(buf, "NULL"...)
|
||
|
} else {
|
||
|
buf = append(buf, "_binary'"...)
|
||
|
if mc.status&statusNoBackslashEscapes == 0 {
|
||
|
buf = escapeBytesBackslash(buf, v)
|
||
|
} else {
|
||
|
buf = escapeBytesQuotes(buf, v)
|
||
|
}
|
||
|
buf = append(buf, '\'')
|
||
|
}
|
||
|
case string:
|
||
|
buf = append(buf, '\'')
|
||
|
if mc.status&statusNoBackslashEscapes == 0 {
|
||
|
buf = escapeStringBackslash(buf, v)
|
||
|
} else {
|
||
|
buf = escapeStringQuotes(buf, v)
|
||
|
}
|
||
|
buf = append(buf, '\'')
|
||
|
default:
|
||
|
return "", driver.ErrSkip
|
||
|
}
|
||
|
|
||
|
if len(buf)+4 > mc.maxPacketAllowed {
|
||
|
return "", driver.ErrSkip
|
||
|
}
|
||
|
}
|
||
|
if argPos != len(args) {
|
||
|
return "", driver.ErrSkip
|
||
|
}
|
||
|
return string(buf), nil
|
||
|
}
|
||
|
|
||
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||
|
if mc.netConn == nil {
|
||
|
errLog.Print(ErrInvalidConn)
|
||
|
return nil, driver.ErrBadConn
|
||
|
}
|
||
|
if len(args) != 0 {
|
||
|
if !mc.cfg.interpolateParams {
|
||
|
return nil, driver.ErrSkip
|
||
|
}
|
||
|
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
|
||
|
prepared, err := mc.interpolateParams(query, args)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
query = prepared
|
||
|
args = nil
|
||
|
}
|
||
|
mc.affectedRows = 0
|
||
|
mc.insertId = 0
|
||
|
|
||
|
err := mc.exec(query)
|
||
|
if err == nil {
|
||
|
return &mysqlResult{
|
||
|
affectedRows: int64(mc.affectedRows),
|
||
|
insertId: int64(mc.insertId),
|
||
|
}, err
|
||
|
}
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Internal function to execute commands
|
||
|
func (mc *mysqlConn) exec(query string) error {
|
||
|
// Send command
|
||
|
err := mc.writeCommandPacketStr(comQuery, query)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Read Result
|
||
|
resLen, err := mc.readResultSetHeaderPacket()
|
||
|
if err == nil && resLen > 0 {
|
||
|
if err = mc.readUntilEOF(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
err = mc.readUntilEOF()
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||
|
if mc.netConn == nil {
|
||
|
errLog.Print(ErrInvalidConn)
|
||
|
return nil, driver.ErrBadConn
|
||
|
}
|
||
|
if len(args) != 0 {
|
||
|
if !mc.cfg.interpolateParams {
|
||
|
return nil, driver.ErrSkip
|
||
|
}
|
||
|
// try client-side prepare to reduce roundtrip
|
||
|
prepared, err := mc.interpolateParams(query, args)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
query = prepared
|
||
|
args = nil
|
||
|
}
|
||
|
// Send command
|
||
|
err := mc.writeCommandPacketStr(comQuery, query)
|
||
|
if err == nil {
|
||
|
// Read Result
|
||
|
var resLen int
|
||
|
resLen, err = mc.readResultSetHeaderPacket()
|
||
|
if err == nil {
|
||
|
rows := new(textRows)
|
||
|
rows.mc = mc
|
||
|
|
||
|
if resLen == 0 {
|
||
|
// no columns, no more data
|
||
|
return emptyRows{}, nil
|
||
|
}
|
||
|
// Columns
|
||
|
rows.columns, err = mc.readColumns(resLen)
|
||
|
return rows, err
|
||
|
}
|
||
|
}
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Gets the value of the given MySQL System Variable
|
||
|
// The returned byte slice is only valid until the next read
|
||
|
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||
|
// Send command
|
||
|
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Read Result
|
||
|
resLen, err := mc.readResultSetHeaderPacket()
|
||
|
if err == nil {
|
||
|
rows := new(textRows)
|
||
|
rows.mc = mc
|
||
|
|
||
|
if resLen > 0 {
|
||
|
// Columns
|
||
|
if err := mc.readUntilEOF(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
dest := make([]driver.Value, resLen)
|
||
|
if err = rows.readRow(dest); err == nil {
|
||
|
return dest[0].([]byte), mc.readUntilEOF()
|
||
|
}
|
||
|
}
|
||
|
return nil, err
|
||
|
}
|