diff --git a/src/math/big/decimal.go b/src/math/big/decimal.go index f4c535acdb..670668baaf 100644 --- a/src/math/big/decimal.go +++ b/src/math/big/decimal.go @@ -71,19 +71,18 @@ func (x *decimal) init(m nat, shift int) { x.exp = n // Trim trailing zeros; instead the exponent is tracking // the decimal point independent of the number of digits. - for n > 0 && s[n-1] == 0 { + for n > 0 && s[n-1] == '0' { n-- } - x.mant = make([]byte, n) - copy(x.mant, s) + x.mant = append(x.mant[:0], s[:n]...) // Do any (remaining) shift right in decimal representation. if shift < 0 { for shift < -maxShift { - x.shr(maxShift) + shr(x, maxShift) shift += maxShift } - x.shr(uint(-shift)) + shr(x, uint(-shift)) } } @@ -94,7 +93,7 @@ func (x *decimal) init(m nat, shift int) { // single +'0' pass at the end). // shr implements x >> s, for s <= maxShift. -func (x *decimal) shr(s uint) { +func shr(x *decimal, s uint) { // Division by 1< 0 && x.mant[w-1] == '0' { - w-- - } - x.mant = x.mant[:w] + trim(x) } func (x *decimal) String() string { @@ -189,3 +183,73 @@ func appendZeros(buf []byte, n int) []byte { } return buf } + +// shouldRoundUp reports if x should be rounded up +// if shortened to n digits. n must be a valid index +// for x.mant. +func shouldRoundUp(x *decimal, n int) bool { + if x.mant[n] == '5' && n+1 == len(x.mant) { + // exactly halfway - round to even + return n > 0 && (x.mant[n-1]-'0')&1 != 0 + } + // not halfway - digit tells all (x.mant has no trailing zeros) + return x.mant[n] >= '5' +} + +// round sets x to (at most) n mantissa digits by rounding it +// to the nearest even value with n (or fever) mantissa digits. +// If n < 0, x remains unchanged. +func (x *decimal) round(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + + if shouldRoundUp(x, n) { + x.roundUp(n) + } else { + x.roundDown(n) + } +} + +func (x *decimal) roundUp(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + // 0 <= n < len(x.mant) + + // find first digit < '9' + for n > 0 && x.mant[n-1] >= '9' { + n-- + } + + if n == 0 { + // all digits are '9's => round up to '1' and update exponent + x.mant[0] = '1' // ok since len(x.mant) > n + x.mant = x.mant[:1] + x.exp++ + return + } + + // n > 0 && x.mant[n-1] < '9' + x.mant[n-1]++ + x.mant = x.mant[:n] + // x already trimmed +} + +func (x *decimal) roundDown(n int) { + if n < 0 || n >= len(x.mant) { + return // nothing to do + } + x.mant = x.mant[:n] + trim(x) +} + +// trim cuts off any trailing zeros from x's mantissa; +// they are meaningless for the value of x. +func trim(x *decimal) { + i := len(x.mant) + for i > 0 && x.mant[i-1] == '0' { + i-- + } + x.mant = x.mant[:i] +} diff --git a/src/math/big/decimal_test.go b/src/math/big/decimal_test.go index ce20800ef0..81e022a47d 100644 --- a/src/math/big/decimal_test.go +++ b/src/math/big/decimal_test.go @@ -49,3 +49,58 @@ func TestDecimalInit(t *testing.T) { } } } + +func TestDecimalRounding(t *testing.T) { + for _, test := range []struct { + x uint64 + n int + down, even, up string + }{ + {0, 0, "0", "0", "0"}, + {0, 1, "0", "0", "0"}, + + {1, 0, "0", "0", "10"}, + {5, 0, "0", "0", "10"}, + {9, 0, "0", "10", "10"}, + + {15, 1, "10", "20", "20"}, + {45, 1, "40", "40", "50"}, + {95, 1, "90", "100", "100"}, + + {12344999, 4, "12340000", "12340000", "12350000"}, + {12345000, 4, "12340000", "12340000", "12350000"}, + {12345001, 4, "12340000", "12350000", "12350000"}, + {23454999, 4, "23450000", "23450000", "23460000"}, + {23455000, 4, "23450000", "23460000", "23460000"}, + {23455001, 4, "23450000", "23460000", "23460000"}, + + {99994999, 4, "99990000", "99990000", "100000000"}, + {99995000, 4, "99990000", "100000000", "100000000"}, + {99999999, 4, "99990000", "100000000", "100000000"}, + + {12994999, 4, "12990000", "12990000", "13000000"}, + {12995000, 4, "12990000", "13000000", "13000000"}, + {12999999, 4, "12990000", "13000000", "13000000"}, + } { + x := nat(nil).setUint64(test.x) + + var d decimal + d.init(x, 0) + d.roundDown(test.n) + if got := d.String(); got != test.down { + t.Errorf("roundDown(%d, %d) = %s; want %s", test.x, test.n, got, test.down) + } + + d.init(x, 0) + d.round(test.n) + if got := d.String(); got != test.even { + t.Errorf("round(%d, %d) = %s; want %s", test.x, test.n, got, test.even) + } + + d.init(x, 0) + d.roundUp(test.n) + if got := d.String(); got != test.up { + t.Errorf("roundUp(%d, %d) = %s; want %s", test.x, test.n, got, test.up) + } + } +}