1
0
mirror of https://github.com/golang/go synced 2024-11-24 20:10:02 -07:00

big: use fast shift routines

- fixed a couple of bugs in the process
  (shift right was incorrect for negative numbers)
- added more tests and made some tests more robust
- changed pidigits back to using shifts to multiply
  by 2 instead of add

  This improves pidigit -s -n 10000 by approx. 5%:

  user 0m6.496s (old)
  user 0m6.156s (new)

R=rsc
CC=golang-dev
https://golang.org/cl/963044
This commit is contained in:
Robert Griesemer 2010-04-30 21:25:48 -07:00
parent 161b44c76a
commit 58e77990ba
5 changed files with 95 additions and 46 deletions

View File

@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
if scanned != len(s) {
goto Error
}
if len(z.abs) == 0 {
z.neg = false // 0 has no sign
}
return z, true
Error:
z.neg = false
z.abs = nil
return nil, false
return z, false
}
@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n)
// Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
addedWords := int(n) / _W
// Don't assign z.abs yet, in case z == x
znew := z.abs.make(len(x.abs) + addedWords + 1)
z.neg = x.neg
znew[addedWords:].shiftLeft(x.abs, n%_W)
for i := range znew[0:addedWords] {
znew[i] = 0
}
z.abs = znew.norm()
z.abs = z.abs.shl(x.abs, n)
return z
}
// Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int {
removedWords := int(n) / _W
// Don't assign z.abs yet, in case z == x
znew := z.abs.make(len(x.abs) - removedWords)
z.neg = x.neg
znew.shiftRight(x.abs[removedWords:], n%_W)
z.abs = znew.norm()
if x.neg {
// (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1)
z.neg = true
t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0
t = t.shr(t, n)
z.abs = t.add(t, natOne)
return z
}
z.neg = false
z.abs = z.abs.shr(x.abs, n)
return z
}

View File

@ -562,6 +562,7 @@ type intShiftTest struct {
var rshTests = []intShiftTest{
intShiftTest{"0", 0, "0"},
intShiftTest{"-0", 0, "0"},
intShiftTest{"0", 1, "0"},
intShiftTest{"0", 2, "0"},
intShiftTest{"1", 0, "1"},
@ -569,7 +570,12 @@ var rshTests = []intShiftTest{
intShiftTest{"1", 2, "0"},
intShiftTest{"2", 0, "2"},
intShiftTest{"2", 1, "1"},
intShiftTest{"2", 2, "0"},
intShiftTest{"-1", 0, "-1"},
intShiftTest{"-1", 1, "-1"},
intShiftTest{"-1", 10, "-1"},
intShiftTest{"-100", 2, "-25"},
intShiftTest{"-100", 3, "-13"},
intShiftTest{"-100", 100, "-1"},
intShiftTest{"4294967296", 0, "4294967296"},
intShiftTest{"4294967296", 1, "2147483648"},
intShiftTest{"4294967296", 2, "1073741824"},

View File

@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
// D1.
shift := uint(leadingZeroBits(v[n-1]))
v.shiftLeft(v, shift)
u.shiftLeft(uIn, shift)
v.shiftLeftDeprecated(v, shift)
u.shiftLeftDeprecated(uIn, shift)
u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
// D2.
@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
}
q = q.norm()
u.shiftRight(u, shift)
v.shiftRight(v, shift)
u.shiftRightDeprecated(u, shift)
v.shiftRightDeprecated(v, shift)
r = u.norm()
return q, r
@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int {
}
// TODO(gri) Make the shift routines faster.
// Use pidigits.go benchmark as a test case.
// z = x << s
func (z nat) shl(x nat, s uint) nat {
m := len(x)
if m == 0 {
return z.make(0)
}
// m > 0
// determine if z can be reused
// TODO(gri) change shlVW so we don't need this
if len(z) > 0 && alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
n := m + int(s/_W)
z = z.make(n + 1)
z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m)
return z.norm()
}
// z = x >> s
func (z nat) shr(x nat, s uint) nat {
m := len(x)
n := m - int(s/_W)
if n <= 0 {
return z.make(0)
}
// n > 0
// determine if z can be reused
// TODO(gri) change shrVW so we don't need this
if len(z) > 0 && alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
z = z.make(n)
shrVW(&z[0], &x[m-n], Word(s%_W), m)
return z.norm()
}
// TODO(gri) Remove these shift functions once shlVW and shrVW can be
// used directly in divLarge and powersOfTwoDecompose
//
// To avoid losing the top n bits, z should be sized so that
// len(z) == len(x) + 1.
func (z nat) shiftLeft(x nat, n uint) nat {
func (z nat) shiftLeftDeprecated(x nat, n uint) nat {
if len(x) == 0 {
return x
}
@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat {
}
func (z nat) shiftRight(x nat, n uint) nat {
func (z nat) shiftRightDeprecated(x nat, n uint) nat {
if len(x) == 0 {
return x
}
@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) {
x := trailingZeroBits(n[zeroWords])
q = q.make(len(n) - zeroWords)
q.shiftRight(n[zeroWords:], uint(x))
q.shiftRightDeprecated(n[zeroWords:], uint(x))
q = q.norm()
k = Word(_W*zeroWords + x)

View File

@ -230,9 +230,8 @@ type shiftTest struct {
var leftShiftTests = []shiftTest{
shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil},
shiftTest{nat{0}, 0, nat{0}},
shiftTest{nat{1}, 0, nat{1}},
shiftTest{nat{1}, 1, nat{2}},
shiftTest{natOne, 0, natOne},
shiftTest{natOne, 1, natTwo},
shiftTest{nat{1 << (_W - 1)}, 1, nat{0}},
shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
}
@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{
func TestShiftLeft(t *testing.T) {
for i, test := range leftShiftTests {
dst := make(nat, len(test.out))
dst.shiftLeft(test.in, test.shift)
for j, v := range dst {
if test.out[j] != v {
t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
var z nat
z = z.shl(test.in, test.shift)
for j, d := range test.out {
if j >= len(z) || z[j] != d {
t.Errorf("#%d: got: %v want: %v", i, z, test.out)
break
}
}
@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) {
var rightShiftTests = []shiftTest{
shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil},
shiftTest{nat{0}, 0, nat{0}},
shiftTest{nat{1}, 0, nat{1}},
shiftTest{nat{1}, 1, nat{0}},
shiftTest{nat{2}, 1, nat{1}},
shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1), 0}},
shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1), 0}},
shiftTest{natOne, 0, natOne},
shiftTest{natOne, 1, nil},
shiftTest{natTwo, 1, natOne},
shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1)}},
shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
}
func TestShiftRight(t *testing.T) {
for i, test := range rightShiftTests {
dst := make(nat, len(test.out))
dst.shiftRight(test.in, test.shift)
for j, v := range dst {
if test.out[j] != v {
t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
var z nat
z = z.shr(test.in, test.shift)
for j, d := range test.out {
if j >= len(z) || z[j] != d {
t.Errorf("#%d: got: %v want: %v", i, z, test.out)
break
}
}

View File

@ -63,7 +63,7 @@ func extract_digit() int64 {
}
// Compute (numer * 3 + accum) / denom
tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
tmp1.Lsh(numer, 1)
tmp1.Add(tmp1, numer)
tmp1.Add(tmp1, accum)
tmp1.DivMod(tmp1, denom, tmp2)
@ -84,7 +84,7 @@ func next_term(k int64) {
y2.New(k*2 + 1)
bigk.New(k)
tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
tmp1.Lsh(numer, 1)
accum.Add(accum, tmp1)
accum.Mul(accum, y2)
numer.Mul(numer, bigk)