mirror of
https://github.com/golang/go
synced 2024-11-20 07:44:41 -07:00
exp/sql: finish transactions, flesh out types, docs
Fixes #2328 (float, bool) R=rsc, r CC=golang-dev https://golang.org/cl/5294067
This commit is contained in:
parent
cefee3c919
commit
8089e57812
@ -203,7 +203,6 @@ NOTEST+=\
|
|||||||
exp/ebnflint\
|
exp/ebnflint\
|
||||||
exp/gui\
|
exp/gui\
|
||||||
exp/gui/x11\
|
exp/gui/x11\
|
||||||
exp/sql/driver\
|
|
||||||
go/doc\
|
go/doc\
|
||||||
hash\
|
hash\
|
||||||
http/pprof\
|
http/pprof\
|
||||||
|
@ -8,6 +8,7 @@ package sql
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"exp/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -36,10 +37,11 @@ func convertAssign(dest, src interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sv := reflect.ValueOf(src)
|
var sv reflect.Value
|
||||||
|
|
||||||
switch d := dest.(type) {
|
switch d := dest.(type) {
|
||||||
case *string:
|
case *string:
|
||||||
|
sv = reflect.ValueOf(src)
|
||||||
switch sv.Kind() {
|
switch sv.Kind() {
|
||||||
case reflect.Bool,
|
case reflect.Bool,
|
||||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
@ -48,6 +50,12 @@ func convertAssign(dest, src interface{}) error {
|
|||||||
*d = fmt.Sprintf("%v", src)
|
*d = fmt.Sprintf("%v", src)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
case *bool:
|
||||||
|
bv, err := driver.Bool.ConvertValue(src)
|
||||||
|
if err == nil {
|
||||||
|
*d = bv.(bool)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if scanner, ok := dest.(ScannerInto); ok {
|
if scanner, ok := dest.(ScannerInto); ok {
|
||||||
@ -59,6 +67,10 @@ func convertAssign(dest, src interface{}) error {
|
|||||||
return errors.New("destination not a pointer")
|
return errors.New("destination not a pointer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !sv.IsValid() {
|
||||||
|
sv = reflect.ValueOf(src)
|
||||||
|
}
|
||||||
|
|
||||||
dv := reflect.Indirect(dpv)
|
dv := reflect.Indirect(dpv)
|
||||||
if dv.Kind() == sv.Kind() {
|
if dv.Kind() == sv.Kind() {
|
||||||
dv.Set(sv)
|
dv.Set(sv)
|
||||||
@ -67,7 +79,7 @@ func convertAssign(dest, src interface{}) error {
|
|||||||
|
|
||||||
switch dv.Kind() {
|
switch dv.Kind() {
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
if s, ok := asString(src); ok {
|
s := asString(src)
|
||||||
i64, err := strconv.Atoi64(s)
|
i64, err := strconv.Atoi64(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
|
return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
|
||||||
@ -77,9 +89,8 @@ func convertAssign(dest, src interface{}) error {
|
|||||||
}
|
}
|
||||||
dv.SetInt(i64)
|
dv.SetInt(i64)
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
if s, ok := asString(src); ok {
|
s := asString(src)
|
||||||
u64, err := strconv.Atoui64(s)
|
u64, err := strconv.Atoui64(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
|
return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
|
||||||
@ -89,18 +100,28 @@ func convertAssign(dest, src interface{}) error {
|
|||||||
}
|
}
|
||||||
dv.SetUint(u64)
|
dv.SetUint(u64)
|
||||||
return nil
|
return nil
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
s := asString(src)
|
||||||
|
f64, err := strconv.Atof64(s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("converting string %q to a %s: %v", s, dv.Kind(), err)
|
||||||
}
|
}
|
||||||
|
if dv.OverflowFloat(f64) {
|
||||||
|
return fmt.Errorf("value %q overflows %s", s, dv.Kind())
|
||||||
|
}
|
||||||
|
dv.SetFloat(f64)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
|
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func asString(src interface{}) (s string, ok bool) {
|
func asString(src interface{}) string {
|
||||||
switch v := src.(type) {
|
switch v := src.(type) {
|
||||||
case string:
|
case string:
|
||||||
return v, true
|
return v
|
||||||
case []byte:
|
case []byte:
|
||||||
return string(v), true
|
return string(v)
|
||||||
}
|
}
|
||||||
return "", false
|
return fmt.Sprintf("%v", src)
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,9 @@ type conversionTest struct {
|
|||||||
wantint int64
|
wantint int64
|
||||||
wantuint uint64
|
wantuint uint64
|
||||||
wantstr string
|
wantstr string
|
||||||
|
wantf32 float32
|
||||||
|
wantf64 float64
|
||||||
|
wantbool bool // used if d is of type *bool
|
||||||
wanterr string
|
wanterr string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,6 +32,9 @@ var (
|
|||||||
scanint32 int32
|
scanint32 int32
|
||||||
scanuint8 uint8
|
scanuint8 uint8
|
||||||
scanuint16 uint16
|
scanuint16 uint16
|
||||||
|
scanbool bool
|
||||||
|
scanf32 float32
|
||||||
|
scanf64 float64
|
||||||
)
|
)
|
||||||
|
|
||||||
var conversionTests = []conversionTest{
|
var conversionTests = []conversionTest{
|
||||||
@ -53,6 +59,35 @@ var conversionTests = []conversionTest{
|
|||||||
{s: "256", d: &scanuint16, wantuint: 256},
|
{s: "256", d: &scanuint16, wantuint: 256},
|
||||||
{s: "-1", d: &scanint, wantint: -1},
|
{s: "-1", d: &scanint, wantint: -1},
|
||||||
{s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: parsing "foo": invalid syntax`},
|
{s: "foo", d: &scanint, wanterr: `converting string "foo" to a int: parsing "foo": invalid syntax`},
|
||||||
|
|
||||||
|
// True bools
|
||||||
|
{s: true, d: &scanbool, wantbool: true},
|
||||||
|
{s: "True", d: &scanbool, wantbool: true},
|
||||||
|
{s: "TRUE", d: &scanbool, wantbool: true},
|
||||||
|
{s: "1", d: &scanbool, wantbool: true},
|
||||||
|
{s: 1, d: &scanbool, wantbool: true},
|
||||||
|
{s: int64(1), d: &scanbool, wantbool: true},
|
||||||
|
{s: uint16(1), d: &scanbool, wantbool: true},
|
||||||
|
|
||||||
|
// False bools
|
||||||
|
{s: false, d: &scanbool, wantbool: false},
|
||||||
|
{s: "false", d: &scanbool, wantbool: false},
|
||||||
|
{s: "FALSE", d: &scanbool, wantbool: false},
|
||||||
|
{s: "0", d: &scanbool, wantbool: false},
|
||||||
|
{s: 0, d: &scanbool, wantbool: false},
|
||||||
|
{s: int64(0), d: &scanbool, wantbool: false},
|
||||||
|
{s: uint16(0), d: &scanbool, wantbool: false},
|
||||||
|
|
||||||
|
// Not bools
|
||||||
|
{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
|
||||||
|
{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},
|
||||||
|
|
||||||
|
// Floats
|
||||||
|
{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
|
||||||
|
{s: int64(1), d: &scanf64, wantf64: float64(1)},
|
||||||
|
{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
|
||||||
|
{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
|
||||||
|
{s: "1.5", d: &scanf64, wantf64: float64(1.5)},
|
||||||
}
|
}
|
||||||
|
|
||||||
func intValue(intptr interface{}) int64 {
|
func intValue(intptr interface{}) int64 {
|
||||||
@ -63,6 +98,14 @@ func uintValue(intptr interface{}) uint64 {
|
|||||||
return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
|
return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func float64Value(ptr interface{}) float64 {
|
||||||
|
return *(ptr.(*float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func float32Value(ptr interface{}) float32 {
|
||||||
|
return *(ptr.(*float32))
|
||||||
|
}
|
||||||
|
|
||||||
func TestConversions(t *testing.T) {
|
func TestConversions(t *testing.T) {
|
||||||
for n, ct := range conversionTests {
|
for n, ct := range conversionTests {
|
||||||
err := convertAssign(ct.d, ct.s)
|
err := convertAssign(ct.d, ct.s)
|
||||||
@ -86,6 +129,15 @@ func TestConversions(t *testing.T) {
|
|||||||
if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
|
if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
|
||||||
errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
|
errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
|
||||||
}
|
}
|
||||||
|
if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
|
||||||
|
errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
|
||||||
|
}
|
||||||
|
if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
|
||||||
|
errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
|
||||||
|
}
|
||||||
|
if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
|
||||||
|
errf("want bool %v, got %v", ct.wantbool, *bp)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,9 +24,13 @@ import "errors"
|
|||||||
// Driver is the interface that must be implemented by a database
|
// Driver is the interface that must be implemented by a database
|
||||||
// driver.
|
// driver.
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
// Open returns a new or cached connection to the database.
|
// Open returns a new connection to the database.
|
||||||
// The name is a string in a driver-specific format.
|
// The name is a string in a driver-specific format.
|
||||||
//
|
//
|
||||||
|
// Open may return a cached connection (one previously
|
||||||
|
// closed), but doing so is unnecessary; the sql package
|
||||||
|
// maintains a pool of idle connections for efficient re-use.
|
||||||
|
//
|
||||||
// The returned connection is only used by one goroutine at a
|
// The returned connection is only used by one goroutine at a
|
||||||
// time.
|
// time.
|
||||||
Open(name string) (Conn, error)
|
Open(name string) (Conn, error)
|
||||||
@ -59,8 +63,12 @@ type Conn interface {
|
|||||||
|
|
||||||
// Close invalidates and potentially stops any current
|
// Close invalidates and potentially stops any current
|
||||||
// prepared statements and transactions, marking this
|
// prepared statements and transactions, marking this
|
||||||
// connection as no longer in use. The driver may cache or
|
// connection as no longer in use.
|
||||||
// close its underlying connection to its database.
|
//
|
||||||
|
// Because the sql package maintains a free pool of
|
||||||
|
// connections and only calls Close when there's a surplus of
|
||||||
|
// idle connections, it shouldn't be necessary for drivers to
|
||||||
|
// do their own connection caching.
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
// Begin starts and returns a new transaction.
|
// Begin starts and returns a new transaction.
|
||||||
|
@ -11,6 +11,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ValueConverter is the interface providing the ConvertValue method.
|
// ValueConverter is the interface providing the ConvertValue method.
|
||||||
|
//
|
||||||
|
// Various implementations of ValueConverter are provided by the
|
||||||
|
// driver package to provide consistent implementations of conversions
|
||||||
|
// between drivers. The ValueConverters have several uses:
|
||||||
|
//
|
||||||
|
// * converting from the subset types as provided by the sql package
|
||||||
|
// into a database table's specific column type and making sure it
|
||||||
|
// fits, such as making sure a particular int64 fits in a
|
||||||
|
// table's uint16 column.
|
||||||
|
//
|
||||||
|
// * converting a value as given from the database into one of the
|
||||||
|
// subset types.
|
||||||
|
//
|
||||||
|
// * by the sql package, for converting from a driver's subset type
|
||||||
|
// to a user's type in a scan.
|
||||||
type ValueConverter interface {
|
type ValueConverter interface {
|
||||||
// ConvertValue converts a value to a restricted subset type.
|
// ConvertValue converts a value to a restricted subset type.
|
||||||
ConvertValue(v interface{}) (interface{}, error)
|
ConvertValue(v interface{}) (interface{}, error)
|
||||||
@ -19,15 +34,56 @@ type ValueConverter interface {
|
|||||||
// Bool is a ValueConverter that converts input values to bools.
|
// Bool is a ValueConverter that converts input values to bools.
|
||||||
//
|
//
|
||||||
// The conversion rules are:
|
// The conversion rules are:
|
||||||
// - .... TODO(bradfitz): TBD
|
// - booleans are returned unchanged
|
||||||
|
// - for integer types,
|
||||||
|
// 1 is true
|
||||||
|
// 0 is false,
|
||||||
|
// other integers are an error
|
||||||
|
// - for strings and []byte, same rules as strconv.Atob
|
||||||
|
// - all other types are an error
|
||||||
var Bool boolType
|
var Bool boolType
|
||||||
|
|
||||||
type boolType struct{}
|
type boolType struct{}
|
||||||
|
|
||||||
var _ ValueConverter = boolType{}
|
var _ ValueConverter = boolType{}
|
||||||
|
|
||||||
func (boolType) ConvertValue(v interface{}) (interface{}, error) {
|
func (boolType) String() string { return "Bool" }
|
||||||
return nil, fmt.Errorf("TODO(bradfitz): bool conversions")
|
|
||||||
|
func (boolType) ConvertValue(src interface{}) (interface{}, error) {
|
||||||
|
switch s := src.(type) {
|
||||||
|
case bool:
|
||||||
|
return s, nil
|
||||||
|
case string:
|
||||||
|
b, err := strconv.Atob(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
case []byte:
|
||||||
|
b, err := strconv.Atob(string(s))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sv := reflect.ValueOf(src)
|
||||||
|
switch sv.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
iv := sv.Int()
|
||||||
|
if iv == 1 || iv == 0 {
|
||||||
|
return iv == 1, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv)
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
uv := sv.Uint()
|
||||||
|
if uv == 1 || uv == 0 {
|
||||||
|
return uv == 1, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Int32 is a ValueConverter that converts input values to int64,
|
// Int32 is a ValueConverter that converts input values to int64,
|
||||||
|
57
src/pkg/exp/sql/driver/types_test.go
Normal file
57
src/pkg/exp/sql/driver/types_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package driver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type valueConverterTest struct {
|
||||||
|
c ValueConverter
|
||||||
|
in interface{}
|
||||||
|
out interface{}
|
||||||
|
err string
|
||||||
|
}
|
||||||
|
|
||||||
|
var valueConverterTests = []valueConverterTest{
|
||||||
|
{Bool, "true", true, ""},
|
||||||
|
{Bool, "True", true, ""},
|
||||||
|
{Bool, []byte("t"), true, ""},
|
||||||
|
{Bool, true, true, ""},
|
||||||
|
{Bool, "1", true, ""},
|
||||||
|
{Bool, 1, true, ""},
|
||||||
|
{Bool, int64(1), true, ""},
|
||||||
|
{Bool, uint16(1), true, ""},
|
||||||
|
{Bool, "false", false, ""},
|
||||||
|
{Bool, false, false, ""},
|
||||||
|
{Bool, "0", false, ""},
|
||||||
|
{Bool, 0, false, ""},
|
||||||
|
{Bool, int64(0), false, ""},
|
||||||
|
{Bool, uint16(0), false, ""},
|
||||||
|
{c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
|
||||||
|
{c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValueConverters(t *testing.T) {
|
||||||
|
for i, tt := range valueConverterTests {
|
||||||
|
out, err := tt.c.ConvertValue(tt.in)
|
||||||
|
goterr := ""
|
||||||
|
if err != nil {
|
||||||
|
goterr = err.Error()
|
||||||
|
}
|
||||||
|
if goterr != tt.err {
|
||||||
|
t.Errorf("test %d: %s(%T(%v)) error = %q; want error = %q",
|
||||||
|
i, tt.c, tt.in, tt.in, goterr, tt.err)
|
||||||
|
}
|
||||||
|
if tt.err != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(out, tt.out) {
|
||||||
|
t.Errorf("test %d: %s(%T(%v)) = %v (%T); want %v (%T)",
|
||||||
|
i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -476,7 +476,7 @@ func (rc *rowsCursor) Next(dest []interface{}) error {
|
|||||||
for i, v := range rc.rows[rc.pos].cols {
|
for i, v := range rc.rows[rc.pos].cols {
|
||||||
// TODO(bradfitz): convert to subset types? naah, I
|
// TODO(bradfitz): convert to subset types? naah, I
|
||||||
// think the subset types should only be input to
|
// think the subset types should only be input to
|
||||||
// driver, but the db package should be able to handle
|
// driver, but the sql package should be able to handle
|
||||||
// a wider range of types coming out of drivers. all
|
// a wider range of types coming out of drivers. all
|
||||||
// for ease of drivers, and to prevent drivers from
|
// for ease of drivers, and to prevent drivers from
|
||||||
// messing up conversions or doing them differently.
|
// messing up conversions or doing them differently.
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"exp/sql/driver"
|
"exp/sql/driver"
|
||||||
@ -192,13 +191,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
|
|||||||
|
|
||||||
// If the driver does not implement driver.Execer, we need
|
// If the driver does not implement driver.Execer, we need
|
||||||
// a connection.
|
// a connection.
|
||||||
conn, err := db.conn()
|
ci, err := db.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer db.putConn(conn)
|
defer db.putConn(ci)
|
||||||
|
|
||||||
if execer, ok := conn.(driver.Execer); ok {
|
if execer, ok := ci.(driver.Execer); ok {
|
||||||
resi, err := execer.Exec(query, args)
|
resi, err := execer.Exec(query, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -206,7 +205,7 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
|
|||||||
return result{resi}, nil
|
return result{resi}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sti, err := conn.Prepare(query)
|
sti, err := ci.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -233,18 +232,26 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
|
|||||||
// Row's Scan method is called.
|
// Row's Scan method is called.
|
||||||
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
|
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
|
||||||
rows, err := db.Query(query, args...)
|
rows, err := db.Query(query, args...)
|
||||||
if err != nil {
|
return &Row{rows: rows, err: err}
|
||||||
return &Row{err: err}
|
|
||||||
}
|
|
||||||
return &Row{rows: rows}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin starts a transaction. The isolation level is dependent on
|
// Begin starts a transaction. The isolation level is dependent on
|
||||||
// the driver.
|
// the driver.
|
||||||
func (db *DB) Begin() (*Tx, error) {
|
func (db *DB) Begin() (*Tx, error) {
|
||||||
// TODO(bradfitz): add another method for beginning a transaction
|
ci, err := db.conn()
|
||||||
// at a specific isolation level.
|
if err != nil {
|
||||||
panic(todo())
|
return nil, err
|
||||||
|
}
|
||||||
|
txi, err := ci.Begin()
|
||||||
|
if err != nil {
|
||||||
|
db.putConn(ci)
|
||||||
|
return nil, fmt.Errorf("sql: failed to Begin transaction: %v", err)
|
||||||
|
}
|
||||||
|
return &Tx{
|
||||||
|
db: db,
|
||||||
|
ci: ci,
|
||||||
|
txi: txi,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DriverDatabase returns the database's underlying driver.
|
// DriverDatabase returns the database's underlying driver.
|
||||||
@ -253,41 +260,158 @@ func (db *DB) Driver() driver.Driver {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Tx is an in-progress database transaction.
|
// Tx is an in-progress database transaction.
|
||||||
|
//
|
||||||
|
// A transaction must end with a call to Commit or Rollback.
|
||||||
|
//
|
||||||
|
// After a call to Commit or Rollback, all operations on the
|
||||||
|
// transaction fail with ErrTransactionFinished.
|
||||||
type Tx struct {
|
type Tx struct {
|
||||||
|
db *DB
|
||||||
|
|
||||||
|
// ci is owned exclusively until Commit or Rollback, at which point
|
||||||
|
// it's returned with putConn.
|
||||||
|
ci driver.Conn
|
||||||
|
txi driver.Tx
|
||||||
|
|
||||||
|
// cimu is held while somebody is using ci (between grabConn
|
||||||
|
// and releaseConn)
|
||||||
|
cimu sync.Mutex
|
||||||
|
|
||||||
|
// done transitions from false to true exactly once, on Commit
|
||||||
|
// or Rollback. once done, all operations fail with
|
||||||
|
// ErrTransactionFinished.
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrTransactionFinished = errors.New("sql: Transaction has already been committed or rolled back")
|
||||||
|
|
||||||
|
func (tx *Tx) close() {
|
||||||
|
if tx.done {
|
||||||
|
panic("double close") // internal error
|
||||||
|
}
|
||||||
|
tx.done = true
|
||||||
|
tx.db.putConn(tx.ci)
|
||||||
|
tx.ci = nil
|
||||||
|
tx.txi = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) grabConn() (driver.Conn, error) {
|
||||||
|
if tx.done {
|
||||||
|
return nil, ErrTransactionFinished
|
||||||
|
}
|
||||||
|
tx.cimu.Lock()
|
||||||
|
return tx.ci, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) releaseConn() {
|
||||||
|
tx.cimu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit commits the transaction.
|
// Commit commits the transaction.
|
||||||
func (tx *Tx) Commit() error {
|
func (tx *Tx) Commit() error {
|
||||||
panic(todo())
|
if tx.done {
|
||||||
|
return ErrTransactionFinished
|
||||||
|
}
|
||||||
|
defer tx.close()
|
||||||
|
return tx.txi.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback aborts the transaction.
|
// Rollback aborts the transaction.
|
||||||
func (tx *Tx) Rollback() error {
|
func (tx *Tx) Rollback() error {
|
||||||
panic(todo())
|
if tx.done {
|
||||||
|
return ErrTransactionFinished
|
||||||
|
}
|
||||||
|
defer tx.close()
|
||||||
|
return tx.txi.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare creates a prepared statement.
|
// Prepare creates a prepared statement.
|
||||||
|
//
|
||||||
|
// The statement is only valid within the scope of this transaction.
|
||||||
func (tx *Tx) Prepare(query string) (*Stmt, error) {
|
func (tx *Tx) Prepare(query string) (*Stmt, error) {
|
||||||
panic(todo())
|
// TODO(bradfitz): the restriction that the returned statement
|
||||||
|
// is only valid for this Transaction is lame and negates a
|
||||||
|
// lot of the benefit of prepared statements. We could be
|
||||||
|
// more efficient here and either provide a method to take an
|
||||||
|
// existing Stmt (created on perhaps a different Conn), and
|
||||||
|
// re-create it on this Conn if necessary. Or, better: keep a
|
||||||
|
// map in DB of query string to Stmts, and have Stmt.Execute
|
||||||
|
// do the right thing and re-prepare if the Conn in use
|
||||||
|
// doesn't have that prepared statement. But we'll want to
|
||||||
|
// avoid caching the statement in the case where we only call
|
||||||
|
// conn.Prepare implicitly (such as in db.Exec or tx.Exec),
|
||||||
|
// but the caller package can't be holding a reference to the
|
||||||
|
// returned statement. Perhaps just looking at the reference
|
||||||
|
// count (by noting Stmt.Close) would be enough. We might also
|
||||||
|
// want a finalizer on Stmt to drop the reference count.
|
||||||
|
ci, err := tx.grabConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.releaseConn()
|
||||||
|
|
||||||
|
si, err := ci.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := &Stmt{
|
||||||
|
db: tx.db,
|
||||||
|
tx: tx,
|
||||||
|
txsi: si,
|
||||||
|
query: query,
|
||||||
|
}
|
||||||
|
return stmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec executes a query that doesn't return rows.
|
// Exec executes a query that doesn't return rows.
|
||||||
// For example: an INSERT and UPDATE.
|
// For example: an INSERT and UPDATE.
|
||||||
func (tx *Tx) Exec(query string, args ...interface{}) {
|
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
|
||||||
panic(todo())
|
ci, err := tx.grabConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.releaseConn()
|
||||||
|
|
||||||
|
if execer, ok := ci.(driver.Execer); ok {
|
||||||
|
resi, err := execer.Exec(query, args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result{resi}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sti, err := ci.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer sti.Close()
|
||||||
|
resi, err := sti.Exec(args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result{resi}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes a query that returns rows, typically a SELECT.
|
// Query executes a query that returns rows, typically a SELECT.
|
||||||
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
|
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
|
||||||
panic(todo())
|
if tx.done {
|
||||||
|
return nil, ErrTransactionFinished
|
||||||
|
}
|
||||||
|
stmt, err := tx.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
return stmt.Query(args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow executes a query that is expected to return at most one row.
|
// QueryRow executes a query that is expected to return at most one row.
|
||||||
// QueryRow always return a non-nil value. Errors are deferred until
|
// QueryRow always return a non-nil value. Errors are deferred until
|
||||||
// Row's Scan method is called.
|
// Row's Scan method is called.
|
||||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
|
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
|
||||||
panic(todo())
|
rows, err := tx.Query(query, args...)
|
||||||
|
return &Row{rows: rows, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
// connStmt is a prepared statement on a particular connection.
|
// connStmt is a prepared statement on a particular connection.
|
||||||
@ -302,24 +426,28 @@ type Stmt struct {
|
|||||||
db *DB // where we came from
|
db *DB // where we came from
|
||||||
query string // that created the Sttm
|
query string // that created the Sttm
|
||||||
|
|
||||||
mu sync.Mutex
|
// If in a transaction, else both nil:
|
||||||
closed bool
|
tx *Tx
|
||||||
css []connStmt // can use any that have idle connections
|
txsi driver.Stmt
|
||||||
}
|
|
||||||
|
|
||||||
func todo() string {
|
mu sync.Mutex // protects the rest of the fields
|
||||||
_, file, line, _ := runtime.Caller(1)
|
closed bool
|
||||||
return fmt.Sprintf("%s:%d: TODO: implement", file, line)
|
|
||||||
|
// css is a list of underlying driver statement interfaces
|
||||||
|
// that are valid on particular connections. This is only
|
||||||
|
// used if tx == nil and one is found that has idle
|
||||||
|
// connections. If tx != nil, txsi is always used.
|
||||||
|
css []connStmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec executes a prepared statement with the given arguments and
|
// Exec executes a prepared statement with the given arguments and
|
||||||
// returns a Result summarizing the effect of the statement.
|
// returns a Result summarizing the effect of the statement.
|
||||||
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
||||||
ci, si, err := s.connStmt()
|
_, releaseConn, si, err := s.connStmt()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer s.db.putConn(ci)
|
defer releaseConn()
|
||||||
|
|
||||||
if want := si.NumInput(); len(args) != want {
|
if want := si.NumInput(); len(args) != want {
|
||||||
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
|
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
|
||||||
@ -353,11 +481,29 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
|||||||
return result{resi}, nil
|
return result{resi}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) {
|
// connStmt returns a free driver connection on which to execute the
|
||||||
|
// statement, a function to call to release the connection, and a
|
||||||
|
// statement bound to that connection.
|
||||||
|
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.closed {
|
if s.closed {
|
||||||
return nil, nil, errors.New("db: statement is closed")
|
s.mu.Unlock()
|
||||||
|
err = errors.New("db: statement is closed")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// In a transaction, we always use the connection that the
|
||||||
|
// transaction was created on.
|
||||||
|
if s.tx != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
ci, err = s.tx.grabConn() // blocks, waiting for the connection.
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
releaseConn = func() { s.tx.releaseConn() }
|
||||||
|
return ci, releaseConn, s.txsi, nil
|
||||||
|
}
|
||||||
|
|
||||||
var cs connStmt
|
var cs connStmt
|
||||||
match := false
|
match := false
|
||||||
for _, v := range s.css {
|
for _, v := range s.css {
|
||||||
@ -375,11 +521,11 @@ func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) {
|
|||||||
if !match {
|
if !match {
|
||||||
ci, err := s.db.conn()
|
ci, err := s.db.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
si, err := ci.Prepare(s.query)
|
si, err := ci.Prepare(s.query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
cs = connStmt{ci, si}
|
cs = connStmt{ci, si}
|
||||||
@ -387,13 +533,15 @@ func (s *Stmt) connStmt(args ...interface{}) (driver.Conn, driver.Stmt, error) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return cs.ci, cs.si, nil
|
conn := cs.ci
|
||||||
|
releaseConn = func() { s.db.putConn(conn) }
|
||||||
|
return conn, releaseConn, cs.si, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes a prepared query statement with the given arguments
|
// Query executes a prepared query statement with the given arguments
|
||||||
// and returns the query results as a *Rows.
|
// and returns the query results as a *Rows.
|
||||||
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
||||||
ci, si, err := s.connStmt(args...)
|
ci, releaseConn, si, err := s.connStmt()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -405,10 +553,12 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
|||||||
s.db.putConn(ci)
|
s.db.putConn(ci)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Note: ownership of ci passes to the *Rows
|
// Note: ownership of ci passes to the *Rows, to be freed
|
||||||
|
// with releaseConn.
|
||||||
rows := &Rows{
|
rows := &Rows{
|
||||||
db: s.db,
|
db: s.db,
|
||||||
ci: ci,
|
ci: ci,
|
||||||
|
releaseConn: releaseConn,
|
||||||
rowsi: rowsi,
|
rowsi: rowsi,
|
||||||
}
|
}
|
||||||
return rows, nil
|
return rows, nil
|
||||||
@ -436,11 +586,15 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
|
|||||||
// Close closes the statement.
|
// Close closes the statement.
|
||||||
func (s *Stmt) Close() error {
|
func (s *Stmt) Close() error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock() // TODO(bradfitz): move this unlock after 'closed = true'?
|
defer s.mu.Unlock()
|
||||||
if s.closed {
|
if s.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.closed = true
|
s.closed = true
|
||||||
|
|
||||||
|
if s.tx != nil {
|
||||||
|
s.txsi.Close()
|
||||||
|
} else {
|
||||||
for _, v := range s.css {
|
for _, v := range s.css {
|
||||||
if ci, match := s.db.connIfFree(v.ci); match {
|
if ci, match := s.db.connIfFree(v.ci); match {
|
||||||
v.si.Close()
|
v.si.Close()
|
||||||
@ -451,6 +605,7 @@ func (s *Stmt) Close() error {
|
|||||||
// connection is in use?
|
// connection is in use?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -469,7 +624,8 @@ func (s *Stmt) Close() error {
|
|||||||
// ...
|
// ...
|
||||||
type Rows struct {
|
type Rows struct {
|
||||||
db *DB
|
db *DB
|
||||||
ci driver.Conn // owned; must be returned when Rows is closed
|
ci driver.Conn // owned; must call putconn when closed to release
|
||||||
|
releaseConn func()
|
||||||
rowsi driver.Rows
|
rowsi driver.Rows
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
@ -538,7 +694,7 @@ func (rs *Rows) Close() error {
|
|||||||
}
|
}
|
||||||
rs.closed = true
|
rs.closed = true
|
||||||
err := rs.rowsi.Close()
|
err := rs.rowsi.Close()
|
||||||
rs.db.putConn(rs.ci)
|
rs.releaseConn()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user