mirror of
https://github.com/golang/go
synced 2024-11-20 10:44:41 -07:00
math/big: correct quadratic space complexity in Mul.
The previous implementation used to have a O(n) recursion depth for unbalanced inputs. A test is added to check that a reasonable amount of bytes is allocated in this case. Fixes #3807. R=golang-dev, dsymonds, gri CC=golang-dev, remy https://golang.org/cl/6345075
This commit is contained in:
parent
eb1c03eacb
commit
ac12131649
@ -342,7 +342,7 @@ func alias(x, y nat) bool {
|
||||
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
|
||||
}
|
||||
|
||||
// addAt implements z += x*(1<<(_W*i)); z must be long enough.
|
||||
// addAt implements z += x<<(_W*i); z must be long enough.
|
||||
// (we don't use nat.add because we need z to stay the same
|
||||
// slice, and we don't need to normalize z after each addition)
|
||||
func addAt(z, x nat, i int) {
|
||||
@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat {
|
||||
|
||||
// determine Karatsuba length k such that
|
||||
//
|
||||
// x = x1*b + x0
|
||||
// y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
|
||||
// x = xh*b + x0 (0 <= x0 < b)
|
||||
// y = yh*b + y0 (0 <= y0 < b)
|
||||
// b = 1<<(_W*k) ("base" of digits xi, yi)
|
||||
//
|
||||
k := karatsubaLen(n)
|
||||
@ -417,27 +417,41 @@ func (z nat) mul(x, y nat) nat {
|
||||
y0 := y[0:k] // y0 is not normalized
|
||||
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
|
||||
karatsuba(z, x0, y0)
|
||||
z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
|
||||
z = z[0 : m+n] // z has final length but may be incomplete
|
||||
z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
|
||||
|
||||
// If x1 and/or y1 are not 0, add missing terms to z explicitly:
|
||||
//
|
||||
// m+n 2*k 0
|
||||
// z = [ ... | x0*y0 ]
|
||||
// + [ x1*y1 ]
|
||||
// + [ x1*y0 ]
|
||||
// + [ x0*y1 ]
|
||||
// If xh != 0 or yh != 0, add the missing terms to z. For
|
||||
//
|
||||
// xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
|
||||
// yh = y1*b (0 <= y1 < b)
|
||||
//
|
||||
// the missing terms are
|
||||
//
|
||||
// x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
|
||||
//
|
||||
// since all the yi for i > 1 are 0 by choice of k: If any of them
|
||||
// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
|
||||
// be a larger valid threshold contradicting the assumption about k.
|
||||
//
|
||||
if k < n || m != n {
|
||||
x1 := x[k:] // x1 is normalized because x is
|
||||
y1 := y[k:] // y1 is normalized because y is
|
||||
var t nat
|
||||
t = t.mul(x1, y1)
|
||||
copy(z[2*k:], t)
|
||||
z[2*k+len(t):].clear() // upper portion of z is garbage
|
||||
t = t.mul(x1, y0.norm())
|
||||
addAt(z, t, k)
|
||||
t = t.mul(x0.norm(), y1)
|
||||
addAt(z, t, k)
|
||||
|
||||
// add x0*y1*b
|
||||
x0 := x0.norm()
|
||||
y1 := y[k:] // y1 is normalized because y is
|
||||
addAt(z, t.mul(x0, y1), k)
|
||||
|
||||
// add xi*y0<<i, xi*y1*b<<(i+k)
|
||||
y0 := y0.norm()
|
||||
for i := k; i < len(x); i += k {
|
||||
xi := x[i:]
|
||||
if len(xi) > k {
|
||||
xi = xi[:k]
|
||||
}
|
||||
xi = xi.norm()
|
||||
addAt(z, t.mul(xi, y0), i)
|
||||
addAt(z, t.mul(xi, y1), i+k)
|
||||
}
|
||||
}
|
||||
|
||||
return z.norm()
|
||||
|
@ -7,6 +7,7 @@ package big
|
||||
import (
|
||||
"io"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@ -63,6 +64,36 @@ var prodNN = []argNN{
|
||||
{nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
|
||||
{nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
|
||||
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
|
||||
// 3^100 * 3^28 = 3^128
|
||||
{
|
||||
natFromString("11790184577738583171520872861412518665678211592275841109096961"),
|
||||
natFromString("515377520732011331036461129765621272702107522001"),
|
||||
natFromString("22876792454961"),
|
||||
},
|
||||
// z = 111....1 (70000 digits)
|
||||
// x = 10^(99*700) + ... + 10^1400 + 10^700 + 1
|
||||
// y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit)
|
||||
{
|
||||
natFromString(strings.Repeat("1", 70000)),
|
||||
natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)),
|
||||
natFromString(strings.Repeat("1", 700)),
|
||||
},
|
||||
// z = 111....1 (20000 digits)
|
||||
// x = 10^10000 + 1
|
||||
// y = 111....1 (10000 digits)
|
||||
{
|
||||
natFromString(strings.Repeat("1", 20000)),
|
||||
natFromString("1" + strings.Repeat("0", 9999) + "1"),
|
||||
natFromString(strings.Repeat("1", 10000)),
|
||||
},
|
||||
}
|
||||
|
||||
func natFromString(s string) nat {
|
||||
x, _, err := nat(nil).scan(strings.NewReader(s), 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
@ -136,6 +167,31 @@ func TestMulRangeN(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// allocBytes returns the number of bytes allocated by invoking f.
|
||||
func allocBytes(f func()) uint64 {
|
||||
var stats runtime.MemStats
|
||||
runtime.ReadMemStats(&stats)
|
||||
t := stats.TotalAlloc
|
||||
f()
|
||||
runtime.ReadMemStats(&stats)
|
||||
return stats.TotalAlloc - t
|
||||
}
|
||||
|
||||
// TestMulUnbalanced tests that multiplying numbers of different lengths
|
||||
// does not cause deep recursion and in turn allocate too much memory.
|
||||
// test case for issue 3807
|
||||
func TestMulUnbalanced(t *testing.T) {
|
||||
x := rndNat(50000)
|
||||
y := rndNat(40)
|
||||
allocSize := allocBytes(func() {
|
||||
nat(nil).mul(x, y)
|
||||
})
|
||||
inputSize := uint64(len(x)+len(y)) * _S
|
||||
if ratio := allocSize / uint64(inputSize); ratio > 10 {
|
||||
t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio)
|
||||
}
|
||||
}
|
||||
|
||||
var rnd = rand.New(rand.NewSource(0x43de683f473542af))
|
||||
var mulx = rndNat(1e4)
|
||||
var muly = rndNat(1e4)
|
||||
|
Loading…
Reference in New Issue
Block a user