1
0
mirror of https://github.com/golang/go synced 2024-09-25 01:10:13 -06:00

math/big: handle NaNs in Float.Cmp

Also:
- Implemented NewFloat convenience factory function (analogous to
  NewInt and NewRat).
- Implemented convenience accessors for Accuracy values returned
  from Float.Cmp.
- Added test and example.

Change-Id: I985bb4f86e6def222d4b2505417250d29a39c60e
Reviewed-on: https://go-review.googlesource.com/6970
Reviewed-by: Alan Donovan <adonovan@google.com>
This commit is contained in:
Robert Griesemer 2015-03-05 13:10:33 -08:00
parent d022266a9a
commit 8a6eca43df
3 changed files with 194 additions and 58 deletions

View File

@ -65,6 +65,12 @@ type Float struct {
exp int32 exp int32
} }
// NewFloat allocates and returns a new Float set to x,
// with precision 53 and rounding mode ToNearestEven.
func NewFloat(x float64) *Float {
return new(Float).SetFloat64(x)
}
// Exponent and precision limits. // Exponent and precision limits.
const ( const (
MaxExp = math.MaxInt32 // largest supported exponent MaxExp = math.MaxInt32 // largest supported exponent
@ -135,13 +141,6 @@ const (
//go:generate stringer -type=Accuracy //go:generate stringer -type=Accuracy
func (x *Float) cmpZero() Accuracy {
if x.neg {
return Above
}
return Below
}
// SetPrec sets z's precision to prec and returns the (possibly) rounded // SetPrec sets z's precision to prec and returns the (possibly) rounded
// value of z. Rounding occurs according to z's rounding mode if the mantissa // value of z. Rounding occurs according to z's rounding mode if the mantissa
// cannot be represented in prec bits without loss of precision. // cannot be represented in prec bits without loss of precision.
@ -173,6 +172,13 @@ func (z *Float) SetPrec(prec uint) *Float {
return z return z
} }
func (x *Float) cmpZero() Accuracy {
if x.neg {
return Above
}
return Below
}
// SetMode sets z's rounding mode to mode and returns an exact z. // SetMode sets z's rounding mode to mode and returns an exact z.
// z remains unchanged otherwise. // z remains unchanged otherwise.
// z.SetMode(z.Mode()) is a cheap way to set z's accuracy to Exact. // z.SetMode(z.Mode()) is a cheap way to set z's accuracy to Exact.
@ -1030,7 +1036,7 @@ func (z *Float) Neg(x *Float) *Float {
} }
// z = x + y, ignoring signs of x and y. // z = x + y, ignoring signs of x and y.
// x and y must not be 0, Inf, or NaN. // x.form and y.form must be finite.
func (z *Float) uadd(x, y *Float) { func (z *Float) uadd(x, y *Float) {
// Note: This implementation requires 2 shifts most of the // Note: This implementation requires 2 shifts most of the
// time. It is also inefficient if exponents or precisions // time. It is also inefficient if exponents or precisions
@ -1074,7 +1080,7 @@ func (z *Float) uadd(x, y *Float) {
} }
// z = x - y for x >= y, ignoring signs of x and y. // z = x - y for x >= y, ignoring signs of x and y.
// x and y must not be 0, Inf, or NaN. // x.form and y.form must be finite.
func (z *Float) usub(x, y *Float) { func (z *Float) usub(x, y *Float) {
// This code is symmetric to uadd. // This code is symmetric to uadd.
// We have not factored the common code out because // We have not factored the common code out because
@ -1116,7 +1122,7 @@ func (z *Float) usub(x, y *Float) {
} }
// z = x * y, ignoring signs of x and y. // z = x * y, ignoring signs of x and y.
// x and y must not be 0, Inf, or NaN. // x.form and y.form must be finite.
func (z *Float) umul(x, y *Float) { func (z *Float) umul(x, y *Float) {
if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
panic("umul called with empty mantissa") panic("umul called with empty mantissa")
@ -1137,7 +1143,7 @@ func (z *Float) umul(x, y *Float) {
} }
// z = x / y, ignoring signs of x and y. // z = x / y, ignoring signs of x and y.
// x and y must not be 0, Inf, or NaN. // x.form and y.form must be finite.
func (z *Float) uquo(x, y *Float) { func (z *Float) uquo(x, y *Float) {
if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
panic("uquo called with empty mantissa") panic("uquo called with empty mantissa")
@ -1184,18 +1190,19 @@ func (z *Float) uquo(x, y *Float) {
z.round(sbit) z.round(sbit)
} }
// ucmp returns -1, 0, or 1, depending on whether x < y, x == y, or x > y, // ucmp returns Below, Exact, or Above, depending
// while ignoring the signs of x and y. x and y must not be 0, Inf, or NaN. // on whether x < y, x == y, or x > y.
func (x *Float) ucmp(y *Float) int { // x.form and y.form must be finite.
func (x *Float) ucmp(y *Float) Accuracy {
if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) { if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
panic("ucmp called with empty mantissa") panic("ucmp called with empty mantissa")
} }
switch { switch {
case x.exp < y.exp: case x.exp < y.exp:
return -1 return Below
case x.exp > y.exp: case x.exp > y.exp:
return 1 return Above
} }
// x.exp == y.exp // x.exp == y.exp
@ -1214,13 +1221,13 @@ func (x *Float) ucmp(y *Float) int {
} }
switch { switch {
case xm < ym: case xm < ym:
return -1 return Below
case xm > ym: case xm > ym:
return 1 return Above
} }
} }
return 0 return Exact
} }
// Handling of sign bit as defined by IEEE 754-2008, section 6.3: // Handling of sign bit as defined by IEEE 754-2008, section 6.3:
@ -1286,7 +1293,7 @@ func (z *Float) Add(x, y *Float) *Float {
} else { } else {
// x + (-y) == x - y == -(y - x) // x + (-y) == x - y == -(y - x)
// (-x) + y == y - x == -(x - y) // (-x) + y == y - x == -(x - y)
if x.ucmp(y) >= 0 { if x.ucmp(y) == Above {
z.usub(x, y) z.usub(x, y)
} else { } else {
z.neg = !z.neg z.neg = !z.neg
@ -1342,7 +1349,7 @@ func (z *Float) Sub(x, y *Float) *Float {
} else { } else {
// x - y == x - y == -(y - x) // x - y == x - y == -(y - x)
// (-x) - (-y) == y - x == -(x - y) // (-x) - (-y) == y - x == -(x - y)
if x.ucmp(y) >= 0 { if x.ucmp(y) == Above {
z.usub(x, y) z.usub(x, y)
} else { } else {
z.neg = !z.neg z.neg = !z.neg
@ -1441,46 +1448,49 @@ func (z *Float) Quo(x, y *Float) *Float {
// Cmp compares x and y and returns: // Cmp compares x and y and returns:
// //
// -1 if x < y // Below if x < y
// 0 if x == y (incl. -0 == 0) // Exact if x == y (incl. -0 == 0, -Inf == -Inf, and +Inf == +Inf)
// +1 if x > y // Above if x > y
// Undef if any of x, y is NaN
// //
// Infinities with matching sign are equal. func (x *Float) Cmp(y *Float) Accuracy {
// NaN values are never equal.
// BUG(gri) Float.Cmp does not implement comparing of NaNs.
func (x *Float) Cmp(y *Float) int {
if debugFloat { if debugFloat {
x.validate() x.validate()
y.validate() y.validate()
} }
if x.form == nan || y.form == nan {
return Undef
}
mx := x.ord() mx := x.ord()
my := y.ord() my := y.ord()
switch { switch {
case mx < my: case mx < my:
return -1 return Below
case mx > my: case mx > my:
return +1 return Above
} }
// mx == my // mx == my
// only if |mx| == 1 we have to compare the mantissae // only if |mx| == 1 we have to compare the mantissae
switch mx { switch mx {
case -1: case -1:
return -x.ucmp(y) return y.ucmp(x)
case +1: case +1:
return +x.ucmp(y) return x.ucmp(y)
} }
return 0 return Exact
} }
func umax32(x, y uint32) uint32 { // The following accessors simplify testing of Cmp results.
if x > y { func (acc Accuracy) Eql() bool { return acc == Exact }
return x func (acc Accuracy) Neq() bool { return acc != Exact }
} func (acc Accuracy) Lss() bool { return acc == Below }
return y func (acc Accuracy) Leq() bool { return acc&Above == 0 }
} func (acc Accuracy) Gtr() bool { return acc == Above }
func (acc Accuracy) Geq() bool { return acc&Below == 0 }
// ord classifies x and returns: // ord classifies x and returns:
// //
@ -1507,3 +1517,10 @@ func (x *Float) ord() int {
} }
return m return m
} }
func umax32(x, y uint32) uint32 {
if x > y {
return x
}
return y
}

View File

@ -1063,8 +1063,8 @@ func TestFloatAdd32(t *testing.T) {
x0, y0 = y0, x0 x0, y0 = y0, x0
} }
x := new(Float).SetFloat64(x0) x := NewFloat(x0)
y := new(Float).SetFloat64(y0) y := NewFloat(y0)
z := new(Float).SetPrec(24) z := new(Float).SetPrec(24)
z.Add(x, y) z.Add(x, y)
@ -1096,8 +1096,8 @@ func TestFloatAdd64(t *testing.T) {
x0, y0 = y0, x0 x0, y0 = y0, x0
} }
x := new(Float).SetFloat64(x0) x := NewFloat(x0)
y := new(Float).SetFloat64(y0) y := NewFloat(y0)
z := new(Float).SetPrec(53) z := new(Float).SetPrec(53)
z.Add(x, y) z.Add(x, y)
@ -1182,8 +1182,8 @@ func TestFloatMul64(t *testing.T) {
x0, y0 = y0, x0 x0, y0 = y0, x0
} }
x := new(Float).SetFloat64(x0) x := NewFloat(x0)
y := new(Float).SetFloat64(y0) y := NewFloat(y0)
z := new(Float).SetPrec(53) z := new(Float).SetPrec(53)
z.Mul(x, y) z.Mul(x, y)
@ -1260,7 +1260,7 @@ func TestFloatQuo(t *testing.T) {
z := bits.Float() z := bits.Float()
// compute accurate x as z*y // compute accurate x as z*y
y := new(Float).SetFloat64(3.14159265358979323e123) y := NewFloat(3.14159265358979323e123)
x := new(Float).SetPrec(z.Prec() + y.Prec()).SetMode(ToZero) x := new(Float).SetPrec(z.Prec() + y.Prec()).SetMode(ToZero)
x.Mul(z, y) x.Mul(z, y)
@ -1329,10 +1329,9 @@ func TestFloatQuoSmoke(t *testing.T) {
} }
} }
// TestFloatArithmeticSpecialValues tests that Float operations produce // TestFloatArithmeticSpecialValues tests that Float operations produce the
// the correct result for all combinations of regular and special value // correct results for combinations of zero (±0), finite (±1 and ±2.71828),
// arguments (±0, ±Inf, NaN) and ±1 and ±2.71828 as representatives for // and non-finite (±Inf, NaN) operands.
// nonzero finite values.
func TestFloatArithmeticSpecialValues(t *testing.T) { func TestFloatArithmeticSpecialValues(t *testing.T) {
zero := 0.0 zero := 0.0
args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1), math.NaN()} args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1), math.NaN()}
@ -1442,6 +1441,39 @@ func TestFloatArithmeticRounding(t *testing.T) {
} }
} }
func TestFloatCmp(t *testing.T) { // TestFloatCmpSpecialValues tests that Cmp produces the correct results for
// TODO(gri) implement this // combinations of zero (±0), finite (±1 and ±2.71828), and non-finite (±Inf,
// NaN) operands.
func TestFloatCmpSpecialValues(t *testing.T) {
zero := 0.0
args := []float64{math.Inf(-1), -2.71828, -1, -zero, zero, 1, 2.71828, math.Inf(1), math.NaN()}
xx := new(Float)
yy := new(Float)
for i := 0; i < 4; i++ {
for _, x := range args {
xx.SetFloat64(x)
// check conversion is correct
// (no need to do this for y, since we see exactly the
// same values there)
if got, acc := xx.Float64(); !math.IsNaN(x) && (got != x || acc != Exact) {
t.Errorf("Float(%g) == %g (%s)", x, got, acc)
}
for _, y := range args {
yy.SetFloat64(y)
got := xx.Cmp(yy)
want := Undef
switch {
case x < y:
want = Below
case x == y:
want = Exact
case x > y:
want = Above
}
if got != want {
t.Errorf("(%g).Cmp(%g) = %s; want %s", x, y, got, want)
}
}
}
}
} }

View File

@ -6,11 +6,10 @@ package big_test
import ( import (
"fmt" "fmt"
"math"
"math/big" "math/big"
) )
// TODO(gri) add more examples
func ExampleFloat_Add() { func ExampleFloat_Add() {
// Operating on numbers of different precision. // Operating on numbers of different precision.
var x, y, z big.Float var x, y, z big.Float
@ -29,11 +28,10 @@ func ExampleFloat_Add() {
func Example_Shift() { func Example_Shift() {
// Implementing Float "shift" by modifying the (binary) exponents directly. // Implementing Float "shift" by modifying the (binary) exponents directly.
var x big.Float
for s := -5; s <= 5; s++ { for s := -5; s <= 5; s++ {
x.SetFloat64(0.5) x := big.NewFloat(0.5)
x.SetMantExp(&x, x.MantExp(nil)+s) // shift x by s x.SetMantExp(x, x.MantExp(nil)+s) // shift x by s
fmt.Println(&x) fmt.Println(x)
} }
// Output: // Output:
// 0.015625 // 0.015625
@ -48,3 +46,92 @@ func Example_Shift() {
// 8 // 8
// 16 // 16
} }
func ExampleFloat_Cmp() {
inf := math.Inf(1)
zero := 0.0
nan := math.NaN()
operands := []float64{-inf, -1.2, -zero, 0, +1.2, +inf, nan}
fmt.Println(" x y cmp eql neq lss leq gtr geq")
fmt.Println("-----------------------------------------------")
for _, x64 := range operands {
x := big.NewFloat(x64)
for _, y64 := range operands {
y := big.NewFloat(y64)
t := x.Cmp(y)
fmt.Printf(
"%4s %4s %5s %c %c %c %c %c %c\n",
x, y, t,
mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq()))
}
fmt.Println()
}
// Output:
// x y cmp eql neq lss leq gtr geq
// -----------------------------------------------
// -Inf -Inf Exact ● ○ ○ ● ○ ●
// -Inf -1.2 Below ○ ● ● ● ○ ○
// -Inf -0 Below ○ ● ● ● ○ ○
// -Inf 0 Below ○ ● ● ● ○ ○
// -Inf 1.2 Below ○ ● ● ● ○ ○
// -Inf +Inf Below ○ ● ● ● ○ ○
// -Inf NaN Undef ○ ● ○ ○ ○ ○
//
// -1.2 -Inf Above ○ ● ○ ○ ● ●
// -1.2 -1.2 Exact ● ○ ○ ● ○ ●
// -1.2 -0 Below ○ ● ● ● ○ ○
// -1.2 0 Below ○ ● ● ● ○ ○
// -1.2 1.2 Below ○ ● ● ● ○ ○
// -1.2 +Inf Below ○ ● ● ● ○ ○
// -1.2 NaN Undef ○ ● ○ ○ ○ ○
//
// -0 -Inf Above ○ ● ○ ○ ● ●
// -0 -1.2 Above ○ ● ○ ○ ● ●
// -0 -0 Exact ● ○ ○ ● ○ ●
// -0 0 Exact ● ○ ○ ● ○ ●
// -0 1.2 Below ○ ● ● ● ○ ○
// -0 +Inf Below ○ ● ● ● ○ ○
// -0 NaN Undef ○ ● ○ ○ ○ ○
//
// 0 -Inf Above ○ ● ○ ○ ● ●
// 0 -1.2 Above ○ ● ○ ○ ● ●
// 0 -0 Exact ● ○ ○ ● ○ ●
// 0 0 Exact ● ○ ○ ● ○ ●
// 0 1.2 Below ○ ● ● ● ○ ○
// 0 +Inf Below ○ ● ● ● ○ ○
// 0 NaN Undef ○ ● ○ ○ ○ ○
//
// 1.2 -Inf Above ○ ● ○ ○ ● ●
// 1.2 -1.2 Above ○ ● ○ ○ ● ●
// 1.2 -0 Above ○ ● ○ ○ ● ●
// 1.2 0 Above ○ ● ○ ○ ● ●
// 1.2 1.2 Exact ● ○ ○ ● ○ ●
// 1.2 +Inf Below ○ ● ● ● ○ ○
// 1.2 NaN Undef ○ ● ○ ○ ○ ○
//
// +Inf -Inf Above ○ ● ○ ○ ● ●
// +Inf -1.2 Above ○ ● ○ ○ ● ●
// +Inf -0 Above ○ ● ○ ○ ● ●
// +Inf 0 Above ○ ● ○ ○ ● ●
// +Inf 1.2 Above ○ ● ○ ○ ● ●
// +Inf +Inf Exact ● ○ ○ ● ○ ●
// +Inf NaN Undef ○ ● ○ ○ ○ ○
//
// NaN -Inf Undef ○ ● ○ ○ ○ ○
// NaN -1.2 Undef ○ ● ○ ○ ○ ○
// NaN -0 Undef ○ ● ○ ○ ○ ○
// NaN 0 Undef ○ ● ○ ○ ○ ○
// NaN 1.2 Undef ○ ● ○ ○ ○ ○
// NaN +Inf Undef ○ ● ○ ○ ○ ○
// NaN NaN Undef ○ ● ○ ○ ○ ○
}
func mark(p bool) rune {
if p {
return '●'
}
return '○'
}