From bc2126507472e51a6820aecce9f07df6e4231a0a Mon Sep 17 00:00:00 2001 From: Nigel Tao Date: Fri, 16 Aug 2013 11:23:35 +1000 Subject: [PATCH] database/sql: make Rows.Next returning false always implicitly call Rows.Close. Previously, callers that followed the example code (but not call rows.Close after "for rows.Next() { ... }") could leak statements if the driver returned an error other than io.EOF. R=bradfitz, alex.brainman CC=golang-dev, rsc https://golang.org/cl/12677050 --- src/pkg/database/sql/fakedb_test.go | 14 +++++++++++--- src/pkg/database/sql/sql.go | 22 +++++++++------------- src/pkg/database/sql/sql_test.go | 29 +++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go index d900e2cebe8..8af753b5d35 100644 --- a/src/pkg/database/sql/fakedb_test.go +++ b/src/pkg/database/sql/fakedb_test.go @@ -608,9 +608,10 @@ rows: } cursor := &rowsCursor{ - pos: -1, - rows: mrows, - cols: s.colName, + pos: -1, + rows: mrows, + cols: s.colName, + errPos: -1, } return cursor, nil } @@ -635,6 +636,10 @@ type rowsCursor struct { rows []*row closed bool + // errPos and err are for making Next return early with error. + errPos int + err error + // a clone of slices to give out to clients, indexed by the // the original slice's first byte address. we clone them // just so we're able to corrupt them on close. @@ -660,6 +665,9 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { return errors.New("fakedb: cursor is closed") } rc.pos++ + if rc.pos == rc.errPos { + return rc.err + } if rc.pos >= len(rc.rows) { return io.EOF // per interface spec } diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index f0c86a8aeb2..d81f6fe9842 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -1293,7 +1293,7 @@ type Rows struct { closed bool lastcols []driver.Value - lasterr error + lasterr error // non-nil only if closed is true closeStmt driver.Stmt // if non-nil, statement to Close on close } @@ -1305,20 +1305,19 @@ func (rs *Rows) Next() bool { if rs.closed { return false } - if rs.lasterr != nil { - return false - } if rs.lastcols == nil { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) } rs.lasterr = rs.rowsi.Next(rs.lastcols) - if rs.lasterr == io.EOF { + if rs.lasterr != nil { rs.Close() + return false } - return rs.lasterr == nil + return true } // Err returns the error, if any, that was encountered during iteration. +// Err may be called after an explicit or implicit Close. func (rs *Rows) Err() error { if rs.lasterr == io.EOF { return nil @@ -1353,10 +1352,7 @@ func (rs *Rows) Columns() ([]string, error) { // is of type []byte, a copy is made and the caller owns the result. func (rs *Rows) Scan(dest ...interface{}) error { if rs.closed { - return errors.New("sql: Rows closed") - } - if rs.lasterr != nil { - return rs.lasterr + return errors.New("sql: Rows are closed") } if rs.lastcols == nil { return errors.New("sql: Scan called without calling Next") @@ -1375,9 +1371,9 @@ func (rs *Rows) Scan(dest ...interface{}) error { var rowsCloseHook func(*Rows, *error) -// Close closes the Rows, preventing further enumeration. If the -// end is encountered, the Rows are closed automatically. Close -// is idempotent. +// Close closes the Rows, preventing further enumeration. If Next returns +// false, the Rows are closed automatically and it will suffice to check the +// result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { if rs.closed { return nil diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index 2a059da4531..4005f154466 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -6,6 +6,7 @@ package sql import ( "database/sql/driver" + "errors" "fmt" "reflect" "runtime" @@ -1039,6 +1040,34 @@ func TestRowsCloseOrder(t *testing.T) { } } +func TestRowsImplicitClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + + want, fail := 2, errors.New("fail") + r := rows.rowsi.(*rowsCursor) + r.errPos, r.err = want, fail + + got := 0 + for rows.Next() { + got++ + } + if got != want { + t.Errorf("got %d rows, want %d", got, want) + } + if err := rows.Err(); err != fail { + t.Errorf("got error %v, want %v", err, fail) + } + if !r.closed { + t.Errorf("r.closed is false, want true") + } +} + func TestStmtCloseOrder(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)