1
0
mirror of https://github.com/golang/go synced 2024-11-25 22:37:59 -07:00
Change-Id: I54d753bd280219f9392cb9061088132e526d1800
This commit is contained in:
Mateusz Poliwczak 2023-10-30 20:13:45 +01:00
parent 9e9479acf9
commit 693c741c76
3 changed files with 81 additions and 59 deletions

View File

@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"encoding/asn1" "encoding/asn1"
"errors" "errors"
"math"
"math/big" "math/big"
"math/bits" "math/bits"
"strconv" "strconv"
@ -44,14 +45,6 @@ func newOIDFromDER(der []byte) (OID, bool) {
return OID{der}, true return OID{der}, true
} }
func mustNewOIDFromInts(ints []uint64) OID {
oid, err := OIDFromInts(ints)
if err != nil {
panic("crypto/x509: mustNewOIDFromInts: " + err.Error())
}
return oid
}
// OIDFromInts creates a new OID using ints, each integer is a separate component. // OIDFromInts creates a new OID using ints, each integer is a separate component.
func OIDFromInts(oid []uint64) (OID, error) { func OIDFromInts(oid []uint64) (OID, error) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
@ -97,57 +90,78 @@ func (oid OID) Equal(other OID) bool {
return bytes.Equal(oid.der, other.der) return bytes.Equal(oid.der, other.der)
} }
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) {
offset = initOffset
var ret64 int64
for shifted := 0; offset < len(bytes); shifted++ {
// 5 * 7 bits per byte == 35 bits of data
// Thus the representation is either non-minimal or too large for an int32
if shifted == 5 {
failed = true
return
}
ret64 <<= 7
b := bytes[offset]
// integers should be minimally encoded, so the leading octet should
// never be 0x80
if shifted == 0 && b == 0x80 {
failed = true
return
}
ret64 |= int64(b & 0x7f)
offset++
if b&0x80 == 0 {
ret = int(ret64)
// Ensure that the returned value fits in an int on all platforms
if ret64 > math.MaxInt32 {
failed = true
}
return
}
}
failed = true
return
}
// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If // EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If
// asn1.ObjectIdentifier cannot represent the OID specified by oid, because // asn1.ObjectIdentifier cannot represent the OID specified by oid, because
// a component of OID requires more than 31 bits, it returns false. // a component of OID requires more than 31 bits, it returns false.
func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool { func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool {
const ( if len(other) < 2 {
valSize = 31 // amount of usable bits of val for OIDs. return false
bitsPerByte = 7 }
maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 v, offset, failed := parseBase128Int(oid.der, 0)
) if failed {
var ( // This should never happen, since we've already parsed the OID,
val = 0 // but just in case.
first = true return false
) }
for _, v := range oid.der { if v < 80 {
if val > maxValSafeShift { a, b := v/40, v%40
if other[0] != a || other[1] != b {
return false return false
} }
val <<= bitsPerByte } else {
val |= int(v & 0x7F) a, b := 2, v-80
if v&0x80 == 0 { if other[0] != a || other[1] != b {
if first { return false
if len(other) < 2 {
return false
}
var val1, val2 int
if val < 80 {
val1 = val / 40
val2 = val % 40
} else {
val1 = 2
val2 = val - 80
}
if val1 != other[0] || val2 != other[1] {
return false
}
val = 0
first = false
other = other[2:]
continue
}
if len(other) == 0 {
return false
}
if val != other[0] {
return false
}
val = 0
other = other[1:]
} }
} }
return true
i := 2
for ; offset < len(oid.der); i++ {
v, offset, failed = parseBase128Int(oid.der, offset)
if failed {
// Again, shouldn't happen, since we've already parsed
// the OID, but better safe than sorry.
return false
}
if v != other[i] {
return false
}
}
return i == len(other)
} }
// Strings returns the string representation of the Object Identifier. // Strings returns the string representation of the Object Identifier.

View File

@ -100,3 +100,11 @@ func TestOID(t *testing.T) {
} }
} }
} }
func mustNewOIDFromInts(t *testing.T, ints []uint64) OID {
oid, err := OIDFromInts(ints)
if err != nil {
t.Fatalf("OIDFromInts(%v) unexpected error: %v", ints, err)
}
return oid
}

View File

@ -673,7 +673,7 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
URIs: []*url.URL{parseURI("https://foo.com/wibble#foo")}, URIs: []*url.URL{parseURI("https://foo.com/wibble#foo")},
PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}},
Policies: []OID{mustNewOIDFromInts([]uint64{1, 2, 3, math.MaxUint32, math.MaxUint64})}, Policies: []OID{mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64})},
PermittedDNSDomains: []string{".example.com", "example.com"}, PermittedDNSDomains: []string{".example.com", "example.com"},
ExcludedDNSDomains: []string{"bar.example.com"}, ExcludedDNSDomains: []string{"bar.example.com"},
PermittedIPRanges: []*net.IPNet{parseCIDR("192.168.1.1/16"), parseCIDR("1.2.3.4/8")}, PermittedIPRanges: []*net.IPNet{parseCIDR("192.168.1.1/16"), parseCIDR("1.2.3.4/8")},
@ -3929,9 +3929,9 @@ func TestCertificateOIDPolicies(t *testing.T) {
NotAfter: time.Unix(100000, 0), NotAfter: time.Unix(100000, 0),
PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}},
Policies: []OID{ Policies: []OID{
mustNewOIDFromInts([]uint64{1, 2, 3, 4, 5}), mustNewOIDFromInts(t, []uint64{1, 2, 3, 4, 5}),
mustNewOIDFromInts([]uint64{1, 2, 3, math.MaxInt32}), mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxInt32}),
mustNewOIDFromInts([]uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}), mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}),
}, },
} }
@ -3942,10 +3942,10 @@ func TestCertificateOIDPolicies(t *testing.T) {
} }
var expectPolicies = []OID{ var expectPolicies = []OID{
mustNewOIDFromInts([]uint64{1, 2, 3, 4, 5}), mustNewOIDFromInts(t, []uint64{1, 2, 3, 4, 5}),
mustNewOIDFromInts([]uint64{1, 2, 3, math.MaxInt32}), mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxInt32}),
mustNewOIDFromInts([]uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}), mustNewOIDFromInts(t, []uint64{1, 2, 3, math.MaxUint32, math.MaxUint64}),
mustNewOIDFromInts([]uint64{1, 2, 3}), mustNewOIDFromInts(t, []uint64{1, 2, 3}),
} }
certDER, err := CreateCertificate(rand.Reader, &template, &template, rsaPrivateKey.Public(), rsaPrivateKey) certDER, err := CreateCertificate(rand.Reader, &template, &template, rsaPrivateKey.Public(), rsaPrivateKey)