diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go index 8af753b5d35..39c02827897 100644 --- a/src/pkg/database/sql/fakedb_test.go +++ b/src/pkg/database/sql/fakedb_test.go @@ -447,6 +447,10 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { return c.prepareCreate(stmt, parts) case "INSERT": return c.prepareInsert(stmt, parts) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + return c.prepareInsert(stmt, parts) default: stmt.Close() return nil, errf("unsupported command type %q", cmd) @@ -497,13 +501,20 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { } return driver.ResultNoRows, nil case "INSERT": - return s.execInsert(args) + return s.execInsert(args, true) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + return s.execInsert(args, false) } fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) } -func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) { +// When doInsert is true, add the row to the table. +// When doInsert is false do prep-work and error checking, but don't +// actually add the row to the table. +func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { db := s.c.db if len(args) != s.placeholders { panic("error in pkg db; should only get here if size is correct") @@ -518,7 +529,10 @@ func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) { t.mu.Lock() defer t.mu.Unlock() - cols := make([]interface{}, len(t.colname)) + var cols []interface{} + if doInsert { + cols = make([]interface{}, len(t.colname)) + } argPos := 0 for n, colname := range s.colName { colidx := t.columnIndex(colname) @@ -532,10 +546,14 @@ func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) { } else { val = s.colValue[n] } - cols[colidx] = val + if doInsert { + cols[colidx] = val + } } - t.rows = append(t.rows, &row{cols: cols}) + if doInsert { + t.rows = append(t.rows, &row{cols: cols}) + } return driver.RowsAffected(1), nil } diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index d81f6fe9842..44257778c1c 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -10,6 +10,7 @@ package sql import ( + "container/list" "database/sql/driver" "errors" "fmt" @@ -192,12 +193,22 @@ type DB struct { driver driver.Driver dsn string - mu sync.Mutex // protects following fields - freeConn []*driverConn + mu sync.Mutex // protects following fields + freeConn *list.List // of *driverConn + connRequests *list.List // of connRequest + numOpen int + pendingOpens int + // Used to sygnal the need for new connections + // a goroutine running connectionOpener() reads on this chan and + // maybeOpenNewConnections sends on the chan (one send per needed connection) + // It is closed during db.Close(). The close tells the connectionOpener + // goroutine to exit. + openerCh chan struct{} closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only maxIdle int // zero means defaultMaxIdleConns; negative means 0 + maxOpen int // <= 0 means unlimited } // driverConn wraps a driver.Conn with a mutex, to @@ -217,6 +228,9 @@ type driverConn struct { inUse bool onPut []func() // code (with db.mu held) run when conn is next returned dbmuClosed bool // same as closed, but guarded by db.mu, for connIfFree + // This is the Element returned by db.freeConn.PushFront(conn). + // It's used by connIfFree to remove the conn from the freeConn list. + listElem *list.Element } func (dc *driverConn) releaseConn(err error) { @@ -254,15 +268,14 @@ func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { } // the dc.db's Mutex is held. -func (dc *driverConn) closeDBLocked() error { +func (dc *driverConn) closeDBLocked() func() error { dc.Lock() + defer dc.Unlock() if dc.closed { - dc.Unlock() - return errors.New("sql: duplicate driverConn close") + return func() error { return errors.New("sql: duplicate driverConn close") } } dc.closed = true - dc.Unlock() // not defer; removeDep finalClose calls may need to lock - return dc.db.removeDepLocked(dc, dc)() + return dc.db.removeDepLocked(dc, dc) } func (dc *driverConn) Close() error { @@ -293,8 +306,13 @@ func (dc *driverConn) finalClose() error { err := dc.ci.Close() dc.ci = nil dc.finalClosed = true - dc.Unlock() + + dc.db.mu.Lock() + dc.db.numOpen-- + dc.db.maybeOpenNewConnections() + dc.db.mu.Unlock() + return err } @@ -380,6 +398,13 @@ func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { } } +// This is the size of the connectionOpener request chan (dn.openerCh). +// This value should be larger than the maximum typical value +// used for db.maxOpen. If maxOpen is significantly larger than +// connectionRequestQueueSize then it is possible for ALL calls into the *DB +// to block until the connectionOpener can satify the backlog of requests. +var connectionRequestQueueSize = 1000000 + // Open opens a database specified by its database driver name and a // driver-specific data source name, usually consisting of at least a // database name and connection information. @@ -398,10 +423,14 @@ func Open(driverName, dataSourceName string) (*DB, error) { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } db := &DB{ - driver: driveri, - dsn: dataSourceName, - lastPut: make(map[*driverConn]string), + driver: driveri, + dsn: dataSourceName, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), } + db.freeConn = list.New() + db.connRequests = list.New() + go db.connectionOpener() return db, nil } @@ -422,16 +451,32 @@ func (db *DB) Ping() error { // Close closes the database, releasing any open resources. func (db *DB) Close() error { db.mu.Lock() - defer db.mu.Unlock() + if db.closed { // Make DB.Close idempotent + db.mu.Unlock() + return nil + } + close(db.openerCh) var err error - for _, dc := range db.freeConn { - err1 := dc.closeDBLocked() + fns := make([]func() error, 0, db.freeConn.Len()) + for db.freeConn.Front() != nil { + dc := db.freeConn.Front().Value.(*driverConn) + dc.listElem = nil + fns = append(fns, dc.closeDBLocked()) + db.freeConn.Remove(db.freeConn.Front()) + } + db.closed = true + for db.connRequests.Front() != nil { + req := db.connRequests.Front().Value.(connRequest) + db.connRequests.Remove(db.connRequests.Front()) + close(req) + } + db.mu.Unlock() + for _, fn := range fns { + err1 := fn() if err1 != nil { err = err1 } } - db.freeConn = nil - db.closed = true return err } @@ -453,6 +498,9 @@ func (db *DB) maxIdleConnsLocked() int { // SetMaxIdleConns sets the maximum number of connections in the idle // connection pool. // +// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns +// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit +// // If n <= 0, no idle connections are retained. func (db *DB) SetMaxIdleConns(n int) { db.mu.Lock() @@ -463,40 +511,148 @@ func (db *DB) SetMaxIdleConns(n int) { // No idle connections. db.maxIdle = -1 } - for len(db.freeConn) > 0 && len(db.freeConn) > n { - nfree := len(db.freeConn) - dc := db.freeConn[nfree-1] - db.freeConn[nfree-1] = nil - db.freeConn = db.freeConn[:nfree-1] + // Make sure maxIdle doesn't exceed maxOpen + if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen { + db.maxIdle = db.maxOpen + } + for db.freeConn.Len() > db.maxIdleConnsLocked() { + dc := db.freeConn.Back().Value.(*driverConn) + dc.listElem = nil + db.freeConn.Remove(db.freeConn.Back()) go dc.Close() } } +// SetMaxOpenConns sets the maximum number of open connections to the database. +// +// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than +// MaxIdleConns, then MaxIdleConns will be reduced to match the new +// MaxOpenConns limit +// +// If n <= 0, then there is no limit on the number of open connections. +// The default is 0 (unlimited). +func (db *DB) SetMaxOpenConns(n int) { + db.mu.Lock() + db.maxOpen = n + if n < 0 { + db.maxOpen = 0 + } + syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen + db.mu.Unlock() + if syncMaxIdle { + db.SetMaxIdleConns(n) + } +} + +// Assumes db.mu is locked. +// If there are connRequests and the connection limit hasn't been reached, +// then tell the connectionOpener to open new connections. +func (db *DB) maybeOpenNewConnections() { + numRequests := db.connRequests.Len() - db.pendingOpens + if db.maxOpen > 0 { + numCanOpen := db.maxOpen - (db.numOpen + db.pendingOpens) + if numRequests > numCanOpen { + numRequests = numCanOpen + } + } + for numRequests > 0 { + db.pendingOpens++ + numRequests-- + db.openerCh <- struct{}{} + } +} + +// Runs in a seperate goroutine, opens new connections when requested. +func (db *DB) connectionOpener() { + for _ = range db.openerCh { + db.openNewConnection() + } +} + +// Open one new connection +func (db *DB) openNewConnection() { + ci, err := db.driver.Open(db.dsn) + db.mu.Lock() + defer db.mu.Unlock() + if db.closed { + if err == nil { + ci.Close() + } + return + } + db.pendingOpens-- + if err != nil { + db.putConnDBLocked(nil, err) + return + } + dc := &driverConn{ + db: db, + ci: ci, + } + db.addDepLocked(dc, dc) + db.numOpen++ + db.putConnDBLocked(dc, err) +} + +// connRequest represents one request for a new connection +// When there are no idle connections available, DB.conn will create +// a new connRequest and put it on the db.connRequests list. +type connRequest chan<- interface{} // takes either a *driverConn or an error + +var errDBClosed = errors.New("sql: database is closed") + // conn returns a newly-opened or cached *driverConn func (db *DB) conn() (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() - return nil, errors.New("sql: database is closed") + return nil, errDBClosed } - if n := len(db.freeConn); n > 0 { - conn := db.freeConn[n-1] - db.freeConn = db.freeConn[:n-1] + + // If db.maxOpen > 0 and the number of open connections is over the limit + // or there are no free connection, then make a request and wait. + if db.maxOpen > 0 && (db.numOpen >= db.maxOpen || db.freeConn.Len() == 0) { + // Make the connRequest channel. It's buffered so that the + // connectionOpener doesn't block while waiting for the req to be read. + ch := make(chan interface{}, 1) + req := connRequest(ch) + db.connRequests.PushBack(req) + db.maybeOpenNewConnections() + db.mu.Unlock() + ret, ok := <-ch + if !ok { + return nil, errDBClosed + } + switch ret.(type) { + case *driverConn: + return ret.(*driverConn), nil + case error: + return nil, ret.(error) + default: + panic("sql: Unexpected type passed through connRequest.ch") + } + } + + if f := db.freeConn.Front(); f != nil { + conn := f.Value.(*driverConn) + conn.listElem = nil + db.freeConn.Remove(f) conn.inUse = true db.mu.Unlock() return conn, nil } - db.mu.Unlock() + db.mu.Unlock() ci, err := db.driver.Open(db.dsn) if err != nil { return nil, err } + db.mu.Lock() + db.numOpen++ dc := &driverConn{ db: db, ci: ci, } - db.mu.Lock() db.addDepLocked(dc, dc) dc.inUse = true db.mu.Unlock() @@ -524,12 +680,9 @@ func (db *DB) connIfFree(wanted *driverConn) (*driverConn, error) { if wanted.inUse { return nil, errConnBusy } - for i, conn := range db.freeConn { - if conn != wanted { - continue - } - db.freeConn[i] = db.freeConn[len(db.freeConn)-1] - db.freeConn = db.freeConn[:len(db.freeConn)-1] + if wanted.listElem != nil { + db.freeConn.Remove(wanted.listElem) + wanted.listElem = nil wanted.inUse = true return wanted, nil } @@ -589,6 +742,10 @@ func (db *DB) putConn(dc *driverConn, err error) { if err == driver.ErrBadConn { // Don't reuse bad connections. + // Since the conn is considered bad and is being discarded, treat it + // as closed. Decrement the open count. + db.numOpen-- + db.maybeOpenNewConnections() db.mu.Unlock() dc.Close() return @@ -596,14 +753,38 @@ func (db *DB) putConn(dc *driverConn, err error) { if putConnHook != nil { putConnHook(db, dc) } - if n := len(db.freeConn); !db.closed && n < db.maxIdleConnsLocked() { - db.freeConn = append(db.freeConn, dc) - db.mu.Unlock() - return - } + added := db.putConnDBLocked(dc, nil) db.mu.Unlock() + if !added { + dc.Close() + } +} - dc.Close() +// Satisfy a connRequest or put the driverConn in the idle pool and return true +// or return false. +// putConnDBLocked will satisfy a connRequest if there is one, or it will +// return the *driverConn to the freeConn list if err != nil and the idle +// connection limit would not be reached. +// If err != nil, the value of dc is ignored. +// If err == nil, then dc must not equal nil. +// If a connRequest was fullfilled or the *driverConn was placed in the +// freeConn list, then true is returned, otherwise false is returned. +func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { + if db.connRequests.Len() > 0 { + req := db.connRequests.Front().Value.(connRequest) + db.connRequests.Remove(db.connRequests.Front()) + if err != nil { + req <- err + } else { + dc.inUse = true + req <- dc + } + return true + } else if err == nil && !db.closed && db.maxIdleConnsLocked() > 0 && db.maxIdleConnsLocked() > db.freeConn.Len() { + dc.listElem = db.freeConn.PushFront(dc) + return true + } + return false } // Prepare creates a prepared statement for later queries or executions. diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index 4005f154466..435d79c24a9 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -8,6 +8,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/rand" "reflect" "runtime" "strings" @@ -23,14 +24,12 @@ func init() { } freedFrom := make(map[dbConn]string) putConnHook = func(db *DB, c *driverConn) { - for _, oc := range db.freeConn { - if oc == c { - // print before panic, as panic may get lost due to conflicting panic - // (all goroutines asleep) elsewhere, since we might not unlock - // the mutex in freeConn here. - println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack()) - panic("double free of conn.") - } + if c.listElem != nil { + // print before panic, as panic may get lost due to conflicting panic + // (all goroutines asleep) elsewhere, since we might not unlock + // the mutex in freeConn here. + println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack()) + panic("double free of conn.") } freedFrom[dbConn{db, c}] = stack() } @@ -80,14 +79,15 @@ func closeDB(t testing.TB, db *DB) { t.Errorf("Error closing fakeConn: %v", err) } }) - for i, dc := range db.freeConn { + for node, i := db.freeConn.Front(), 0; node != nil; node, i = node.Next(), i+1 { + dc := node.Value.(*driverConn) if n := len(dc.openStmt); n > 0 { // Just a sanity check. This is legal in // general, but if we make the tests clean up // their statements first, then we can safely // verify this is always zero here, and any // other value is a leak. - t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n) + t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, db.freeConn.Len(), n) } } err := db.Close() @@ -99,10 +99,10 @@ func closeDB(t testing.TB, db *DB) { // numPrepares assumes that db has exactly 1 idle conn and returns // its count of calls to Prepare func numPrepares(t *testing.T, db *DB) int { - if n := len(db.freeConn); n != 1 { + if n := db.freeConn.Len(); n != 1 { t.Fatalf("free conns = %d; want 1", n) } - return db.freeConn[0].ci.(*fakeConn).numPrepare + return (db.freeConn.Front().Value.(*driverConn)).ci.(*fakeConn).numPrepare } func (db *DB) numDeps() int { @@ -127,7 +127,7 @@ func (db *DB) numDepsPollUntil(want int, d time.Duration) int { func (db *DB) numFreeConns() int { db.mu.Lock() defer db.mu.Unlock() - return len(db.freeConn) + return db.freeConn.Len() } func (db *DB) dumpDeps(t *testing.T) { @@ -642,10 +642,10 @@ func TestQueryRowClosingStmt(t *testing.T) { if err != nil { t.Fatal(err) } - if len(db.freeConn) != 1 { + if db.freeConn.Len() != 1 { t.Fatalf("expected 1 free conn") } - fakeConn := db.freeConn[0].ci.(*fakeConn) + fakeConn := (db.freeConn.Front().Value.(*driverConn)).ci.(*fakeConn) if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { t.Errorf("statement close mismatch: made %d, closed %d", made, closed) } @@ -841,13 +841,13 @@ func TestMaxIdleConns(t *testing.T) { t.Fatal(err) } tx.Commit() - if got := len(db.freeConn); got != 1 { + if got := db.freeConn.Len(); got != 1 { t.Errorf("freeConns = %d; want 1", got) } db.SetMaxIdleConns(0) - if got := len(db.freeConn); got != 0 { + if got := db.freeConn.Len(); got != 0 { t.Errorf("freeConns after set to zero = %d; want 0", got) } @@ -856,11 +856,146 @@ func TestMaxIdleConns(t *testing.T) { t.Fatal(err) } tx.Commit() - if got := len(db.freeConn); got != 0 { + if got := db.freeConn.Len(); got != 0 { t.Errorf("freeConns = %d; want 0", got) } } +func TestMaxOpenConns(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.driver.(*fakeDriver) + + // Force the number of open connections to 0 so we can get an accurate + // count for the test + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(0, time.Second); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + + db.SetMaxIdleConns(10) + db.SetMaxOpenConns(10) + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + t.Fatal(err) + } + + // Start 50 parallel slow queries. + const ( + nquery = 50 + sleepMillis = 25 + nbatch = 2 + ) + var wg sync.WaitGroup + for batch := 0; batch < nbatch; batch++ { + for i := 0; i < nquery; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var op string + if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows { + t.Error(err) + } + }() + } + // Sleep for twice the expected length of time for the + // batch of 50 queries above to finish before starting + // the next round. + time.Sleep(2 * sleepMillis * time.Millisecond) + } + wg.Wait() + + if g, w := db.numFreeConns(), 10; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(20, time.Second); n > 20 { + t.Errorf("number of dependencies = %d; expected <= 20", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + driver.mu.Unlock() + + if opens > 10 { + t.Logf("open calls = %d", opens) + t.Logf("close calls = %d", closes) + t.Errorf("db connections opened = %d; want <= 10", opens) + db.dumpDeps(t) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if g, w := db.numFreeConns(), 10; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(10, time.Second); n > 10 { + t.Errorf("number of dependencies = %d; expected <= 10", n) + db.dumpDeps(t) + } + + db.SetMaxOpenConns(5) + + if g, w := db.numFreeConns(), 5; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(5, time.Second); n > 5 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + db.SetMaxOpenConns(0) + + if g, w := db.numFreeConns(), 5; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(5, time.Second); n > 5 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPollUntil(0, time.Second); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } +} + // golang.org/issue/5323 func TestStmtCloseDeps(t *testing.T) { if testing.Short() { @@ -926,8 +1061,8 @@ func TestStmtCloseDeps(t *testing.T) { driver.mu.Lock() opens := driver.openCount - opens0 closes := driver.closeCount - closes0 - driver.mu.Unlock() openDelta := (driver.openCount - driver.closeCount) - openDelta0 + driver.mu.Unlock() if openDelta > 2 { t.Logf("open calls = %d", opens) @@ -985,10 +1120,10 @@ func TestCloseConnBeforeStmts(t *testing.T) { t.Fatal(err) } - if len(db.freeConn) != 1 { - t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn)) + if db.freeConn.Len() != 1 { + t.Fatalf("expected 1 freeConn; got %d", db.freeConn.Len()) } - dc := db.freeConn[0] + dc := db.freeConn.Front().Value.(*driverConn) if dc.closed { t.Errorf("conn shouldn't be closed") } @@ -1082,6 +1217,350 @@ func TestStmtCloseOrder(t *testing.T) { } } +type concurrentTest interface { + init(t testing.TB, db *DB) + finish(t testing.TB) + test(t testing.TB) error +} + +type concurrentDBQueryTest struct { + db *DB +} + +func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) { + c.db = db +} + +func (c *concurrentDBQueryTest) finish(t testing.TB) { + c.db = nil +} + +func (c *concurrentDBQueryTest) test(t testing.TB) error { + rows, err := c.db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return err + } + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentDBExecTest struct { + db *DB +} + +func (c *concurrentDBExecTest) init(t testing.TB, db *DB) { + c.db = db +} + +func (c *concurrentDBExecTest) finish(t testing.TB) { + c.db = nil +} + +func (c *concurrentDBExecTest) test(t testing.TB) error { + _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + if err != nil { + t.Error(err) + return err + } + return nil +} + +type concurrentStmtQueryTest struct { + db *DB + stmt *Stmt +} + +func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.stmt, err = db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentStmtQueryTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + c.db = nil +} + +func (c *concurrentStmtQueryTest) test(t testing.TB) error { + rows, err := c.stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + return err + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentStmtExecTest struct { + db *DB + stmt *Stmt +} + +func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentStmtExecTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + c.db = nil +} + +func (c *concurrentStmtExecTest) test(t testing.TB) error { + _, err := c.stmt.Exec(3, chrisBirthday) + if err != nil { + t.Errorf("error on exec: %v", err) + return err + } + return nil +} + +type concurrentTxQueryTest struct { + db *DB + tx *Tx +} + +func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxQueryTest) finish(t testing.TB) { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxQueryTest) test(t testing.TB) error { + rows, err := c.db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return err + } + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentTxExecTest struct { + db *DB + tx *Tx +} + +func (c *concurrentTxExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxExecTest) finish(t testing.TB) { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxExecTest) test(t testing.TB) error { + _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + if err != nil { + t.Error(err) + return err + } + return nil +} + +type concurrentTxStmtQueryTest struct { + db *DB + tx *Tx + stmt *Stmt +} + +func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } + c.stmt, err = c.tx.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxStmtQueryTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxStmtQueryTest) test(t testing.TB) error { + rows, err := c.stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + return err + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentTxStmtExecTest struct { + db *DB + tx *Tx + stmt *Stmt +} + +func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } + c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxStmtExecTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxStmtExecTest) test(t testing.TB) error { + _, err := c.stmt.Exec(3, chrisBirthday) + if err != nil { + t.Errorf("error on exec: %v", err) + return err + } + return nil +} + +type concurrentRandomTest struct { + tests []concurrentTest +} + +func (c *concurrentRandomTest) init(t testing.TB, db *DB) { + c.tests = []concurrentTest{ + new(concurrentDBQueryTest), + new(concurrentDBExecTest), + new(concurrentStmtQueryTest), + new(concurrentStmtExecTest), + new(concurrentTxQueryTest), + new(concurrentTxExecTest), + new(concurrentTxStmtQueryTest), + new(concurrentTxStmtExecTest), + } + for _, ct := range c.tests { + ct.init(t, db) + } +} + +func (c *concurrentRandomTest) finish(t testing.TB) { + for _, ct := range c.tests { + ct.finish(t) + } +} + +func (c *concurrentRandomTest) test(t testing.TB) error { + ct := c.tests[rand.Intn(len(c.tests))] + return ct.test(t) +} + +func doConcurrentTest(t testing.TB, ct concurrentTest) { + maxProcs, numReqs := 1, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + + db := newTestDB(t, "people") + defer closeDB(t, db) + + ct.init(t, db) + defer ct.finish(t) + + var wg sync.WaitGroup + wg.Add(numReqs) + + reqs := make(chan bool) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for _ = range reqs { + err := ct.test(t) + if err != nil { + wg.Done() + continue + } + wg.Done() + } + }() + } + + for i := 0; i < numReqs; i++ { + reqs <- true + } + + wg.Wait() +} + func manyConcurrentQueries(t testing.TB) { maxProcs, numReqs := 16, 500 if testing.Short() { @@ -1178,12 +1657,77 @@ func TestIssue6081(t *testing.T) { } func TestConcurrency(t *testing.T) { - manyConcurrentQueries(t) + doConcurrentTest(t, new(concurrentDBQueryTest)) + doConcurrentTest(t, new(concurrentDBExecTest)) + doConcurrentTest(t, new(concurrentStmtQueryTest)) + doConcurrentTest(t, new(concurrentStmtExecTest)) + doConcurrentTest(t, new(concurrentTxQueryTest)) + doConcurrentTest(t, new(concurrentTxExecTest)) + doConcurrentTest(t, new(concurrentTxStmtQueryTest)) + doConcurrentTest(t, new(concurrentTxStmtExecTest)) + doConcurrentTest(t, new(concurrentRandomTest)) } -func BenchmarkConcurrency(b *testing.B) { +func BenchmarkConcurrentDBExec(b *testing.B) { b.ReportAllocs() + ct := new(concurrentDBExecTest) for i := 0; i < b.N; i++ { - manyConcurrentQueries(b) + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentStmtQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentStmtQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentStmtExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentStmtExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxStmtQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxStmtQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxStmtExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxStmtExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentRandom(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentRandomTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) } }