// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Type conversions for Scan. package sql import ( "database/sql/driver" "errors" "fmt" "reflect" "strconv" "time" "unicode" "unicode/utf8" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error func describeNamedValue(nv *driver.NamedValue) string { if len(nv.Name) == 0 { return fmt.Sprintf("$%d", nv.Ordinal) } return fmt.Sprintf("with name %q", nv.Name) } func validateNamedValueName(name string) error { if len(name) == 0 { return nil } r, _ := utf8.DecodeRuneInString(name) if unicode.IsLetter(r) { return nil } return fmt.Errorf("name %q does not begin with a letter", name) } // ccChecker wraps the driver.ColumnConverter and allows it to be used // as if it were a NamedValueChecker. If the driver ColumnConverter // is not present then the NamedValueChecker will return driver.ErrSkip. type ccChecker struct { cci driver.ColumnConverter want int } func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { if c.cci == nil { return driver.ErrSkip } // The column converter shouldn't be called on any index // it isn't expecting. The final error will be thrown // in the argument converter loop. index := nv.Ordinal - 1 if c.want <= index { return nil } // First, see if the value itself knows how to convert // itself to a driver type. For example, a NullString // struct changing into a string or nil. if vr, ok := nv.Value.(driver.Valuer); ok { sv, err := callValuerValue(vr) if err != nil { return err } if !driver.IsValue(sv) { return fmt.Errorf("non-subset type %T returned from Value", sv) } nv.Value = sv } // Second, ask the column to sanity check itself. For // example, drivers might use this to make sure that // an int64 values being inserted into a 16-bit // integer field is in range (before getting // truncated), or that a nil can't go into a NOT NULL // column before going across the network to get the // same error. var err error arg := nv.Value nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) if err != nil { return err } if !driver.IsValue(nv.Value) { return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value) } return nil } // defaultCheckNamedValue wraps the default ColumnConverter to have the same // function signature as the CheckNamedValue in the driver.NamedValueChecker // interface. func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) return err } // driverArgs converts arguments from callers of Stmt.Exec and // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { nvargs := make([]driver.NamedValue, len(args)) // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. want := -1 var si driver.Stmt var cc ccChecker if ds != nil { si = ds.si want = ds.si.NumInput() cc.want = want } // Check all types of interfaces from the start. // Drivers may opt to use the NamedValueChecker for special // argument types, then return driver.ErrSkip to pass it along // to the column converter. nvc, ok := si.(driver.NamedValueChecker) if !ok { nvc, ok = ci.(driver.NamedValueChecker) } cci, ok := si.(driver.ColumnConverter) if ok { cc.cci = cci } // Loop through all the arguments, checking each one. // If no error is returned simply increment the index // and continue. However if driver.ErrRemoveArgument // is returned the argument is not included in the query // argument list. var err error var n int for _, arg := range args { nv := &nvargs[n] if np, ok := arg.(NamedArg); ok { if err = validateNamedValueName(np.Name); err != nil { return nil, err } arg = np.Value nv.Name = np.Name } nv.Ordinal = n + 1 nv.Value = arg // Checking sequence has four routes: // A: 1. Default // B: 1. NamedValueChecker 2. Column Converter 3. Default // C: 1. NamedValueChecker 3. Default // D: 1. Column Converter 2. Default // // The only time a Column Converter is called is first // or after NamedValueConverter. If first it is handled before // the nextCheck label. Thus for repeats tries only when the // NamedValueConverter is selected should the Column Converter // be used in the retry. checker := defaultCheckNamedValue nextCC := false switch { case nvc != nil: nextCC = cci != nil checker = nvc.CheckNamedValue case cci != nil: checker = cc.CheckNamedValue } nextCheck: err = checker(nv) switch err { case nil: n++ continue case driver.ErrRemoveArgument: nvargs = nvargs[:len(nvargs)-1] continue case driver.ErrSkip: if nextCC { nextCC = false checker = cc.CheckNamedValue } else { checker = defaultCheckNamedValue } goto nextCheck default: return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) } } // Check the length of arguments after conversion to allow for omitted // arguments. if want != -1 && len(nvargs) != want { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) } return nvargs, nil } // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. func convertAssign(dest, src interface{}) error { // Common cases, without reflect. switch s := src.(type) { case string: switch d := dest.(type) { case *string: if d == nil { return errNilPtr } *d = s return nil case *[]byte: if d == nil { return errNilPtr } *d = []byte(s) return nil case *RawBytes: if d == nil { return errNilPtr } *d = append((*d)[:0], s...) return nil } case []byte: switch d := dest.(type) { case *string: if d == nil { return errNilPtr } *d = string(s) return nil case *interface{}: if d == nil { return errNilPtr } *d = cloneBytes(s) return nil case *[]byte: if d == nil { return errNilPtr } *d = cloneBytes(s) return nil case *RawBytes: if d == nil { return errNilPtr } *d = s return nil } case time.Time: switch d := dest.(type) { case *time.Time: *d = s return nil case *string: *d = s.Format(time.RFC3339Nano) return nil case *[]byte: if d == nil { return errNilPtr } *d = []byte(s.Format(time.RFC3339Nano)) return nil case *RawBytes: if d == nil { return errNilPtr } *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) return nil } case nil: switch d := dest.(type) { case *interface{}: if d == nil { return errNilPtr } *d = nil return nil case *[]byte: if d == nil { return errNilPtr } *d = nil return nil case *RawBytes: if d == nil { return errNilPtr } *d = nil return nil } } var sv reflect.Value switch d := dest.(type) { case *string: sv = reflect.ValueOf(src) switch sv.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: *d = asString(src) return nil } case *[]byte: sv = reflect.ValueOf(src) if b, ok := asBytes(nil, sv); ok { *d = b return nil } case *RawBytes: sv = reflect.ValueOf(src) if b, ok := asBytes([]byte(*d)[:0], sv); ok { *d = RawBytes(b) return nil } case *bool: bv, err := driver.Bool.ConvertValue(src) if err == nil { *d = bv.(bool) } return err case *interface{}: *d = src return nil } if scanner, ok := dest.(Scanner); ok { return scanner.Scan(src) } dpv := reflect.ValueOf(dest) if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } if dpv.IsNil() { return errNilPtr } if !sv.IsValid() { sv = reflect.ValueOf(src) } dv := reflect.Indirect(dpv) if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { switch b := src.(type) { case []byte: dv.Set(reflect.ValueOf(cloneBytes(b))) default: dv.Set(sv) } return nil } if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { dv.Set(sv.Convert(dv.Type())) return nil } // The following conversions use a string value as an intermediate representation // to convert between various numeric types. // // This also allows scanning into user defined types such as "type Int int64". // For symmetry, also check for string destination types. switch dv.Kind() { case reflect.Ptr: if src == nil { dv.Set(reflect.Zero(dv.Type())) return nil } dv.Set(reflect.New(dv.Type().Elem())) return convertAssign(dv.Interface(), src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: s := asString(src) i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) if err != nil { err = strconvErr(err) return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetInt(i64) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: s := asString(src) u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) if err != nil { err = strconvErr(err) return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetUint(u64) return nil case reflect.Float32, reflect.Float64: s := asString(src) f64, err := strconv.ParseFloat(s, dv.Type().Bits()) if err != nil { err = strconvErr(err) return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetFloat(f64) return nil case reflect.String: switch v := src.(type) { case string: dv.SetString(v) return nil case []byte: dv.SetString(string(v)) return nil } } return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) } func strconvErr(err error) error { if ne, ok := err.(*strconv.NumError); ok { return ne.Err } return err } func cloneBytes(b []byte) []byte { if b == nil { return nil } c := make([]byte, len(b)) copy(c, b) return c } func asString(src interface{}) string { switch v := src.(type) { case string: return v case []byte: return string(v) } rv := reflect.ValueOf(src) switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return strconv.FormatInt(rv.Int(), 10) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return strconv.FormatUint(rv.Uint(), 10) case reflect.Float64: return strconv.FormatFloat(rv.Float(), 'g', -1, 64) case reflect.Float32: return strconv.FormatFloat(rv.Float(), 'g', -1, 32) case reflect.Bool: return strconv.FormatBool(rv.Bool()) } return fmt.Sprintf("%v", src) } func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return strconv.AppendInt(buf, rv.Int(), 10), true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return strconv.AppendUint(buf, rv.Uint(), 10), true case reflect.Float32: return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true case reflect.Float64: return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true case reflect.Bool: return strconv.AppendBool(buf, rv.Bool()), true case reflect.String: s := rv.String() return append(buf, s...), true } return } var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() // callValuerValue returns vr.Value(), with one exception: // If vr.Value is an auto-generated method on a pointer type and the // pointer is nil, it would panic at runtime in the panicwrap // method. Treat it like nil instead. // Issue 8415. // // This is so people can implement driver.Value on value types and // still use nil pointers to those types to mean nil/NULL, just like // string/*string. // // This function is mirrored in the database/sql/driver package. func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && rv.IsNil() && rv.Type().Elem().Implements(valuerReflectType) { return nil, nil } return vr.Value() }