1
0
mirror of https://github.com/golang/go synced 2024-11-23 16:10:05 -07:00

math/big: reduce amount of copying in Montgomery multiplication

Instead shifting the accumulator every iteration of the loop, shift
once in the end. This significantly improves performance on arm64.

On arm64:

name                  old time/op    new time/op    delta
RSA2048Decrypt          3.33ms ± 0%    2.63ms ± 0%  -20.94%  (p=0.000 n=11+11)
RSA2048Sign             4.22ms ± 0%    3.55ms ± 0%  -15.89%  (p=0.000 n=11+11)
3PrimeRSA2048Decrypt    1.95ms ± 0%    1.59ms ± 0%  -18.59%  (p=0.000 n=11+11)

On Skylake:

name                    old time/op  new time/op  delta
RSA2048Decrypt-8        1.73ms ± 2%  1.55ms ± 2%  -10.19%  (p=0.000 n=10+10)
RSA2048Sign-8           2.17ms ± 2%  2.00ms ± 2%   -7.93%  (p=0.000 n=10+10)
3PrimeRSA2048Decrypt-8  1.10ms ± 2%  0.96ms ± 2%  -13.03%  (p=0.000 n=10+9)

Change-Id: I5786191a1a09e4217fdb1acfd90880d35c5855f7
Reviewed-on: https://go-review.googlesource.com/99838
Run-TryBot: Vlad Krasnov <vlad@cloudflare.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
This commit is contained in:
Vlad Krasnov 2018-03-09 17:14:49 +00:00 committed by Robert Griesemer
parent 4b837c7023
commit 26d74e8b65

View File

@ -213,18 +213,17 @@ func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
if len(x) != n || len(y) != n || len(m) != n {
panic("math/big: mismatched montgomery number lengths")
}
z = z.make(n)
z = z.make(n * 2)
z.clear()
var c Word
for i := 0; i < n; i++ {
d := y[i]
c2 := addMulVVW(z, x, d)
t := z[0] * k
c3 := addMulVVW(z, m, t)
copy(z, z[1:])
c2 := addMulVVW(z[i:n+i], x, d)
t := z[i] * k
c3 := addMulVVW(z[i:n+i], m, t)
cx := c + c2
cy := cx + c3
z[n-1] = cy
z[n+i] = cy
if cx < c2 || cy < c3 {
c = 1
} else {
@ -232,9 +231,11 @@ func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
}
}
if c != 0 {
subVV(z, z, m)
subVV(z[:n], z[n:], m)
} else {
copy(z[:n], z[n:])
}
return z
return z[:n]
}
// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.