1
0
mirror of https://github.com/golang/go synced 2024-11-19 14:24:47 -07:00

math/big: recognize z.Mul(x, x) as squaring of x

updates #13745

Multiprecision squaring can be done in a straightforward manner
with about half the multiplications of a basic multiplication
due to the symmetry of the operands.  This change implements
basic squaring for nat types and uses it for Int multiplication
when the same variable is supplied to both arguments of
z.Mul(x, x). This has some overhead to allocate a temporary
variable to hold the cross products, shift them to double and
add them to the diagonal terms.  There is a speed benefit in
the intermediate range when the overhead is neglible and the
asymptotic performance of karatsuba multiplication has not been
reached.

basicSqrThreshold = 20
karatsubaSqrThreshold = 400

Were set by running calibrate_test.go to measure timing differences
between the algorithms.  Benchmarks for squaring:

name           old time/op  new time/op  delta
IntSqr/1-4     51.5ns ±25%  25.1ns ± 7%  -51.38%  (p=0.008 n=5+5)
IntSqr/2-4     79.1ns ± 4%  72.4ns ± 2%   -8.47%  (p=0.008 n=5+5)
IntSqr/3-4      102ns ± 4%    97ns ± 5%     ~     (p=0.056 n=5+5)
IntSqr/5-4      161ns ± 4%   163ns ± 7%     ~     (p=0.952 n=5+5)
IntSqr/8-4      277ns ± 5%   267ns ± 6%     ~     (p=0.087 n=5+5)
IntSqr/10-4     358ns ± 3%   360ns ± 4%     ~     (p=0.730 n=5+5)
IntSqr/20-4    1.07µs ± 3%  1.01µs ± 6%     ~     (p=0.056 n=5+5)
IntSqr/30-4    2.36µs ± 4%  1.72µs ± 2%  -27.03%  (p=0.008 n=5+5)
IntSqr/50-4    5.19µs ± 3%  3.88µs ± 4%  -25.37%  (p=0.008 n=5+5)
IntSqr/80-4    11.3µs ± 4%   8.6µs ± 3%  -23.78%  (p=0.008 n=5+5)
IntSqr/100-4   16.2µs ± 4%  12.8µs ± 3%  -21.49%  (p=0.008 n=5+5)
IntSqr/200-4   50.1µs ± 5%  44.7µs ± 3%  -10.65%  (p=0.008 n=5+5)
IntSqr/300-4    105µs ±11%    95µs ± 3%   -9.50%  (p=0.008 n=5+5)
IntSqr/500-4    231µs ± 5%   227µs ± 2%     ~     (p=0.310 n=5+5)
IntSqr/800-4    496µs ± 9%   459µs ± 3%   -7.40%  (p=0.016 n=5+5)
IntSqr/1000-4   700µs ± 3%   710µs ± 5%     ~     (p=0.841 n=5+5)

Show a speed up of 10-25% in the range where basicSqr is optimal,
improved single word squaring and no significant difference when
the fallback to standard multiplication is used.

Change-Id: Iae2c82ca91cf890823f91e5c83bbe9a2c534b72b
Reviewed-on: https://go-review.googlesource.com/53638
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Brian Kessler 2017-07-28 22:29:30 -07:00 committed by Robert Griesemer
parent 259f78f001
commit 25b040c287
5 changed files with 210 additions and 6 deletions

View File

@ -2,13 +2,20 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Calibration used to determine thresholds for using
// different algorithms. Ideally, this would be converted
// to go generate to create thresholds.go
// This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get
// a clear picture.
// Usage: go test -run=TestCalibrate -calibrate
// Calculates lower and upper thresholds for when basicSqr
// is faster than standard multiplication.
// Usage: go test -run=TestCalibrate -v -calibrate
package big
@ -21,6 +28,27 @@ import (
var calibrate = flag.Bool("calibrate", false, "run calibration test")
func TestCalibrate(t *testing.T) {
if *calibrate {
computeKaratsubaThresholds()
// compute basicSqrThreshold where overhead becomes neglible
minSqr := computeSqrThreshold(10, 30, 1, 3)
// compute karatsubaSqrThreshold where karatsuba is faster
maxSqr := computeSqrThreshold(300, 500, 10, 3)
if minSqr != 0 {
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
} else {
fmt.Println("no basicSqrThreshold found")
}
if maxSqr != 0 {
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
} else {
fmt.Println("no karatsubaSqrThreshold found")
}
}
}
func karatsubaLoad(b *testing.B) {
BenchmarkMul(b)
}
@ -34,7 +62,7 @@ func measureKaratsuba(th int) time.Duration {
return time.Duration(res.NsPerOp())
}
func computeThresholds() {
func computeKaratsubaThresholds() {
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
fmt.Printf("(run repeatedly for good results)\n")
@ -81,8 +109,56 @@ func computeThresholds() {
}
}
func TestCalibrate(t *testing.T) {
if *calibrate {
computeThresholds()
func measureBasicSqr(words, nruns int, enable bool) time.Duration {
// more runs for better statistics
initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
if enable {
// set thresholds to use basicSqr at this number of words
basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
} else {
// set thresholds to disable basicSqr for any number of words
basicSqrThreshold, karatsubaSqrThreshold = -1, -1
}
var testval int64
for i := 0; i < nruns; i++ {
res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
testval += res.NsPerOp()
}
testval /= int64(nruns)
basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
return time.Duration(testval)
}
func computeSqrThreshold(from, to, step, nruns int) int {
fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
var initPos bool
var threshold int
for i := from; i <= to; i += step {
baseline := measureBasicSqr(i, nruns, false)
testval := measureBasicSqr(i, nruns, true)
pos := baseline > testval
delta := baseline - testval
percent := delta * 100 / baseline
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
if i == from {
initPos = pos
}
if threshold == 0 && pos != initPos {
threshold = i
fmt.Printf(" threshold found")
}
fmt.Println()
}
if threshold != 0 {
fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
} else {
fmt.Printf("Found NO threshold between %d - %d\n", from, to)
}
return threshold
}

View File

@ -153,6 +153,11 @@ func (z *Int) Mul(x, y *Int) *Int {
// x * (-y) == -(x * y)
// (-x) * y == -(x * y)
// (-x) * (-y) == x * y
if x == y {
z.abs = z.abs.sqr(x.abs)
z.neg = false
return z
}
z.abs = z.abs.mul(x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z

View File

@ -7,6 +7,7 @@ package big
import (
"bytes"
"encoding/hex"
"fmt"
"math/rand"
"strconv"
"strings"
@ -1544,3 +1545,24 @@ func BenchmarkSqrt(b *testing.B) {
t.Sqrt(n)
}
}
func benchmarkIntSqr(b *testing.B, nwords int) {
x := new(Int)
x.abs = rndNat(nwords)
t := new(Int)
b.ResetTimer()
for i := 0; i < b.N; i++ {
t.Mul(x, x)
}
}
func BenchmarkIntSqr(b *testing.B) {
for _, n := range sqrBenchSizes {
if isRaceBuilder && n > 1e3 {
continue
}
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
benchmarkIntSqr(b, n)
})
}
}

View File

@ -249,7 +249,7 @@ func karatsubaSub(z, x nat, n int) {
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
var karatsubaThreshold int = 40 // computed by calibrate.go
var karatsubaThreshold = 40 // computed by calibrate_test.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
@ -473,6 +473,61 @@ func (z nat) mul(x, y nat) nat {
return z.norm()
}
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) >= 2*len(x)
// The (non-normalized) result is placed in z[0 : 2 * len(x)].
func basicSqr(z, x nat) {
n := len(x)
t := make(nat, 2*n) // temporary variable to hold the products
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
for i := 1; i < n; i++ {
d := x[i]
// z collects the squares x[i] * x[i]
z[2*i+1], z[2*i] = mulWW(d, d)
// t collects the products x[i] * x[j] where j < i
t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
}
t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
addVV(z, z, t) // combine the result
}
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
// the Karatsuba algorithm is used.
var basicSqrThreshold = 20 // computed by calibrate_test.go
var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
// z = x*x
func (z nat) sqr(x nat) nat {
n := len(x)
switch {
case n == 0:
return z[:0]
case n == 1:
d := x[0]
z = z.make(2)
z[1], z[0] = mulWW(d, d)
return z.norm()
}
if alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
z = z.make(2 * n)
if n < basicSqrThreshold {
basicMul(z, x, x)
return z.norm()
}
if n < karatsubaSqrThreshold {
basicSqr(z, x)
return z.norm()
}
return z.mul(x, x)
}
// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
func (z nat) mulRange(a, b uint64) nat {

View File

@ -619,3 +619,49 @@ func TestSticky(t *testing.T) {
}
}
}
func testBasicSqr(t *testing.T, x nat) {
got := make(nat, 2*len(x))
want := make(nat, 2*len(x))
basicSqr(got, x)
basicMul(want, x, x)
if got.cmp(want) != 0 {
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
}
}
func TestBasicSqr(t *testing.T) {
for _, a := range prodNN {
if a.x != nil {
testBasicSqr(t, a.x)
}
if a.y != nil {
testBasicSqr(t, a.y)
}
if a.z != nil {
testBasicSqr(t, a.z)
}
}
}
func benchmarkNatSqr(b *testing.B, nwords int) {
x := rndNat(nwords)
var z nat
b.ResetTimer()
for i := 0; i < b.N; i++ {
z.sqr(x)
}
}
var sqrBenchSizes = []int{1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800, 1000}
func BenchmarkNatSqr(b *testing.B) {
for _, n := range sqrBenchSizes {
if isRaceBuilder && n > 1e3 {
continue
}
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
benchmarkNatSqr(b, n)
})
}
}