mirror of
https://github.com/golang/go
synced 2024-11-25 07:17:56 -07:00
exp/sql: copy when scanning into []byte by default
Fixes #2698 R=rsc CC=golang-dev https://golang.org/cl/5539060
This commit is contained in:
parent
ddef49dfce
commit
ebc8013edf
@ -30,6 +30,11 @@ func Register(name string, driver driver.Driver) {
|
||||
drivers[name] = driver
|
||||
}
|
||||
|
||||
// RawBytes is a byte slice that holds a reference to memory owned by
|
||||
// the database itself. After a Scan into a RawBytes, the slice is only
|
||||
// valid until the next call to Next, Scan, or Close.
|
||||
type RawBytes []byte
|
||||
|
||||
// NullableString represents a string that may be null.
|
||||
// NullableString implements the ScannerInto interface so
|
||||
// it can be used as a scan destination:
|
||||
@ -760,9 +765,13 @@ func (rs *Rows) Columns() ([]string, error) {
|
||||
}
|
||||
|
||||
// Scan copies the columns in the current row into the values pointed
|
||||
// at by dest. If dest contains pointers to []byte, the slices should
|
||||
// not be modified and should only be considered valid until the next
|
||||
// call to Next or Scan.
|
||||
// at by dest.
|
||||
//
|
||||
// If an argument has type *[]byte, Scan saves in that argument a copy
|
||||
// of the corresponding data. The copy is owned by the caller and can
|
||||
// be modified and held indefinitely. The copy can be avoided by using
|
||||
// an argument of type *RawBytes instead; see the documentation for
|
||||
// RawBytes for restrictions on its use.
|
||||
func (rs *Rows) Scan(dest ...interface{}) error {
|
||||
if rs.closed {
|
||||
return errors.New("sql: Rows closed")
|
||||
@ -782,6 +791,18 @@ func (rs *Rows) Scan(dest ...interface{}) error {
|
||||
return fmt.Errorf("sql: Scan error on column index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
for _, dp := range dest {
|
||||
b, ok := dp.(*[]byte)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok = dp.(*RawBytes); ok {
|
||||
continue
|
||||
}
|
||||
clone := make([]byte, len(*b))
|
||||
copy(clone, *b)
|
||||
*b = clone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -838,6 +859,9 @@ func (r *Row) Scan(dest ...interface{}) error {
|
||||
// they were obtained from the network anyway) But for now we
|
||||
// don't care.
|
||||
for _, dp := range dest {
|
||||
if _, ok := dp.(*RawBytes); ok {
|
||||
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
|
||||
}
|
||||
b, ok := dp.(*[]byte)
|
||||
if !ok {
|
||||
continue
|
||||
|
@ -76,7 +76,7 @@ func TestQuery(t *testing.T) {
|
||||
{age: 3, name: "Chris"},
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Logf(" got: %#v\nwant: %#v", got, want)
|
||||
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
|
||||
// And verify that the final rows.Next() call, which hit EOF,
|
||||
@ -86,6 +86,43 @@ func TestQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteOwnership(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
defer closeDB(t, db)
|
||||
rows, err := db.Query("SELECT|people|name,photo|")
|
||||
if err != nil {
|
||||
t.Fatalf("Query: %v", err)
|
||||
}
|
||||
type row struct {
|
||||
name []byte
|
||||
photo RawBytes
|
||||
}
|
||||
got := []row{}
|
||||
for rows.Next() {
|
||||
var r row
|
||||
err = rows.Scan(&r.name, &r.photo)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan: %v", err)
|
||||
}
|
||||
got = append(got, r)
|
||||
}
|
||||
corruptMemory := []byte("\xffPHOTO")
|
||||
want := []row{
|
||||
{name: []byte("Alice"), photo: corruptMemory},
|
||||
{name: []byte("Bob"), photo: corruptMemory},
|
||||
{name: []byte("Chris"), photo: corruptMemory},
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
|
||||
var photo RawBytes
|
||||
err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
|
||||
if err == nil {
|
||||
t.Error("want error scanning into RawBytes from QueryRow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsColumns(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
defer closeDB(t, db)
|
||||
@ -300,6 +337,6 @@ func TestQueryRowClosingStmt(t *testing.T) {
|
||||
}
|
||||
fakeConn := db.freeConn[0].(*fakeConn)
|
||||
if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
|
||||
t.Logf("statement close mismatch: made %d, closed %d", made, closed)
|
||||
t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user