From 320e4adc4bd153cb0cb7e31e186fb3b4564fd0a7 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 23 Jun 2020 18:14:25 -0400 Subject: [PATCH] =?UTF-8?q?crypto/elliptic:=20document=20and=20test=20that?= =?UTF-8?q?=20IsOnCurve(=E2=88=9E)=20=3D=3D=20false?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This also implies it can't be passed to Marshal. Fixes #37294 Change-Id: I1e6b6abd87ff31f323486958d5cb34a5c8f76b5f Reviewed-on: https://go-review.googlesource.com/c/go/+/239562 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Katie Hockman --- src/crypto/elliptic/elliptic.go | 11 ++-- src/crypto/elliptic/elliptic_test.go | 75 ++++++++++++++++++---------- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/src/crypto/elliptic/elliptic.go b/src/crypto/elliptic/elliptic.go index 8735d3acf6..f93dc16419 100644 --- a/src/crypto/elliptic/elliptic.go +++ b/src/crypto/elliptic/elliptic.go @@ -20,7 +20,10 @@ import ( ) // A Curve represents a short-form Weierstrass curve with a=-3. -// See https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html +// +// Note that the point at infinity (0, 0) is not considered on the curve, and +// although it can be returned by Add, Double, ScalarMult, or ScalarBaseMult, it +// can't be marshaled or unmarshaled, and IsOnCurve will return false for it. type Curve interface { // Params returns the parameters for the curve. Params() *CurveParams @@ -307,7 +310,8 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e return } -// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62. +// Marshal converts a point on the curve into the uncompressed form specified in +// section 4.3.6 of ANSI X9.62. func Marshal(curve Curve, x, y *big.Int) []byte { byteLen := (curve.Params().BitSize + 7) / 8 @@ -320,7 +324,8 @@ func Marshal(curve Curve, x, y *big.Int) []byte { return ret } -// MarshalCompressed converts a point into the compressed form specified in section 4.3.6 of ANSI X9.62. +// MarshalCompressed converts a point on the curve into the compressed form +// specified in section 4.3.6 of ANSI X9.62. func MarshalCompressed(curve Curve, x, y *big.Int) []byte { byteLen := (curve.Params().BitSize + 7) / 8 compressed := make([]byte, 1+byteLen) diff --git a/src/crypto/elliptic/elliptic_test.go b/src/crypto/elliptic/elliptic_test.go index 45c2fb63f5..e80e7731aa 100644 --- a/src/crypto/elliptic/elliptic_test.go +++ b/src/crypto/elliptic/elliptic_test.go @@ -418,41 +418,62 @@ func TestP256Mult(t *testing.T) { } } +func testInfinity(t *testing.T, curve Curve) { + _, x, y, _ := GenerateKey(curve, rand.Reader) + x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes()) + if x.Sign() != 0 || y.Sign() != 0 { + t.Errorf("x^q != ∞") + } + + x, y = curve.ScalarBaseMult([]byte{0}) + if x.Sign() != 0 || y.Sign() != 0 { + t.Errorf("b^0 != ∞") + x.SetInt64(0) + y.SetInt64(0) + } + + x2, y2 := curve.Double(x, y) + if x2.Sign() != 0 || y2.Sign() != 0 { + t.Errorf("2∞ != ∞") + } + + baseX := curve.Params().Gx + baseY := curve.Params().Gy + + x3, y3 := curve.Add(baseX, baseY, x, y) + if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 { + t.Errorf("x+∞ != x") + } + + x4, y4 := curve.Add(x, y, baseX, baseY) + if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 { + t.Errorf("∞+x != x") + } + + if curve.IsOnCurve(x, y) { + t.Errorf("IsOnCurve(∞) == true") + } +} + func TestInfinity(t *testing.T) { tests := []struct { name string curve Curve }{ - {"p224", P224()}, - {"p256", P256()}, + {"P-224", P224()}, + {"P-256", P256()}, + {"P-256/Generic", P256().Params()}, + {"P-384", P384()}, + {"P-521", P521()}, + } + if testing.Short() { + tests = tests[:1] } - for _, test := range tests { curve := test.curve - x, y := curve.ScalarBaseMult(nil) - if x.Sign() != 0 || y.Sign() != 0 { - t.Errorf("%s: x^0 != ∞", test.name) - } - x.SetInt64(0) - y.SetInt64(0) - - x2, y2 := curve.Double(x, y) - if x2.Sign() != 0 || y2.Sign() != 0 { - t.Errorf("%s: 2∞ != ∞", test.name) - } - - baseX := curve.Params().Gx - baseY := curve.Params().Gy - - x3, y3 := curve.Add(baseX, baseY, x, y) - if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 { - t.Errorf("%s: x+∞ != x", test.name) - } - - x4, y4 := curve.Add(x, y, baseX, baseY) - if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 { - t.Errorf("%s: ∞+x != x", test.name) - } + t.Run(test.name, func(t *testing.T) { + testInfinity(t, curve) + }) } }