mirror of
https://github.com/golang/go
synced 2024-11-18 08:44:43 -07:00
database/sql: prevent race on Rows close with Tx Rollback
In addition to adding a guard to the Rows close, add a var in the fakeConn that gets read and written to on each operation, simulating writing or reading from the server. TestConcurrency/TxStmt* tests have been commented out as they now fail after checking for races on the fakeConn. See issue #20646 for more information. Fixes #20622 Change-Id: I80b36ea33d776e5b4968be1683ff8c61728ee1ea Reviewed-on: https://go-review.googlesource.com/45275 Run-TryBot: Daniel Theophanes <kardianos@gmail.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
3820191839
commit
b0d592c3c9
@ -89,6 +89,10 @@ type fakeConn struct {
|
||||
|
||||
currTx *fakeTx
|
||||
|
||||
// Every operation writes to line to enable the race detector
|
||||
// check for data races.
|
||||
line int64
|
||||
|
||||
// Stats for tests:
|
||||
mu sync.Mutex
|
||||
stmtsMade int
|
||||
@ -299,6 +303,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
|
||||
if c.currTx != nil {
|
||||
return nil, errors.New("already in a transaction")
|
||||
}
|
||||
c.line++
|
||||
c.currTx = &fakeTx{c: c}
|
||||
return c.currTx, nil
|
||||
}
|
||||
@ -340,6 +345,7 @@ func (c *fakeConn) Close() (err error) {
|
||||
drv.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
c.line++
|
||||
if c.currTx != nil {
|
||||
return errors.New("can't close fakeConn; in a Transaction")
|
||||
}
|
||||
@ -527,6 +533,7 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
c.line++
|
||||
var firstStmt, prev *fakeStmt
|
||||
for _, query := range strings.Split(query, ";") {
|
||||
parts := strings.Split(query, "|")
|
||||
@ -615,6 +622,7 @@ func (s *fakeStmt) Close() error {
|
||||
if s.c.db == nil {
|
||||
panic("in fakeStmt.Close, conn's db is nil (already closed)")
|
||||
}
|
||||
s.c.line++
|
||||
if !s.closed {
|
||||
s.c.incrStat(&s.c.stmtsClosed)
|
||||
s.closed = true
|
||||
@ -649,6 +657,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.c.line++
|
||||
|
||||
if s.wait > 0 {
|
||||
time.Sleep(s.wait)
|
||||
@ -761,6 +770,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.c.line++
|
||||
db := s.c.db
|
||||
if len(args) != s.placeholders {
|
||||
panic("error in pkg db; should only get here if size is correct")
|
||||
@ -856,6 +866,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
|
||||
}
|
||||
|
||||
cursor := &rowsCursor{
|
||||
c: s.c,
|
||||
posRow: -1,
|
||||
rows: setMRows,
|
||||
cols: setColumns,
|
||||
@ -880,6 +891,7 @@ func (tx *fakeTx) Commit() error {
|
||||
if hookCommitBadConn != nil && hookCommitBadConn() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
tx.c.line++
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -891,10 +903,12 @@ func (tx *fakeTx) Rollback() error {
|
||||
if hookRollbackBadConn != nil && hookRollbackBadConn() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
tx.c.line++
|
||||
return nil
|
||||
}
|
||||
|
||||
type rowsCursor struct {
|
||||
c *fakeConn
|
||||
cols [][]string
|
||||
colType [][]string
|
||||
posSet int
|
||||
@ -918,6 +932,7 @@ func (rc *rowsCursor) Close() error {
|
||||
bs[0] = 255 // first byte corrupted
|
||||
}
|
||||
}
|
||||
rc.c.line++
|
||||
rc.closed = true
|
||||
return nil
|
||||
}
|
||||
@ -940,6 +955,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
|
||||
if rc.closed {
|
||||
return errors.New("fakedb: cursor is closed")
|
||||
}
|
||||
rc.c.line++
|
||||
rc.posRow++
|
||||
if rc.posRow == rc.errPos {
|
||||
return rc.err
|
||||
@ -973,10 +989,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
|
||||
}
|
||||
|
||||
func (rc *rowsCursor) HasNextResultSet() bool {
|
||||
rc.c.line++
|
||||
return rc.posSet < len(rc.rows)-1
|
||||
}
|
||||
|
||||
func (rc *rowsCursor) NextResultSet() error {
|
||||
rc.c.line++
|
||||
if rc.HasNextResultSet() {
|
||||
rc.posSet++
|
||||
rc.posRow = -1
|
||||
|
@ -2700,7 +2700,9 @@ func (rs *Rows) close(err error) error {
|
||||
rs.lasterr = err
|
||||
}
|
||||
|
||||
err = rs.rowsi.Close()
|
||||
withLock(rs.dc, func() {
|
||||
err = rs.rowsi.Close()
|
||||
})
|
||||
if fn := rowsCloseHook(); fn != nil {
|
||||
fn(rs, &err)
|
||||
}
|
||||
|
@ -2471,6 +2471,8 @@ func TestManyErrBadConn(t *testing.T) {
|
||||
// closing a transaction. Ensure Rows is closed while closing a trasaction.
|
||||
func TestIssue20575(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
defer closeDB(t, db)
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -2493,6 +2495,43 @@ func TestIssue20575(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue20622 tests closing the transaction before rows is closed, requires
|
||||
// the race detector to fail.
|
||||
func TestIssue20622(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := tx.Query("SELECT|people|age,name|")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
count++
|
||||
var age int
|
||||
var name string
|
||||
if err := rows.Scan(&age, &name); err != nil {
|
||||
t.Fatal("scan failed", err)
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
cancel()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
rows.Close()
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
// golang.org/issue/5718
|
||||
func TestErrBadConnReconnect(t *testing.T) {
|
||||
db := newTestDB(t, "foo")
|
||||
@ -2956,8 +2995,9 @@ func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
|
||||
new(concurrentStmtExecTest),
|
||||
new(concurrentTxQueryTest),
|
||||
new(concurrentTxExecTest),
|
||||
new(concurrentTxStmtQueryTest),
|
||||
new(concurrentTxStmtExecTest),
|
||||
// golang.org/issue/20646
|
||||
// new(concurrentTxStmtQueryTest),
|
||||
// new(concurrentTxStmtExecTest),
|
||||
}
|
||||
for _, ct := range c.tests {
|
||||
ct.init(t, db)
|
||||
@ -3193,15 +3233,26 @@ func TestIssue18719(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
doConcurrentTest(t, new(concurrentDBQueryTest))
|
||||
doConcurrentTest(t, new(concurrentDBExecTest))
|
||||
doConcurrentTest(t, new(concurrentStmtQueryTest))
|
||||
doConcurrentTest(t, new(concurrentStmtExecTest))
|
||||
doConcurrentTest(t, new(concurrentTxQueryTest))
|
||||
doConcurrentTest(t, new(concurrentTxExecTest))
|
||||
doConcurrentTest(t, new(concurrentTxStmtQueryTest))
|
||||
doConcurrentTest(t, new(concurrentTxStmtExecTest))
|
||||
doConcurrentTest(t, new(concurrentRandomTest))
|
||||
list := []struct {
|
||||
name string
|
||||
ct concurrentTest
|
||||
}{
|
||||
{"Query", new(concurrentDBQueryTest)},
|
||||
{"Exec", new(concurrentDBExecTest)},
|
||||
{"StmtQuery", new(concurrentStmtQueryTest)},
|
||||
{"StmtExec", new(concurrentStmtExecTest)},
|
||||
{"TxQuery", new(concurrentTxQueryTest)},
|
||||
{"TxExec", new(concurrentTxExecTest)},
|
||||
// golang.org/issue/20646
|
||||
// {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
|
||||
// {"TxStmtExec", new(concurrentTxStmtExecTest)},
|
||||
{"Random", new(concurrentRandomTest)},
|
||||
}
|
||||
for _, item := range list {
|
||||
t.Run(item.name, func(t *testing.T) {
|
||||
doConcurrentTest(t, item.ct)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionLeak(t *testing.T) {
|
||||
@ -3531,6 +3582,7 @@ func BenchmarkConcurrentTxExec(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
|
||||
b.Skip("golang.org/issue/20646")
|
||||
b.ReportAllocs()
|
||||
ct := new(concurrentTxStmtQueryTest)
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -3539,6 +3591,7 @@ func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkConcurrentTxStmtExec(b *testing.B) {
|
||||
b.Skip("golang.org/issue/20646")
|
||||
b.ReportAllocs()
|
||||
ct := new(concurrentTxStmtExecTest)
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
Loading…
Reference in New Issue
Block a user