diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index f09396175a..ea1de5a8fb 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -156,6 +156,9 @@ var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented") // if there's a possibility that the database server might have // performed the operation. Even if the server sends back an error, // you shouldn't return ErrBadConn. +// +// Errors will be checked using errors.Is. An error may +// wrap ErrBadConn or implement the Is(error) bool method. var ErrBadConn = errors.New("driver: bad connection") // Pinger is an optional interface that may be implemented by a Conn. diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 4b68f1cba9..34e97e012b 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -96,6 +96,19 @@ type fakeDB struct { allowAny bool } +type fakeError struct { + Message string + Wrapped error +} + +func (err fakeError) Error() string { + return err.Message +} + +func (err fakeError) Unwrap() error { + return err.Wrapped +} + type table struct { mu sync.Mutex colname []string @@ -368,7 +381,7 @@ func (c *fakeConn) isDirtyAndMark() bool { func (c *fakeConn) Begin() (driver.Tx, error) { if c.isBad() { - return nil, driver.ErrBadConn + return nil, fakeError{Wrapped: driver.ErrBadConn} } if c.currTx != nil { return nil, errors.New("fakedb: already in a transaction") @@ -401,7 +414,7 @@ func (c *fakeConn) ResetSession(ctx context.Context) error { c.dirtySession = false c.currTx = nil if c.isBad() { - return driver.ErrBadConn + return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} } return nil } @@ -629,7 +642,7 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm } if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { - return nil, driver.ErrBadConn + return nil, fakeError{Message: "Preapre: Sticky Bad", Wrapped: driver.ErrBadConn} } c.touchMem() @@ -756,7 +769,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d } if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { - return nil, driver.ErrBadConn + return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} } if s.c.isDirtyAndMark() { return nil, errFakeConnSessionDirty @@ -870,7 +883,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( } if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { - return nil, driver.ErrBadConn + return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} } if s.c.isDirtyAndMark() { return nil, errFakeConnSessionDirty @@ -1031,7 +1044,7 @@ var hookCommitBadConn func() bool func (tx *fakeTx) Commit() error { tx.c.currTx = nil if hookCommitBadConn != nil && hookCommitBadConn() { - return driver.ErrBadConn + return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} } tx.c.touchMem() return nil @@ -1043,7 +1056,7 @@ var hookRollbackBadConn func() bool func (tx *fakeTx) Rollback() error { tx.c.currTx = nil if hookRollbackBadConn != nil && hookRollbackBadConn() { - return driver.ErrBadConn + return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} } tx.c.touchMem() return nil diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index e4a5a225b0..897bca059b 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -848,14 +848,15 @@ func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) e func (db *DB) PingContext(ctx context.Context) error { var dc *driverConn var err error - + var isBadConn bool for i := 0; i < maxBadConnRetries; i++ { dc, err = db.conn(ctx, cachedOrNewConn) - if err != driver.ErrBadConn { + isBadConn = errors.Is(err, driver.ErrBadConn) + if !isBadConn { break } } - if err == driver.ErrBadConn { + if isBadConn { dc, err = db.conn(ctx, alwaysNewConn) } if err != nil { @@ -1317,9 +1318,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn db.mu.Unlock() // Reset the session if required. - if err := conn.resetSession(ctx); err == driver.ErrBadConn { + if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) { conn.Close() - return nil, driver.ErrBadConn + return nil, err } return conn, nil @@ -1381,9 +1382,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn } // Reset the session if required. - if err := ret.conn.resetSession(ctx); err == driver.ErrBadConn { + if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) { ret.conn.Close() - return nil, driver.ErrBadConn + return nil, err } return ret.conn, ret.err } @@ -1442,7 +1443,7 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { - if err != driver.ErrBadConn { + if !errors.Is(err, driver.ErrBadConn) { if !dc.validateConnection(resetSession) { err = driver.ErrBadConn } @@ -1456,7 +1457,7 @@ func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { panic("sql: connection returned that was never out") } - if err != driver.ErrBadConn && dc.expired(db.maxLifetime) { + if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) { db.maxLifetimeClosed++ err = driver.ErrBadConn } @@ -1471,7 +1472,7 @@ func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { } dc.onPut = nil - if err == driver.ErrBadConn { + if errors.Is(err, driver.ErrBadConn) { // Don't reuse bad connections. // Since the conn is considered bad and is being discarded, treat it // as closed. Don't decrement the open count here, finalClose will @@ -1551,13 +1552,15 @@ const maxBadConnRetries = 2 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt var err error + var isBadConn bool for i := 0; i < maxBadConnRetries; i++ { stmt, err = db.prepare(ctx, query, cachedOrNewConn) - if err != driver.ErrBadConn { + isBadConn = errors.Is(err, driver.ErrBadConn) + if !isBadConn { break } } - if err == driver.ErrBadConn { + if isBadConn { return db.prepare(ctx, query, alwaysNewConn) } return stmt, err @@ -1627,13 +1630,15 @@ func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error) func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { var res Result var err error + var isBadConn bool for i := 0; i < maxBadConnRetries; i++ { res, err = db.exec(ctx, query, args, cachedOrNewConn) - if err != driver.ErrBadConn { + isBadConn = errors.Is(err, driver.ErrBadConn) + if !isBadConn { break } } - if err == driver.ErrBadConn { + if isBadConn { return db.exec(ctx, query, args, alwaysNewConn) } return res, err @@ -1700,13 +1705,15 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { var rows *Rows var err error + var isBadConn bool for i := 0; i < maxBadConnRetries; i++ { rows, err = db.query(ctx, query, args, cachedOrNewConn) - if err != driver.ErrBadConn { + isBadConn = errors.Is(err, driver.ErrBadConn) + if !isBadConn { break } } - if err == driver.ErrBadConn { + if isBadConn { return db.query(ctx, query, args, alwaysNewConn) } return rows, err @@ -1835,13 +1842,15 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row { func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { var tx *Tx var err error + var isBadConn bool for i := 0; i < maxBadConnRetries; i++ { tx, err = db.begin(ctx, opts, cachedOrNewConn) - if err != driver.ErrBadConn { + isBadConn = errors.Is(err, driver.ErrBadConn) + if !isBadConn { break } } - if err == driver.ErrBadConn { + if isBadConn { return db.begin(ctx, opts, alwaysNewConn) } return tx, err @@ -1914,13 +1923,15 @@ var ErrConnDone = errors.New("sql: connection is already closed") func (db *DB) Conn(ctx context.Context) (*Conn, error) { var dc *driverConn var err error + var isBadConn bool for i := 0; i < maxBadConnRetries; i++ { dc, err = db.conn(ctx, cachedOrNewConn) - if err != driver.ErrBadConn { + isBadConn = errors.Is(err, driver.ErrBadConn) + if !isBadConn { break } } - if err == driver.ErrBadConn { + if isBadConn { dc, err = db.conn(ctx, alwaysNewConn) } if err != nil { @@ -2032,8 +2043,8 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) // Raw executes f exposing the underlying driver connection for the // duration of f. The driverConn must not be used outside of f. // -// Once f returns and err is not equal to driver.ErrBadConn, the Conn will -// continue to be usable until Conn.Close is called. +// Once f returns and err is not driver.ErrBadConn, the Conn will continue to be usable +// until Conn.Close is called. func (c *Conn) Raw(f func(driverConn interface{}) error) (err error) { var dc *driverConn var release releaseConn @@ -2084,7 +2095,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { // as the sql operation is done with the dc. func (c *Conn) closemuRUnlockCondReleaseConn(err error) { c.closemu.RUnlock() - if err == driver.ErrBadConn { + if errors.Is(err, driver.ErrBadConn) { c.close(err) } } @@ -2278,7 +2289,7 @@ func (tx *Tx) Commit() error { withLock(tx.dc, func() { err = tx.txi.Commit() }) - if err != driver.ErrBadConn { + if !errors.Is(err, driver.ErrBadConn) { tx.closePrepared() } tx.close(err) @@ -2310,7 +2321,7 @@ func (tx *Tx) rollback(discardConn bool) error { withLock(tx.dc, func() { err = tx.txi.Rollback() }) - if err != driver.ErrBadConn { + if !errors.Is(err, driver.ErrBadConn) { tx.closePrepared() } if discardConn { @@ -2616,7 +2627,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er } dc, releaseConn, ds, err := s.connStmt(ctx, strategy) if err != nil { - if err == driver.ErrBadConn { + if errors.Is(err, driver.ErrBadConn) { continue } return nil, err @@ -2624,7 +2635,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er res, err = resultFromStatement(ctx, dc.ci, ds, args...) releaseConn(err) - if err != driver.ErrBadConn { + if !errors.Is(err, driver.ErrBadConn) { return res, err } } @@ -2764,7 +2775,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er } dc, releaseConn, ds, err := s.connStmt(ctx, strategy) if err != nil { - if err == driver.ErrBadConn { + if errors.Is(err, driver.ErrBadConn) { continue } return nil, err @@ -2798,7 +2809,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er } releaseConn(err) - if err != driver.ErrBadConn { + if !errors.Is(err, driver.ErrBadConn) { return nil, err } } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 15c30e0d00..889adc3164 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -3159,7 +3159,7 @@ func TestTxEndBadConn(t *testing.T) { return broken } - if err := op(); err != driver.ErrBadConn { + if err := op(); !errors.Is(err, driver.ErrBadConn) { t.Errorf(name+": %v", err) return }