diff --git a/src/pkg/xml/read.go b/src/pkg/xml/read.go index 4f944038e87..c85b697025a 100644 --- a/src/pkg/xml/read.go +++ b/src/pkg/xml/read.go @@ -9,6 +9,7 @@ import ( "io" "os" "reflect" + "strconv" "strings" "unicode" ) @@ -106,6 +107,10 @@ import ( // // Unmarshal maps an XML element to a bool by setting the bool to true. // +// Unmarshal maps an XML element to an integer or floating-point +// field by setting the field to the result of interpreting the string +// value in decimal. There is no check for overflow. +// // Unmarshal maps an XML element to an xml.Name by recording the // element name. // @@ -200,6 +205,9 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { styp *reflect.StructType ) switch v := val.(type) { + default: + return os.ErrorString("unknown type " + v.Type().String()) + case *reflect.BoolValue: v.Set(true) @@ -232,7 +240,11 @@ func (p *Parser) unmarshal(val reflect.Value, start *StartElement) os.Error { } return nil - case *reflect.StringValue: + case *reflect.StringValue, + *reflect.IntValue, *reflect.UintValue, *reflect.UintptrValue, + *reflect.Int8Value, *reflect.Int16Value, *reflect.Int32Value, *reflect.Int64Value, + *reflect.Uint8Value, *reflect.Uint16Value, *reflect.Uint32Value, *reflect.Uint64Value, + *reflect.FloatValue, *reflect.Float32Value, *reflect.Float64Value: saveData = v case *reflect.StructValue: @@ -365,8 +377,103 @@ Loop: } } - // Save accumulated character data and comments + var err os.Error + // Helper functions for integer and unsigned integer conversions + var itmp int64 + getInt64 := func() bool { + itmp, err = strconv.Atoi64(string(data)) + // TODO: should check sizes + return err == nil + } + var utmp uint64 + getUint64 := func() bool { + utmp, err = strconv.Atoui64(string(data)) + // TODO: check for overflow? + return err == nil + } + var ftmp float64 + getFloat64 := func() bool { + ftmp, err = strconv.Atof64(string(data)) + // TODO: check for overflow? + return err == nil + } + + // Save accumulated data and comments switch t := saveData.(type) { + case nil: + // Probably a comment, handled below + default: + return os.ErrorString("cannot happen: unknown type " + t.Type().String()) + case *reflect.IntValue: + if !getInt64() { + return err + } + t.Set(int(itmp)) + case *reflect.Int8Value: + if !getInt64() { + return err + } + t.Set(int8(itmp)) + case *reflect.Int16Value: + if !getInt64() { + return err + } + t.Set(int16(itmp)) + case *reflect.Int32Value: + if !getInt64() { + return err + } + t.Set(int32(itmp)) + case *reflect.Int64Value: + if !getInt64() { + return err + } + t.Set(itmp) + case *reflect.UintValue: + if !getUint64() { + return err + } + t.Set(uint(utmp)) + case *reflect.Uint8Value: + if !getUint64() { + return err + } + t.Set(uint8(utmp)) + case *reflect.Uint16Value: + if !getUint64() { + return err + } + t.Set(uint16(utmp)) + case *reflect.Uint32Value: + if !getUint64() { + return err + } + t.Set(uint32(utmp)) + case *reflect.Uint64Value: + if !getUint64() { + return err + } + t.Set(utmp) + case *reflect.UintptrValue: + if !getUint64() { + return err + } + t.Set(uintptr(utmp)) + case *reflect.FloatValue: + if !getFloat64() { + return err + } + t.Set(float(ftmp)) + case *reflect.Float32Value: + if !getFloat64() { + return err + } + t.Set(float32(ftmp)) + case *reflect.Float64Value: + if !getFloat64() { + return err + } + t.Set(ftmp) case *reflect.StringValue: t.Set(string(data)) case *reflect.SliceValue: diff --git a/src/pkg/xml/xml_test.go b/src/pkg/xml/xml_test.go index f228dfba370..fa19495001e 100644 --- a/src/pkg/xml/xml_test.go +++ b/src/pkg/xml/xml_test.go @@ -214,6 +214,76 @@ func TestSyntax(t *testing.T) { } } +type allScalars struct { + Bool bool + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint int + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Uintptr uintptr + Float float + Float32 float32 + Float64 float64 + String string +} + +var all = allScalars{ + Bool: true, + Int: 1, + Int8: -2, + Int16: 3, + Int32: -4, + Int64: 5, + Uint: 6, + Uint8: 7, + Uint16: 8, + Uint32: 9, + Uint64: 10, + Uintptr: 11, + Float: 12.0, + Float32: 13.0, + Float64: 14.0, + String: "15", +} + +const testScalarsInput = ` + + 1 + -2 + 3 + -4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12.0 + 13.0 + 14.0 + 15 +` + +func TestAllScalars(t *testing.T) { + var a allScalars + buf := bytes.NewBufferString(testScalarsInput) + err := Unmarshal(buf, &a) + + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(a, all) { + t.Errorf("expected %+v got %+v", a, all) + } +} + type item struct { Field_a string }