diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index ad17eb3da2c..8dd48107a6c 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1947,20 +1947,27 @@ type Conn struct { // it's returned to the connection pool. dc *driverConn - // done transitions from 0 to 1 exactly once, on close. + // done transitions from false to true exactly once, on close. // Once done, all operations fail with ErrConnDone. - // Use atomic operations on value when checking value. - done int32 + done atomic.Bool + + // releaseConn is a cache of c.closemuRUnlockCondReleaseConn + // to save allocations in a call to grabConn. + releaseConnOnce sync.Once + releaseConnCache releaseConn } // grabConn takes a context to implement stmtConnGrabber // but the context is not used. func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) { - if atomic.LoadInt32(&c.done) != 0 { + if c.done.Load() { return nil, nil, ErrConnDone } + c.releaseConnOnce.Do(func() { + c.releaseConnCache = c.closemuRUnlockCondReleaseConn + }) c.closemu.RLock() - return c.dc, c.closemuRUnlockCondReleaseConn, nil + return c.dc, c.releaseConnCache, nil } // PingContext verifies the connection to the database is still alive. @@ -2084,7 +2091,7 @@ func (c *Conn) txCtx() context.Context { } func (c *Conn) close(err error) error { - if !atomic.CompareAndSwapInt32(&c.done, 0, 1) { + if !c.done.CompareAndSwap(false, true) { return ErrConnDone } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 8c58723c030..2b3d76f5130 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -9,6 +9,8 @@ import ( "database/sql/driver" "errors" "fmt" + "internal/race" + "internal/testenv" "math/rand" "reflect" "runtime" @@ -4583,3 +4585,35 @@ func BenchmarkManyConcurrentQueries(b *testing.B) { } }) } + +func TestGrabConnAllocs(t *testing.T) { + testenv.SkipIfOptimizationOff(t) + if race.Enabled { + t.Skip("skipping allocation test when using race detector") + } + c := new(Conn) + ctx := context.Background() + n := int(testing.AllocsPerRun(1000, func() { + _, release, err := c.grabConn(ctx) + if err != nil { + t.Fatal(err) + } + release(nil) + })) + if n > 0 { + t.Fatalf("Conn.grabConn allocated %v objects; want 0", n) + } +} + +func BenchmarkGrabConn(b *testing.B) { + b.ReportAllocs() + c := new(Conn) + ctx := context.Background() + for i := 0; i < b.N; i++ { + _, release, err := c.grabConn(ctx) + if err != nil { + b.Fatal(err) + } + release(nil) + } +}