diff --git a/orm/db.go b/orm/db.go index d2f8a5b2..52ce9cc3 100644 --- a/orm/db.go +++ b/orm/db.go @@ -310,7 +310,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { +func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -341,7 +341,12 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo sep = fmt.Sprintf("%s = ? AND %s", Q, Q) wheres := strings.Join(whereCols, sep) - query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q) + forUpdate := "" + if isForUpdate { + forUpdate = "FOR UPDATE" + } + + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) refs := make([]interface{}, colsNum) for i := range refs { diff --git a/orm/orm.go b/orm/orm.go index 390d300f..42931be9 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -122,7 +122,17 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { // read data to model func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + if err != nil { + return err + } + return nil +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { + mi, ind := o.getMiInd(md, true) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) if err != nil { return err } @@ -133,7 +143,7 @@ func (o *orm) Read(md interface{}, cols ...string) error { func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) + err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create id, err := o.Insert(md) diff --git a/orm/qb.go b/orm/qb.go index 9f778916..e0655a17 100644 --- a/orm/qb.go +++ b/orm/qb.go @@ -19,6 +19,7 @@ import "errors" // QueryBuilder is the Query builder interface type QueryBuilder interface { Select(fields ...string) QueryBuilder + ForUpdate() QueryBuilder From(tables ...string) QueryBuilder InnerJoin(table string) QueryBuilder LeftJoin(table string) QueryBuilder diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go index 886bc50e..23bdc9ee 100644 --- a/orm/qb_mysql.go +++ b/orm/qb_mysql.go @@ -34,6 +34,12 @@ func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { return qb } +// ForUpdate add the FOR UPDATE clause +func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + // From join the tables func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go index c504049e..87b3ae84 100644 --- a/orm/qb_tidb.go +++ b/orm/qb_tidb.go @@ -31,6 +31,12 @@ func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { return qb } +// ForUpdate add the FOR UPDATE clause +func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { + qb.Tokens = append(qb.Tokens, "FOR UPDATE") + return qb +} + // From join the tables func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) diff --git a/orm/types.go b/orm/types.go index 8c17271d..a36e0637 100644 --- a/orm/types.go +++ b/orm/types.go @@ -45,6 +45,9 @@ type Ormer interface { // u = &User{UserName: "astaxie", Password: "pass"} // err = Ormer.Read(u, "UserName") Read(md interface{}, cols ...string) error + // Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error // Try to read a row from the database, or insert one if it doesn't exist ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) // insert model data to database @@ -394,7 +397,7 @@ type txEnder interface { // base database struct type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error + Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)