mirror of
https://github.com/golang/go
synced 2024-11-24 20:10:02 -07:00
big: use fast shift routines
- fixed a couple of bugs in the process (shift right was incorrect for negative numbers) - added more tests and made some tests more robust - changed pidigits back to using shifts to multiply by 2 instead of add This improves pidigit -s -n 10000 by approx. 5%: user 0m6.496s (old) user 0m6.156s (new) R=rsc CC=golang-dev https://golang.org/cl/963044
This commit is contained in:
parent
161b44c76a
commit
58e77990ba
@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
|
||||
if scanned != len(s) {
|
||||
goto Error
|
||||
}
|
||||
if len(z.abs) == 0 {
|
||||
z.neg = false // 0 has no sign
|
||||
}
|
||||
|
||||
return z, true
|
||||
|
||||
Error:
|
||||
z.neg = false
|
||||
z.abs = nil
|
||||
return nil, false
|
||||
return z, false
|
||||
}
|
||||
|
||||
|
||||
@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n)
|
||||
|
||||
// Lsh sets z = x << n and returns z.
|
||||
func (z *Int) Lsh(x *Int, n uint) *Int {
|
||||
addedWords := int(n) / _W
|
||||
// Don't assign z.abs yet, in case z == x
|
||||
znew := z.abs.make(len(x.abs) + addedWords + 1)
|
||||
z.neg = x.neg
|
||||
znew[addedWords:].shiftLeft(x.abs, n%_W)
|
||||
for i := range znew[0:addedWords] {
|
||||
znew[i] = 0
|
||||
}
|
||||
z.abs = znew.norm()
|
||||
z.abs = z.abs.shl(x.abs, n)
|
||||
return z
|
||||
}
|
||||
|
||||
|
||||
// Rsh sets z = x >> n and returns z.
|
||||
func (z *Int) Rsh(x *Int, n uint) *Int {
|
||||
removedWords := int(n) / _W
|
||||
// Don't assign z.abs yet, in case z == x
|
||||
znew := z.abs.make(len(x.abs) - removedWords)
|
||||
z.neg = x.neg
|
||||
znew.shiftRight(x.abs[removedWords:], n%_W)
|
||||
z.abs = znew.norm()
|
||||
if x.neg {
|
||||
// (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1)
|
||||
z.neg = true
|
||||
t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0
|
||||
t = t.shr(t, n)
|
||||
z.abs = t.add(t, natOne)
|
||||
return z
|
||||
}
|
||||
|
||||
z.neg = false
|
||||
z.abs = z.abs.shr(x.abs, n)
|
||||
return z
|
||||
}
|
||||
|
@ -562,6 +562,7 @@ type intShiftTest struct {
|
||||
|
||||
var rshTests = []intShiftTest{
|
||||
intShiftTest{"0", 0, "0"},
|
||||
intShiftTest{"-0", 0, "0"},
|
||||
intShiftTest{"0", 1, "0"},
|
||||
intShiftTest{"0", 2, "0"},
|
||||
intShiftTest{"1", 0, "1"},
|
||||
@ -569,7 +570,12 @@ var rshTests = []intShiftTest{
|
||||
intShiftTest{"1", 2, "0"},
|
||||
intShiftTest{"2", 0, "2"},
|
||||
intShiftTest{"2", 1, "1"},
|
||||
intShiftTest{"2", 2, "0"},
|
||||
intShiftTest{"-1", 0, "-1"},
|
||||
intShiftTest{"-1", 1, "-1"},
|
||||
intShiftTest{"-1", 10, "-1"},
|
||||
intShiftTest{"-100", 2, "-25"},
|
||||
intShiftTest{"-100", 3, "-13"},
|
||||
intShiftTest{"-100", 100, "-1"},
|
||||
intShiftTest{"4294967296", 0, "4294967296"},
|
||||
intShiftTest{"4294967296", 1, "2147483648"},
|
||||
intShiftTest{"4294967296", 2, "1073741824"},
|
||||
|
@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
|
||||
|
||||
// D1.
|
||||
shift := uint(leadingZeroBits(v[n-1]))
|
||||
v.shiftLeft(v, shift)
|
||||
u.shiftLeft(uIn, shift)
|
||||
v.shiftLeftDeprecated(v, shift)
|
||||
u.shiftLeftDeprecated(uIn, shift)
|
||||
u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
|
||||
|
||||
// D2.
|
||||
@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
|
||||
}
|
||||
|
||||
q = q.norm()
|
||||
u.shiftRight(u, shift)
|
||||
v.shiftRight(v, shift)
|
||||
u.shiftRightDeprecated(u, shift)
|
||||
v.shiftRightDeprecated(v, shift)
|
||||
r = u.norm()
|
||||
|
||||
return q, r
|
||||
@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int {
|
||||
}
|
||||
|
||||
|
||||
// TODO(gri) Make the shift routines faster.
|
||||
// Use pidigits.go benchmark as a test case.
|
||||
// z = x << s
|
||||
func (z nat) shl(x nat, s uint) nat {
|
||||
m := len(x)
|
||||
if m == 0 {
|
||||
return z.make(0)
|
||||
}
|
||||
// m > 0
|
||||
|
||||
// determine if z can be reused
|
||||
// TODO(gri) change shlVW so we don't need this
|
||||
if len(z) > 0 && alias(z, x) {
|
||||
z = nil // z is an alias for x - cannot reuse
|
||||
}
|
||||
|
||||
n := m + int(s/_W)
|
||||
z = z.make(n + 1)
|
||||
z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m)
|
||||
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// z = x >> s
|
||||
func (z nat) shr(x nat, s uint) nat {
|
||||
m := len(x)
|
||||
n := m - int(s/_W)
|
||||
if n <= 0 {
|
||||
return z.make(0)
|
||||
}
|
||||
// n > 0
|
||||
|
||||
// determine if z can be reused
|
||||
// TODO(gri) change shrVW so we don't need this
|
||||
if len(z) > 0 && alias(z, x) {
|
||||
z = nil // z is an alias for x - cannot reuse
|
||||
}
|
||||
|
||||
z = z.make(n)
|
||||
shrVW(&z[0], &x[m-n], Word(s%_W), m)
|
||||
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
|
||||
// TODO(gri) Remove these shift functions once shlVW and shrVW can be
|
||||
// used directly in divLarge and powersOfTwoDecompose
|
||||
//
|
||||
// To avoid losing the top n bits, z should be sized so that
|
||||
// len(z) == len(x) + 1.
|
||||
func (z nat) shiftLeft(x nat, n uint) nat {
|
||||
func (z nat) shiftLeftDeprecated(x nat, n uint) nat {
|
||||
if len(x) == 0 {
|
||||
return x
|
||||
}
|
||||
@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat {
|
||||
}
|
||||
|
||||
|
||||
func (z nat) shiftRight(x nat, n uint) nat {
|
||||
func (z nat) shiftRightDeprecated(x nat, n uint) nat {
|
||||
if len(x) == 0 {
|
||||
return x
|
||||
}
|
||||
@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) {
|
||||
x := trailingZeroBits(n[zeroWords])
|
||||
|
||||
q = q.make(len(n) - zeroWords)
|
||||
q.shiftRight(n[zeroWords:], uint(x))
|
||||
q.shiftRightDeprecated(n[zeroWords:], uint(x))
|
||||
q = q.norm()
|
||||
|
||||
k = Word(_W*zeroWords + x)
|
||||
|
@ -230,9 +230,8 @@ type shiftTest struct {
|
||||
var leftShiftTests = []shiftTest{
|
||||
shiftTest{nil, 0, nil},
|
||||
shiftTest{nil, 1, nil},
|
||||
shiftTest{nat{0}, 0, nat{0}},
|
||||
shiftTest{nat{1}, 0, nat{1}},
|
||||
shiftTest{nat{1}, 1, nat{2}},
|
||||
shiftTest{natOne, 0, natOne},
|
||||
shiftTest{natOne, 1, natTwo},
|
||||
shiftTest{nat{1 << (_W - 1)}, 1, nat{0}},
|
||||
shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
|
||||
}
|
||||
@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{
|
||||
|
||||
func TestShiftLeft(t *testing.T) {
|
||||
for i, test := range leftShiftTests {
|
||||
dst := make(nat, len(test.out))
|
||||
dst.shiftLeft(test.in, test.shift)
|
||||
for j, v := range dst {
|
||||
if test.out[j] != v {
|
||||
t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
|
||||
var z nat
|
||||
z = z.shl(test.in, test.shift)
|
||||
for j, d := range test.out {
|
||||
if j >= len(z) || z[j] != d {
|
||||
t.Errorf("#%d: got: %v want: %v", i, z, test.out)
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) {
|
||||
var rightShiftTests = []shiftTest{
|
||||
shiftTest{nil, 0, nil},
|
||||
shiftTest{nil, 1, nil},
|
||||
shiftTest{nat{0}, 0, nat{0}},
|
||||
shiftTest{nat{1}, 0, nat{1}},
|
||||
shiftTest{nat{1}, 1, nat{0}},
|
||||
shiftTest{nat{2}, 1, nat{1}},
|
||||
shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1), 0}},
|
||||
shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1), 0}},
|
||||
shiftTest{natOne, 0, natOne},
|
||||
shiftTest{natOne, 1, nil},
|
||||
shiftTest{natTwo, 1, natOne},
|
||||
shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1)}},
|
||||
shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
|
||||
}
|
||||
|
||||
|
||||
func TestShiftRight(t *testing.T) {
|
||||
for i, test := range rightShiftTests {
|
||||
dst := make(nat, len(test.out))
|
||||
dst.shiftRight(test.in, test.shift)
|
||||
for j, v := range dst {
|
||||
if test.out[j] != v {
|
||||
t.Errorf("#%d: got: %v want: %v", i, dst, test.out)
|
||||
var z nat
|
||||
z = z.shr(test.in, test.shift)
|
||||
for j, d := range test.out {
|
||||
if j >= len(z) || z[j] != d {
|
||||
t.Errorf("#%d: got: %v want: %v", i, z, test.out)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func extract_digit() int64 {
|
||||
}
|
||||
|
||||
// Compute (numer * 3 + accum) / denom
|
||||
tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
|
||||
tmp1.Lsh(numer, 1)
|
||||
tmp1.Add(tmp1, numer)
|
||||
tmp1.Add(tmp1, accum)
|
||||
tmp1.DivMod(tmp1, denom, tmp2)
|
||||
@ -84,7 +84,7 @@ func next_term(k int64) {
|
||||
y2.New(k*2 + 1)
|
||||
bigk.New(k)
|
||||
|
||||
tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1)
|
||||
tmp1.Lsh(numer, 1)
|
||||
accum.Add(accum, tmp1)
|
||||
accum.Mul(accum, y2)
|
||||
numer.Mul(numer, bigk)
|
||||
|
Loading…
Reference in New Issue
Block a user