From 23fd374bf22aa2eea9c07076061ef8cfbc6cf3d7 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Fri, 13 Mar 2015 17:24:30 -0700 Subject: [PATCH] math/big: wrap Float.Cmp result in struct to prevent wrong use Float.Cmp used to return a value < 0, 0, or > 0 depending on how arguments x, y compared against each other. With the possibility of NaNs, the result was changed into an Accuracy (to include Undef). Consequently, Float.Cmp results could still be compared for (in-) equality with 0, but comparing if < 0 or > 0 would provide the wrong answer w/o any obvious notice by the compiler. This change wraps Float.Cmp results into a struct and accessors are used to access the desired result. This prevents incorrect use. Change-Id: I34e6a6c1859251ec99b5cf953e82542025ace56f Reviewed-on: https://go-review.googlesource.com/7526 Reviewed-by: Rob Pike --- src/math/big/float.go | 31 ++++++++++++++++++------------- src/math/big/float_test.go | 20 ++++++++++---------- src/math/big/floatconv_test.go | 2 +- src/math/big/floatexample_test.go | 2 +- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/src/math/big/float.go b/src/math/big/float.go index 44691c47839..d716c8ca59f 100644 --- a/src/math/big/float.go +++ b/src/math/big/float.go @@ -1446,6 +1446,10 @@ func (z *Float) Quo(x, y *Float) *Float { return z } +type cmpResult struct { + acc Accuracy +} + // Cmp compares x and y and returns: // // Below if x < y @@ -1453,44 +1457,45 @@ func (z *Float) Quo(x, y *Float) *Float { // Above if x > y // Undef if any of x, y is NaN // -func (x *Float) Cmp(y *Float) Accuracy { +func (x *Float) Cmp(y *Float) cmpResult { if debugFloat { x.validate() y.validate() } if x.form == nan || y.form == nan { - return Undef + return cmpResult{Undef} } mx := x.ord() my := y.ord() switch { case mx < my: - return Below + return cmpResult{Below} case mx > my: - return Above + return cmpResult{Above} } // mx == my // only if |mx| == 1 we have to compare the mantissae switch mx { case -1: - return y.ucmp(x) + return cmpResult{y.ucmp(x)} case +1: - return x.ucmp(y) + return cmpResult{x.ucmp(y)} } - return Exact + return cmpResult{Exact} } // The following accessors simplify testing of Cmp results. -func (acc Accuracy) Eql() bool { return acc == Exact } -func (acc Accuracy) Neq() bool { return acc != Exact } -func (acc Accuracy) Lss() bool { return acc == Below } -func (acc Accuracy) Leq() bool { return acc&Above == 0 } -func (acc Accuracy) Gtr() bool { return acc == Above } -func (acc Accuracy) Geq() bool { return acc&Below == 0 } +func (res cmpResult) Acc() Accuracy { return res.acc } +func (res cmpResult) Eql() bool { return res.acc == Exact } +func (res cmpResult) Neq() bool { return res.acc != Exact } +func (res cmpResult) Lss() bool { return res.acc == Below } +func (res cmpResult) Leq() bool { return res.acc&Above == 0 } +func (res cmpResult) Gtr() bool { return res.acc == Above } +func (res cmpResult) Geq() bool { return res.acc&Below == 0 } // ord classifies x and returns: // diff --git a/src/math/big/float_test.go b/src/math/big/float_test.go index 86b1c6f7a1a..683809bf567 100644 --- a/src/math/big/float_test.go +++ b/src/math/big/float_test.go @@ -207,7 +207,7 @@ func feq(x, y *Float) bool { if x.IsNaN() || y.IsNaN() { return x.IsNaN() && y.IsNaN() } - return x.Cmp(y) == 0 && x.IsNeg() == y.IsNeg() + return x.Cmp(y).Eql() && x.IsNeg() == y.IsNeg() } func TestFloatMantExp(t *testing.T) { @@ -918,7 +918,7 @@ func TestFloatRat(t *testing.T) { // inverse conversion if res != nil { got := new(Float).SetPrec(64).SetRat(res) - if got.Cmp(x) != 0 { + if got.Cmp(x).Neq() { t.Errorf("%s: got %s; want %s", test.x, got, x) } } @@ -995,7 +995,7 @@ func TestFloatInc(t *testing.T) { for i := 0; i < n; i++ { x.Add(&x, &one) } - if x.Cmp(new(Float).SetInt64(n)) != 0 { + if x.Cmp(new(Float).SetInt64(n)).Neq() { t.Errorf("prec = %d: got %s; want %d", prec, &x, n) } } @@ -1036,14 +1036,14 @@ func TestFloatAdd(t *testing.T) { got := new(Float).SetPrec(prec).SetMode(mode) got.Add(x, y) want := zbits.round(prec, mode) - if got.Cmp(want) != 0 { + if got.Cmp(want).Neq() { t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t+ %s %v\n\t= %s\n\twant %s", i, prec, mode, x, xbits, y, ybits, got, want) } got.Sub(z, x) want = ybits.round(prec, mode) - if got.Cmp(want) != 0 { + if got.Cmp(want).Neq() { t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t- %s %v\n\t= %s\n\twant %s", i, prec, mode, z, zbits, x, xbits, got, want) } @@ -1137,7 +1137,7 @@ func TestFloatMul(t *testing.T) { got := new(Float).SetPrec(prec).SetMode(mode) got.Mul(x, y) want := zbits.round(prec, mode) - if got.Cmp(want) != 0 { + if got.Cmp(want).Neq() { t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t* %s %v\n\t= %s\n\twant %s", i, prec, mode, x, xbits, y, ybits, got, want) } @@ -1147,7 +1147,7 @@ func TestFloatMul(t *testing.T) { } got.Quo(z, x) want = ybits.round(prec, mode) - if got.Cmp(want) != 0 { + if got.Cmp(want).Neq() { t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t/ %s %v\n\t= %s\n\twant %s", i, prec, mode, z, zbits, x, xbits, got, want) } @@ -1230,7 +1230,7 @@ func TestIssue6866(t *testing.T) { p.Mul(p, psix) z2.Sub(two, p) - if z1.Cmp(z2) != 0 { + if z1.Cmp(z2).Neq() { t.Fatalf("prec %d: got z1 = %s != z2 = %s; want z1 == z2\n", prec, z1, z2) } if z1.Sign() != 0 { @@ -1281,7 +1281,7 @@ func TestFloatQuo(t *testing.T) { prec := uint(preci + d) got := new(Float).SetPrec(prec).SetMode(mode).Quo(x, y) want := bits.round(prec, mode) - if got.Cmp(want) != 0 { + if got.Cmp(want).Neq() { t.Errorf("i = %d, prec = %d, %s:\n\t %s\n\t/ %s\n\t= %s\n\twant %s", i, prec, mode, x, y, got, want) } @@ -1462,7 +1462,7 @@ func TestFloatCmpSpecialValues(t *testing.T) { } for _, y := range args { yy.SetFloat64(y) - got := xx.Cmp(yy) + got := xx.Cmp(yy).Acc() want := Undef switch { case x < y: diff --git a/src/math/big/floatconv_test.go b/src/math/big/floatconv_test.go index e7920d0c079..17c8b147862 100644 --- a/src/math/big/floatconv_test.go +++ b/src/math/big/floatconv_test.go @@ -102,7 +102,7 @@ func TestFloatSetFloat64String(t *testing.T) { } f, _ := x.Float64() want := new(Float).SetFloat64(test.x) - if x.Cmp(want) != 0 { + if x.Cmp(want).Neq() { t.Errorf("%s: got %s (%v); want %v", test.s, &x, f, test.x) } } diff --git a/src/math/big/floatexample_test.go b/src/math/big/floatexample_test.go index 181c0bc136a..5e4a9cbe89c 100644 --- a/src/math/big/floatexample_test.go +++ b/src/math/big/floatexample_test.go @@ -63,7 +63,7 @@ func ExampleFloat_Cmp() { t := x.Cmp(y) fmt.Printf( "%4s %4s %5s %c %c %c %c %c %c\n", - x, y, t, + x, y, t.Acc(), mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq())) } fmt.Println()