diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index d89aa597927..bd450c7ec9c 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -192,14 +192,12 @@ type DB struct { driver driver.Driver dsn string - mu sync.Mutex // protects following fields - outConn map[*driverConn]bool // whether the conn is in use - freeConn []*driverConn - closed bool - dep map[finalCloser]depSet - onConnPut map[*driverConn][]func() // code (with mu held) run when conn is next returned - lastPut map[*driverConn]string // stacktrace of last conn's put; debug only - maxIdle int // zero means defaultMaxIdleConns; negative means 0 + mu sync.Mutex // protects following fields + freeConn []*driverConn + closed bool + dep map[finalCloser]depSet + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only + maxIdle int // zero means defaultMaxIdleConns; negative means 0 } // driverConn wraps a driver.Conn with a mutex, to @@ -212,6 +210,10 @@ type driverConn struct { sync.Mutex // guards following ci driver.Conn closed bool + + // guarded by db.mu + inUse bool + onPut []func() // code (with db.mu held) run when conn is next returned } // the dc.db's Mutex is held. @@ -341,11 +343,9 @@ func Open(driverName, dataSourceName string) (*DB, error) { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } db := &DB{ - driver: driveri, - dsn: dataSourceName, - outConn: make(map[*driverConn]bool), - lastPut: make(map[*driverConn]string), - onConnPut: make(map[*driverConn][]func()), + driver: driveri, + dsn: dataSourceName, + lastPut: make(map[*driverConn]string), } return db, nil } @@ -427,7 +427,7 @@ func (db *DB) conn() (*driverConn, error) { if n := len(db.freeConn); n > 0 { conn := db.freeConn[n-1] db.freeConn = db.freeConn[:n-1] - db.outConn[conn] = true + conn.inUse = true db.mu.Unlock() return conn, nil } @@ -443,7 +443,7 @@ func (db *DB) conn() (*driverConn, error) { } db.mu.Lock() db.addDepLocked(dc, dc) - db.outConn[dc] = true + dc.inUse = true db.mu.Unlock() return dc, nil } @@ -456,7 +456,7 @@ func (db *DB) conn() (*driverConn, error) { func (db *DB) connIfFree(wanted *driverConn) (conn *driverConn, ok bool) { db.mu.Lock() defer db.mu.Unlock() - if db.outConn[wanted] { + if wanted.inUse { return conn, false } for i, conn := range db.freeConn { @@ -465,7 +465,7 @@ func (db *DB) connIfFree(wanted *driverConn) (conn *driverConn, ok bool) { } db.freeConn[i] = db.freeConn[len(db.freeConn)-1] db.freeConn = db.freeConn[:len(db.freeConn)-1] - db.outConn[wanted] = true + wanted.inUse = true return wanted, true } return nil, false @@ -480,8 +480,8 @@ var putConnHook func(*DB, *driverConn) func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { db.mu.Lock() defer db.mu.Unlock() - if db.outConn[c] { - db.onConnPut[c] = append(db.onConnPut[c], func() { + if c.inUse { + c.onPut = append(c.onPut, func() { si.Close() }) } else { @@ -497,7 +497,7 @@ const debugGetPut = false // err is optionally the last error that occurred on this connection. func (db *DB) putConn(dc *driverConn, err error) { db.mu.Lock() - if !db.outConn[dc] { + if !dc.inUse { if debugGetPut { fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) } @@ -506,14 +506,12 @@ func (db *DB) putConn(dc *driverConn, err error) { if debugGetPut { db.lastPut[dc] = stack() } - delete(db.outConn, dc) + dc.inUse = false - if fns, ok := db.onConnPut[dc]; ok { - for _, fn := range fns { - fn() - } - delete(db.onConnPut, dc) + for _, fn := range dc.onPut { + fn() } + dc.onPut = nil if err == driver.ErrBadConn { // Don't reuse bad connections. diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index 54aad3a5d01..6b91783784e 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -7,7 +7,9 @@ package sql import ( "fmt" "reflect" + "runtime" "strings" + "sync" "testing" "time" ) @@ -36,7 +38,14 @@ const fakeDBName = "foo" var chrisBirthday = time.Unix(123456789, 0) -func newTestDB(t *testing.T, name string) *DB { +type testOrBench interface { + Fatalf(string, ...interface{}) + Errorf(string, ...interface{}) + Fatal(...interface{}) + Error(...interface{}) +} + +func newTestDB(t testOrBench, name string) *DB { db, err := Open("test", fakeDBName) if err != nil { t.Fatalf("Open: %v", err) @@ -53,14 +62,14 @@ func newTestDB(t *testing.T, name string) *DB { return db } -func exec(t *testing.T, db *DB, query string, args ...interface{}) { +func exec(t testOrBench, db *DB, query string, args ...interface{}) { _, err := db.Exec(query, args...) if err != nil { t.Fatalf("Exec of %q: %v", query, err) } } -func closeDB(t *testing.T, db *DB) { +func closeDB(t testOrBench, db *DB) { if e := recover(); e != nil { fmt.Printf("Panic: %v\n", e) panic(e) @@ -844,3 +853,63 @@ func TestCloseConnBeforeStmts(t *testing.T) { t.Errorf("after Stmt Close, driverConn's Conn interface should be nil") } } + +func manyConcurrentQueries(t testOrBench) { + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + + db := newTestDB(t, "people") + defer closeDB(t, db) + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + wg.Add(numReqs) + + reqs := make(chan bool) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for _ = range reqs { + rows, err := stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + wg.Done() + continue + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + + wg.Done() + } + }() + } + + for i := 0; i < numReqs; i++ { + reqs <- true + } + + wg.Wait() +} + +func TestConcurrency(t *testing.T) { + manyConcurrentQueries(t) +} + +func BenchmarkConcurrency(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + manyConcurrentQueries(b) + } +}