From dd84bb682482390bb8465482cb7b13d2e3b17297 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Tue, 31 Oct 2023 17:26:43 +0000 Subject: [PATCH] crypto/x509: add new OID type and use it in Certificate Fixes #60665 Change-Id: I814b7d4b26b964f74443584fb2048b3e27e3b675 GitHub-Last-Rev: 693c741c76e6369e36aa2a599ee6242d632573c7 GitHub-Pull-Request: golang/go#62096 Reviewed-on: https://go-review.googlesource.com/c/go/+/520535 TryBot-Result: Gopher Robot Run-TryBot: Mateusz Poliwczak Auto-Submit: Roland Shoemaker Reviewed-by: Roland Shoemaker Reviewed-by: Cherry Mui --- api/next/60665.txt | 6 + src/crypto/x509/oid.go | 273 +++++++++++++++++++++++++++++++++++ src/crypto/x509/oid_test.go | 110 ++++++++++++++ src/crypto/x509/parser.go | 20 ++- src/crypto/x509/x509.go | 35 ++++- src/crypto/x509/x509_test.go | 49 +++++++ 6 files changed, 479 insertions(+), 14 deletions(-) create mode 100644 api/next/60665.txt create mode 100644 src/crypto/x509/oid.go create mode 100644 src/crypto/x509/oid_test.go diff --git a/api/next/60665.txt b/api/next/60665.txt new file mode 100644 index 0000000000..10e50e1832 --- /dev/null +++ b/api/next/60665.txt @@ -0,0 +1,6 @@ +pkg crypto/x509, type Certificate struct, Policies []OID #60665 +pkg crypto/x509, type OID struct #60665 +pkg crypto/x509, method (OID) Equal(OID) bool #60665 +pkg crypto/x509, method (OID) EqualASN1OID(asn1.ObjectIdentifier) bool #60665 +pkg crypto/x509, method (OID) String() string #60665 +pkg crypto/x509, func OIDFromInts([]uint64) (OID, error) #60665 diff --git a/src/crypto/x509/oid.go b/src/crypto/x509/oid.go new file mode 100644 index 0000000000..5359af624b --- /dev/null +++ b/src/crypto/x509/oid.go @@ -0,0 +1,273 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "bytes" + "encoding/asn1" + "errors" + "math" + "math/big" + "math/bits" + "strconv" + "strings" +) + +var ( + errInvalidOID = errors.New("invalid oid") +) + +// An OID represents an ASN.1 OBJECT IDENTIFIER. +type OID struct { + der []byte +} + +func newOIDFromDER(der []byte) (OID, bool) { + if len(der) == 0 || der[len(der)-1]&0x80 != 0 { + return OID{}, false + } + + start := 0 + for i, v := range der { + // ITU-T X.690, section 8.19.2: + // The subidentifier shall be encoded in the fewest possible octets, + // that is, the leading octet of the subidentifier shall not have the value 0x80. + if i == start && v == 0x80 { + return OID{}, false + } + if v&0x80 == 0 { + start = i + 1 + } + } + + return OID{der}, true +} + +// OIDFromInts creates a new OID using ints, each integer is a separate component. +func OIDFromInts(oid []uint64) (OID, error) { + if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { + return OID{}, errInvalidOID + } + + length := base128IntLength(oid[0]*40 + oid[1]) + for _, v := range oid[2:] { + length += base128IntLength(v) + } + + der := make([]byte, 0, length) + der = appendBase128Int(der, oid[0]*40+oid[1]) + for _, v := range oid[2:] { + der = appendBase128Int(der, v) + } + return OID{der}, nil +} + +func base128IntLength(n uint64) int { + if n == 0 { + return 1 + } + return (bits.Len64(n) + 6) / 7 +} + +func appendBase128Int(dst []byte, n uint64) []byte { + for i := base128IntLength(n) - 1; i >= 0; i-- { + o := byte(n >> uint(i*7)) + o &= 0x7f + if i != 0 { + o |= 0x80 + } + dst = append(dst, o) + } + return dst +} + +// Equal returns true when oid and other represents the same Object Identifier. +func (oid OID) Equal(other OID) bool { + // There is only one possible DER encoding of + // each unique Object Identifier. + return bytes.Equal(oid.der, other.der) +} + +func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) { + offset = initOffset + var ret64 int64 + for shifted := 0; offset < len(bytes); shifted++ { + // 5 * 7 bits per byte == 35 bits of data + // Thus the representation is either non-minimal or too large for an int32 + if shifted == 5 { + failed = true + return + } + ret64 <<= 7 + b := bytes[offset] + // integers should be minimally encoded, so the leading octet should + // never be 0x80 + if shifted == 0 && b == 0x80 { + failed = true + return + } + ret64 |= int64(b & 0x7f) + offset++ + if b&0x80 == 0 { + ret = int(ret64) + // Ensure that the returned value fits in an int on all platforms + if ret64 > math.MaxInt32 { + failed = true + } + return + } + } + failed = true + return +} + +// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If +// asn1.ObjectIdentifier cannot represent the OID specified by oid, because +// a component of OID requires more than 31 bits, it returns false. +func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool { + if len(other) < 2 { + return false + } + v, offset, failed := parseBase128Int(oid.der, 0) + if failed { + // This should never happen, since we've already parsed the OID, + // but just in case. + return false + } + if v < 80 { + a, b := v/40, v%40 + if other[0] != a || other[1] != b { + return false + } + } else { + a, b := 2, v-80 + if other[0] != a || other[1] != b { + return false + } + } + + i := 2 + for ; offset < len(oid.der); i++ { + v, offset, failed = parseBase128Int(oid.der, offset) + if failed { + // Again, shouldn't happen, since we've already parsed + // the OID, but better safe than sorry. + return false + } + if v != other[i] { + return false + } + } + + return i == len(other) +} + +// Strings returns the string representation of the Object Identifier. +func (oid OID) String() string { + var b strings.Builder + b.Grow(32) + const ( + valSize = 64 // size in bits of val. + bitsPerByte = 7 + maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 + ) + var ( + start = 0 + val = uint64(0) + numBuf = make([]byte, 0, 21) + bigVal *big.Int + overflow bool + ) + for i, v := range oid.der { + curVal := v & 0x7F + valEnd := v&0x80 == 0 + if valEnd { + if start != 0 { + b.WriteByte('.') + } + } + if !overflow && val > maxValSafeShift { + if bigVal == nil { + bigVal = new(big.Int) + } + bigVal = bigVal.SetUint64(val) + overflow = true + } + if overflow { + bigVal = bigVal.Lsh(bigVal, bitsPerByte).Or(bigVal, big.NewInt(int64(curVal))) + if valEnd { + if start == 0 { + b.WriteString("2.") + bigVal = bigVal.Sub(bigVal, big.NewInt(80)) + } + numBuf = bigVal.Append(numBuf, 10) + b.Write(numBuf) + numBuf = numBuf[:0] + val = 0 + start = i + 1 + overflow = false + } + continue + } + val <<= bitsPerByte + val |= uint64(curVal) + if valEnd { + if start == 0 { + if val < 80 { + b.Write(strconv.AppendUint(numBuf, val/40, 10)) + b.WriteByte('.') + b.Write(strconv.AppendUint(numBuf, val%40, 10)) + } else { + b.WriteString("2.") + b.Write(strconv.AppendUint(numBuf, val-80, 10)) + } + } else { + b.Write(strconv.AppendUint(numBuf, val, 10)) + } + val = 0 + start = i + 1 + } + } + return b.String() +} + +func (oid OID) toASN1OID() (asn1.ObjectIdentifier, bool) { + out := make([]int, 0, len(oid.der)+1) + + const ( + valSize = 31 // amount of usable bits of val for OIDs. + bitsPerByte = 7 + maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 + ) + + val := 0 + + for _, v := range oid.der { + if val > maxValSafeShift { + return nil, false + } + + val <<= bitsPerByte + val |= int(v & 0x7F) + + if v&0x80 == 0 { + if len(out) == 0 { + if val < 80 { + out = append(out, val/40) + out = append(out, val%40) + } else { + out = append(out, 2) + out = append(out, val-80) + } + val = 0 + continue + } + out = append(out, val) + val = 0 + } + } + + return out, true +} diff --git a/src/crypto/x509/oid_test.go b/src/crypto/x509/oid_test.go new file mode 100644 index 0000000000..b2be1079c1 --- /dev/null +++ b/src/crypto/x509/oid_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import ( + "encoding/asn1" + "math" + "testing" +) + +func TestOID(t *testing.T) { + var tests = []struct { + raw []byte + valid bool + str string + ints []uint64 + }{ + {[]byte{}, false, "", nil}, + {[]byte{0x80, 0x01}, false, "", nil}, + {[]byte{0x01, 0x80, 0x01}, false, "", nil}, + + {[]byte{1, 2, 3}, true, "0.1.2.3", []uint64{0, 1, 2, 3}}, + {[]byte{41, 2, 3}, true, "1.1.2.3", []uint64{1, 1, 2, 3}}, + {[]byte{86, 2, 3}, true, "2.6.2.3", []uint64{2, 6, 2, 3}}, + + {[]byte{41, 255, 255, 255, 127}, true, "1.1.268435455", []uint64{1, 1, 268435455}}, + {[]byte{41, 0x87, 255, 255, 255, 127}, true, "1.1.2147483647", []uint64{1, 1, 2147483647}}, + {[]byte{41, 255, 255, 255, 255, 127}, true, "1.1.34359738367", []uint64{1, 1, 34359738367}}, + {[]byte{42, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "1.2.9223372036854775807", []uint64{1, 2, 9223372036854775807}}, + {[]byte{43, 0x81, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "1.3.18446744073709551615", []uint64{1, 3, 18446744073709551615}}, + {[]byte{44, 0x83, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "1.4.36893488147419103231", nil}, + {[]byte{85, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.5.1180591620717411303423", nil}, + {[]byte{85, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.5.19342813113834066795298815", nil}, + + {[]byte{255, 255, 255, 127}, true, "2.268435375", []uint64{2, 268435375}}, + {[]byte{0x87, 255, 255, 255, 127}, true, "2.2147483567", []uint64{2, 2147483567}}, + {[]byte{255, 127}, true, "2.16303", []uint64{2, 16303}}, + {[]byte{255, 255, 255, 255, 127}, true, "2.34359738287", []uint64{2, 34359738287}}, + {[]byte{255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.9223372036854775727", []uint64{2, 9223372036854775727}}, + {[]byte{0x81, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.18446744073709551535", []uint64{2, 18446744073709551535}}, + {[]byte{0x83, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.36893488147419103151", nil}, + {[]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.1180591620717411303343", nil}, + {[]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, true, "2.19342813113834066795298735", nil}, + } + + for _, v := range tests { + oid, ok := newOIDFromDER(v.raw) + if ok != v.valid { + if ok { + t.Errorf("%v: unexpected success while parsing: %v", v.raw, oid) + } else { + t.Errorf("%v: unexpected failure while parsing", v.raw) + } + continue + } + + if !ok { + continue + } + + if str := oid.String(); str != v.str { + t.Errorf("%v: oid.String() = %v, want; %v", v.raw, str, v.str) + } + + var asn1OID asn1.ObjectIdentifier + for _, v := range v.ints { + if v > math.MaxInt32 { + asn1OID = nil + break + } + asn1OID = append(asn1OID, int(v)) + } + + o, ok := oid.toASN1OID() + if shouldOk := asn1OID != nil; shouldOk != ok { + if ok { + t.Errorf("%v: oid.toASN1OID() unexpected success", v.raw) + } else { + t.Errorf("%v: oid.toASN1OID() unexpected fauilure", v.raw) + } + continue + } + + if asn1OID != nil { + if !o.Equal(asn1OID) { + t.Errorf("%v: oid.toASN1OID(asn1OID).Equal(oid) = false, want: true", v.raw) + } + } + + if v.ints != nil { + oid2, err := OIDFromInts(v.ints) + if err != nil { + t.Errorf("%v: OIDFromInts() unexpected error: %v", v.raw, err) + } + if !oid2.Equal(oid) { + t.Errorf("%v: %#v.Equal(%#v) = false, want: true", v.raw, oid2, oid) + } + } + } +} + +func mustNewOIDFromInts(t *testing.T, ints []uint64) OID { + oid, err := OIDFromInts(ints) + if err != nil { + t.Fatalf("OIDFromInts(%v) unexpected error: %v", ints, err) + } + return oid +} diff --git a/src/crypto/x509/parser.go b/src/crypto/x509/parser.go index 019a53b5dc..812b0d2d28 100644 --- a/src/crypto/x509/parser.go +++ b/src/crypto/x509/parser.go @@ -435,23 +435,23 @@ func parseExtKeyUsageExtension(der cryptobyte.String) ([]ExtKeyUsage, []asn1.Obj return extKeyUsages, unknownUsages, nil } -func parseCertificatePoliciesExtension(der cryptobyte.String) ([]asn1.ObjectIdentifier, error) { - var oids []asn1.ObjectIdentifier +func parseCertificatePoliciesExtension(der cryptobyte.String) ([]OID, error) { + var oids []OID if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { return nil, errors.New("x509: invalid certificate policies") } for !der.Empty() { var cp cryptobyte.String - if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) { + var OIDBytes cryptobyte.String + if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) || !cp.ReadASN1(&OIDBytes, cryptobyte_asn1.OBJECT_IDENTIFIER) { return nil, errors.New("x509: invalid certificate policies") } - var oid asn1.ObjectIdentifier - if !cp.ReadASN1ObjectIdentifier(&oid) { + oid, ok := newOIDFromDER(OIDBytes) + if !ok { return nil, errors.New("x509: invalid certificate policies") } oids = append(oids, oid) } - return oids, nil } @@ -748,10 +748,16 @@ func processExtensions(out *Certificate) error { } out.SubjectKeyId = skid case 32: - out.PolicyIdentifiers, err = parseCertificatePoliciesExtension(e.Value) + out.Policies, err = parseCertificatePoliciesExtension(e.Value) if err != nil { return err } + out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, 0, len(out.Policies)) + for _, oid := range out.Policies { + if oid, ok := oid.toASN1OID(); ok { + out.PolicyIdentifiers = append(out.PolicyIdentifiers, oid) + } + } default: // Unknown extensions are recorded if critical. unhandled = true diff --git a/src/crypto/x509/x509.go b/src/crypto/x509/x509.go index dfc5092b30..b2e31f76b4 100644 --- a/src/crypto/x509/x509.go +++ b/src/crypto/x509/x509.go @@ -772,7 +772,15 @@ type Certificate struct { // CRL Distribution Points CRLDistributionPoints []string + // PolicyIdentifiers contains asn1.ObjectIdentifiers, the components + // of which are limited to int32. If a certificate contains a policy which + // cannot be represented by asn1.ObjectIdentifier, it will not be included in + // PolicyIdentifiers, but will be present in Policies, which contains all parsed + // policy OIDs. PolicyIdentifiers []asn1.ObjectIdentifier + + // Policies contains all policy identifiers included in the certificate. + Policies []OID } // ErrUnsupportedAlgorithm results from attempting to perform an operation that @@ -1179,7 +1187,7 @@ func buildCertExtensions(template *Certificate, subjectIsEmpty bool, authorityKe if len(template.PolicyIdentifiers) > 0 && !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) { - ret[n], err = marshalCertificatePolicies(template.PolicyIdentifiers) + ret[n], err = marshalCertificatePolicies(template.Policies, template.PolicyIdentifiers) if err != nil { return nil, err } @@ -1364,14 +1372,27 @@ func marshalBasicConstraints(isCA bool, maxPathLen int, maxPathLenZero bool) (pk return ext, err } -func marshalCertificatePolicies(policyIdentifiers []asn1.ObjectIdentifier) (pkix.Extension, error) { +func marshalCertificatePolicies(policies []OID, policyIdentifiers []asn1.ObjectIdentifier) (pkix.Extension, error) { ext := pkix.Extension{Id: oidExtensionCertificatePolicies} - policies := make([]policyInformation, len(policyIdentifiers)) - for i, policy := range policyIdentifiers { - policies[i].Policy = policy - } + + b := cryptobyte.NewBuilder(make([]byte, 0, 128)) + b.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) { + for _, v := range policies { + child.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) { + child.AddASN1(cryptobyte_asn1.OBJECT_IDENTIFIER, func(child *cryptobyte.Builder) { + child.AddBytes(v.der) + }) + }) + } + for _, v := range policyIdentifiers { + child.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) { + child.AddASN1ObjectIdentifier(v) + }) + } + }) + var err error - ext.Value, err = asn1.Marshal(policies) + ext.Value, err = b.Bytes() return ext, err } diff --git a/src/crypto/x509/x509_test.go b/src/crypto/x509/x509_test.go index 9a80b2b434..bdc03216bc 100644 --- a/src/crypto/x509/x509_test.go +++ b/src/crypto/x509/x509_test.go @@ -24,12 +24,14 @@ import ( "fmt" "internal/testenv" "io" + "math" "math/big" "net" "net/url" "os/exec" "reflect" "runtime" + "slices" "strings" "testing" "time" @@ -671,6 +673,7 @@ func TestCreateSelfSignedCertificate(t *testing.T) { URIs: []*url.URL{parseURI("https://foo.com/wibble#foo")}, PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, + Policies: []OID{mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64})}, PermittedDNSDomains: []string{".example.com", "example.com"}, ExcludedDNSDomains: []string{"bar.example.com"}, PermittedIPRanges: []*net.IPNet{parseCIDR("192.168.1.1/16"), parseCIDR("1.2.3.4/8")}, @@ -3917,3 +3920,49 @@ func TestDuplicateAttributesCSR(t *testing.T) { t.Fatal("ParseCertificateRequest should succeed when parsing CSR with duplicate attributes") } } + +func TestCertificateOIDPolicies(t *testing.T) { + template := Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Cert"}, + NotBefore: time.Unix(1000, 0), + NotAfter: time.Unix(100000, 0), + PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, + Policies: []OID{ + mustNewOIDFromInts(t, []uint64{1, 2, 3, 4, 5}), + mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxInt32}), + mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}), + }, + } + + var expectPolicyIdentifiers = []asn1.ObjectIdentifier{ + []int{1, 2, 3, 4, 5}, + []int{1, 2, 3, math.MaxInt32}, + []int{1, 2, 3}, + } + + var expectPolicies = []OID{ + mustNewOIDFromInts(t, []uint64{1, 2, 3, 4, 5}), + mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxInt32}), + mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}), + mustNewOIDFromInts(t, []uint64{1, 2, 3}), + } + + certDER, err := CreateCertificate(rand.Reader, &template, &template, rsaPrivateKey.Public(), rsaPrivateKey) + if err != nil { + t.Fatalf("CreateCertificate() unexpected error: %v", err) + } + + cert, err := ParseCertificate(certDER) + if err != nil { + t.Fatalf("ParseCertificate() unexpected error: %v", err) + } + + if !slices.EqualFunc(cert.PolicyIdentifiers, expectPolicyIdentifiers, slices.Equal) { + t.Errorf("cert.PolicyIdentifiers = %v, want: %v", cert.PolicyIdentifiers, expectPolicyIdentifiers) + } + + if !slices.EqualFunc(cert.Policies, expectPolicies, OID.Equal) { + t.Errorf("cert.Policies = %v, want: %v", cert.Policies, expectPolicies) + } +}