diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go index 2474a86f644..0883dd9f3ed 100644 --- a/src/pkg/exp/sql/fakedb_test.go +++ b/src/pkg/exp/sql/fakedb_test.go @@ -110,25 +110,34 @@ func init() { // Supports dsn forms: // -// ;wipe +// ; (no currently supported options) func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { - d.mu.Lock() - defer d.mu.Unlock() - d.openCount++ - if d.dbs == nil { - d.dbs = make(map[string]*fakeDB) - } parts := strings.Split(dsn, ";") if len(parts) < 1 { return nil, errors.New("fakedb: no database name") } name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + return &fakeConn{db: db}, nil +} + +func (d *fakeDriver) getDB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } db, ok := d.dbs[name] if !ok { db = &fakeDB{name: name} d.dbs[name] = db } - return &fakeConn{db: db}, nil + return db } func (db *fakeDB) wipe() { diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go index 937982cdbe6..f53691b7c42 100644 --- a/src/pkg/exp/sql/sql.go +++ b/src/pkg/exp/sql/sql.go @@ -549,8 +549,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { // statement, a function to call to release the connection, and a // statement bound to that connection. func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { - if s.stickyErr != nil { - return nil, nil, nil, s.stickyErr + if err = s.stickyErr; err != nil { + return } s.mu.Lock() if s.closed { @@ -726,6 +726,9 @@ func (rs *Rows) Next() bool { rs.lastcols = make([]interface{}, len(rs.rowsi.Columns())) } rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr == io.EOF { + rs.Close() + } return rs.lasterr == nil } diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go index 5307a235ddf..590bf818fef 100644 --- a/src/pkg/exp/sql/sql_test.go +++ b/src/pkg/exp/sql/sql_test.go @@ -10,8 +10,10 @@ import ( "testing" ) +const fakeDBName = "foo" + func newTestDB(t *testing.T, name string) *DB { - db, err := Open("test", "foo") + db, err := Open("test", fakeDBName) if err != nil { t.Fatalf("Open: %v", err) } @@ -73,6 +75,12 @@ func TestQuery(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Logf(" got: %#v\nwant: %#v", got, want) } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := len(db.freeConn); n != 1 { + t.Errorf("free conns after query hitting EOF = %d; want 1", n) + } } func TestRowsColumns(t *testing.T) {