From d8c9dd604801958c649a32511deef373adeecfe0 Mon Sep 17 00:00:00 2001 From: OneOfOne Date: Sun, 10 Apr 2016 03:50:11 +0200 Subject: [PATCH] math/big: implement GobDecode/Encode for big.Float Added GobEncode/Decode and a test for them. Fixes #14593 Change-Id: Ic8d3efd24d0313a1a66f01da293c4c1fd39764a8 Reviewed-on: https://go-review.googlesource.com/21755 Reviewed-by: Robert Griesemer --- src/math/big/floatmarsh.go | 65 ++++++++++++++++++++++++++++++++- src/math/big/floatmarsh_test.go | 63 ++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/math/big/floatmarsh.go b/src/math/big/floatmarsh.go index 44987ee03a..6127aa2e83 100644 --- a/src/math/big/floatmarsh.go +++ b/src/math/big/floatmarsh.go @@ -6,7 +6,70 @@ package big -import "fmt" +import ( + "encoding/binary" + "fmt" +) + +// Gob codec version. Permits backward-compatible changes to the encoding. +const floatGobVersion byte = 1 + +// GobEncode implements the gob.GobEncoder interface. +func (x *Float) GobEncode() ([]byte, error) { + if x == nil { + return nil, nil + } + sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec + if x.form == finite { + sz += 4 + int((x.prec+(_W-1))/_W)*_S // exp + mant + } + buf := make([]byte, sz) + + buf[0] = floatGobVersion + b := byte(x.mode&7)<<5 | byte((x.acc+1)&3)<<3 | byte(x.form&3)<<1 + if x.neg { + b |= 1 + } + buf[1] = b + binary.BigEndian.PutUint32(buf[2:], x.prec) + if x.form == finite { + binary.BigEndian.PutUint32(buf[6:], uint32(x.exp)) + x.mant.bytes(buf[10:]) + } + return buf, nil +} + +// GobDecode implements the gob.GobDecoder interface. +func (z *Float) GobDecode(buf []byte) error { + if len(buf) == 0 { + // Other side sent a nil or default value. + *z = Float{} + return nil + } + + if buf[0] != floatGobVersion { + return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0]) + } + + b := buf[1] + z.mode = RoundingMode((b >> 5) & 7) + z.acc = Accuracy((b>>3)&3) - 1 + z.form = form((b >> 1) & 3) + z.neg = b&1 != 0 + + oldPrec := uint(z.prec) + z.prec = binary.BigEndian.Uint32(buf[2:]) + + if z.form == finite { + z.exp = int32(binary.BigEndian.Uint32(buf[6:])) + z.mant = z.mant.setBytes(buf[10:]) + } + + if oldPrec != 0 { + z.SetPrec(oldPrec) + } + return nil +} // MarshalText implements the encoding.TextMarshaler interface. // Only the Float value is marshaled (in full precision), other diff --git a/src/math/big/floatmarsh_test.go b/src/math/big/floatmarsh_test.go index d7ef2fca68..f726c35e99 100644 --- a/src/math/big/floatmarsh_test.go +++ b/src/math/big/floatmarsh_test.go @@ -5,7 +5,10 @@ package big import ( + "bytes" + "encoding/gob" "encoding/json" + "io" "testing" ) @@ -23,6 +26,66 @@ var floatVals = []string{ "Inf", } +func TestFloatGobEncoding(t *testing.T) { + var medium bytes.Buffer + for _, test := range floatVals { + for _, sign := range []string{"", "+", "-"} { + for _, prec := range []uint{0, 1, 2, 10, 53, 64, 100, 1000} { + medium.Reset() // empty buffer for each test case (in case of failures) + enc := gob.NewEncoder(&medium) + dec := gob.NewDecoder(&medium) + x := sign + test + var tx Float + _, _, err := tx.SetPrec(prec).Parse(x, 0) + if err != nil { + t.Errorf("parsing of %s (prec = %d) failed (invalid test case): %v", x, prec, err) + continue + } + tx.SetMode(ToPositiveInf) + if err := enc.Encode(&tx); err != nil { + t.Errorf("encoding of %v (prec = %d) failed: %v", &tx, prec, err) + continue + } + + var rx Float + if err := dec.Decode(&rx); err != nil { + t.Errorf("decoding of %v (prec = %d) failed: %v", &tx, prec, err) + continue + } + + if rx.Cmp(&tx) != 0 { + t.Errorf("transmission of %s failed: got %s want %s", x, rx.String(), tx.String()) + continue + } + + if rx.Mode() != ToPositiveInf { + t.Errorf("transmission of %s's mode failed: got %s want %s", x, rx.Mode(), ToPositiveInf) + } + } + } + } +} +func TestFloatCorruptGob(t *testing.T) { + var buf bytes.Buffer + tx := NewFloat(4 / 3).SetPrec(1000).SetMode(ToPositiveInf) + if err := gob.NewEncoder(&buf).Encode(tx); err != nil { + t.Fatal(err) + } + b := buf.Bytes() + var rx Float + if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err != nil { + t.Fatal(err) + } + var rx2 Float + if err := gob.NewDecoder(bytes.NewReader(b[:10])).Decode(&rx2); err != io.ErrUnexpectedEOF { + t.Errorf("expected io.ErrUnexpectedEOF, got %v", err) + } + b[1] = 0 + if err := gob.NewDecoder(bytes.NewReader(b)).Decode(&rx); err == nil { + t.Fatal("expected a version error, got nil") + } + +} func TestFloatJSONEncoding(t *testing.T) { for _, test := range floatVals { for _, sign := range []string{"", "+", "-"} {