diff --git a/src/pkg/exp/sql/driver/driver.go b/src/pkg/exp/sql/driver/driver.go index 91a388421d1..1139afa6bb4 100644 --- a/src/pkg/exp/sql/driver/driver.go +++ b/src/pkg/exp/sql/driver/driver.go @@ -94,12 +94,35 @@ type Result interface { // used by multiple goroutines concurrently. type Stmt interface { // Close closes the statement. + // + // Closing a statement should not interrupt any outstanding + // query created from that statement. That is, the following + // order of operations is valid: + // + // * create a driver statement + // * call Query on statement, returning Rows + // * close the statement + // * read from Rows + // + // If closing a statement invalidates currently-running + // queries, the final step above will incorrectly fail. + // + // TODO(bradfitz): possibly remove the restriction above, if + // enough driver authors object and find it complicates their + // code too much. The sql package could be smarter about + // refcounting the statement and closing it at the appropriate + // time. Close() error // NumInput returns the number of placeholder parameters. - // -1 means the driver doesn't know how to count the number of - // placeholders, so we won't sanity check input here and instead let the - // driver deal with errors. + // + // If NumInput returns >= 0, the sql package will sanity check + // argument counts from callers and return errors to the caller + // before the statement's Exec or Query methods are called. + // + // NumInput may also return -1, if the driver doesn't know + // its number of placeholders. In that case, the sql package + // will not sanity check Exec or Query argument counts. NumInput() int // Exec executes a query that doesn't return rows, such diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go index 17028e2cc38..2474a86f644 100644 --- a/src/pkg/exp/sql/fakedb_test.go +++ b/src/pkg/exp/sql/fakedb_test.go @@ -90,6 +90,8 @@ type fakeStmt struct { cmd string table string + closed bool + colName []string // used by CREATE, INSERT, SELECT (selected columns) colType []string // used by CREATE colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) @@ -232,6 +234,9 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e stmt.table = parts[0] stmt.colName = strings.Split(parts[1], ",") for n, colspec := range strings.Split(parts[2], ",") { + if colspec == "" { + continue + } nameVal := strings.Split(colspec, "=") if len(nameVal) != 2 { return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) @@ -342,10 +347,16 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { } func (s *fakeStmt) Close() error { + s.closed = true return nil } +var errClosed = errors.New("fakedb: statement has been closed") + func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { + if s.closed { + return nil, errClosed + } err := checkSubsetTypes(args) if err != nil { return nil, err @@ -405,6 +416,9 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) { } func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) { + if s.closed { + return nil, errClosed + } err := checkSubsetTypes(args) if err != nil { return nil, err diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go index d365f6ba190..5b8bcc91422 100644 --- a/src/pkg/exp/sql/sql_test.go +++ b/src/pkg/exp/sql/sql_test.go @@ -5,6 +5,7 @@ package sql import ( + "reflect" "strings" "testing" ) @@ -22,7 +23,6 @@ func newTestDB(t *testing.T, name string) *DB { exec(t, db, "INSERT|people|name=Alice,age=?", 1) exec(t, db, "INSERT|people|name=Bob,age=?", 2) exec(t, db, "INSERT|people|name=Chris,age=?", 3) - } return db } @@ -42,6 +42,40 @@ func closeDB(t *testing.T, db *DB) { } func TestQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got, want) { + t.Logf(" got: %#v\nwant: %#v", got, want) + } +} + +func TestQueryRow(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) var name string @@ -75,6 +109,24 @@ func TestQuery(t *testing.T) { } } +func TestStatementErrorAfterClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + var name string + err = stmt.QueryRow("foo").Scan(&name) + if err == nil { + t.Errorf("expected error from QueryRow.Scan after Stmt.Close") + } +} + func TestStatementQueryRow(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)