1
0
mirror of https://github.com/golang/go synced 2024-11-22 03:14:41 -07:00

database/sql: make Rows.Next returning false always implicitly call

Rows.Close.

Previously, callers that followed the example code (but not call
rows.Close after "for rows.Next() { ... }") could leak statements if
the driver returned an error other than io.EOF.

R=bradfitz, alex.brainman
CC=golang-dev, rsc
https://golang.org/cl/12677050
This commit is contained in:
Nigel Tao 2013-08-16 11:23:35 +10:00
parent b75a08d03c
commit bc21265074
3 changed files with 49 additions and 16 deletions

View File

@ -608,9 +608,10 @@ rows:
} }
cursor := &rowsCursor{ cursor := &rowsCursor{
pos: -1, pos: -1,
rows: mrows, rows: mrows,
cols: s.colName, cols: s.colName,
errPos: -1,
} }
return cursor, nil return cursor, nil
} }
@ -635,6 +636,10 @@ type rowsCursor struct {
rows []*row rows []*row
closed bool closed bool
// errPos and err are for making Next return early with error.
errPos int
err error
// a clone of slices to give out to clients, indexed by the // a clone of slices to give out to clients, indexed by the
// the original slice's first byte address. we clone them // the original slice's first byte address. we clone them
// just so we're able to corrupt them on close. // just so we're able to corrupt them on close.
@ -660,6 +665,9 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
return errors.New("fakedb: cursor is closed") return errors.New("fakedb: cursor is closed")
} }
rc.pos++ rc.pos++
if rc.pos == rc.errPos {
return rc.err
}
if rc.pos >= len(rc.rows) { if rc.pos >= len(rc.rows) {
return io.EOF // per interface spec return io.EOF // per interface spec
} }

View File

@ -1293,7 +1293,7 @@ type Rows struct {
closed bool closed bool
lastcols []driver.Value lastcols []driver.Value
lasterr error lasterr error // non-nil only if closed is true
closeStmt driver.Stmt // if non-nil, statement to Close on close closeStmt driver.Stmt // if non-nil, statement to Close on close
} }
@ -1305,20 +1305,19 @@ func (rs *Rows) Next() bool {
if rs.closed { if rs.closed {
return false return false
} }
if rs.lasterr != nil {
return false
}
if rs.lastcols == nil { if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
} }
rs.lasterr = rs.rowsi.Next(rs.lastcols) rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr == io.EOF { if rs.lasterr != nil {
rs.Close() rs.Close()
return false
} }
return rs.lasterr == nil return true
} }
// Err returns the error, if any, that was encountered during iteration. // Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit Close.
func (rs *Rows) Err() error { func (rs *Rows) Err() error {
if rs.lasterr == io.EOF { if rs.lasterr == io.EOF {
return nil return nil
@ -1353,10 +1352,7 @@ func (rs *Rows) Columns() ([]string, error) {
// is of type []byte, a copy is made and the caller owns the result. // is of type []byte, a copy is made and the caller owns the result.
func (rs *Rows) Scan(dest ...interface{}) error { func (rs *Rows) Scan(dest ...interface{}) error {
if rs.closed { if rs.closed {
return errors.New("sql: Rows closed") return errors.New("sql: Rows are closed")
}
if rs.lasterr != nil {
return rs.lasterr
} }
if rs.lastcols == nil { if rs.lastcols == nil {
return errors.New("sql: Scan called without calling Next") return errors.New("sql: Scan called without calling Next")
@ -1375,9 +1371,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
var rowsCloseHook func(*Rows, *error) var rowsCloseHook func(*Rows, *error)
// Close closes the Rows, preventing further enumeration. If the // Close closes the Rows, preventing further enumeration. If Next returns
// end is encountered, the Rows are closed automatically. Close // false, the Rows are closed automatically and it will suffice to check the
// is idempotent. // result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error { func (rs *Rows) Close() error {
if rs.closed { if rs.closed {
return nil return nil

View File

@ -6,6 +6,7 @@ package sql
import ( import (
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"runtime" "runtime"
@ -1039,6 +1040,34 @@ func TestRowsCloseOrder(t *testing.T) {
} }
} }
func TestRowsImplicitClose(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
rows, err := db.Query("SELECT|people|age,name|")
if err != nil {
t.Fatal(err)
}
want, fail := 2, errors.New("fail")
r := rows.rowsi.(*rowsCursor)
r.errPos, r.err = want, fail
got := 0
for rows.Next() {
got++
}
if got != want {
t.Errorf("got %d rows, want %d", got, want)
}
if err := rows.Err(); err != fail {
t.Errorf("got error %v, want %v", err, fail)
}
if !r.closed {
t.Errorf("r.closed is false, want true")
}
}
func TestStmtCloseOrder(t *testing.T) { func TestStmtCloseOrder(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)