diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 6bc869fc864..8c58723c030 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -449,6 +449,16 @@ func TestQueryContextWait(t *testing.T) { // TestTxContextWait tests the transaction behavior when the tx context is canceled // during execution of the query. func TestTxContextWait(t *testing.T) { + testContextWait(t, false) +} + +// TestTxContextWaitNoDiscard is the same as TestTxContextWait, but should not discard +// the final connection. +func TestTxContextWaitNoDiscard(t *testing.T) { + testContextWait(t, true) +} + +func testContextWait(t *testing.T, keepConnOnRollback bool) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -458,7 +468,7 @@ func TestTxContextWait(t *testing.T) { if err != nil { t.Fatal(err) } - tx.keepConnOnRollback = false + tx.keepConnOnRollback = keepConnOnRollback tx.dc.ci.(*fakeConn).waiter = func(c context.Context) { cancel() @@ -472,36 +482,11 @@ func TestTxContextWait(t *testing.T) { t.Fatalf("expected QueryContext to error with context canceled but returned %v", err) } - waitForFree(t, db, 0) -} - -// TestTxContextWaitNoDiscard is the same as TestTxContextWait, but should not discard -// the final connection. -func TestTxContextWaitNoDiscard(t *testing.T) { - db := newTestDB(t, "people") - defer closeDB(t, db) - - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) - defer cancel() - - tx, err := db.BeginTx(ctx, nil) - if err != nil { - // Guard against the context being canceled before BeginTx completes. - if err == context.DeadlineExceeded { - t.Skip("tx context canceled prior to first use") - } - t.Fatal(err) + if keepConnOnRollback { + waitForFree(t, db, 1) + } else { + waitForFree(t, db, 0) } - - // This will trigger the *fakeConn.Prepare method which will take time - // performing the query. The ctxDriverPrepare func will check the context - // after this and close the rows and return an error. - _, err = tx.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|") - if err != context.DeadlineExceeded { - t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) - } - - waitForFree(t, db, 1) } // TestUnsupportedOptions checks that the database fails when a driver that