mirror of
https://github.com/golang/go
synced 2024-11-25 04:17:57 -07:00
database/sql: ensure Stmts are correctly closed.
To make sure that there is no resource leak, I suggest to fix the 'fakedb' driver such as it fails when any Stmt is not closed. First, add a check in fakeConn.Close(). Then, fix all missing Stmt.Close()/Rows.Close(). I am not sure that the strategy choose in fakeConn.Prepare/prepare* is ok. The weak point in this patch is the change in Tx.Query: - Tests pass without this change, - I found it by manually analyzing the code, - I just try to make Tx.Query look like DB.Query. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/5759050
This commit is contained in:
parent
959d0c7ac0
commit
c3954dd5da
@ -214,6 +214,9 @@ func (c *fakeConn) Close() error {
|
|||||||
if c.db == nil {
|
if c.db == nil {
|
||||||
return errors.New("can't close fakeConn; already closed")
|
return errors.New("can't close fakeConn; already closed")
|
||||||
}
|
}
|
||||||
|
if c.stmtsMade > c.stmtsClosed {
|
||||||
|
return errors.New("can't close; dangling statement(s)")
|
||||||
|
}
|
||||||
c.db = nil
|
c.db = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -250,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
|
|||||||
// just a limitation for fakedb)
|
// just a limitation for fakedb)
|
||||||
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
|
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
|
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
|
||||||
}
|
}
|
||||||
stmt.table = parts[0]
|
stmt.table = parts[0]
|
||||||
@ -260,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
|
|||||||
}
|
}
|
||||||
nameVal := strings.Split(colspec, "=")
|
nameVal := strings.Split(colspec, "=")
|
||||||
if len(nameVal) != 2 {
|
if len(nameVal) != 2 {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
|
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
|
||||||
}
|
}
|
||||||
column, value := nameVal[0], nameVal[1]
|
column, value := nameVal[0], nameVal[1]
|
||||||
_, ok := c.db.columnType(stmt.table, column)
|
_, ok := c.db.columnType(stmt.table, column)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
|
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
|
||||||
}
|
}
|
||||||
if value != "?" {
|
if value != "?" {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
|
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
|
||||||
stmt.table, column)
|
stmt.table, column)
|
||||||
}
|
}
|
||||||
@ -280,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
|
|||||||
// parts are table|col=type,col2=type2
|
// parts are table|col=type,col2=type2
|
||||||
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
|
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
|
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
|
||||||
}
|
}
|
||||||
stmt.table = parts[0]
|
stmt.table = parts[0]
|
||||||
for n, colspec := range strings.Split(parts[1], ",") {
|
for n, colspec := range strings.Split(parts[1], ",") {
|
||||||
nameType := strings.Split(colspec, "=")
|
nameType := strings.Split(colspec, "=")
|
||||||
if len(nameType) != 2 {
|
if len(nameType) != 2 {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
|
return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
|
||||||
}
|
}
|
||||||
stmt.colName = append(stmt.colName, nameType[0])
|
stmt.colName = append(stmt.colName, nameType[0])
|
||||||
@ -297,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
|
|||||||
// parts are table|col=?,col2=val
|
// parts are table|col=?,col2=val
|
||||||
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
|
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
|
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
|
||||||
}
|
}
|
||||||
stmt.table = parts[0]
|
stmt.table = parts[0]
|
||||||
for n, colspec := range strings.Split(parts[1], ",") {
|
for n, colspec := range strings.Split(parts[1], ",") {
|
||||||
nameVal := strings.Split(colspec, "=")
|
nameVal := strings.Split(colspec, "=")
|
||||||
if len(nameVal) != 2 {
|
if len(nameVal) != 2 {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
|
return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
|
||||||
}
|
}
|
||||||
column, value := nameVal[0], nameVal[1]
|
column, value := nameVal[0], nameVal[1]
|
||||||
ctype, ok := c.db.columnType(stmt.table, column)
|
ctype, ok := c.db.columnType(stmt.table, column)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
|
return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
|
||||||
}
|
}
|
||||||
stmt.colName = append(stmt.colName, column)
|
stmt.colName = append(stmt.colName, column)
|
||||||
@ -323,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
|
|||||||
case "int32":
|
case "int32":
|
||||||
i, err := strconv.Atoi(value)
|
i, err := strconv.Atoi(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("invalid conversion to int32 from %q", value)
|
return nil, errf("invalid conversion to int32 from %q", value)
|
||||||
}
|
}
|
||||||
subsetVal = int64(i) // int64 is a subset type, but not int32
|
subsetVal = int64(i) // int64 is a subset type, but not int32
|
||||||
default:
|
default:
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
|
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
|
||||||
}
|
}
|
||||||
stmt.colValue = append(stmt.colValue, subsetVal)
|
stmt.colValue = append(stmt.colValue, subsetVal)
|
||||||
@ -362,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
|
|||||||
case "INSERT":
|
case "INSERT":
|
||||||
return c.prepareInsert(stmt, parts)
|
return c.prepareInsert(stmt, parts)
|
||||||
default:
|
default:
|
||||||
|
stmt.Close()
|
||||||
return nil, errf("unsupported command type %q", cmd)
|
return nil, errf("unsupported command type %q", cmd)
|
||||||
}
|
}
|
||||||
return stmt, nil
|
return stmt, nil
|
||||||
|
@ -612,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
rows, err := stmt.Query(args...)
|
rows, err := stmt.Query(args...)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
rows.closeStmt = stmt
|
stmt.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
rows.closeStmt = stmt
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Prepare: %v", err)
|
t.Fatalf("Prepare: %v", err)
|
||||||
}
|
}
|
||||||
|
defer stmt.Close()
|
||||||
var age int
|
var age int
|
||||||
for n, tt := range []struct {
|
for n, tt := range []struct {
|
||||||
name string
|
name string
|
||||||
@ -291,6 +292,7 @@ func TestExec(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Stmt, err = %v, %v", stmt, err)
|
t.Errorf("Stmt, err = %v, %v", stmt, err)
|
||||||
}
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
type execTest struct {
|
type execTest struct {
|
||||||
args []interface{}
|
args []interface{}
|
||||||
@ -332,11 +334,14 @@ func TestTxStmt(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Stmt, err = %v, %v", stmt, err)
|
t.Fatalf("Stmt, err = %v, %v", stmt, err)
|
||||||
}
|
}
|
||||||
|
defer stmt.Close()
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Begin = %v", err)
|
t.Fatalf("Begin = %v", err)
|
||||||
}
|
}
|
||||||
_, err = tx.Stmt(stmt).Exec("Bobby", 7)
|
txs := tx.Stmt(stmt)
|
||||||
|
defer txs.Close()
|
||||||
|
_, err = txs.Exec("Bobby", 7)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Exec = %v", err)
|
t.Fatalf("Exec = %v", err)
|
||||||
}
|
}
|
||||||
@ -365,6 +370,7 @@ func TestTxQuery(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
if !r.Next() {
|
if !r.Next() {
|
||||||
if r.Err() != nil {
|
if r.Err() != nil {
|
||||||
@ -561,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("prepare: %v", err)
|
t.Fatalf("prepare: %v", err)
|
||||||
}
|
}
|
||||||
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
|
if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
|
||||||
t.Errorf("exec insert chris: %v", err)
|
t.Errorf("exec insert chris: %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user