1
0
mirror of https://github.com/golang/go synced 2024-11-20 05:44:44 -07:00

math/big: fix big.Exp and document better

- always return 1 for y <= 0
- document that the sign of m is ignored
- protect against div-0 panics by treating
  m == 0 the same way as m == nil
- added extra tests

Fixes #4239.

R=agl, remyoudompheng, agl
CC=golang-dev
https://golang.org/cl/6724046
This commit is contained in:
Robert Griesemer 2012-10-16 13:46:27 -07:00
parent cfa1ba34cc
commit 7565726875
3 changed files with 32 additions and 20 deletions

View File

@ -561,19 +561,18 @@ func (x *Int) BitLen() int {
return x.abs.bitLen() return x.abs.bitLen()
} }
// Exp sets z = x**y mod m and returns z. If m is nil, z = x**y. // Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z.
// If y <= 0, the result is 1; if m == nil or m == 0, z = x**y.
// See Knuth, volume 2, section 4.6.3. // See Knuth, volume 2, section 4.6.3.
func (z *Int) Exp(x, y, m *Int) *Int { func (z *Int) Exp(x, y, m *Int) *Int {
if y.neg || len(y.abs) == 0 { if y.neg || len(y.abs) == 0 {
neg := x.neg return z.SetInt64(1)
z.SetInt64(1)
z.neg = neg
return z
} }
// y > 0
var mWords nat var mWords nat
if m != nil { if m != nil {
mWords = m.abs mWords = m.abs // m.abs may be nil for m == 0
} }
z.abs = z.abs.expNN(x.abs, y.abs, mWords) z.abs = z.abs.expNN(x.abs, y.abs, mWords)

View File

@ -767,8 +767,10 @@ var expTests = []struct {
x, y, m string x, y, m string
out string out string
}{ }{
{"5", "-7", "", "1"},
{"-5", "-7", "", "1"},
{"5", "0", "", "1"}, {"5", "0", "", "1"},
{"-5", "0", "", "-1"}, {"-5", "0", "", "1"},
{"5", "1", "", "5"}, {"5", "1", "", "5"},
{"-5", "1", "", "-5"}, {"-5", "1", "", "-5"},
{"-2", "3", "2", "0"}, {"-2", "3", "2", "0"},
@ -779,6 +781,7 @@ var expTests = []struct {
{"0x8000000000000000", "3", "6719", "5447"}, {"0x8000000000000000", "3", "6719", "5447"},
{"0x8000000000000000", "1000", "6719", "1603"}, {"0x8000000000000000", "1000", "6719", "1603"},
{"0x8000000000000000", "1000000", "6719", "3199"}, {"0x8000000000000000", "1000000", "6719", "3199"},
{"0x8000000000000000", "-1000000", "6719", "1"},
{ {
"2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347",
"298472983472983471903246121093472394872319615612417471234712061", "298472983472983471903246121093472394872319615612417471234712061",
@ -807,12 +810,22 @@ func TestExp(t *testing.T) {
continue continue
} }
z := y.Exp(x, y, m) z1 := new(Int).Exp(x, y, m)
if !isNormalized(z) { if !isNormalized(z1) {
t.Errorf("#%d: %v is not normalized", i, *z) t.Errorf("#%d: %v is not normalized", i, *z1)
}
if z1.Cmp(out) != 0 {
t.Errorf("#%d: got %s want %s", i, z1, out)
}
if m == nil {
// the result should be the same as for m == 0;
// specifically, there should be no div-zero panic
m = &Int{abs: nat{}} // m != nil && len(m.abs) == 0
z2 := new(Int).Exp(x, y, m)
if z2.Cmp(z1) != 0 {
t.Errorf("#%d: got %s want %s", i, z1, z2)
} }
if z.Cmp(out) != 0 {
t.Errorf("#%d: got %s want %s", i, z, out)
} }
} }
} }

View File

@ -1227,8 +1227,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
return z.norm() return z.norm()
} }
// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// reuses the storage of z if possible. // otherwise it sets z to x**y. The result is the value of z.
func (z nat) expNN(x, y, m nat) nat { func (z nat) expNN(x, y, m nat) nat {
if alias(z, x) || alias(z, y) { if alias(z, x) || alias(z, y) {
// We cannot allow in-place modification of x or y. // We cannot allow in-place modification of x or y.
@ -1240,15 +1240,15 @@ func (z nat) expNN(x, y, m nat) nat {
z[0] = 1 z[0] = 1
return z return z
} }
// y > 0
if m != nil { if len(m) != 0 {
// We likely end up being as long as the modulus. // We likely end up being as long as the modulus.
z = z.make(len(m)) z = z.make(len(m))
} }
z = z.set(x) z = z.set(x)
v := y[len(y)-1]
// It's invalid for the most significant word to be zero, therefore we v := y[len(y)-1] // v > 0 because y is normalized and y > 0
// will find a one bit.
shift := leadingZeros(v) + 1 shift := leadingZeros(v) + 1
v <<= shift v <<= shift
var q nat var q nat
@ -1272,7 +1272,7 @@ func (z nat) expNN(x, y, m nat) nat {
zz, z = z, zz zz, z = z, zz
} }
if m != nil { if len(m) != 0 {
zz, r = zz.div(r, z, m) zz, r = zz.div(r, z, m)
zz, r, q, z = q, z, zz, r zz, r, q, z = q, z, zz, r
} }
@ -1292,7 +1292,7 @@ func (z nat) expNN(x, y, m nat) nat {
zz, z = z, zz zz, z = z, zz
} }
if m != nil { if len(m) != 0 {
zz, r = zz.div(r, z, m) zz, r = zz.div(r, z, m)
zz, r, q, z = q, z, zz, r zz, r, q, z = q, z, zz, r
} }