diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index 35d5338c564..352a7e82d94 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -1372,6 +1372,8 @@ func (rs *Rows) Scan(dest ...interface{}) error { return nil } +var rowsCloseHook func(*Rows, *error) + // Close closes the Rows, preventing further enumeration. If the // end is encountered, the Rows are closed automatically. Close // is idempotent. @@ -1381,6 +1383,9 @@ func (rs *Rows) Close() error { } rs.closed = true err := rs.rowsi.Close() + if fn := rowsCloseHook; fn != nil { + fn(rs, &err) + } if rs.closeStmt != nil { rs.closeStmt.Close() } diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index fc620bd6dcf..2b9347aedaa 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -5,6 +5,7 @@ package sql import ( + "database/sql/driver" "fmt" "reflect" "runtime" @@ -1110,6 +1111,52 @@ func manyConcurrentQueries(t testOrBench) { wg.Wait() } +func TestIssue6081(t *testing.T) { + t.Skip("known broken test") + db := newTestDB(t, "people") + defer closeDB(t, db) + + drv := db.driver.(*fakeDriver) + drv.mu.Lock() + opens0 := drv.openCount + closes0 := drv.closeCount + drv.mu.Unlock() + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + rowsCloseHook = func(rows *Rows, err *error) { + *err = driver.ErrBadConn + } + defer func() { rowsCloseHook = nil }() + for i := 0; i < 10; i++ { + rows, err := stmt.Query() + if err != nil { + t.Fatal(err) + } + rows.Close() + } + if n := len(stmt.css); n > 1 { + t.Errorf("len(css slice) = %d; want <= 1", n) + } + stmt.Close() + if n := len(stmt.css); n != 0 { + t.Errorf("len(css slice) after Close = %d; want 0", n) + } + + drv.mu.Lock() + opens := drv.openCount - opens0 + closes := drv.closeCount - closes0 + drv.mu.Unlock() + if opens < 9 { + t.Errorf("opens = %d; want >= 9", opens) + } + if closes < 9 { + t.Errorf("closes = %d; want >= 9", closes) + } +} + func TestConcurrency(t *testing.T) { manyConcurrentQueries(t) }