From 8089e578124227a322a77f9c5599d2244f9d0cfc Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 2 Nov 2011 11:46:04 -0700 Subject: [PATCH] exp/sql: finish transactions, flesh out types, docs Fixes #2328 (float, bool) R=rsc, r CC=golang-dev https://golang.org/cl/5294067 --- src/pkg/Makefile | 1 - src/pkg/exp/sql/convert.go | 71 +++++--- src/pkg/exp/sql/convert_test.go | 52 ++++++ src/pkg/exp/sql/driver/driver.go | 14 +- src/pkg/exp/sql/driver/types.go | 62 ++++++- src/pkg/exp/sql/driver/types_test.go | 57 ++++++ src/pkg/exp/sql/fakedb_test.go | 2 +- src/pkg/exp/sql/sql.go | 260 +++++++++++++++++++++------ 8 files changed, 434 insertions(+), 85 deletions(-) create mode 100644 src/pkg/exp/sql/driver/types_test.go diff --git a/src/pkg/Makefile b/src/pkg/Makefile index f23f7fc4ed..3d11502f24 100644 --- a/src/pkg/Makefile +++ b/src/pkg/Makefile @@ -203,7 +203,6 @@ NOTEST+=\ exp/ebnflint\ exp/gui\ exp/gui/x11\ - exp/sql/driver\ go/doc\ hash\ http/pprof\ diff --git a/src/pkg/exp/sql/convert.go b/src/pkg/exp/sql/convert.go index b1feef0eb8..e46cebe9a3 100644 --- a/src/pkg/exp/sql/convert.go +++ b/src/pkg/exp/sql/convert.go @@ -8,6 +8,7 @@ package sql import ( "errors" + "exp/sql/driver" "fmt" "reflect" "strconv" @@ -36,10 +37,11 @@ func convertAssign(dest, src interface{}) error { } } - sv := reflect.ValueOf(src) + var sv reflect.Value switch d := dest.(type) { case *string: + sv = reflect.ValueOf(src) switch sv.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, @@ -48,6 +50,12 @@ func convertAssign(dest, src interface{}) error { *d = fmt.Sprintf("%v", src) return nil } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err } if scanner, ok := dest.(ScannerInto); ok { @@ -59,6 +67,10 @@ func convertAssign(dest, src interface{}) error { return errors.New("destination not a pointer") } + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + dv := reflect.Indirect(dpv) if dv.Kind() == sv.Kind() { dv.Set(sv) @@ -67,40 +79,49 @@ func convertAssign(dest, src interface{}) error { switch dv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if s, ok := asString(src); ok { - i64, err := strconv.Atoi64(s) - if err != nil { - return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) - } - if dv.OverflowInt(i64) { - return fmt.Errorf("string %q overflows %s", s, dv.Kind()) - } - dv.SetInt(i64) - return nil + s := asString(src) + i64, err := strconv.Atoi64(s) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) } + if dv.OverflowInt(i64) { + return fmt.Errorf("string %q overflows %s", s, dv.Kind()) + } + dv.SetInt(i64) + return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if s, ok := asString(src); ok { - u64, err := strconv.Atoui64(s) - if err != nil { - return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) - } - if dv.OverflowUint(u64) { - return fmt.Errorf("string %q overflows %s", s, dv.Kind()) - } - dv.SetUint(u64) - return nil + s := asString(src) + u64, err := strconv.Atoui64(s) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) } + if dv.OverflowUint(u64) { + return fmt.Errorf("string %q overflows %s", s, dv.Kind()) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.Atof64(s) + if err != nil { + return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err) + } + if dv.OverflowFloat(f64) { + return fmt.Errorf("value %q overflows %s", s, dv.Kind()) + } + dv.SetFloat(f64) + return nil } return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) } -func asString(src interface{}) (s string, ok bool) { +func asString(src interface{}) string { switch v := src.(type) { case string: - return v, true + return v case []byte: - return string(v), true + return string(v) } - return "", false + return fmt.Sprintf("%v", src) } diff --git a/src/pkg/exp/sql/convert_test.go b/src/pkg/exp/sql/convert_test.go index f85ed99978..52cee92724 100644 --- a/src/pkg/exp/sql/convert_test.go +++ b/src/pkg/exp/sql/convert_test.go @@ -17,6 +17,9 @@ type conversionTest struct { wantint int64 wantuint uint64 wantstr string + wantf32 float32 + wantf64 float64 + wantbool bool // used if d is of type *bool wanterr string } @@ -29,6 +32,9 @@ var ( scanint32 int32 scanuint8 uint8 scanuint16 uint16 + scanbool bool + scanf32 float32 + scanf64 float64 ) var conversionTests = []conversionTest{ @@ -53,6 +59,35 @@ var conversionTests = []conversionTest{ {s: "256", d: &scanuint16, wantuint: 256}, {s: "-1", d: &scanint, wantint: -1}, {s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: parsing "foo": invalid syntax`}, + + // True bools + {s: true, d: &scanbool, wantbool: true}, + {s: "True", d: &scanbool, wantbool: true}, + {s: "TRUE", d: &scanbool, wantbool: true}, + {s: "1", d: &scanbool, wantbool: true}, + {s: 1, d: &scanbool, wantbool: true}, + {s: int64(1), d: &scanbool, wantbool: true}, + {s: uint16(1), d: &scanbool, wantbool: true}, + + // False bools + {s: false, d: &scanbool, wantbool: false}, + {s: "false", d: &scanbool, wantbool: false}, + {s: "FALSE", d: &scanbool, wantbool: false}, + {s: "0", d: &scanbool, wantbool: false}, + {s: 0, d: &scanbool, wantbool: false}, + {s: int64(0), d: &scanbool, wantbool: false}, + {s: uint16(0), d: &scanbool, wantbool: false}, + + // Not bools + {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`}, + {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`}, + + // Floats + {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)}, + {s: int64(1), d: &scanf64, wantf64: float64(1)}, + {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf64, wantf64: float64(1.5)}, } func intValue(intptr interface{}) int64 { @@ -63,6 +98,14 @@ func uintValue(intptr interface{}) uint64 { return reflect.Indirect(reflect.ValueOf(intptr)).Uint() } +func float64Value(ptr interface{}) float64 { + return *(ptr.(*float64)) +} + +func float32Value(ptr interface{}) float32 { + return *(ptr.(*float32)) +} + func TestConversions(t *testing.T) { for n, ct := range conversionTests { err := convertAssign(ct.d, ct.s) @@ -86,6 +129,15 @@ func TestConversions(t *testing.T) { if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) } + if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d)) + } + if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d)) + } + if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { + errf("want bool %v, got %v", ct.wantbool, *bp) + } } } diff --git a/src/pkg/exp/sql/driver/driver.go b/src/pkg/exp/sql/driver/driver.go index 52714e817a..6a51c34241 100644 --- a/src/pkg/exp/sql/driver/driver.go +++ b/src/pkg/exp/sql/driver/driver.go @@ -24,9 +24,13 @@ import "errors" // Driver is the interface that must be implemented by a database // driver. type Driver interface { - // Open returns a new or cached connection to the database. + // Open returns a new connection to the database. // The name is a string in a driver-specific format. // + // Open may return a cached connection (one previously + // closed), but doing so is unnecessary; the sql package + // maintains a pool of idle connections for efficient re-use. + // // The returned connection is only used by one goroutine at a // time. Open(name string) (Conn, error) @@ -59,8 +63,12 @@ type Conn interface { // Close invalidates and potentially stops any current // prepared statements and transactions, marking this - // connection as no longer in use. The driver may cache or - // close its underlying connection to its database. + // connection as no longer in use. + // + // Because the sql package maintains a free pool of + // connections and only calls Close when there's a surplus of + // idle connections, it shouldn't be necessary for drivers to + // do their own connection caching. Close() error // Begin starts and returns a new transaction. diff --git a/src/pkg/exp/sql/driver/types.go b/src/pkg/exp/sql/driver/types.go index 9faf32f671..6e0ce4339c 100644 --- a/src/pkg/exp/sql/driver/types.go +++ b/src/pkg/exp/sql/driver/types.go @@ -11,6 +11,21 @@ import ( ) // ValueConverter is the interface providing the ConvertValue method. +// +// Various implementations of ValueConverter are provided by the +// driver package to provide consistent implementations of conversions +// between drivers. The ValueConverters have several uses: +// +// * converting from the subset types as provided by the sql package +// into a database table's specific column type and making sure it +// fits, such as making sure a particular int64 fits in a +// table's uint16 column. +// +// * converting a value as given from the database into one of the +// subset types. +// +// * by the sql package, for converting from a driver's subset type +// to a user's type in a scan. type ValueConverter interface { // ConvertValue converts a value to a restricted subset type. ConvertValue(v interface{}) (interface{}, error) @@ -19,15 +34,56 @@ type ValueConverter interface { // Bool is a ValueConverter that converts input values to bools. // // The conversion rules are: -// - .... TODO(bradfitz): TBD +// - booleans are returned unchanged +// - for integer types, +// 1 is true +// 0 is false, +// other integers are an error +// - for strings and []byte, same rules as strconv.Atob +// - all other types are an error var Bool boolType type boolType struct{} var _ ValueConverter = boolType{} -func (boolType) ConvertValue(v interface{}) (interface{}, error) { - return nil, fmt.Errorf("TODO(bradfitz): bool conversions") +func (boolType) String() string { return "Bool" } + +func (boolType) ConvertValue(src interface{}) (interface{}, error) { + switch s := src.(type) { + case bool: + return s, nil + case string: + b, err := strconv.Atob(s) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + case []byte: + b, err := strconv.Atob(string(s)) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + } + + sv := reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + iv := sv.Int() + if iv == 1 || iv == 0 { + return iv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uv := sv.Uint() + if uv == 1 || uv == 0 { + return uv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv) + } + + return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src) } // Int32 is a ValueConverter that converts input values to int64, diff --git a/src/pkg/exp/sql/driver/types_test.go b/src/pkg/exp/sql/driver/types_test.go new file mode 100644 index 0000000000..4b049e26e5 --- /dev/null +++ b/src/pkg/exp/sql/driver/types_test.go @@ -0,0 +1,57 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "reflect" + "testing" +) + +type valueConverterTest struct { + c ValueConverter + in interface{} + out interface{} + err string +} + +var valueConverterTests = []valueConverterTest{ + {Bool, "true", true, ""}, + {Bool, "True", true, ""}, + {Bool, []byte("t"), true, ""}, + {Bool, true, true, ""}, + {Bool, "1", true, ""}, + {Bool, 1, true, ""}, + {Bool, int64(1), true, ""}, + {Bool, uint16(1), true, ""}, + {Bool, "false", false, ""}, + {Bool, false, false, ""}, + {Bool, "0", false, ""}, + {Bool, 0, false, ""}, + {Bool, int64(0), false, ""}, + {Bool, uint16(0), false, ""}, + {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, + {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %s(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %s(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} diff --git a/src/pkg/exp/sql/fakedb_test.go b/src/pkg/exp/sql/fakedb_test.go index 289294bee2..c8a19974d6 100644 --- a/src/pkg/exp/sql/fakedb_test.go +++ b/src/pkg/exp/sql/fakedb_test.go @@ -476,7 +476,7 @@ func (rc *rowsCursor) Next(dest []interface{}) error { for i, v := range rc.rows[rc.pos].cols { // TODO(bradfitz): convert to subset types? naah, I // think the subset types should only be input to - // driver, but the db package should be able to handle + // driver, but the sql package should be able to handle // a wider range of types coming out of drivers. all // for ease of drivers, and to prevent drivers from // messing up conversions or doing them differently. diff --git a/src/pkg/exp/sql/sql.go b/src/pkg/exp/sql/sql.go index 4f1c539127..1af8e063cf 100644 --- a/src/pkg/exp/sql/sql.go +++ b/src/pkg/exp/sql/sql.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "runtime" "sync" "exp/sql/driver" @@ -192,13 +191,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { // If the driver does not implement driver.Execer, we need // a connection. - conn, err := db.conn() + ci, err := db.conn() if err != nil { return nil, err } - defer db.putConn(conn) + defer db.putConn(ci) - if execer, ok := conn.(driver.Execer); ok { + if execer, ok := ci.(driver.Execer); ok { resi, err := execer.Exec(query, args) if err != nil { return nil, err @@ -206,7 +205,7 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { return result{resi}, nil } - sti, err := conn.Prepare(query) + sti, err := ci.Prepare(query) if err != nil { return nil, err } @@ -233,18 +232,26 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { // Row's Scan method is called. func (db *DB) QueryRow(query string, args ...interface{}) *Row { rows, err := db.Query(query, args...) - if err != nil { - return &Row{err: err} - } - return &Row{rows: rows} + return &Row{rows: rows, err: err} } -// Begin starts a transaction. The isolation level is dependent on +// Begin starts a transaction. The isolation level is dependent on // the driver. func (db *DB) Begin() (*Tx, error) { - // TODO(bradfitz): add another method for beginning a transaction - // at a specific isolation level. - panic(todo()) + ci, err := db.conn() + if err != nil { + return nil, err + } + txi, err := ci.Begin() + if err != nil { + db.putConn(ci) + return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err) + } + return &Tx{ + db: db, + ci: ci, + txi: txi, + }, nil } // DriverDatabase returns the database's underlying driver. @@ -253,41 +260,158 @@ func (db *DB) Driver() driver.Driver { } // Tx is an in-progress database transaction. +// +// A transaction must end with a call to Commit or Rollback. +// +// After a call to Commit or Rollback, all operations on the +// transaction fail with ErrTransactionFinished. type Tx struct { + db *DB + // ci is owned exclusively until Commit or Rollback, at which point + // it's returned with putConn. + ci driver.Conn + txi driver.Tx + + // cimu is held while somebody is using ci (between grabConn + // and releaseConn) + cimu sync.Mutex + + // done transitions from false to true exactly once, on Commit + // or Rollback. once done, all operations fail with + // ErrTransactionFinished. + done bool +} + +var ErrTransactionFinished = errors.New("sql: Transaction has already been committed or rolled back") + +func (tx *Tx) close() { + if tx.done { + panic("double close") // internal error + } + tx.done = true + tx.db.putConn(tx.ci) + tx.ci = nil + tx.txi = nil +} + +func (tx *Tx) grabConn() (driver.Conn, error) { + if tx.done { + return nil, ErrTransactionFinished + } + tx.cimu.Lock() + return tx.ci, nil +} + +func (tx *Tx) releaseConn() { + tx.cimu.Unlock() } // Commit commits the transaction. func (tx *Tx) Commit() error { - panic(todo()) + if tx.done { + return ErrTransactionFinished + } + defer tx.close() + return tx.txi.Commit() } // Rollback aborts the transaction. func (tx *Tx) Rollback() error { - panic(todo()) + if tx.done { + return ErrTransactionFinished + } + defer tx.close() + return tx.txi.Rollback() } // Prepare creates a prepared statement. +// +// The statement is only valid within the scope of this transaction. func (tx *Tx) Prepare(query string) (*Stmt, error) { - panic(todo()) + // TODO(bradfitz): the restriction that the returned statement + // is only valid for this Transaction is lame and negates a + // lot of the benefit of prepared statements. We could be + // more efficient here and either provide a method to take an + // existing Stmt (created on perhaps a different Conn), and + // re-create it on this Conn if necessary. Or, better: keep a + // map in DB of query string to Stmts, and have Stmt.Execute + // do the right thing and re-prepare if the Conn in use + // doesn't have that prepared statement. But we'll want to + // avoid caching the statement in the case where we only call + // conn.Prepare implicitly (such as in db.Exec or tx.Exec), + // but the caller package can't be holding a reference to the + // returned statement. Perhaps just looking at the reference + // count (by noting Stmt.Close) would be enough. We might also + // want a finalizer on Stmt to drop the reference count. + ci, err := tx.grabConn() + if err != nil { + return nil, err + } + defer tx.releaseConn() + + si, err := ci.Prepare(query) + if err != nil { + return nil, err + } + + stmt := &Stmt{ + db: tx.db, + tx: tx, + txsi: si, + query: query, + } + return stmt, nil } // Exec executes a query that doesn't return rows. // For example: an INSERT and UPDATE. -func (tx *Tx) Exec(query string, args ...interface{}) { - panic(todo()) +func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { + ci, err := tx.grabConn() + if err != nil { + return nil, err + } + defer tx.releaseConn() + + if execer, ok := ci.(driver.Execer); ok { + resi, err := execer.Exec(query, args) + if err != nil { + return nil, err + } + return result{resi}, nil + } + + sti, err := ci.Prepare(query) + if err != nil { + return nil, err + } + defer sti.Close() + resi, err := sti.Exec(args) + if err != nil { + return nil, err + } + return result{resi}, nil } // Query executes a query that returns rows, typically a SELECT. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - panic(todo()) + if tx.done { + return nil, ErrTransactionFinished + } + stmt, err := tx.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Query(args...) } // QueryRow executes a query that is expected to return at most one row. // QueryRow always return a non-nil value. Errors are deferred until // Row's Scan method is called. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - panic(todo()) + rows, err := tx.Query(query, args...) + return &Row{rows: rows, err: err} } // connStmt is a prepared statement on a particular connection. @@ -302,24 +426,28 @@ type Stmt struct { db *DB // where we came from query string // that created the Sttm - mu sync.Mutex - closed bool - css []connStmt // can use any that have idle connections -} + // If in a transaction, else both nil: + tx *Tx + txsi driver.Stmt -func todo() string { - _, file, line, _ := runtime.Caller(1) - return fmt.Sprintf("%s:%d: TODO: implement", file, line) + mu sync.Mutex // protects the rest of the fields + closed bool + + // css is a list of underlying driver statement interfaces + // that are valid on particular connections. This is only + // used if tx == nil and one is found that has idle + // connections. If tx != nil, txsi is always used. + css []connStmt } // Exec executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. func (s *Stmt) Exec(args ...interface{}) (Result, error) { - ci, si, err := s.connStmt() + _, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } - defer s.db.putConn(ci) + defer releaseConn() if want := si.NumInput(); len(args) != want { return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args)) @@ -353,11 +481,29 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return result{resi}, nil } -func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) { +// connStmt returns a free driver connection on which to execute the +// statement, a function to call to release the connection, and a +// statement bound to that connection. +func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { s.mu.Lock() if s.closed { - return nil, nil, errors.New("db: statement is closed") + s.mu.Unlock() + err = errors.New("db: statement is closed") + return } + + // In a transaction, we always use the connection that the + // transaction was created on. + if s.tx != nil { + s.mu.Unlock() + ci, err = s.tx.grabConn() // blocks, waiting for the connection. + if err != nil { + return + } + releaseConn = func() { s.tx.releaseConn() } + return ci, releaseConn, s.txsi, nil + } + var cs connStmt match := false for _, v := range s.css { @@ -375,11 +521,11 @@ func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) { if !match { ci, err := s.db.conn() if err != nil { - return nil, nil, err + return nil, nil, nil, err } si, err := ci.Prepare(s.query) if err != nil { - return nil, nil, err + return nil, nil, nil, err } s.mu.Lock() cs = connStmt{ci, si} @@ -387,13 +533,15 @@ func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) { s.mu.Unlock() } - return cs.ci, cs.si, nil + conn := cs.ci + releaseConn = func() { s.db.putConn(conn) } + return conn, releaseConn, cs.si, nil } // Query executes a prepared query statement with the given arguments // and returns the query results as a *Rows. func (s *Stmt) Query(args ...interface{}) (*Rows, error) { - ci, si, err := s.connStmt(args...) + ci, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } @@ -405,11 +553,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { s.db.putConn(ci) return nil, err } - // Note: ownership of ci passes to the *Rows + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. rows := &Rows{ - db: s.db, - ci: ci, - rowsi: rowsi, + db: s.db, + ci: ci, + releaseConn: releaseConn, + rowsi: rowsi, } return rows, nil } @@ -436,19 +586,24 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row { // Close closes the statement. func (s *Stmt) Close() error { s.mu.Lock() - defer s.mu.Unlock() // TODO(bradfitz): move this unlock after 'closed = true'? + defer s.mu.Unlock() if s.closed { return nil } s.closed = true - for _, v := range s.css { - if ci, match := s.db.connIfFree(v.ci); match { - v.si.Close() - s.db.putConn(ci) - } else { - // TODO(bradfitz): care that we can't close - // this statement because the statement's - // connection is in use? + + if s.tx != nil { + s.txsi.Close() + } else { + for _, v := range s.css { + if ci, match := s.db.connIfFree(v.ci); match { + v.si.Close() + s.db.putConn(ci) + } else { + // TODO(bradfitz): care that we can't close + // this statement because the statement's + // connection is in use? + } } } return nil @@ -468,9 +623,10 @@ func (s *Stmt) Close() error { // err = rows.Error() // get any Error encountered during iteration // ... type Rows struct { - db *DB - ci driver.Conn // owned; must be returned when Rows is closed - rowsi driver.Rows + db *DB + ci driver.Conn // owned; must call putconn when closed to release + releaseConn func() + rowsi driver.Rows closed bool lastcols []interface{} @@ -538,7 +694,7 @@ func (rs *Rows) Close() error { } rs.closed = true err := rs.rowsi.Close() - rs.db.putConn(rs.ci) + rs.releaseConn() return err }