From 707a83341b8c7973f4e0fce731fa279c618f233b Mon Sep 17 00:00:00 2001 From: Daniel Theophanes Date: Mon, 3 Oct 2016 09:49:25 -0700 Subject: [PATCH] database/sql: add option to use named parameter in query arguments Modify the new Context methods to take a name-value driver struct. This will require more modifications to drivers to use, but will reduce the overall number of structures that need to be maintained over time. Fixes #12381 Change-Id: I30747533ce418a1be5991a0c8767a26e8451adbd Reviewed-on: https://go-review.googlesource.com/30166 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/database/sql/convert.go | 27 +++++++--- src/database/sql/ctxutil.go | 43 +++++++++++++--- src/database/sql/driver/driver.go | 18 +++++-- src/database/sql/fakedb_test.go | 83 ++++++++++++++++++++++++------- src/database/sql/sql.go | 23 ++++++++- src/database/sql/sql_test.go | 47 +++++++++++++++++ 6 files changed, 202 insertions(+), 39 deletions(-) diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go index 99aed2398e..cee96319da 100644 --- a/src/database/sql/convert.go +++ b/src/database/sql/convert.go @@ -21,8 +21,8 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. -func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { - dargs := make([]driver.Value, len(args)) +func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { + nvargs := make([]driver.NamedValue, len(args)) var si driver.Stmt if ds != nil { si = ds.si @@ -33,16 +33,27 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { if !ok { for n, arg := range args { var err error - dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg) + nvargs[n].Ordinal = n + 1 + if np, ok := arg.(NamedParam); ok { + arg = np.Value + nvargs[n].Name = np.Name + } + nvargs[n].Value, err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err) } } - return dargs, nil + return nvargs, nil } // Let the Stmt convert its own arguments. for n, arg := range args { + nvargs[n].Ordinal = n + 1 + if np, ok := arg.(NamedParam); ok { + arg = np.Value + nvargs[n].Name = np.Name + } // First, see if the value itself knows how to convert // itself to a driver type. For example, a NullString // struct changing into a string or nil. @@ -66,18 +77,18 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) { // same error. var err error ds.Lock() - dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) + nvargs[n].Value, err = cc.ColumnConverter(n).ConvertValue(arg) ds.Unlock() if err != nil { return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err) } - if !driver.IsValue(dargs[n]) { + if !driver.IsValue(nvargs[n].Value) { return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T", - arg, dargs[n]) + arg, nvargs[n].Value) } } - return dargs, nil + return nvargs, nil } // convertAssign copies to dest the value in src, converting it if possible. diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go index e1d4c03c9a..173f6a9d2b 100644 --- a/src/database/sql/ctxutil.go +++ b/src/database/sql/ctxutil.go @@ -50,9 +50,13 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver } } -func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) { +func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { if execerCtx, is := execer.(driver.ExecerContext); is { - return execerCtx.ExecContext(ctx, query, dargs) + return execerCtx.ExecContext(ctx, query, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err } if ctx.Done() == context.Background().Done() { return execer.Exec(query, dargs) @@ -90,9 +94,13 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg } } -func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) { +func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { if queryerCtx, is := queryer.(driver.QueryerContext); is { - return queryerCtx.QueryContext(ctx, query, dargs) + return queryerCtx.QueryContext(ctx, query, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err } if ctx.Done() == context.Background().Done() { return queryer.Query(query, dargs) @@ -130,9 +138,13 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d } } -func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) { +func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { if siCtx, is := si.(driver.StmtExecContext); is { - return siCtx.ExecContext(ctx, dargs) + return siCtx.ExecContext(ctx, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err } if ctx.Done() == context.Background().Done() { return si.Exec(dargs) @@ -170,9 +182,13 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value } } -func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) { +func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { if siCtx, is := si.(driver.StmtQueryContext); is { - return siCtx.QueryContext(ctx, dargs) + return siCtx.QueryContext(ctx, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err } if ctx.Done() == context.Background().Done() { return si.Query(dargs) @@ -253,3 +269,14 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { return r.txi, r.err } } + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("sql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index b3d83f3ff4..6cc970f688 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -24,6 +24,16 @@ import ( // time.Time type Value interface{} +// NamedValue holds both the value name and value. +// The Ordinal is the position of the parameter starting from one and is always set. +// If the Name is not empty it should be used for the parameter identifier and +// not the ordinal position. +type NamedValue struct { + Name string + Ordinal int + Value Value +} + // Driver is the interface that must be implemented by a database // driver. type Driver interface { @@ -71,7 +81,7 @@ type Execer interface { // ExecerContext is like execer, but must honor the context timeout and return // when the context is cancelled. type ExecerContext interface { - ExecContext(ctx context.Context, query string, args []Value) (Result, error) + ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error) } // Queryer is an optional interface that may be implemented by a Conn. @@ -88,7 +98,7 @@ type Queryer interface { // QueryerContext is like Queryer, but most honor the context timeout and return // when the context is cancelled. type QueryerContext interface { - QueryContext(ctx context.Context, query string, args []Value) (Rows, error) + QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error) } // Conn is a connection to a database. It is not used concurrently @@ -174,13 +184,13 @@ type Stmt interface { // StmtExecContext enhances the Stmt interface by providing Exec with context. type StmtExecContext interface { // ExecContext must honor the context timeout and return when it is cancelled. - ExecContext(ctx context.Context, args []Value) (Result, error) + ExecContext(ctx context.Context, args []NamedValue) (Result, error) } // StmtQueryContext enhances the Stmt interface by providing Query with context. type StmtQueryContext interface { // QueryContext must honor the context timeout and return when it is cancelled. - QueryContext(ctx context.Context, args []Value) (Rows, error) + QueryContext(ctx context.Context, args []NamedValue) (Rows, error) } // ColumnConverter may be optionally implemented by Stmt if the diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index aaa13a6799..07f50196a5 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -5,6 +5,7 @@ package sql import ( + "context" "database/sql/driver" "errors" "fmt" @@ -32,6 +33,7 @@ var _ = log.Printf // where types are: "string", [u]int{8,16,32,64}, "bool" // INSERT||col=val,col2=val2,col3=? // SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? +// SELECT||projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 // // Any of these can be preceded by PANIC||, to cause the // named method on fakeStmt to panic. @@ -103,6 +105,12 @@ type fakeTx struct { c *fakeConn } +type boundCol struct { + Column string + Placeholder string + Ordinal int +} + type fakeStmt struct { c *fakeConn q string // just for debugging @@ -120,7 +128,7 @@ type fakeStmt struct { colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) placeholders int // used by INSERT/SELECT: number of ? params - whereCol []string // used by SELECT (all placeholders) + whereCol []boundCol // used by SELECT (all placeholders) placeholderConverter []driver.ValueConverter // used by INSERT } @@ -339,18 +347,23 @@ func (c *fakeConn) Close() (err error) { return nil } -func checkSubsetTypes(args []driver.Value) error { - for n, arg := range args { - switch arg.(type) { +func checkSubsetTypes(args []driver.NamedValue) error { + for _, arg := range args { + switch arg.Value.(type) { case int64, float64, bool, nil, []byte, string, time.Time: default: - return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) + return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) } } return nil } func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // Ensure that ExecContext is called if available. + panic("ExecContext was not called.") +} + +func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { // This is an optional interface, but it's implemented here // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -363,6 +376,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error } func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // Ensure that ExecContext is called if available. + panic("QueryContext was not called.") +} + +func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { // This is an optional interface, but it's implemented here // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -403,13 +421,13 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err stmt.Close() return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) } - if value != "?" { + if !strings.HasPrefix(value, "?") { stmt.Close() return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", stmt.table, column) } - stmt.whereCol = append(stmt.whereCol, column) stmt.placeholders++ + stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) } return stmt, nil } @@ -454,7 +472,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err } stmt.colName = append(stmt.colName, column) - if value != "?" { + if !strings.HasPrefix(value, "?") { var subsetVal interface{} // Convert to driver subset type switch ctype { @@ -477,7 +495,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err } else { stmt.placeholders++ stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) - stmt.colValue = append(stmt.colValue, "?") + stmt.colValue = append(stmt.colValue, value) } } return stmt, nil @@ -580,6 +598,9 @@ var errClosed = errors.New("fakedb: statement has been closed") var hookExecBadConn func() bool func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + panic("Using ExecContext") +} +func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { if s.panic == "Exec" { panic(s.panic) } @@ -620,7 +641,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { // When doInsert is true, add the row to the table. // When doInsert is false do prep-work and error checking, but don't // actually add the row to the table. -func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { +func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { db := s.c.db if len(args) != s.placeholders { panic("error in pkg db; should only get here if size is correct") @@ -646,8 +667,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) } var val interface{} - if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { - val = args[argPos] + if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { + if strvalue == "?" { + val = args[argPos].Value + } else { + // Assign value from argument placeholder name. + for _, a := range args { + if a.Name == strvalue { + val = a.Value + break + } + } + } argPos++ } else { val = s.colValue[n] @@ -667,6 +698,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result var hookQueryBadConn func() bool func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + panic("Use QueryContext") +} + +func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { if s.panic == "Query" { panic(s.panic) } @@ -700,9 +735,9 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { } if s.table == "magicquery" { - if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { - if args[0] == "sleep" { - time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) + if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { + if args[0].Value == "sleep" { + time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) } } } @@ -725,8 +760,8 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { // Process the where clause, skipping non-match rows. This is lazy // and just uses fmt.Sprintf("%v") to test equality. Good enough // for test code. - for widx, wcol := range s.whereCol { - idx := t.columnIndex(wcol) + for _, wcol := range s.whereCol { + idx := t.columnIndex(wcol.Column) if idx == -1 { t.mu.Unlock() return nil, fmt.Errorf("db: invalid where clause column %q", wcol) @@ -736,7 +771,19 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { // lazy hack to avoid sprintf %v on a []byte tcol = string(bs) } - if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + var argValue interface{} + if wcol.Placeholder == "?" { + argValue = args[wcol.Ordinal-1].Value + } else { + // Assign arg value from placeholder name. + for _, a := range args { + if a.Name == wcol.Placeholder { + argValue = a.Value + break + } + } + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { continue rows } } diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 970334269d..616acb2be1 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -67,6 +67,27 @@ func Drivers() []string { return list } +// NamedParam may be passed into query parameter arguments to associate +// a named placeholder with a value. +type NamedParam struct { + // Name of the parameter placeholder. If empty the ordinal position in the + // argument list will be used. + Name string + + // Value of the parameter. It may be assigned the same value types as + // the query arguments. + Value interface{} +} + +// Param provides a more concise way to create NamedParam values. +func Param(name string, value interface{}) NamedParam { + // This method exists because the go1compat promise + // doesn't guarantee that structs don't grow more fields, + // so unkeyed struct literals are a vet error. Thus, we don't + // want to encourage sql.NamedParam{name, value}. + return NamedParam{Name: name, Value: value} +} + // RawBytes is a byte slice that holds a reference to memory owned by // the database itself. After a Scan into a RawBytes, the slice is only // valid until the next call to Next, Scan, or Close. @@ -1064,7 +1085,7 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate }() if execer, ok := dc.ci.(driver.Execer); ok { - var dargs []driver.Value + var dargs []driver.NamedValue dargs, err = driverArgs(nil, args) if err != nil { return nil, err diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index bce210da97..885cadf3c6 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -395,6 +395,53 @@ func TestMultiResultSetQuery(t *testing.T) { } } +func TestQueryNamedParam(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query( + // Ensure the name and age parameters only match on placeholder name, not position. + "SELECT|people|age,name|name=?name,age=?age", + Param("?age", 2), + Param("?name", "Bob"), + ) + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)