mirror of
https://github.com/golang/go
synced 2024-11-23 16:00:06 -07:00
math/big: implement recursive algorithm for division
The current division algorithm produces one word of result at a time, using 2-word division to compute the top word and mulAddVWW to compute the remainder. The top word may need to be adjusted by 1 or 2 units. The recursive version, based on Burnikel, Ziegler, "Fast Recursive Division", uses the same principles, but in a multi-word setting, so that multiplication benefits from the Karatsuba algorithm (and possibly later improvements). benchmark old ns/op new ns/op delta BenchmarkDiv/20/10-4 38.2 38.3 +0.26% BenchmarkDiv/40/20-4 38.7 38.5 -0.52% BenchmarkDiv/100/50-4 62.5 62.6 +0.16% BenchmarkDiv/200/100-4 238 259 +8.82% BenchmarkDiv/400/200-4 311 338 +8.68% BenchmarkDiv/1000/500-4 604 649 +7.45% BenchmarkDiv/2000/1000-4 1214 1278 +5.27% BenchmarkDiv/20000/10000-4 38279 36510 -4.62% BenchmarkDiv/200000/100000-4 3022057 1359615 -55.01% BenchmarkDiv/2000000/1000000-4 310827664 54012939 -82.62% BenchmarkDiv/20000000/10000000-4 33272829421 1965401359 -94.09% BenchmarkString/10/Base10-4 158 156 -1.27% BenchmarkString/100/Base10-4 797 792 -0.63% BenchmarkString/1000/Base10-4 3677 3814 +3.73% BenchmarkString/10000/Base10-4 16633 17116 +2.90% BenchmarkString/100000/Base10-4 5779029 1793808 -68.96% BenchmarkString/1000000/Base10-4 889840820 85524031 -90.39% BenchmarkString/10000000/Base10-4 134338236860 4935657026 -96.33% Fixes #21960 Updates #30943 Change-Id: I134c6f81a47870c688ca95b6081eb9211def15a2 Reviewed-on: https://go-review.googlesource.com/c/go/+/172018 Reviewed-by: Robert Griesemer <gri@golang.org> Run-TryBot: Robert Griesemer <gri@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
b8cb75fb17
commit
194ae3236d
@ -1829,8 +1829,11 @@ func benchmarkDiv(b *testing.B, aSize, bSize int) {
|
||||
}
|
||||
|
||||
func BenchmarkDiv(b *testing.B) {
|
||||
min, max, step := 10, 100000, 10
|
||||
for i := min; i <= max; i *= step {
|
||||
sizes := []int{
|
||||
10, 20, 50, 100, 200, 500, 1000,
|
||||
1e4, 1e5, 1e6, 1e7,
|
||||
}
|
||||
for _, i := range sizes {
|
||||
j := 2 * i
|
||||
b.Run(fmt.Sprintf("%d/%d", j, i), func(b *testing.B) {
|
||||
benchmarkDiv(b, j, i)
|
||||
|
@ -693,7 +693,7 @@ func putNat(x *nat) {
|
||||
|
||||
var natPool sync.Pool
|
||||
|
||||
// q = (uIn-r)/vIn, with 0 <= r < y
|
||||
// q = (uIn-r)/vIn, with 0 <= r < vIn
|
||||
// Uses z as storage for q, and u as storage for r if possible.
|
||||
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
|
||||
// Preconditions:
|
||||
@ -721,6 +721,30 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
|
||||
}
|
||||
q = z.make(m + 1)
|
||||
|
||||
if n < divRecursiveThreshold {
|
||||
q.divBasic(u, v)
|
||||
} else {
|
||||
q.divRecursive(u, v)
|
||||
}
|
||||
putNat(vp)
|
||||
|
||||
q = q.norm()
|
||||
shrVU(u, u, shift)
|
||||
r = u.norm()
|
||||
|
||||
return q, r
|
||||
}
|
||||
|
||||
// divBasic performs word-by-word division of u by v.
|
||||
// The quotient is written in pre-allocated q.
|
||||
// The remainder overwrites input u.
|
||||
//
|
||||
// Precondition:
|
||||
// - len(q) >= len(u)-len(v)
|
||||
func (q nat) divBasic(u, v nat) {
|
||||
n := len(v)
|
||||
m := len(u) - n
|
||||
|
||||
qhatvp := getNat(n + 1)
|
||||
qhatv := *qhatvp
|
||||
|
||||
@ -729,7 +753,11 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
|
||||
for j := m; j >= 0; j-- {
|
||||
// D3.
|
||||
qhat := Word(_M)
|
||||
if ujn := u[j+n]; ujn != vn1 {
|
||||
var ujn Word
|
||||
if j+n < len(u) {
|
||||
ujn = u[j+n]
|
||||
}
|
||||
if ujn != vn1 {
|
||||
var rhat Word
|
||||
qhat, rhat = divWW(ujn, u[j+n-1], vn1)
|
||||
|
||||
@ -752,25 +780,175 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
|
||||
|
||||
// D4.
|
||||
qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)
|
||||
|
||||
c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
|
||||
qhl := len(qhatv)
|
||||
if j+qhl > len(u) && qhatv[n] == 0 {
|
||||
qhl--
|
||||
}
|
||||
c := subVV(u[j:j+qhl], u[j:], qhatv)
|
||||
if c != 0 {
|
||||
c := addVV(u[j:j+n], u[j:], v)
|
||||
u[j+n] += c
|
||||
qhat--
|
||||
}
|
||||
|
||||
if j == m && m == len(q) && qhat == 0 {
|
||||
continue
|
||||
}
|
||||
q[j] = qhat
|
||||
}
|
||||
|
||||
putNat(vp)
|
||||
putNat(qhatvp)
|
||||
}
|
||||
|
||||
q = q.norm()
|
||||
shrVU(u, u, shift)
|
||||
r = u.norm()
|
||||
const divRecursiveThreshold = 100
|
||||
|
||||
return q, r
|
||||
// divRecursive performs word-by-word division of u by v.
|
||||
// The quotient is written in pre-allocated z.
|
||||
// The remainder overwrites input u.
|
||||
//
|
||||
// Precondition:
|
||||
// - len(z) >= len(u)-len(v)
|
||||
//
|
||||
// See Burnikel, Ziegler, "Fast Recursive Division", Algorithm 1 and 2.
|
||||
func (z nat) divRecursive(u, v nat) {
|
||||
// Recursion depth is less than 2 log2(len(v))
|
||||
// Allocate a slice of temporaries to be reused across recursion.
|
||||
recDepth := 2 * bits.Len(uint(len(v)))
|
||||
// large enough to perform Karatsuba on operands as large as v
|
||||
tmp := getNat(3 * len(v))
|
||||
temps := make([]*nat, recDepth)
|
||||
z.clear()
|
||||
z.divRecursiveStep(u, v, 0, tmp, temps)
|
||||
for _, n := range temps {
|
||||
if n != nil {
|
||||
putNat(n)
|
||||
}
|
||||
}
|
||||
putNat(tmp)
|
||||
}
|
||||
|
||||
func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
|
||||
u = u.norm()
|
||||
v = v.norm()
|
||||
|
||||
if len(u) == 0 {
|
||||
z.clear()
|
||||
return
|
||||
}
|
||||
n := len(v)
|
||||
if n < divRecursiveThreshold {
|
||||
z.divBasic(u, v)
|
||||
return
|
||||
}
|
||||
m := len(u) - n
|
||||
if m < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Produce the quotient by blocks of B words.
|
||||
// Division by v (length n) is done using a length n/2 division
|
||||
// and a length n/2 multiplication for each block. The final
|
||||
// complexity is driven by multiplication complexity.
|
||||
B := n / 2
|
||||
|
||||
// Allocate a nat for qhat below.
|
||||
if temps[depth] == nil {
|
||||
temps[depth] = getNat(n)
|
||||
} else {
|
||||
*temps[depth] = temps[depth].make(B + 1)
|
||||
}
|
||||
|
||||
j := m
|
||||
for j > B {
|
||||
// Divide u[j-B:j+n] by vIn. Keep remainder in u
|
||||
// for next block.
|
||||
//
|
||||
// The following property will be used (Lemma 2):
|
||||
// if u = u1 << s + u0
|
||||
// v = v1 << s + v0
|
||||
// then floor(u1/v1) >= floor(u/v)
|
||||
//
|
||||
// Moreover, the difference is at most 2 if len(v1) >= len(u/v)
|
||||
// We choose s = B-1 since len(v)-B >= B+1 >= len(u/v)
|
||||
s := (B - 1)
|
||||
// Except for the first step, the top bits are always
|
||||
// a division remainder, so the quotient length is <= n.
|
||||
uu := u[j-B:]
|
||||
|
||||
qhat := *temps[depth]
|
||||
qhat.clear()
|
||||
qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
|
||||
qhat = qhat.norm()
|
||||
// Adjust the quotient:
|
||||
// u = u_h << s + u_l
|
||||
// v = v_h << s + v_l
|
||||
// u_h = q̂ v_h + rh
|
||||
// u = q̂ (v - v_l) + rh << s + u_l
|
||||
// After the above step, u contains a remainder:
|
||||
// u = rh << s + u_l
|
||||
// and we need to substract q̂ v_l
|
||||
//
|
||||
// But it may be a bit too large, in which case q̂ needs to be smaller.
|
||||
qhatv := tmp.make(3 * n)
|
||||
qhatv.clear()
|
||||
qhatv = qhatv.mul(qhat, v[:s])
|
||||
for i := 0; i < 2; i++ {
|
||||
e := qhatv.cmp(uu.norm())
|
||||
if e <= 0 {
|
||||
break
|
||||
}
|
||||
subVW(qhat, qhat, 1)
|
||||
c := subVV(qhatv[:s], qhatv[:s], v[:s])
|
||||
if len(qhatv) > s {
|
||||
subVW(qhatv[s:], qhatv[s:], c)
|
||||
}
|
||||
addAt(uu[s:], v[s:], 0)
|
||||
}
|
||||
if qhatv.cmp(uu.norm()) > 0 {
|
||||
panic("impossible")
|
||||
}
|
||||
c := subVV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv)
|
||||
if c > 0 {
|
||||
subVW(uu[len(qhatv):], uu[len(qhatv):], c)
|
||||
}
|
||||
addAt(z, qhat, j-B)
|
||||
j -= B
|
||||
}
|
||||
|
||||
// Now u < (v<<B), compute lower bits in the same way.
|
||||
// Choose shift = B-1 again.
|
||||
s := B
|
||||
qhat := *temps[depth]
|
||||
qhat.clear()
|
||||
qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
|
||||
qhat = qhat.norm()
|
||||
qhatv := tmp.make(3 * n)
|
||||
qhatv.clear()
|
||||
qhatv = qhatv.mul(qhat, v[:s])
|
||||
// Set the correct remainder as before.
|
||||
for i := 0; i < 2; i++ {
|
||||
if e := qhatv.cmp(u.norm()); e > 0 {
|
||||
subVW(qhat, qhat, 1)
|
||||
c := subVV(qhatv[:s], qhatv[:s], v[:s])
|
||||
if len(qhatv) > s {
|
||||
subVW(qhatv[s:], qhatv[s:], c)
|
||||
}
|
||||
addAt(u[s:], v[s:], 0)
|
||||
}
|
||||
}
|
||||
if qhatv.cmp(u.norm()) > 0 {
|
||||
panic("impossible")
|
||||
}
|
||||
c := subVV(u[0:len(qhatv)], u[0:len(qhatv)], qhatv)
|
||||
if c > 0 {
|
||||
c = subVW(u[len(qhatv):B+n], u[len(qhatv):B+n], c)
|
||||
}
|
||||
if c > 0 {
|
||||
panic("impossible")
|
||||
}
|
||||
|
||||
// Done!
|
||||
addAt(z, qhat.norm(), 0)
|
||||
}
|
||||
|
||||
// Length of x in bits. x must be normalized.
|
||||
|
@ -739,3 +739,27 @@ func BenchmarkNatSetBytes(b *testing.B) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNatDiv(t *testing.T) {
|
||||
sizes := []int{
|
||||
1, 2, 5, 8, 15, 25, 40, 65, 100,
|
||||
200, 500, 800, 1500, 2500, 4000, 6500, 10000,
|
||||
}
|
||||
for _, i := range sizes {
|
||||
for _, j := range sizes {
|
||||
a := rndNat(i)
|
||||
b := rndNat(j)
|
||||
x := nat(nil).mul(a, b)
|
||||
addVW(x, x, 1)
|
||||
|
||||
var q, r nat
|
||||
q, r = q.div(r, x, b)
|
||||
if q.cmp(a) != 0 {
|
||||
t.Fatal("wrong quotient", i, j)
|
||||
}
|
||||
if len(r) != 1 || r[0] != 1 {
|
||||
t.Fatal("wrong remainder")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user