mirror of
https://github.com/golang/go
synced 2024-11-19 14:24:47 -07:00
math/big: recognize z.Mul(x, x) as squaring of x
updates #13745 Multiprecision squaring can be done in a straightforward manner with about half the multiplications of a basic multiplication due to the symmetry of the operands. This change implements basic squaring for nat types and uses it for Int multiplication when the same variable is supplied to both arguments of z.Mul(x, x). This has some overhead to allocate a temporary variable to hold the cross products, shift them to double and add them to the diagonal terms. There is a speed benefit in the intermediate range when the overhead is neglible and the asymptotic performance of karatsuba multiplication has not been reached. basicSqrThreshold = 20 karatsubaSqrThreshold = 400 Were set by running calibrate_test.go to measure timing differences between the algorithms. Benchmarks for squaring: name old time/op new time/op delta IntSqr/1-4 51.5ns ±25% 25.1ns ± 7% -51.38% (p=0.008 n=5+5) IntSqr/2-4 79.1ns ± 4% 72.4ns ± 2% -8.47% (p=0.008 n=5+5) IntSqr/3-4 102ns ± 4% 97ns ± 5% ~ (p=0.056 n=5+5) IntSqr/5-4 161ns ± 4% 163ns ± 7% ~ (p=0.952 n=5+5) IntSqr/8-4 277ns ± 5% 267ns ± 6% ~ (p=0.087 n=5+5) IntSqr/10-4 358ns ± 3% 360ns ± 4% ~ (p=0.730 n=5+5) IntSqr/20-4 1.07µs ± 3% 1.01µs ± 6% ~ (p=0.056 n=5+5) IntSqr/30-4 2.36µs ± 4% 1.72µs ± 2% -27.03% (p=0.008 n=5+5) IntSqr/50-4 5.19µs ± 3% 3.88µs ± 4% -25.37% (p=0.008 n=5+5) IntSqr/80-4 11.3µs ± 4% 8.6µs ± 3% -23.78% (p=0.008 n=5+5) IntSqr/100-4 16.2µs ± 4% 12.8µs ± 3% -21.49% (p=0.008 n=5+5) IntSqr/200-4 50.1µs ± 5% 44.7µs ± 3% -10.65% (p=0.008 n=5+5) IntSqr/300-4 105µs ±11% 95µs ± 3% -9.50% (p=0.008 n=5+5) IntSqr/500-4 231µs ± 5% 227µs ± 2% ~ (p=0.310 n=5+5) IntSqr/800-4 496µs ± 9% 459µs ± 3% -7.40% (p=0.016 n=5+5) IntSqr/1000-4 700µs ± 3% 710µs ± 5% ~ (p=0.841 n=5+5) Show a speed up of 10-25% in the range where basicSqr is optimal, improved single word squaring and no significant difference when the fallback to standard multiplication is used. Change-Id: Iae2c82ca91cf890823f91e5c83bbe9a2c534b72b Reviewed-on: https://go-review.googlesource.com/53638 Reviewed-by: Robert Griesemer <gri@golang.org> Run-TryBot: Robert Griesemer <gri@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
259f78f001
commit
25b040c287
@ -2,13 +2,20 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Calibration used to determine thresholds for using
|
||||
// different algorithms. Ideally, this would be converted
|
||||
// to go generate to create thresholds.go
|
||||
|
||||
// This file prints execution times for the Mul benchmark
|
||||
// given different Karatsuba thresholds. The result may be
|
||||
// used to manually fine-tune the threshold constant. The
|
||||
// results are somewhat fragile; use repeated runs to get
|
||||
// a clear picture.
|
||||
|
||||
// Usage: go test -run=TestCalibrate -calibrate
|
||||
// Calculates lower and upper thresholds for when basicSqr
|
||||
// is faster than standard multiplication.
|
||||
|
||||
// Usage: go test -run=TestCalibrate -v -calibrate
|
||||
|
||||
package big
|
||||
|
||||
@ -21,6 +28,27 @@ import (
|
||||
|
||||
var calibrate = flag.Bool("calibrate", false, "run calibration test")
|
||||
|
||||
func TestCalibrate(t *testing.T) {
|
||||
if *calibrate {
|
||||
computeKaratsubaThresholds()
|
||||
|
||||
// compute basicSqrThreshold where overhead becomes neglible
|
||||
minSqr := computeSqrThreshold(10, 30, 1, 3)
|
||||
// compute karatsubaSqrThreshold where karatsuba is faster
|
||||
maxSqr := computeSqrThreshold(300, 500, 10, 3)
|
||||
if minSqr != 0 {
|
||||
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
|
||||
} else {
|
||||
fmt.Println("no basicSqrThreshold found")
|
||||
}
|
||||
if maxSqr != 0 {
|
||||
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
|
||||
} else {
|
||||
fmt.Println("no karatsubaSqrThreshold found")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func karatsubaLoad(b *testing.B) {
|
||||
BenchmarkMul(b)
|
||||
}
|
||||
@ -34,7 +62,7 @@ func measureKaratsuba(th int) time.Duration {
|
||||
return time.Duration(res.NsPerOp())
|
||||
}
|
||||
|
||||
func computeThresholds() {
|
||||
func computeKaratsubaThresholds() {
|
||||
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
|
||||
fmt.Printf("(run repeatedly for good results)\n")
|
||||
|
||||
@ -81,8 +109,56 @@ func computeThresholds() {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalibrate(t *testing.T) {
|
||||
if *calibrate {
|
||||
computeThresholds()
|
||||
func measureBasicSqr(words, nruns int, enable bool) time.Duration {
|
||||
// more runs for better statistics
|
||||
initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
|
||||
|
||||
if enable {
|
||||
// set thresholds to use basicSqr at this number of words
|
||||
basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
|
||||
} else {
|
||||
// set thresholds to disable basicSqr for any number of words
|
||||
basicSqrThreshold, karatsubaSqrThreshold = -1, -1
|
||||
}
|
||||
|
||||
var testval int64
|
||||
for i := 0; i < nruns; i++ {
|
||||
res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
|
||||
testval += res.NsPerOp()
|
||||
}
|
||||
testval /= int64(nruns)
|
||||
|
||||
basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
|
||||
|
||||
return time.Duration(testval)
|
||||
}
|
||||
|
||||
func computeSqrThreshold(from, to, step, nruns int) int {
|
||||
fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
|
||||
fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
|
||||
var initPos bool
|
||||
var threshold int
|
||||
for i := from; i <= to; i += step {
|
||||
baseline := measureBasicSqr(i, nruns, false)
|
||||
testval := measureBasicSqr(i, nruns, true)
|
||||
pos := baseline > testval
|
||||
delta := baseline - testval
|
||||
percent := delta * 100 / baseline
|
||||
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
|
||||
if i == from {
|
||||
initPos = pos
|
||||
}
|
||||
if threshold == 0 && pos != initPos {
|
||||
threshold = i
|
||||
fmt.Printf(" threshold found")
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
}
|
||||
if threshold != 0 {
|
||||
fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
|
||||
} else {
|
||||
fmt.Printf("Found NO threshold between %d - %d\n", from, to)
|
||||
}
|
||||
return threshold
|
||||
}
|
||||
|
@ -153,6 +153,11 @@ func (z *Int) Mul(x, y *Int) *Int {
|
||||
// x * (-y) == -(x * y)
|
||||
// (-x) * y == -(x * y)
|
||||
// (-x) * (-y) == x * y
|
||||
if x == y {
|
||||
z.abs = z.abs.sqr(x.abs)
|
||||
z.neg = false
|
||||
return z
|
||||
}
|
||||
z.abs = z.abs.mul(x.abs, y.abs)
|
||||
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
|
||||
return z
|
||||
|
@ -7,6 +7,7 @@ package big
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -1544,3 +1545,24 @@ func BenchmarkSqrt(b *testing.B) {
|
||||
t.Sqrt(n)
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkIntSqr(b *testing.B, nwords int) {
|
||||
x := new(Int)
|
||||
x.abs = rndNat(nwords)
|
||||
t := new(Int)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
t.Mul(x, x)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIntSqr(b *testing.B) {
|
||||
for _, n := range sqrBenchSizes {
|
||||
if isRaceBuilder && n > 1e3 {
|
||||
continue
|
||||
}
|
||||
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
|
||||
benchmarkIntSqr(b, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ func karatsubaSub(z, x nat, n int) {
|
||||
// Operands that are shorter than karatsubaThreshold are multiplied using
|
||||
// "grade school" multiplication; for longer operands the Karatsuba algorithm
|
||||
// is used.
|
||||
var karatsubaThreshold int = 40 // computed by calibrate.go
|
||||
var karatsubaThreshold = 40 // computed by calibrate_test.go
|
||||
|
||||
// karatsuba multiplies x and y and leaves the result in z.
|
||||
// Both x and y must have the same length n and n must be a
|
||||
@ -473,6 +473,61 @@ func (z nat) mul(x, y nat) nat {
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
// basicSqr sets z = x*x and is asymptotically faster than basicMul
|
||||
// by about a factor of 2, but slower for small arguments due to overhead.
|
||||
// Requirements: len(x) > 0, len(z) >= 2*len(x)
|
||||
// The (non-normalized) result is placed in z[0 : 2 * len(x)].
|
||||
func basicSqr(z, x nat) {
|
||||
n := len(x)
|
||||
t := make(nat, 2*n) // temporary variable to hold the products
|
||||
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
|
||||
for i := 1; i < n; i++ {
|
||||
d := x[i]
|
||||
// z collects the squares x[i] * x[i]
|
||||
z[2*i+1], z[2*i] = mulWW(d, d)
|
||||
// t collects the products x[i] * x[j] where j < i
|
||||
t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
|
||||
}
|
||||
t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
|
||||
addVV(z, z, t) // combine the result
|
||||
}
|
||||
|
||||
// Operands that are shorter than basicSqrThreshold are squared using
|
||||
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
|
||||
// the Karatsuba algorithm is used.
|
||||
var basicSqrThreshold = 20 // computed by calibrate_test.go
|
||||
var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
|
||||
|
||||
// z = x*x
|
||||
func (z nat) sqr(x nat) nat {
|
||||
n := len(x)
|
||||
switch {
|
||||
case n == 0:
|
||||
return z[:0]
|
||||
case n == 1:
|
||||
d := x[0]
|
||||
z = z.make(2)
|
||||
z[1], z[0] = mulWW(d, d)
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
if alias(z, x) {
|
||||
z = nil // z is an alias for x - cannot reuse
|
||||
}
|
||||
z = z.make(2 * n)
|
||||
|
||||
if n < basicSqrThreshold {
|
||||
basicMul(z, x, x)
|
||||
return z.norm()
|
||||
}
|
||||
if n < karatsubaSqrThreshold {
|
||||
basicSqr(z, x)
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
return z.mul(x, x)
|
||||
}
|
||||
|
||||
// mulRange computes the product of all the unsigned integers in the
|
||||
// range [a, b] inclusively. If a > b (empty range), the result is 1.
|
||||
func (z nat) mulRange(a, b uint64) nat {
|
||||
|
@ -619,3 +619,49 @@ func TestSticky(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testBasicSqr(t *testing.T, x nat) {
|
||||
got := make(nat, 2*len(x))
|
||||
want := make(nat, 2*len(x))
|
||||
basicSqr(got, x)
|
||||
basicMul(want, x, x)
|
||||
if got.cmp(want) != 0 {
|
||||
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicSqr(t *testing.T) {
|
||||
for _, a := range prodNN {
|
||||
if a.x != nil {
|
||||
testBasicSqr(t, a.x)
|
||||
}
|
||||
if a.y != nil {
|
||||
testBasicSqr(t, a.y)
|
||||
}
|
||||
if a.z != nil {
|
||||
testBasicSqr(t, a.z)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkNatSqr(b *testing.B, nwords int) {
|
||||
x := rndNat(nwords)
|
||||
var z nat
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
z.sqr(x)
|
||||
}
|
||||
}
|
||||
|
||||
var sqrBenchSizes = []int{1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800, 1000}
|
||||
|
||||
func BenchmarkNatSqr(b *testing.B) {
|
||||
for _, n := range sqrBenchSizes {
|
||||
if isRaceBuilder && n > 1e3 {
|
||||
continue
|
||||
}
|
||||
b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
|
||||
benchmarkNatSqr(b, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user