mirror of
https://github.com/golang/go
synced 2024-11-19 20:54:39 -07:00
database/sql: record the context error in Rows if canceled
Previously it was intended that Rows.Scan would return an error and Rows.Err would return nil. This was problematic because drivers could not differentiate between a normal Rows.Close or a context cancel close. The alternative is to require drivers to return a Scan to return an error if the driver is closed while there are still rows to be read. This is currently not how several drivers currently work and may be difficult to detect when there are additional rows. At the same time guard the the Rows.lasterr and prevent a close while a Rows operation is active. For the drivers that do not have Context methods, do not check for context cancelation after the operation, but before for any operation that may modify the database state. Fixes #18961 Change-Id: I49a25318ecd9f97a35d5b50540ecd850c01cfa5e Reviewed-on: https://go-review.googlesource.com/36485 Reviewed-by: Russ Cox <rsc@golang.org> Run-TryBot: Russ Cox <rsc@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
0c9325e13d
commit
c026845bd2
@ -35,15 +35,12 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resi, err := execer.Exec(query, dargs)
|
select {
|
||||||
if err == nil {
|
default:
|
||||||
select {
|
case <-ctx.Done():
|
||||||
default:
|
return nil, ctx.Err()
|
||||||
case <-ctx.Done():
|
|
||||||
return resi, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return resi, err
|
return execer.Exec(query, dargs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
|
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
|
||||||
@ -56,16 +53,12 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, n
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsi, err := queryer.Query(query, dargs)
|
select {
|
||||||
if err == nil {
|
default:
|
||||||
select {
|
case <-ctx.Done():
|
||||||
default:
|
return nil, ctx.Err()
|
||||||
case <-ctx.Done():
|
|
||||||
rowsi.Close()
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return rowsi, err
|
return queryer.Query(query, dargs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
|
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
|
||||||
@ -77,15 +70,12 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.Nam
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resi, err := si.Exec(dargs)
|
select {
|
||||||
if err == nil {
|
default:
|
||||||
select {
|
case <-ctx.Done():
|
||||||
default:
|
return nil, ctx.Err()
|
||||||
case <-ctx.Done():
|
|
||||||
return resi, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return resi, err
|
return si.Exec(dargs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
|
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
|
||||||
@ -97,16 +87,12 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.Na
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsi, err := si.Query(dargs)
|
select {
|
||||||
if err == nil {
|
default:
|
||||||
select {
|
case <-ctx.Done():
|
||||||
default:
|
return nil, ctx.Err()
|
||||||
case <-ctx.Done():
|
|
||||||
rowsi.Close()
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return rowsi, err
|
return si.Query(dargs)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
|
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
|
||||||
|
@ -2071,14 +2071,21 @@ type Rows struct {
|
|||||||
dc *driverConn // owned; must call releaseConn when closed to release
|
dc *driverConn // owned; must call releaseConn when closed to release
|
||||||
releaseConn func(error)
|
releaseConn func(error)
|
||||||
rowsi driver.Rows
|
rowsi driver.Rows
|
||||||
|
cancel func() // called when Rows is closed, may be nil.
|
||||||
|
closeStmt *driverStmt // if non-nil, statement to Close on close
|
||||||
|
|
||||||
// closed value is 1 when the Rows is closed.
|
// closemu prevents Rows from closing while there
|
||||||
// Use atomic operations on value when checking value.
|
// is an active streaming result. It is held for read during non-close operations
|
||||||
closed int32
|
// and exclusively during close.
|
||||||
cancel func() // called when Rows is closed, may be nil.
|
//
|
||||||
lastcols []driver.Value
|
// closemu guards lasterr and closed.
|
||||||
lasterr error // non-nil only if closed is true
|
closemu sync.RWMutex
|
||||||
closeStmt *driverStmt // if non-nil, statement to Close on close
|
closed bool
|
||||||
|
lasterr error // non-nil only if closed is true
|
||||||
|
|
||||||
|
// lastcols is only used in Scan, Next, and NextResultSet which are expected
|
||||||
|
// not not be called concurrently.
|
||||||
|
lastcols []driver.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *Rows) initContextClose(ctx context.Context) {
|
func (rs *Rows) initContextClose(ctx context.Context) {
|
||||||
@ -2089,7 +2096,7 @@ func (rs *Rows) initContextClose(ctx context.Context) {
|
|||||||
// awaitDone blocks until the rows are closed or the context canceled.
|
// awaitDone blocks until the rows are closed or the context canceled.
|
||||||
func (rs *Rows) awaitDone(ctx context.Context) {
|
func (rs *Rows) awaitDone(ctx context.Context) {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
rs.Close()
|
rs.close(ctx.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next prepares the next result row for reading with the Scan method. It
|
// Next prepares the next result row for reading with the Scan method. It
|
||||||
@ -2099,8 +2106,19 @@ func (rs *Rows) awaitDone(ctx context.Context) {
|
|||||||
//
|
//
|
||||||
// Every call to Scan, even the first one, must be preceded by a call to Next.
|
// Every call to Scan, even the first one, must be preceded by a call to Next.
|
||||||
func (rs *Rows) Next() bool {
|
func (rs *Rows) Next() bool {
|
||||||
if rs.isClosed() {
|
var doClose, ok bool
|
||||||
return false
|
withLock(rs.closemu.RLocker(), func() {
|
||||||
|
doClose, ok = rs.nextLocked()
|
||||||
|
})
|
||||||
|
if doClose {
|
||||||
|
rs.Close()
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *Rows) nextLocked() (doClose, ok bool) {
|
||||||
|
if rs.closed {
|
||||||
|
return false, false
|
||||||
}
|
}
|
||||||
if rs.lastcols == nil {
|
if rs.lastcols == nil {
|
||||||
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
|
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
|
||||||
@ -2109,23 +2127,21 @@ func (rs *Rows) Next() bool {
|
|||||||
if rs.lasterr != nil {
|
if rs.lasterr != nil {
|
||||||
// Close the connection if there is a driver error.
|
// Close the connection if there is a driver error.
|
||||||
if rs.lasterr != io.EOF {
|
if rs.lasterr != io.EOF {
|
||||||
rs.Close()
|
return true, false
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
|
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
rs.Close()
|
return true, false
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
// The driver is at the end of the current result set.
|
// The driver is at the end of the current result set.
|
||||||
// Test to see if there is another result set after the current one.
|
// Test to see if there is another result set after the current one.
|
||||||
// Only close Rows if there is no further result sets to read.
|
// Only close Rows if there is no further result sets to read.
|
||||||
if !nextResultSet.HasNextResultSet() {
|
if !nextResultSet.HasNextResultSet() {
|
||||||
rs.Close()
|
doClose = true
|
||||||
}
|
}
|
||||||
return false
|
return doClose, false
|
||||||
}
|
}
|
||||||
return true
|
return false, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextResultSet prepares the next result set for reading. It returns true if
|
// NextResultSet prepares the next result set for reading. It returns true if
|
||||||
@ -2137,18 +2153,28 @@ func (rs *Rows) Next() bool {
|
|||||||
// scanning. If there are further result sets they may not have rows in the result
|
// scanning. If there are further result sets they may not have rows in the result
|
||||||
// set.
|
// set.
|
||||||
func (rs *Rows) NextResultSet() bool {
|
func (rs *Rows) NextResultSet() bool {
|
||||||
if rs.isClosed() {
|
var doClose bool
|
||||||
|
defer func() {
|
||||||
|
if doClose {
|
||||||
|
rs.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
rs.closemu.RLock()
|
||||||
|
defer rs.closemu.RUnlock()
|
||||||
|
|
||||||
|
if rs.closed {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
rs.lastcols = nil
|
rs.lastcols = nil
|
||||||
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
|
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
|
||||||
if !ok {
|
if !ok {
|
||||||
rs.Close()
|
doClose = true
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
rs.lasterr = nextResultSet.NextResultSet()
|
rs.lasterr = nextResultSet.NextResultSet()
|
||||||
if rs.lasterr != nil {
|
if rs.lasterr != nil {
|
||||||
rs.Close()
|
doClose = true
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
@ -2157,6 +2183,8 @@ func (rs *Rows) NextResultSet() bool {
|
|||||||
// Err returns the error, if any, that was encountered during iteration.
|
// Err returns the error, if any, that was encountered during iteration.
|
||||||
// Err may be called after an explicit or implicit Close.
|
// Err may be called after an explicit or implicit Close.
|
||||||
func (rs *Rows) Err() error {
|
func (rs *Rows) Err() error {
|
||||||
|
rs.closemu.RLock()
|
||||||
|
defer rs.closemu.RUnlock()
|
||||||
if rs.lasterr == io.EOF {
|
if rs.lasterr == io.EOF {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -2167,7 +2195,9 @@ func (rs *Rows) Err() error {
|
|||||||
// Columns returns an error if the rows are closed, or if the rows
|
// Columns returns an error if the rows are closed, or if the rows
|
||||||
// are from QueryRow and there was a deferred error.
|
// are from QueryRow and there was a deferred error.
|
||||||
func (rs *Rows) Columns() ([]string, error) {
|
func (rs *Rows) Columns() ([]string, error) {
|
||||||
if rs.isClosed() {
|
rs.closemu.RLock()
|
||||||
|
defer rs.closemu.RUnlock()
|
||||||
|
if rs.closed {
|
||||||
return nil, errors.New("sql: Rows are closed")
|
return nil, errors.New("sql: Rows are closed")
|
||||||
}
|
}
|
||||||
if rs.rowsi == nil {
|
if rs.rowsi == nil {
|
||||||
@ -2179,7 +2209,9 @@ func (rs *Rows) Columns() ([]string, error) {
|
|||||||
// ColumnTypes returns column information such as column type, length,
|
// ColumnTypes returns column information such as column type, length,
|
||||||
// and nullable. Some information may not be available from some drivers.
|
// and nullable. Some information may not be available from some drivers.
|
||||||
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
|
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
|
||||||
if rs.isClosed() {
|
rs.closemu.RLock()
|
||||||
|
defer rs.closemu.RUnlock()
|
||||||
|
if rs.closed {
|
||||||
return nil, errors.New("sql: Rows are closed")
|
return nil, errors.New("sql: Rows are closed")
|
||||||
}
|
}
|
||||||
if rs.rowsi == nil {
|
if rs.rowsi == nil {
|
||||||
@ -2329,9 +2361,13 @@ func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
|
|||||||
// For scanning into *bool, the source may be true, false, 1, 0, or
|
// For scanning into *bool, the source may be true, false, 1, 0, or
|
||||||
// string inputs parseable by strconv.ParseBool.
|
// string inputs parseable by strconv.ParseBool.
|
||||||
func (rs *Rows) Scan(dest ...interface{}) error {
|
func (rs *Rows) Scan(dest ...interface{}) error {
|
||||||
if rs.isClosed() {
|
rs.closemu.RLock()
|
||||||
|
if rs.closed {
|
||||||
|
rs.closemu.RUnlock()
|
||||||
return errors.New("sql: Rows are closed")
|
return errors.New("sql: Rows are closed")
|
||||||
}
|
}
|
||||||
|
rs.closemu.RUnlock()
|
||||||
|
|
||||||
if rs.lastcols == nil {
|
if rs.lastcols == nil {
|
||||||
return errors.New("sql: Scan called without calling Next")
|
return errors.New("sql: Scan called without calling Next")
|
||||||
}
|
}
|
||||||
@ -2351,20 +2387,28 @@ func (rs *Rows) Scan(dest ...interface{}) error {
|
|||||||
// hook through a test only mutex.
|
// hook through a test only mutex.
|
||||||
var rowsCloseHook = func() func(*Rows, *error) { return nil }
|
var rowsCloseHook = func() func(*Rows, *error) { return nil }
|
||||||
|
|
||||||
func (rs *Rows) isClosed() bool {
|
|
||||||
return atomic.LoadInt32(&rs.closed) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the Rows, preventing further enumeration. If Next is called
|
// Close closes the Rows, preventing further enumeration. If Next is called
|
||||||
// and returns false and there are no further result sets,
|
// and returns false and there are no further result sets,
|
||||||
// the Rows are closed automatically and it will suffice to check the
|
// the Rows are closed automatically and it will suffice to check the
|
||||||
// result of Err. Close is idempotent and does not affect the result of Err.
|
// result of Err. Close is idempotent and does not affect the result of Err.
|
||||||
func (rs *Rows) Close() error {
|
func (rs *Rows) Close() error {
|
||||||
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
|
return rs.close(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *Rows) close(err error) error {
|
||||||
|
rs.closemu.Lock()
|
||||||
|
defer rs.closemu.Unlock()
|
||||||
|
|
||||||
|
if rs.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
rs.closed = true
|
||||||
|
|
||||||
err := rs.rowsi.Close()
|
if rs.lasterr == nil {
|
||||||
|
rs.lasterr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rs.rowsi.Close()
|
||||||
if fn := rowsCloseHook(); fn != nil {
|
if fn := rowsCloseHook(); fn != nil {
|
||||||
fn(rs, &err)
|
fn(rs, &err)
|
||||||
}
|
}
|
||||||
|
@ -313,9 +313,13 @@ func TestQueryContext(t *testing.T) {
|
|||||||
got = append(got, r)
|
got = append(got, r)
|
||||||
index++
|
index++
|
||||||
}
|
}
|
||||||
err = rows.Err()
|
select {
|
||||||
if err != nil {
|
case <-ctx.Done():
|
||||||
t.Fatalf("Err: %v", err)
|
if err := ctx.Err(); err != context.Canceled {
|
||||||
|
t.Fatalf("context err = %v; want context.Canceled")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Fatalf("context err = nil; want context.Canceled")
|
||||||
}
|
}
|
||||||
want := []row{
|
want := []row{
|
||||||
{age: 1, name: "Alice"},
|
{age: 1, name: "Alice"},
|
||||||
|
Loading…
Reference in New Issue
Block a user