mirror of
https://github.com/golang/go
synced 2024-11-18 05:54:49 -07:00
crypto/elliptic: document and test that IsOnCurve(∞) == false
This also implies it can't be passed to Marshal. Fixes #37294 Change-Id: I1e6b6abd87ff31f323486958d5cb34a5c8f76b5f Reviewed-on: https://go-review.googlesource.com/c/go/+/239562 Run-TryBot: Filippo Valsorda <filippo@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Katie Hockman <katie@golang.org>
This commit is contained in:
parent
9f33108dfa
commit
320e4adc4b
@ -20,7 +20,10 @@ import (
|
||||
)
|
||||
|
||||
// A Curve represents a short-form Weierstrass curve with a=-3.
|
||||
// See https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
|
||||
//
|
||||
// Note that the point at infinity (0, 0) is not considered on the curve, and
|
||||
// although it can be returned by Add, Double, ScalarMult, or ScalarBaseMult, it
|
||||
// can't be marshaled or unmarshaled, and IsOnCurve will return false for it.
|
||||
type Curve interface {
|
||||
// Params returns the parameters for the curve.
|
||||
Params() *CurveParams
|
||||
@ -307,7 +310,8 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
|
||||
// Marshal converts a point on the curve into the uncompressed form specified in
|
||||
// section 4.3.6 of ANSI X9.62.
|
||||
func Marshal(curve Curve, x, y *big.Int) []byte {
|
||||
byteLen := (curve.Params().BitSize + 7) / 8
|
||||
|
||||
@ -320,7 +324,8 @@ func Marshal(curve Curve, x, y *big.Int) []byte {
|
||||
return ret
|
||||
}
|
||||
|
||||
// MarshalCompressed converts a point into the compressed form specified in section 4.3.6 of ANSI X9.62.
|
||||
// MarshalCompressed converts a point on the curve into the compressed form
|
||||
// specified in section 4.3.6 of ANSI X9.62.
|
||||
func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
|
||||
byteLen := (curve.Params().BitSize + 7) / 8
|
||||
compressed := make([]byte, 1+byteLen)
|
||||
|
@ -418,41 +418,62 @@ func TestP256Mult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testInfinity(t *testing.T, curve Curve) {
|
||||
_, x, y, _ := GenerateKey(curve, rand.Reader)
|
||||
x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
|
||||
if x.Sign() != 0 || y.Sign() != 0 {
|
||||
t.Errorf("x^q != ∞")
|
||||
}
|
||||
|
||||
x, y = curve.ScalarBaseMult([]byte{0})
|
||||
if x.Sign() != 0 || y.Sign() != 0 {
|
||||
t.Errorf("b^0 != ∞")
|
||||
x.SetInt64(0)
|
||||
y.SetInt64(0)
|
||||
}
|
||||
|
||||
x2, y2 := curve.Double(x, y)
|
||||
if x2.Sign() != 0 || y2.Sign() != 0 {
|
||||
t.Errorf("2∞ != ∞")
|
||||
}
|
||||
|
||||
baseX := curve.Params().Gx
|
||||
baseY := curve.Params().Gy
|
||||
|
||||
x3, y3 := curve.Add(baseX, baseY, x, y)
|
||||
if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
|
||||
t.Errorf("x+∞ != x")
|
||||
}
|
||||
|
||||
x4, y4 := curve.Add(x, y, baseX, baseY)
|
||||
if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
|
||||
t.Errorf("∞+x != x")
|
||||
}
|
||||
|
||||
if curve.IsOnCurve(x, y) {
|
||||
t.Errorf("IsOnCurve(∞) == true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfinity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
curve Curve
|
||||
}{
|
||||
{"p224", P224()},
|
||||
{"p256", P256()},
|
||||
{"P-224", P224()},
|
||||
{"P-256", P256()},
|
||||
{"P-256/Generic", P256().Params()},
|
||||
{"P-384", P384()},
|
||||
{"P-521", P521()},
|
||||
}
|
||||
if testing.Short() {
|
||||
tests = tests[:1]
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
curve := test.curve
|
||||
x, y := curve.ScalarBaseMult(nil)
|
||||
if x.Sign() != 0 || y.Sign() != 0 {
|
||||
t.Errorf("%s: x^0 != ∞", test.name)
|
||||
}
|
||||
x.SetInt64(0)
|
||||
y.SetInt64(0)
|
||||
|
||||
x2, y2 := curve.Double(x, y)
|
||||
if x2.Sign() != 0 || y2.Sign() != 0 {
|
||||
t.Errorf("%s: 2∞ != ∞", test.name)
|
||||
}
|
||||
|
||||
baseX := curve.Params().Gx
|
||||
baseY := curve.Params().Gy
|
||||
|
||||
x3, y3 := curve.Add(baseX, baseY, x, y)
|
||||
if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
|
||||
t.Errorf("%s: x+∞ != x", test.name)
|
||||
}
|
||||
|
||||
x4, y4 := curve.Add(x, y, baseX, baseY)
|
||||
if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
|
||||
t.Errorf("%s: ∞+x != x", test.name)
|
||||
}
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testInfinity(t, curve)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user