mirror of
https://github.com/golang/go
synced 2024-11-19 21:04:43 -07:00
database/sql: fix auto-reconnect in prepared statements
This also fixes several connection leaks. Fixes #5718 R=bradfitz, adg CC=alberto.garcia.hierro, golang-dev https://golang.org/cl/14920046
This commit is contained in:
parent
a075fdbaa6
commit
762a9d934e
@ -433,11 +433,19 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
// hook to simulate broken connections
|
||||
var hookPrepareBadConn func() bool
|
||||
|
||||
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
|
||||
c.numPrepare++
|
||||
if c.db == nil {
|
||||
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
|
||||
}
|
||||
|
||||
if hookPrepareBadConn != nil && hookPrepareBadConn() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
parts := strings.Split(query, "|")
|
||||
if len(parts) < 1 {
|
||||
return nil, errf("empty query")
|
||||
@ -489,10 +497,18 @@ func (s *fakeStmt) Close() error {
|
||||
|
||||
var errClosed = errors.New("fakedb: statement has been closed")
|
||||
|
||||
// hook to simulate broken connections
|
||||
var hookExecBadConn func() bool
|
||||
|
||||
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if s.closed {
|
||||
return nil, errClosed
|
||||
}
|
||||
|
||||
if hookExecBadConn != nil && hookExecBadConn() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
err := checkSubsetTypes(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -565,10 +581,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
|
||||
return driver.RowsAffected(1), nil
|
||||
}
|
||||
|
||||
// hook to simulate broken connections
|
||||
var hookQueryBadConn func() bool
|
||||
|
||||
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
if s.closed {
|
||||
return nil, errClosed
|
||||
}
|
||||
|
||||
if hookQueryBadConn != nil && hookQueryBadConn() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
err := checkSubsetTypes(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -256,7 +256,7 @@ func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
|
||||
// stmt closes if the conn is about to close anyway? For now
|
||||
// do the safe thing, in case stmts need to be closed.
|
||||
//
|
||||
// TODO(bradfitz): after Go 1.1, closing driver.Stmts
|
||||
// TODO(bradfitz): after Go 1.2, closing driver.Stmts
|
||||
// should be moved to driverStmt, using unique
|
||||
// *driverStmts everywhere (including from
|
||||
// *Stmt.connStmt, instead of returning a
|
||||
@ -798,13 +798,17 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// maxBadConnRetries is the number of maximum retries if the driver returns
|
||||
// driver.ErrBadConn to signal a broken connection.
|
||||
const maxBadConnRetries = 10
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the
|
||||
// returned statement.
|
||||
func (db *DB) Prepare(query string) (*Stmt, error) {
|
||||
var stmt *Stmt
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
stmt, err = db.prepare(query)
|
||||
if err != driver.ErrBadConn {
|
||||
break
|
||||
@ -846,7 +850,7 @@ func (db *DB) prepare(query string) (*Stmt, error) {
|
||||
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
|
||||
var res Result
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
res, err = db.exec(query, args)
|
||||
if err != driver.ErrBadConn {
|
||||
break
|
||||
@ -895,7 +899,7 @@ func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
|
||||
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
|
||||
var rows *Rows
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
rows, err = db.query(query, args)
|
||||
if err != driver.ErrBadConn {
|
||||
break
|
||||
@ -983,7 +987,7 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
|
||||
func (db *DB) Begin() (*Tx, error) {
|
||||
var tx *Tx
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
tx, err = db.begin()
|
||||
if err != driver.ErrBadConn {
|
||||
break
|
||||
@ -1245,13 +1249,24 @@ type Stmt struct {
|
||||
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
||||
s.closemu.RLock()
|
||||
defer s.closemu.RUnlock()
|
||||
dc, releaseConn, si, err := s.connStmt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer releaseConn(nil)
|
||||
|
||||
return resultFromStatement(driverStmt{dc, si}, args...)
|
||||
var res Result
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
dc, releaseConn, si, err := s.connStmt()
|
||||
if err != nil {
|
||||
if err == driver.ErrBadConn {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err = resultFromStatement(driverStmt{dc, si}, args...)
|
||||
releaseConn(err)
|
||||
if err != driver.ErrBadConn {
|
||||
return res, err
|
||||
}
|
||||
}
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
|
||||
@ -1329,26 +1344,21 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
|
||||
// Make a new conn if all are busy.
|
||||
// TODO(bradfitz): or wait for one? make configurable later?
|
||||
if !match {
|
||||
for i := 0; ; i++ {
|
||||
dc, err := s.db.conn()
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
dc.Lock()
|
||||
si, err := dc.prepareLocked(s.query)
|
||||
dc.Unlock()
|
||||
if err == driver.ErrBadConn && i < 10 {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
s.mu.Lock()
|
||||
cs = connStmt{dc, si}
|
||||
s.css = append(s.css, cs)
|
||||
s.mu.Unlock()
|
||||
break
|
||||
dc, err := s.db.conn()
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
dc.Lock()
|
||||
si, err := dc.prepareLocked(s.query)
|
||||
dc.Unlock()
|
||||
if err != nil {
|
||||
s.db.putConn(dc, err)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
s.mu.Lock()
|
||||
cs = connStmt{dc, si}
|
||||
s.css = append(s.css, cs)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
conn := cs.dc
|
||||
@ -1361,31 +1371,39 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
||||
s.closemu.RLock()
|
||||
defer s.closemu.RUnlock()
|
||||
|
||||
dc, releaseConn, si, err := s.connStmt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var rowsi driver.Rows
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
dc, releaseConn, si, err := s.connStmt()
|
||||
if err != nil {
|
||||
if err == driver.ErrBadConn {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ds := driverStmt{dc, si}
|
||||
rowsi, err := rowsiFromStatement(ds, args...)
|
||||
if err != nil {
|
||||
releaseConn(err)
|
||||
return nil, err
|
||||
}
|
||||
rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
|
||||
if err == nil {
|
||||
// Note: ownership of ci passes to the *Rows, to be freed
|
||||
// with releaseConn.
|
||||
rows := &Rows{
|
||||
dc: dc,
|
||||
rowsi: rowsi,
|
||||
// releaseConn set below
|
||||
}
|
||||
s.db.addDep(s, rows)
|
||||
rows.releaseConn = func(err error) {
|
||||
releaseConn(err)
|
||||
s.db.removeDep(s, rows)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Note: ownership of ci passes to the *Rows, to be freed
|
||||
// with releaseConn.
|
||||
rows := &Rows{
|
||||
dc: dc,
|
||||
rowsi: rowsi,
|
||||
// releaseConn set below
|
||||
}
|
||||
s.db.addDep(s, rows)
|
||||
rows.releaseConn = func(err error) {
|
||||
releaseConn(err)
|
||||
s.db.removeDep(s, rows)
|
||||
if err != driver.ErrBadConn {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
|
||||
|
@ -348,7 +348,6 @@ func TestStatementQueryRow(t *testing.T) {
|
||||
t.Errorf("%d: age=%d, want %d", n, age, tt.want)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// golang.org/issue/3734
|
||||
@ -1255,6 +1254,111 @@ func TestStmtCloseOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// golang.org/issue/5781
|
||||
func TestErrBadConnReconnect(t *testing.T) {
|
||||
db := newTestDB(t, "foo")
|
||||
defer closeDB(t, db)
|
||||
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
|
||||
|
||||
simulateBadConn := func(name string, hook *func() bool, op func() error) {
|
||||
broken, retried := false, false
|
||||
numOpen := db.numOpen
|
||||
|
||||
// simulate a broken connection on the first try
|
||||
*hook = func() bool {
|
||||
if !broken {
|
||||
broken = true
|
||||
return true
|
||||
}
|
||||
retried = true
|
||||
return false
|
||||
}
|
||||
|
||||
if err := op(); err != nil {
|
||||
t.Errorf(name+": %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !broken || !retried {
|
||||
t.Error(name + ": Failed to simulate broken connection")
|
||||
}
|
||||
*hook = nil
|
||||
|
||||
if numOpen != db.numOpen {
|
||||
t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
|
||||
numOpen = db.numOpen
|
||||
}
|
||||
}
|
||||
|
||||
// db.Exec
|
||||
dbExec := func() error {
|
||||
_, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
|
||||
return err
|
||||
}
|
||||
simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
|
||||
simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
|
||||
|
||||
// db.Query
|
||||
dbQuery := func() error {
|
||||
rows, err := db.Query("SELECT|t1|age,name|")
|
||||
if err == nil {
|
||||
err = rows.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
|
||||
simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
|
||||
|
||||
// db.Prepare
|
||||
simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
|
||||
stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
// stmt.Exec
|
||||
stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare: %v", err)
|
||||
}
|
||||
defer stmt1.Close()
|
||||
// make sure we must prepare the stmt first
|
||||
for _, cs := range stmt1.css {
|
||||
cs.dc.inUse = true
|
||||
}
|
||||
|
||||
stmtExec := func() error {
|
||||
_, err := stmt1.Exec("Gopher", 3, false)
|
||||
return err
|
||||
}
|
||||
simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
|
||||
simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
|
||||
|
||||
// stmt.Query
|
||||
stmt2, err := db.Prepare("SELECT|t1|age,name|")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare: %v", err)
|
||||
}
|
||||
defer stmt2.Close()
|
||||
// make sure we must prepare the stmt first
|
||||
for _, cs := range stmt2.css {
|
||||
cs.dc.inUse = true
|
||||
}
|
||||
|
||||
stmtQuery := func() error {
|
||||
rows, err := stmt2.Query()
|
||||
if err == nil {
|
||||
err = rows.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
|
||||
simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
|
||||
}
|
||||
|
||||
type concurrentTest interface {
|
||||
init(t testing.TB, db *DB)
|
||||
finish(t testing.TB)
|
||||
|
Loading…
Reference in New Issue
Block a user