diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go index c055fdd68c6..f17d12eaa13 100644 --- a/src/pkg/exp/sql/sql.go +++ b/src/pkg/exp/sql/sql.go @@ -344,25 +344,26 @@ func (tx *Tx) Rollback() error { return tx.txi.Rollback() } -// Prepare creates a prepared statement. +// Prepare creates a prepared statement for use within a transaction. // -// The statement is only valid within the scope of this transaction. +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. func (tx *Tx) Prepare(query string) (*Stmt, error) { - // TODO(bradfitz): the restriction that the returned statement - // is only valid for this Transaction is lame and negates a - // lot of the benefit of prepared statements. We could be - // more efficient here and either provide a method to take an - // existing Stmt (created on perhaps a different Conn), and - // re-create it on this Conn if necessary. Or, better: keep a - // map in DB of query string to Stmts, and have Stmt.Execute - // do the right thing and re-prepare if the Conn in use - // doesn't have that prepared statement. But we'll want to - // avoid caching the statement in the case where we only call - // conn.Prepare implicitly (such as in db.Exec or tx.Exec), - // but the caller package can't be holding a reference to the - // returned statement. Perhaps just looking at the reference - // count (by noting Stmt.Close) would be enough. We might also - // want a finalizer on Stmt to drop the reference count. + // TODO(bradfitz): We could be more efficient here and either + // provide a method to take an existing Stmt (created on + // perhaps a different Conn), and re-create it on this Conn if + // necessary. Or, better: keep a map in DB of query string to + // Stmts, and have Stmt.Execute do the right thing and + // re-prepare if the Conn in use doesn't have that prepared + // statement. But we'll want to avoid caching the statement + // in the case where we only call conn.Prepare implicitly + // (such as in db.Exec or tx.Exec), but the caller package + // can't be holding a reference to the returned statement. + // Perhaps just looking at the reference count (by noting + // Stmt.Close) would be enough. We might also want a finalizer + // on Stmt to drop the reference count. ci, err := tx.grabConn() if err != nil { return nil, err @@ -383,6 +384,39 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { return stmt, nil } +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + // TODO(bradfitz): optimize this. Currently this re-prepares + // each time. This is fine for now to illustrate the API but + // we should really cache already-prepared statements + // per-Conn. See also the big comment in Tx.Prepare. + + if tx.db != stmt.db { + return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} + } + ci, err := tx.grabConn() + if err != nil { + return &Stmt{stickyErr: err} + } + defer tx.releaseConn() + si, err := ci.Prepare(stmt.query) + return &Stmt{ + db: tx.db, + tx: tx, + txsi: si, + query: stmt.query, + stickyErr: err, + } +} + // Exec executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { @@ -448,8 +482,9 @@ type connStmt struct { // Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines. type Stmt struct { // Immutable: - db *DB // where we came from - query string // that created the Sttm + db *DB // where we came from + query string // that created the Stmt + stickyErr error // if non-nil, this error is returned for all operations // If in a transaction, else both nil: tx *Tx @@ -513,6 +548,9 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { // statement, a function to call to release the connection, and a // statement bound to that connection. func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { + if s.stickyErr != nil { + return nil, nil, nil, s.stickyErr + } s.mu.Lock() if s.closed { s.mu.Unlock() @@ -621,6 +659,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row { // Close closes the statement. func (s *Stmt) Close() error { + if s.stickyErr != nil { + return s.stickyErr + } s.mu.Lock() defer s.mu.Unlock() if s.closed { diff --git a/src/pkg/exp/sql/sql_test.go b/src/pkg/exp/sql/sql_test.go index 5b8bcc91422..4f8318d26ef 100644 --- a/src/pkg/exp/sql/sql_test.go +++ b/src/pkg/exp/sql/sql_test.go @@ -166,7 +166,7 @@ func TestBogusPreboundParameters(t *testing.T) { } } -func TestDb(t *testing.T) { +func TestExec(t *testing.T) { db := newTestDB(t, "foo") defer closeDB(t, db) exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") @@ -206,3 +206,25 @@ func TestDb(t *testing.T) { } } } + +func TestTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + _, err = tx.Stmt(stmt).Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } +}