mirror of
https://github.com/golang/go
synced 2024-09-24 11:10:12 -06:00
database/sql: ensure a Stmt from a Conn executes on the same driver.Conn
Ensure a Stmt prepared on a Conn executes on the same driver.Conn. This also removes another instance of duplicated prepare logic as a side effect. Fixes #20647 Change-Id: Ia00a19e4dd15e19e4d754105babdff5dc127728f Reviewed-on: https://go-review.googlesource.com/45391 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
200d0cc192
commit
cd24a8a550
@ -393,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
|
||||
return dc.createdAt.Add(timeout).Before(nowFunc())
|
||||
}
|
||||
|
||||
func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
|
||||
// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
|
||||
// the prepared statements in a pool.
|
||||
func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
|
||||
si, err := ctxDriverPrepare(ctx, dc.ci, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ds := &driverStmt{Locker: dc, si: si}
|
||||
|
||||
// No need to manage open statements if there is a single connection grabber.
|
||||
if cg != nil {
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
// Track each driverConn's open statements, so we can close them
|
||||
// before closing the conn.
|
||||
@ -406,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS
|
||||
if dc.openStmt == nil {
|
||||
dc.openStmt = make(map[*driverStmt]bool)
|
||||
}
|
||||
ds := &driverStmt{Locker: dc, si: si}
|
||||
dc.openStmt[ds] = true
|
||||
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
@ -1165,28 +1171,39 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.prepareDC(ctx, dc, dc.releaseConn, query)
|
||||
return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
|
||||
}
|
||||
|
||||
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), query string) (*Stmt, error) {
|
||||
// prepareDC prepares a query on the driverConn and calls release before
|
||||
// returning. When cg == nil it implies that a connection pool is used, and
|
||||
// when cg != nil only a single driver connection is used.
|
||||
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
|
||||
var ds *driverStmt
|
||||
var err error
|
||||
defer func() {
|
||||
release(err)
|
||||
}()
|
||||
withLock(dc, func() {
|
||||
ds, err = dc.prepareLocked(ctx, query)
|
||||
ds, err = dc.prepareLocked(ctx, cg, query)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := &Stmt{
|
||||
db: db,
|
||||
query: query,
|
||||
css: []connStmt{{dc, ds}},
|
||||
lastNumClosed: atomic.LoadUint64(&db.numClosed),
|
||||
db: db,
|
||||
query: query,
|
||||
cg: cg,
|
||||
cgds: ds,
|
||||
}
|
||||
|
||||
// When cg == nil this statement will need to keep track of various
|
||||
// connections they are prepared on and record the stmt dependency on
|
||||
// the DB.
|
||||
if cg == nil {
|
||||
stmt.css = []connStmt{{dc, ds}}
|
||||
stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
|
||||
db.addDep(stmt, stmt)
|
||||
}
|
||||
db.addDep(stmt, stmt)
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
@ -1474,6 +1491,8 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type releaseConn func(error)
|
||||
|
||||
// Conn represents a single database session rather a pool of database
|
||||
// sessions. Prefer running queries from DB unless there is a specific
|
||||
// need for a continuous single database session.
|
||||
@ -1501,46 +1520,41 @@ type Conn struct {
|
||||
done int32
|
||||
}
|
||||
|
||||
func (c *Conn) grabConn() (*driverConn, error) {
|
||||
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
|
||||
if atomic.LoadInt32(&c.done) != 0 {
|
||||
return nil, ErrConnDone
|
||||
return nil, nil, ErrConnDone
|
||||
}
|
||||
return c.dc, nil
|
||||
c.closemu.RLock()
|
||||
return c.dc, c.closemuRUnlockCondReleaseConn, nil
|
||||
}
|
||||
|
||||
// PingContext verifies the connection to the database is still alive.
|
||||
func (c *Conn) PingContext(ctx context.Context) error {
|
||||
dc, err := c.grabConn()
|
||||
dc, release, err := c.grabConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.closemu.RLock()
|
||||
return c.db.pingDC(ctx, dc, c.closemuRUnlockCondReleaseConn)
|
||||
return c.db.pingDC(ctx, dc, release)
|
||||
}
|
||||
|
||||
// ExecContext executes a query without returning any rows.
|
||||
// The args are for any placeholder parameters in the query.
|
||||
func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
|
||||
dc, err := c.grabConn()
|
||||
dc, release, err := c.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.closemu.RLock()
|
||||
return c.db.execDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args)
|
||||
return c.db.execDC(ctx, dc, release, query, args)
|
||||
}
|
||||
|
||||
// QueryContext executes a query that returns rows, typically a SELECT.
|
||||
// The args are for any placeholder parameters in the query.
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
|
||||
dc, err := c.grabConn()
|
||||
dc, release, err := c.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.closemu.RLock()
|
||||
return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args)
|
||||
return c.db.queryDC(ctx, nil, dc, release, query, args)
|
||||
}
|
||||
|
||||
// QueryRowContext executes a query that is expected to return at most one row.
|
||||
@ -1563,13 +1577,11 @@ func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interf
|
||||
// The provided context is used for the preparation of the statement, not for the
|
||||
// execution of the statement.
|
||||
func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
|
||||
dc, err := c.grabConn()
|
||||
dc, release, err := c.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.closemu.RLock()
|
||||
return c.db.prepareDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query)
|
||||
return c.db.prepareDC(ctx, dc, release, c, query)
|
||||
}
|
||||
|
||||
// BeginTx starts a transaction.
|
||||
@ -1583,13 +1595,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
|
||||
// If a non-default isolation level is used that the driver doesn't support,
|
||||
// an error will be returned.
|
||||
func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
|
||||
dc, err := c.grabConn()
|
||||
dc, release, err := c.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.closemu.RLock()
|
||||
return c.db.beginDC(ctx, dc, c.closemuRUnlockCondReleaseConn, opts)
|
||||
return c.db.beginDC(ctx, dc, release, opts)
|
||||
}
|
||||
|
||||
// closemuRUnlockCondReleaseConn read unlocks closemu
|
||||
@ -1601,6 +1611,10 @@ func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) txCtx() context.Context {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) close(err error) error {
|
||||
if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
|
||||
return ErrConnDone
|
||||
@ -1712,19 +1726,28 @@ func (tx *Tx) close(err error) {
|
||||
// a successful call to (*Tx).grabConn. For tests.
|
||||
var hookTxGrabConn func()
|
||||
|
||||
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
|
||||
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
|
||||
// closeme.RLock must come before the check for isDone to prevent the Tx from
|
||||
// closing while a query is executing.
|
||||
tx.closemu.RLock()
|
||||
if tx.isDone() {
|
||||
return nil, ErrTxDone
|
||||
tx.closemu.RUnlock()
|
||||
return nil, nil, ErrTxDone
|
||||
}
|
||||
if hookTxGrabConn != nil { // test hook
|
||||
hookTxGrabConn()
|
||||
}
|
||||
return tx.dc, nil
|
||||
return tx.dc, tx.closemuRUnlockRelease, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) txCtx() context.Context {
|
||||
return tx.ctx
|
||||
}
|
||||
|
||||
// closemuRUnlockRelease is used as a func(error) method value in
|
||||
@ -1801,31 +1824,15 @@ func (tx *Tx) Rollback() error {
|
||||
// for the execution of the returned statement. The returned statement
|
||||
// will run in the transaction context.
|
||||
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
|
||||
tx.closemu.RLock()
|
||||
defer tx.closemu.RUnlock()
|
||||
|
||||
dc, err := tx.grabConn(ctx)
|
||||
dc, release, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var si driver.Stmt
|
||||
withLock(dc, func() {
|
||||
si, err = ctxDriverPrepare(ctx, dc.ci, query)
|
||||
})
|
||||
stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt := &Stmt{
|
||||
db: tx.db,
|
||||
tx: tx,
|
||||
txds: &driverStmt{
|
||||
Locker: dc,
|
||||
si: si,
|
||||
},
|
||||
query: query,
|
||||
}
|
||||
tx.stmts.Lock()
|
||||
tx.stmts.v = append(tx.stmts.v, stmt)
|
||||
tx.stmts.Unlock()
|
||||
@ -1855,20 +1862,19 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
|
||||
// The returned statement operates within the transaction and will be closed
|
||||
// when the transaction has been committed or rolled back.
|
||||
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
|
||||
tx.closemu.RLock()
|
||||
defer tx.closemu.RUnlock()
|
||||
dc, release, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
return &Stmt{stickyErr: err}
|
||||
}
|
||||
defer release(nil)
|
||||
|
||||
if tx.db != stmt.db {
|
||||
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
|
||||
}
|
||||
dc, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
return &Stmt{stickyErr: err}
|
||||
}
|
||||
var si driver.Stmt
|
||||
var parentStmt *Stmt
|
||||
stmt.mu.Lock()
|
||||
if stmt.closed || stmt.tx != nil {
|
||||
if stmt.closed || stmt.cg != nil {
|
||||
// If the statement has been closed or already belongs to a
|
||||
// transaction, we can't reuse it in this connection.
|
||||
// Since tx.StmtContext should never need to be called with a
|
||||
@ -1907,8 +1913,8 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
|
||||
|
||||
txs := &Stmt{
|
||||
db: tx.db,
|
||||
tx: tx,
|
||||
txds: &driverStmt{
|
||||
cg: tx,
|
||||
cgds: &driverStmt{
|
||||
Locker: dc,
|
||||
si: si,
|
||||
},
|
||||
@ -1943,14 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
|
||||
// ExecContext executes a query that doesn't return rows.
|
||||
// For example: an INSERT and UPDATE.
|
||||
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
|
||||
tx.closemu.RLock()
|
||||
|
||||
dc, err := tx.grabConn(ctx)
|
||||
dc, release, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
tx.closemu.RUnlock()
|
||||
return nil, err
|
||||
}
|
||||
return tx.db.execDC(ctx, dc, tx.closemuRUnlockRelease, query, args)
|
||||
return tx.db.execDC(ctx, dc, release, query, args)
|
||||
}
|
||||
|
||||
// Exec executes a query that doesn't return rows.
|
||||
@ -1961,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
|
||||
|
||||
// QueryContext executes a query that returns rows, typically a SELECT.
|
||||
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
|
||||
tx.closemu.RLock()
|
||||
|
||||
dc, err := tx.grabConn(ctx)
|
||||
dc, release, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
tx.closemu.RUnlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args)
|
||||
return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
|
||||
}
|
||||
|
||||
// Query executes a query that returns rows, typically a SELECT.
|
||||
@ -2004,6 +2004,24 @@ type connStmt struct {
|
||||
ds *driverStmt
|
||||
}
|
||||
|
||||
// stmtConnGrabber represents a Tx or Conn that will return the underlying
|
||||
// driverConn and release function.
|
||||
type stmtConnGrabber interface {
|
||||
// grabConn returns the driverConn and the associated release function
|
||||
// that must be called when the operation completes.
|
||||
grabConn(context.Context) (*driverConn, releaseConn, error)
|
||||
|
||||
// txCtx returns the transaction context if available.
|
||||
// The returned context should be selected on along with
|
||||
// any query context when awaiting a cancel.
|
||||
txCtx() context.Context
|
||||
}
|
||||
|
||||
var (
|
||||
_ stmtConnGrabber = &Tx{}
|
||||
_ stmtConnGrabber = &Conn{}
|
||||
)
|
||||
|
||||
// Stmt is a prepared statement.
|
||||
// A Stmt is safe for concurrent use by multiple goroutines.
|
||||
type Stmt struct {
|
||||
@ -2014,9 +2032,13 @@ type Stmt struct {
|
||||
|
||||
closemu sync.RWMutex // held exclusively during close, for read otherwise.
|
||||
|
||||
// If in a transaction, else both nil:
|
||||
tx *Tx
|
||||
txds *driverStmt
|
||||
// If Stmt is prepared on a Tx or Conn then cg is present and will
|
||||
// only ever grab a connection from cg.
|
||||
// If cg is nil then the Stmt must grab an arbitrary connection
|
||||
// from db and determine if it must prepare the stmt again by
|
||||
// inspecting css.
|
||||
cg stmtConnGrabber
|
||||
cgds *driverStmt
|
||||
|
||||
// parentStmt is set when a transaction-specific statement
|
||||
// is requested from an identical statement prepared on the same
|
||||
@ -2031,8 +2053,8 @@ type Stmt struct {
|
||||
|
||||
// css is a list of underlying driver statement interfaces
|
||||
// that are valid on particular connections. This is only
|
||||
// used if tx == nil and one is found that has idle
|
||||
// connections. If tx != nil, txds is always used.
|
||||
// used if cg == nil and one is found that has idle
|
||||
// connections. If cg != nil, cgds is always used.
|
||||
css []connStmt
|
||||
|
||||
// lastNumClosed is copied from db.numClosed when Stmt is created
|
||||
@ -2120,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() {
|
||||
// connStmt returns a free driver connection on which to execute the
|
||||
// statement, a function to call to release the connection, and a
|
||||
// statement bound to that connection.
|
||||
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
|
||||
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
|
||||
if err = s.stickyErr; err != nil {
|
||||
return
|
||||
}
|
||||
@ -2131,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
|
||||
return
|
||||
}
|
||||
|
||||
// In a transaction, we always use the connection that the
|
||||
// transaction was created on.
|
||||
if s.tx != nil {
|
||||
// In a transaction or connection, we always use the connection that the
|
||||
// the stmt was created on.
|
||||
if s.cg != nil {
|
||||
s.mu.Unlock()
|
||||
ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
|
||||
dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
releaseConn = func(error) {}
|
||||
return ci, releaseConn, s.txds, nil
|
||||
return dc, releaseConn, s.cgds, nil
|
||||
}
|
||||
|
||||
s.removeClosedStmtLocked()
|
||||
s.mu.Unlock()
|
||||
|
||||
dc, err := s.db.conn(ctx, strategy)
|
||||
dc, err = s.db.conn(ctx, strategy)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -2175,7 +2196,7 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
|
||||
// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
|
||||
// open connStmt on the statement. It assumes the caller is holding the lock on dc.
|
||||
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
|
||||
si, err := dc.prepareLocked(ctx, s.query)
|
||||
si, err := dc.prepareLocked(ctx, s.cg, s.query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -2226,8 +2247,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
|
||||
s.db.removeDep(s, rows)
|
||||
}
|
||||
var txctx context.Context
|
||||
if s.tx != nil {
|
||||
txctx = s.tx.ctx
|
||||
if s.cg != nil {
|
||||
txctx = s.cg.txCtx()
|
||||
}
|
||||
rows.initContextClose(ctx, txctx)
|
||||
return rows, nil
|
||||
@ -2323,9 +2344,12 @@ func (s *Stmt) Close() error {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
txds := s.cgds
|
||||
s.cgds = nil
|
||||
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.tx == nil {
|
||||
if s.cg == nil {
|
||||
return s.db.removeDep(s, s)
|
||||
}
|
||||
|
||||
@ -2334,7 +2358,7 @@ func (s *Stmt) Close() error {
|
||||
// in the css array of the parentStmt.
|
||||
return s.db.removeDep(s.parentStmt, s)
|
||||
}
|
||||
return s.txds.Close()
|
||||
return txds.Close()
|
||||
}
|
||||
|
||||
func (s *Stmt) finalClose() error {
|
||||
|
@ -877,7 +877,7 @@ func TestStatementClose(t *testing.T) {
|
||||
msg string
|
||||
}{
|
||||
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
|
||||
{&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
|
||||
{&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
if err := test.stmt.Close(); err != want {
|
||||
@ -3231,6 +3231,42 @@ func TestIssue18719(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestIssue20647(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows1, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatal("rows1", err)
|
||||
}
|
||||
defer rows1.Close()
|
||||
|
||||
rows2, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatal("rows2", err)
|
||||
}
|
||||
defer rows2.Close()
|
||||
|
||||
if rows1.dc != rows2.dc {
|
||||
t.Fatal("stmt prepared on Conn does not use same connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
list := []struct {
|
||||
name string
|
||||
|
Loading…
Reference in New Issue
Block a user