diff --git a/src/pkg/math/big/int.go b/src/pkg/math/big/int.go index 4bd7958ae5..caa23ae3d2 100644 --- a/src/pkg/math/big/int.go +++ b/src/pkg/math/big/int.go @@ -561,19 +561,18 @@ func (x *Int) BitLen() int { return x.abs.bitLen() } -// Exp sets z = x**y mod m and returns z. If m is nil, z = x**y. +// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z. +// If y <= 0, the result is 1; if m == nil or m == 0, z = x**y. // See Knuth, volume 2, section 4.6.3. func (z *Int) Exp(x, y, m *Int) *Int { if y.neg || len(y.abs) == 0 { - neg := x.neg - z.SetInt64(1) - z.neg = neg - return z + return z.SetInt64(1) } + // y > 0 var mWords nat if m != nil { - mWords = m.abs + mWords = m.abs // m.abs may be nil for m == 0 } z.abs = z.abs.expNN(x.abs, y.abs, mWords) diff --git a/src/pkg/math/big/int_test.go b/src/pkg/math/big/int_test.go index 27834cec6a..d3c5a0e8bf 100644 --- a/src/pkg/math/big/int_test.go +++ b/src/pkg/math/big/int_test.go @@ -767,8 +767,10 @@ var expTests = []struct { x, y, m string out string }{ + {"5", "-7", "", "1"}, + {"-5", "-7", "", "1"}, {"5", "0", "", "1"}, - {"-5", "0", "", "-1"}, + {"-5", "0", "", "1"}, {"5", "1", "", "5"}, {"-5", "1", "", "-5"}, {"-2", "3", "2", "0"}, @@ -779,6 +781,7 @@ var expTests = []struct { {"0x8000000000000000", "3", "6719", "5447"}, {"0x8000000000000000", "1000", "6719", "1603"}, {"0x8000000000000000", "1000000", "6719", "3199"}, + {"0x8000000000000000", "-1000000", "6719", "1"}, { "2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347", "298472983472983471903246121093472394872319615612417471234712061", @@ -807,12 +810,22 @@ func TestExp(t *testing.T) { continue } - z := y.Exp(x, y, m) - if !isNormalized(z) { - t.Errorf("#%d: %v is not normalized", i, *z) + z1 := new(Int).Exp(x, y, m) + if !isNormalized(z1) { + t.Errorf("#%d: %v is not normalized", i, *z1) } - if z.Cmp(out) != 0 { - t.Errorf("#%d: got %s want %s", i, z, out) + if z1.Cmp(out) != 0 { + t.Errorf("#%d: got %s want %s", i, z1, out) + } + + if m == nil { + // the result should be the same as for m == 0; + // specifically, there should be no div-zero panic + m = &Int{abs: nat{}} // m != nil && len(m.abs) == 0 + z2 := new(Int).Exp(x, y, m) + if z2.Cmp(z1) != 0 { + t.Errorf("#%d: got %s want %s", i, z1, z2) + } } } } diff --git a/src/pkg/math/big/nat.go b/src/pkg/math/big/nat.go index b2d6cd96c6..2e5c56d461 100644 --- a/src/pkg/math/big/nat.go +++ b/src/pkg/math/big/nat.go @@ -1227,8 +1227,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { return z.norm() } -// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It -// reuses the storage of z if possible. +// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; +// otherwise it sets z to x**y. The result is the value of z. func (z nat) expNN(x, y, m nat) nat { if alias(z, x) || alias(z, y) { // We cannot allow in-place modification of x or y. @@ -1240,15 +1240,15 @@ func (z nat) expNN(x, y, m nat) nat { z[0] = 1 return z } + // y > 0 - if m != nil { + if len(m) != 0 { // We likely end up being as long as the modulus. z = z.make(len(m)) } z = z.set(x) - v := y[len(y)-1] - // It's invalid for the most significant word to be zero, therefore we - // will find a one bit. + + v := y[len(y)-1] // v > 0 because y is normalized and y > 0 shift := leadingZeros(v) + 1 v <<= shift var q nat @@ -1272,7 +1272,7 @@ func (z nat) expNN(x, y, m nat) nat { zz, z = z, zz } - if m != nil { + if len(m) != 0 { zz, r = zz.div(r, z, m) zz, r, q, z = q, z, zz, r } @@ -1292,7 +1292,7 @@ func (z nat) expNN(x, y, m nat) nat { zz, z = z, zz } - if m != nil { + if len(m) != 0 { zz, r = zz.div(r, z, m) zz, r, q, z = q, z, zz, r }