diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 09de1c34e8..9d8afb01b0 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -983,9 +983,10 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { if err != nil { return nil, err } - dc.Lock() - si, err := dc.prepareLocked(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = dc.prepareLocked(query) + }) if err != nil { db.putConn(dc, err) return nil, err @@ -1028,13 +1029,15 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) }() if execer, ok := dc.ci.(driver.Execer); ok { - dargs, err := driverArgs(nil, args) + var dargs []driver.Value + dargs, err = driverArgs(nil, args) if err != nil { return nil, err } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() + var resi driver.Result + withLock(dc, func() { + resi, err = execer.Exec(query, dargs) + }) if err != driver.ErrSkip { if err != nil { return nil, err @@ -1043,9 +1046,10 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = dc.ci.Prepare(query) + }) if err != nil { return nil, err } @@ -1088,9 +1092,10 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a releaseConn(err) return nil, err } - dc.Lock() - rowsi, err := queryer.Query(query, dargs) - dc.Unlock() + var rowsi driver.Rows + withLock(dc, func() { + rowsi, err = queryer.Query(query, dargs) + }) if err != driver.ErrSkip { if err != nil { releaseConn(err) @@ -1107,9 +1112,11 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + var err error + withLock(dc, func() { + si, err = dc.ci.Prepare(query) + }) if err != nil { releaseConn(err) return nil, err @@ -1118,9 +1125,9 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a ds := driverStmt{dc, si} rowsi, err := rowsiFromStatement(ds, args...) if err != nil { - dc.Lock() - si.Close() - dc.Unlock() + withLock(dc, func() { + si.Close() + }) releaseConn(err) return nil, err } @@ -1166,9 +1173,10 @@ func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) { if err != nil { return nil, err } - dc.Lock() - txi, err := dc.ci.Begin() - dc.Unlock() + var txi driver.Tx + withLock(dc, func() { + txi, err = dc.ci.Begin() + }) if err != nil { db.putConn(dc, err) return nil, err @@ -1238,10 +1246,10 @@ func (tx *Tx) grabConn() (*driverConn, error) { // Closes all Stmts prepared for this transaction. func (tx *Tx) closePrepared() { tx.stmts.Lock() + defer tx.stmts.Unlock() for _, stmt := range tx.stmts.v { stmt.Close() } - tx.stmts.Unlock() } // Commit commits the transaction. @@ -1249,9 +1257,10 @@ func (tx *Tx) Commit() error { if tx.done { return ErrTxDone } - tx.dc.Lock() - err := tx.txi.Commit() - tx.dc.Unlock() + var err error + withLock(tx.dc, func() { + err = tx.txi.Commit() + }) if err != driver.ErrBadConn { tx.closePrepared() } @@ -1264,9 +1273,10 @@ func (tx *Tx) Rollback() error { if tx.done { return ErrTxDone } - tx.dc.Lock() - err := tx.txi.Rollback() - tx.dc.Unlock() + var err error + withLock(tx.dc, func() { + err = tx.txi.Rollback() + }) if err != driver.ErrBadConn { tx.closePrepared() } @@ -1299,9 +1309,10 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { return nil, err } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = dc.ci.Prepare(query) + }) if err != nil { return nil, err } @@ -1346,9 +1357,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if err != nil { return &Stmt{stickyErr: err} } - dc.Lock() - si, err := dc.ci.Prepare(stmt.query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = dc.ci.Prepare(stmt.query) + }) txs := &Stmt{ db: tx.db, tx: tx, @@ -1378,9 +1390,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { if err != nil { return nil, err } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() + var resi driver.Result + withLock(dc, func() { + resi, err = execer.Exec(query, dargs) + }) if err == nil { return driverResult{dc, resi}, nil } @@ -1389,9 +1402,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = dc.ci.Prepare(query) + }) if err != nil { return nil, err } @@ -1578,9 +1592,9 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St s.mu.Unlock() // No luck; we need to prepare the statement on this connection - dc.Lock() - si, err = dc.prepareLocked(s.query) - dc.Unlock() + withLock(dc, func() { + si, err = dc.prepareLocked(s.query) + }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err @@ -1635,9 +1649,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { } func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { - ds.Lock() - want := ds.si.NumInput() - ds.Unlock() + var want int + withLock(ds, func() { + want = ds.si.NumInput() + }) // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the @@ -1652,8 +1667,8 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) } ds.Lock() + defer ds.Unlock() rowsi, err := ds.si.Query(dargs) - ds.Unlock() if err != nil { return nil, err } @@ -1695,9 +1710,8 @@ func (s *Stmt) Close() error { s.closed = true if s.tx != nil { - err := s.txsi.Close() - s.mu.Unlock() - return err + defer s.mu.Unlock() + return s.txsi.Close() } s.mu.Unlock() diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 08df0c7666..41afd00e92 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -2299,6 +2299,53 @@ func TestConnectionLeak(t *testing.T) { wg.Wait() } +// badConn implements a bad driver.Conn, for TestBadDriver. +// The Exec method panics. +type badConn struct{} + +func (bc badConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("badConn Prepare") +} + +func (bc badConn) Close() error { + return nil +} + +func (bc badConn) Begin() (driver.Tx, error) { + return nil, errors.New("badConn Begin") +} + +func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) { + panic("badConn.Exec") +} + +// badDriver is a driver.Driver that uses badConn. +type badDriver struct{} + +func (bd badDriver) Open(name string) (driver.Conn, error) { + return badConn{}, nil +} + +// Issue 15901. +func TestBadDriver(t *testing.T) { + Register("bad", badDriver{}) + db, err := Open("bad", "ignored") + if err != nil { + t.Fatal(err) + } + defer func() { + if r := recover(); r == nil { + t.Error("expected panic") + } else { + if want := "badConn.Exec"; r.(string) != want { + t.Errorf("panic was %v, expected %v", r, want) + } + } + }() + defer db.Close() + db.Exec("ignored") +} + func BenchmarkConcurrentDBExec(b *testing.B) { b.ReportAllocs() ct := new(concurrentDBExecTest)