From 190d0d3e69b113bea0b6b604ba2f0beb62c08741 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 16 Mar 2024 18:29:06 -0700 Subject: [PATCH] database/sql: optimize connection request pool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This replaces a map used as a set with a slice. We were using a surprising amount of CPU in this code, making mapiters to pull out a random element of the map. Instead, just rand.IntN to pick a random element of the slice. It also adds a benchmark: │ before │ after │ │ sec/op │ sec/op vs base │ ConnRequestSet-8 1818.0n ± 0% 452.4n ± 0% -75.12% (p=0.000 n=10) (whether random is a good policy is a bigger question, but this optimizes the current policy without changing behavior) Updates #66361 Change-Id: I3d456a819cc720c2d18e1befffd2657e5f50f1e7 Reviewed-on: https://go-review.googlesource.com/c/go/+/572119 Reviewed-by: Cherry Mui Reviewed-by: Ian Lance Taylor Reviewed-by: Emmanuel Odeke Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI Auto-Submit: Brad Fitzpatrick --- src/database/sql/sql.go | 163 +++++++++++++++++++++++++++-------- src/database/sql/sql_test.go | 104 +++++++++++++++++++++- src/go/build/deps_test.go | 5 +- 3 files changed, 234 insertions(+), 38 deletions(-) diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 4f1197dc6e..b5facdbf2a 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "math/rand/v2" "reflect" "runtime" "sort" @@ -497,9 +498,8 @@ type DB struct { mu sync.Mutex // protects following fields freeConn []*driverConn // free connections ordered by returnedAt oldest to newest - connRequests map[uint64]chan connRequest - nextRequest uint64 // Next key to use in connRequests. - numOpen int // number of opened and pending open connections + connRequests connRequestSet + numOpen int // number of opened and pending open connections // Used to signal the need for new connections // a goroutine running connectionOpener() reads on this chan and // maybeOpenNewConnections sends on the chan (one send per needed connection) @@ -814,11 +814,10 @@ func (t dsnConnector) Driver() driver.Driver { func OpenDB(c driver.Connector) *DB { ctx, cancel := context.WithCancel(context.Background()) db := &DB{ - connector: c, - openerCh: make(chan struct{}, connectionRequestQueueSize), - lastPut: make(map[*driverConn]string), - connRequests: make(map[uint64]chan connRequest), - stop: cancel, + connector: c, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), + stop: cancel, } go db.connectionOpener(ctx) @@ -922,9 +921,7 @@ func (db *DB) Close() error { } db.freeConn = nil db.closed = true - for _, req := range db.connRequests { - close(req) - } + db.connRequests.CloseAndRemoveAll() db.mu.Unlock() for _, fn := range fns { err1 := fn() @@ -1223,7 +1220,7 @@ func (db *DB) Stats() DBStats { // If there are connRequests and the connection limit hasn't been reached, // then tell the connectionOpener to open new connections. func (db *DB) maybeOpenNewConnections() { - numRequests := len(db.connRequests) + numRequests := db.connRequests.Len() if db.maxOpen > 0 { numCanOpen := db.maxOpen - db.numOpen if numRequests > numCanOpen { @@ -1297,14 +1294,6 @@ type connRequest struct { var errDBClosed = errors.New("sql: database is closed") -// nextRequestKeyLocked returns the next connection request key. -// It is assumed that nextRequest will not overflow. -func (db *DB) nextRequestKeyLocked() uint64 { - next := db.nextRequest - db.nextRequest++ - return next -} - // conn returns a newly-opened or cached *driverConn. func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { db.mu.Lock() @@ -1352,8 +1341,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn // Make the connRequest channel. It's buffered so that the // connectionOpener doesn't block while waiting for the req to be read. req := make(chan connRequest, 1) - reqKey := db.nextRequestKeyLocked() - db.connRequests[reqKey] = req + delHandle := db.connRequests.Add(req) db.waitCount++ db.mu.Unlock() @@ -1365,16 +1353,26 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn // Remove the connection request and ensure no value has been sent // on it after removing. db.mu.Lock() - delete(db.connRequests, reqKey) + deleted := db.connRequests.Delete(delHandle) db.mu.Unlock() db.waitDuration.Add(int64(time.Since(waitStart))) - select { - default: - case ret, ok := <-req: - if ok && ret.conn != nil { - db.putConn(ret.conn, ret.err, false) + // If we failed to delete it, that means something else + // grabbed it and is about to send on it. + if !deleted { + // TODO(bradfitz): rather than this best effort select, we + // should probably start a goroutine to read from req. This best + // effort select existed before the change to check 'deleted'. + // But if we know for sure it wasn't deleted and a sender is + // outstanding, we should probably block on req (in a new + // goroutine) to get the connection back. + select { + default: + case ret, ok := <-req: + if ok && ret.conn != nil { + db.putConn(ret.conn, ret.err, false) + } } } return nil, ctx.Err() @@ -1530,13 +1528,7 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { if db.maxOpen > 0 && db.numOpen > db.maxOpen { return false } - if c := len(db.connRequests); c > 0 { - var req chan connRequest - var reqKey uint64 - for reqKey, req = range db.connRequests { - break - } - delete(db.connRequests, reqKey) // Remove from pending requests. + if req, ok := db.connRequests.TakeRandom(); ok { if err == nil { dc.inUse = true } @@ -3529,3 +3521,104 @@ func withLock(lk sync.Locker, fn func()) { defer lk.Unlock() // in case fn panics fn() } + +// connRequestSet is a set of chan connRequest that's +// optimized for: +// +// - adding an element +// - removing an element (only by the caller who added it) +// - taking (get + delete) a random element +// +// We previously used a map for this but the take of a random element +// was expensive, making mapiters. This type avoids a map entirely +// and just uses a slice. +type connRequestSet struct { + // s are the elements in the set. + s []connRequestAndIndex +} + +type connRequestAndIndex struct { + // req is the element in the set. + req chan connRequest + + // curIdx points to the current location of this element in + // connRequestSet.s. It gets set to -1 upon removal. + curIdx *int +} + +// CloseAndRemoveAll closes all channels in the set +// and clears the set. +func (s *connRequestSet) CloseAndRemoveAll() { + for _, v := range s.s { + close(v.req) + } + s.s = nil +} + +// Len returns the length of the set. +func (s *connRequestSet) Len() int { return len(s.s) } + +// connRequestDelHandle is an opaque handle to delete an +// item from calling Add. +type connRequestDelHandle struct { + idx *int // pointer to index; or -1 if not in slice +} + +// Add adds v to the set of waiting requests. +// The returned connRequestDelHandle can be used to remove the item from +// the set. +func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle { + idx := len(s.s) + // TODO(bradfitz): for simplicity, this always allocates a new int-sized + // allocation to store the index. But generally the set will be small and + // under a scannable-threshold. As an optimization, we could permit the *int + // to be nil when the set is small and should be scanned. This works even if + // the set grows over the threshold with delete handles outstanding because + // an element can only move to a lower index. So if it starts with a nil + // position, it'll always be in a low index and thus scannable. But that + // can be done in a follow-up change. + idxPtr := &idx + s.s = append(s.s, connRequestAndIndex{v, idxPtr}) + return connRequestDelHandle{idxPtr} +} + +// Delete removes an element from the set. +// +// It reports whether the element was deleted. (It can return false if a caller +// of TakeRandom took it meanwhile, or upon the second call to Delete) +func (s *connRequestSet) Delete(h connRequestDelHandle) bool { + idx := *h.idx + if idx < 0 { + return false + } + s.deleteIndex(idx) + return true +} + +func (s *connRequestSet) deleteIndex(idx int) { + // Mark item as deleted. + *(s.s[idx].curIdx) = -1 + // Copy last element, updating its position + // to its new home. + if idx < len(s.s)-1 { + last := s.s[len(s.s)-1] + *last.curIdx = idx + s.s[idx] = last + } + // Zero out last element (for GC) before shrinking the slice. + s.s[len(s.s)-1] = connRequestAndIndex{} + s.s = s.s[:len(s.s)-1] +} + +// TakeRandom returns and removes a random element from s +// and reports whether there was one to take. (It returns ok=false +// if the set is empty.) +func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) { + if len(s.s) == 0 { + return nil, false + } + pick := rand.IntN(len(s.s)) + e := s.s[pick] + s.deleteIndex(pick) + return e.req, true +} diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index bf0ecc243f..e786ecbfab 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -14,6 +14,7 @@ import ( "math/rand" "reflect" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -2983,7 +2984,7 @@ func TestConnExpiresFreshOutOfPool(t *testing.T) { return } db.mu.Lock() - ct := len(db.connRequests) + ct := db.connRequests.Len() db.mu.Unlock() if ct > 0 { return @@ -4803,3 +4804,104 @@ func BenchmarkGrabConn(b *testing.B) { release(nil) } } + +func TestConnRequestSet(t *testing.T) { + var s connRequestSet + wantLen := func(want int) { + t.Helper() + if got := s.Len(); got != want { + t.Errorf("Len = %d; want %d", got, want) + } + if want == 0 && !t.Failed() { + if _, ok := s.TakeRandom(); ok { + t.Fatalf("TakeRandom returned result when empty") + } + } + } + reset := func() { s = connRequestSet{} } + + t.Run("add-delete", func(t *testing.T) { + reset() + wantLen(0) + dh := s.Add(nil) + wantLen(1) + if !s.Delete(dh) { + t.Fatal("failed to delete") + } + wantLen(0) + if s.Delete(dh) { + t.Error("delete worked twice") + } + wantLen(0) + }) + t.Run("take-before-delete", func(t *testing.T) { + reset() + ch1 := make(chan connRequest) + dh := s.Add(ch1) + wantLen(1) + if got, ok := s.TakeRandom(); !ok || got != ch1 { + t.Fatalf("wrong take; ok=%v", ok) + } + wantLen(0) + if s.Delete(dh) { + t.Error("unexpected delete after take") + } + }) + t.Run("get-take-many", func(t *testing.T) { + reset() + m := map[chan connRequest]bool{} + const N = 100 + var inOrder, backOut []chan connRequest + for range N { + c := make(chan connRequest) + m[c] = true + s.Add(c) + inOrder = append(inOrder, c) + } + if s.Len() != N { + t.Fatalf("Len = %v; want %v", s.Len(), N) + } + for s.Len() > 0 { + c, ok := s.TakeRandom() + if !ok { + t.Fatal("failed to take when non-empty") + } + if !m[c] { + t.Fatal("returned item not in remaining set") + } + delete(m, c) + backOut = append(backOut, c) + } + if len(m) > 0 { + t.Error("items remain in expected map") + } + if slices.Equal(inOrder, backOut) { // N! chance of flaking; N=100 is fine + t.Error("wasn't random") + } + }) +} + +func BenchmarkConnRequestSet(b *testing.B) { + var s connRequestSet + for range b.N { + for range 16 { + s.Add(nil) + } + for range 8 { + if _, ok := s.TakeRandom(); !ok { + b.Fatal("want ok") + } + } + for range 8 { + s.Add(nil) + } + for range 16 { + if _, ok := s.TakeRandom(); !ok { + b.Fatal("want ok") + } + } + if _, ok := s.TakeRandom(); ok { + b.Fatal("unexpected ok") + } + } +} diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index 26e6e8a77d..427f5a96b5 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -321,8 +321,9 @@ var depsRules = ` # databases FMT < database/sql/internal - < database/sql/driver - < database/sql; + < database/sql/driver; + + database/sql/driver, math/rand/v2 < database/sql; # images FMT, compress/lzw, compress/zlib