1
0
mirror of https://github.com/golang/go synced 2024-11-15 00:20:30 -07:00

crypto/tls: don't cache marshal'd bytes

Only cache the wire representation for clientHelloMsg and serverHelloMsg
during unmarshal, which are the only places we actually need to hold
onto them. For everything else, remove the raw field.

This appears to have zero performance impact:

name                                               old time/op   new time/op   delta
CertCache/0-10                                       177µs ± 2%    189µs ±11%   ~     (p=0.700 n=3+3)
CertCache/1-10                                       184µs ± 3%    182µs ± 6%   ~     (p=1.000 n=3+3)
CertCache/2-10                                       187µs ±12%    187µs ± 2%   ~     (p=1.000 n=3+3)
CertCache/3-10                                       204µs ±21%    187µs ± 1%   ~     (p=0.700 n=3+3)
HandshakeServer/RSA-10                               410µs ± 2%    410µs ± 3%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-P256-RSA/TLSv13-10             473µs ± 3%    460µs ± 2%   ~     (p=0.200 n=3+3)
HandshakeServer/ECDHE-P256-RSA/TLSv12-10             498µs ± 3%    489µs ± 2%   ~     (p=0.700 n=3+3)
HandshakeServer/ECDHE-P256-ECDSA-P256/TLSv13-10      140µs ± 5%    138µs ± 5%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-P256-ECDSA-P256/TLSv12-10      132µs ± 1%    133µs ± 2%   ~     (p=0.400 n=3+3)
HandshakeServer/ECDHE-X25519-ECDSA-P256/TLSv13-10    168µs ± 1%    171µs ± 4%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-X25519-ECDSA-P256/TLSv12-10    166µs ± 3%    163µs ± 0%   ~     (p=0.700 n=3+3)
HandshakeServer/ECDHE-P521-ECDSA-P521/TLSv13-10     1.87ms ± 2%   1.81ms ± 0%   ~     (p=0.100 n=3+3)
HandshakeServer/ECDHE-P521-ECDSA-P521/TLSv12-10     1.86ms ± 0%   1.86ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv12-10                  6.79ms ± 3%   6.73ms ± 0%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv13-10                  6.73ms ± 1%   6.75ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv12-10                  12.8ms ± 2%   12.7ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv13-10                  13.1ms ± 3%   12.8ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/4MB/TLSv12-10                  24.9ms ± 2%   24.7ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/4MB/TLSv13-10                  26.0ms ± 4%   24.9ms ± 1%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/8MB/TLSv12-10                  50.0ms ± 3%   48.9ms ± 0%   ~     (p=0.200 n=3+3)
Throughput/MaxPacket/8MB/TLSv13-10                  49.8ms ± 2%   49.3ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/16MB/TLSv12-10                 97.3ms ± 1%   97.4ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/16MB/TLSv13-10                 97.9ms ± 0%   97.9ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/32MB/TLSv12-10                  195ms ± 0%    194ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/32MB/TLSv13-10                  196ms ± 0%    196ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/64MB/TLSv12-10                  405ms ± 3%    385ms ± 0%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/64MB/TLSv13-10                  391ms ± 1%    388ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/1MB/TLSv12-10              6.75ms ± 0%   6.75ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/1MB/TLSv13-10              6.84ms ± 1%   6.77ms ± 0%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/2MB/TLSv12-10              12.8ms ± 1%   12.8ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/2MB/TLSv13-10              12.8ms ± 1%   13.0ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/4MB/TLSv12-10              24.8ms ± 1%   24.8ms ± 0%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/4MB/TLSv13-10              25.1ms ± 2%   25.1ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/8MB/TLSv12-10              49.2ms ± 2%   48.9ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/8MB/TLSv13-10              49.3ms ± 1%   49.4ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/16MB/TLSv12-10             97.1ms ± 0%   98.0ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/16MB/TLSv13-10             98.8ms ± 1%   98.4ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/32MB/TLSv12-10              192ms ± 0%    198ms ± 5%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/32MB/TLSv13-10              194ms ± 0%    196ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/64MB/TLSv12-10              385ms ± 1%    384ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/64MB/TLSv13-10              387ms ± 0%    388ms ± 0%   ~     (p=0.400 n=3+3)
Latency/MaxPacket/200kbps/TLSv12-10                  694ms ± 0%    694ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/200kbps/TLSv13-10                  699ms ± 0%    699ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/500kbps/TLSv12-10                  278ms ± 0%    278ms ± 0%   ~     (p=0.400 n=3+3)
Latency/MaxPacket/500kbps/TLSv13-10                  280ms ± 0%    280ms ± 0%   ~     (p=1.000 n=3+3)
Latency/MaxPacket/1000kbps/TLSv12-10                 140ms ± 1%    140ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/1000kbps/TLSv13-10                 141ms ± 0%    141ms ± 0%   ~     (p=1.000 n=3+3)
Latency/MaxPacket/2000kbps/TLSv12-10                70.5ms ± 0%   70.4ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/2000kbps/TLSv13-10                70.7ms ± 0%   70.7ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/5000kbps/TLSv12-10                28.8ms ± 0%   28.8ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/5000kbps/TLSv13-10                28.9ms ± 0%   28.9ms ± 0%   ~     (p=0.700 n=3+3)
Latency/DynamicPacket/200kbps/TLSv12-10              134ms ± 0%    134ms ± 0%   ~     (p=0.700 n=3+3)
Latency/DynamicPacket/200kbps/TLSv13-10              138ms ± 0%    138ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/500kbps/TLSv12-10             54.1ms ± 0%   54.1ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/500kbps/TLSv13-10             55.7ms ± 0%   55.7ms ± 0%   ~     (p=0.100 n=3+3)
Latency/DynamicPacket/1000kbps/TLSv12-10            27.6ms ± 0%   27.6ms ± 0%   ~     (p=0.200 n=3+3)
Latency/DynamicPacket/1000kbps/TLSv13-10            28.4ms ± 0%   28.4ms ± 0%   ~     (p=0.200 n=3+3)
Latency/DynamicPacket/2000kbps/TLSv12-10            14.4ms ± 0%   14.4ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/2000kbps/TLSv13-10            14.6ms ± 0%   14.6ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/5000kbps/TLSv12-10            6.44ms ± 0%   6.45ms ± 0%   ~     (p=0.100 n=3+3)
Latency/DynamicPacket/5000kbps/TLSv13-10            6.49ms ± 0%   6.49ms ± 0%   ~     (p=0.700 n=3+3)

name                                               old speed     new speed     delta
Throughput/MaxPacket/1MB/TLSv12-10                 155MB/s ± 3%  156MB/s ± 0%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv13-10                 156MB/s ± 1%  155MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv12-10                 163MB/s ± 2%  165MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv13-10                 160MB/s ± 3%  164MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/4MB/TLSv12-10                 168MB/s ± 2%  170MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/4MB/TLSv13-10                 162MB/s ± 4%  168MB/s ± 1%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/8MB/TLSv12-10                 168MB/s ± 3%  172MB/s ± 0%   ~     (p=0.200 n=3+3)
Throughput/MaxPacket/8MB/TLSv13-10                 168MB/s ± 2%  170MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/16MB/TLSv12-10                172MB/s ± 1%  172MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/16MB/TLSv13-10                171MB/s ± 0%  171MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/32MB/TLSv12-10                172MB/s ± 0%  173MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/32MB/TLSv13-10                171MB/s ± 0%  172MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/64MB/TLSv12-10                166MB/s ± 3%  174MB/s ± 0%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/64MB/TLSv13-10                171MB/s ± 1%  173MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/1MB/TLSv12-10             155MB/s ± 0%  155MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/1MB/TLSv13-10             153MB/s ± 1%  155MB/s ± 0%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/2MB/TLSv12-10             164MB/s ± 1%  164MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/2MB/TLSv13-10             163MB/s ± 1%  162MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/4MB/TLSv12-10             169MB/s ± 1%  169MB/s ± 0%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/4MB/TLSv13-10             167MB/s ± 1%  167MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/8MB/TLSv12-10             170MB/s ± 2%  171MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/8MB/TLSv13-10             170MB/s ± 1%  170MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/16MB/TLSv12-10            173MB/s ± 0%  171MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/16MB/TLSv13-10            170MB/s ± 1%  170MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/32MB/TLSv12-10            175MB/s ± 0%  170MB/s ± 5%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/32MB/TLSv13-10            173MB/s ± 0%  171MB/s ± 1%   ~     (p=0.300 n=3+3)
Throughput/DynamicPacket/64MB/TLSv12-10            174MB/s ± 1%  175MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/64MB/TLSv13-10            174MB/s ± 0%  173MB/s ± 0%   ~     (p=0.400 n=3+3)

Change-Id: Ifa79cce002011850ed8b2835edd34f60e014eea8
Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest,gotip-linux-arm64-longtest
Reviewed-on: https://go-review.googlesource.com/c/go/+/580215
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>
This commit is contained in:
Roland Shoemaker 2024-04-18 10:51:25 -07:00 committed by Gopher Robot
parent 1a3682b4c1
commit f0d6ddfac0
4 changed files with 114 additions and 189 deletions

View File

@ -1445,6 +1445,15 @@ type handshakeMessage interface {
unmarshal([]byte) bool
}
type handshakeMessageWithOriginalBytes interface {
handshakeMessage
// originalBytes should return the original bytes that were passed to
// unmarshal to create the message. If the message was not produced by
// unmarshal, it should return nil.
originalBytes() []byte
}
// lruSessionCache is a ClientSessionCache implementation that uses an LRU
// caching strategy.
type lruSessionCache struct {

View File

@ -249,7 +249,6 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {

View File

@ -68,7 +68,7 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
}
type clientHelloMsg struct {
raw []byte
original []byte
vers uint16
random []byte
sessionId []byte
@ -98,10 +98,6 @@ type clientHelloMsg struct {
}
func (m *clientHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var exts cryptobyte.Builder
if len(m.serverName) > 0 {
// RFC 6066, Section 3
@ -310,8 +306,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
}
})
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
// marshalWithoutBinders returns the ClientHello through the
@ -324,16 +319,21 @@ func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen += len(binder)
}
fullMessage, err := m.marshal()
if err != nil {
return nil, err
var fullMessage []byte
if m.original != nil {
fullMessage = m.original
} else {
var err error
fullMessage, err = m.marshal()
if err != nil {
return nil, err
}
}
return fullMessage[:len(fullMessage)-bindersLen], nil
}
// updateBinders updates the m.pskBinders field, if necessary updating the
// cached marshaled representation. The supplied binders must have the same
// length as the current m.pskBinders.
// updateBinders updates the m.pskBinders field. The supplied binders must have
// the same length as the current m.pskBinders.
func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) {
return errors.New("tls: internal error: pskBinders length mismatch")
@ -344,30 +344,12 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
}
}
m.pskBinders = pskBinders
if m.raw != nil {
helloBytes, err := m.marshalWithoutBinders()
if err != nil {
return err
}
lenWithoutBinders := len(helloBytes)
b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(binder)
})
}
})
if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
return errors.New("tls: internal error: failed to update binders")
}
}
return nil
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
*m = clientHelloMsg{raw: data}
*m = clientHelloMsg{original: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
@ -625,8 +607,12 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return true
}
func (m *clientHelloMsg) originalBytes() []byte {
return m.original
}
type serverHelloMsg struct {
raw []byte
original []byte
vers uint16
random []byte
sessionId []byte
@ -651,10 +637,6 @@ type serverHelloMsg struct {
}
func (m *serverHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var exts cryptobyte.Builder
if m.ocspStapling {
exts.AddUint16(extensionStatusRequest)
@ -766,12 +748,11 @@ func (m *serverHelloMsg) marshal() ([]byte, error) {
}
})
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
*m = serverHelloMsg{raw: data}
*m = serverHelloMsg{original: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
@ -888,18 +869,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return true
}
func (m *serverHelloMsg) originalBytes() []byte {
return m.original
}
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string
quicTransportParameters []byte
earlyData bool
}
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeEncryptedExtensions)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -929,13 +909,11 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
*m = encryptedExtensionsMsg{raw: data}
*m = encryptedExtensionsMsg{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
@ -998,15 +976,10 @@ func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
}
type keyUpdateMsg struct {
raw []byte
updateRequested bool
}
func (m *keyUpdateMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1017,13 +990,10 @@ func (m *keyUpdateMsg) marshal() ([]byte, error) {
}
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
var updateRequested uint8
@ -1043,7 +1013,6 @@ func (m *keyUpdateMsg) unmarshal(data []byte) bool {
}
type newSessionTicketMsgTLS13 struct {
raw []byte
lifetime uint32
ageAdd uint32
nonce []byte
@ -1052,10 +1021,6 @@ type newSessionTicketMsgTLS13 struct {
}
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1078,13 +1043,11 @@ func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
*m = newSessionTicketMsgTLS13{raw: data}
*m = newSessionTicketMsgTLS13{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
@ -1125,7 +1088,6 @@ func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
}
type certificateRequestMsgTLS13 struct {
raw []byte
ocspStapling bool
scts bool
supportedSignatureAlgorithms []SignatureScheme
@ -1134,10 +1096,6 @@ type certificateRequestMsgTLS13 struct {
}
func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateRequest)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1194,13 +1152,11 @@ func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
*m = certificateRequestMsgTLS13{raw: data}
*m = certificateRequestMsgTLS13{}
s := cryptobyte.String(data)
var context, extensions cryptobyte.String
@ -1276,15 +1232,10 @@ func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
}
type certificateMsg struct {
raw []byte
certificates [][]byte
}
func (m *certificateMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var i int
for _, slice := range m.certificates {
i += len(slice)
@ -1311,8 +1262,7 @@ func (m *certificateMsg) marshal() ([]byte, error) {
y = y[3+len(slice):]
}
m.raw = x
return m.raw, nil
return x, nil
}
func (m *certificateMsg) unmarshal(data []byte) bool {
@ -1320,7 +1270,6 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
return false
}
m.raw = data
certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
if uint32(len(data)) != certsLen+7 {
return false
@ -1353,17 +1302,12 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
}
type certificateMsgTLS13 struct {
raw []byte
certificate Certificate
ocspStapling bool
scts bool
}
func (m *certificateMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1379,9 +1323,7 @@ func (m *certificateMsgTLS13) marshal() ([]byte, error) {
marshalCertificate(b, certificate)
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
@ -1422,7 +1364,7 @@ func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
}
func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
*m = certificateMsgTLS13{raw: data}
*m = certificateMsgTLS13{}
s := cryptobyte.String(data)
var context cryptobyte.String
@ -1500,14 +1442,10 @@ func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
}
type serverKeyExchangeMsg struct {
raw []byte
key []byte
}
func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
length := len(m.key)
x := make([]byte, length+4)
x[0] = typeServerKeyExchange
@ -1516,12 +1454,10 @@ func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
x[3] = uint8(length)
copy(x[4:], m.key)
m.raw = x
return x, nil
}
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 4 {
return false
}
@ -1530,15 +1466,10 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
}
type certificateStatusMsg struct {
raw []byte
response []byte
}
func (m *certificateStatusMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateStatus)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1548,13 +1479,10 @@ func (m *certificateStatusMsg) marshal() ([]byte, error) {
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
var statusType uint8
@ -1580,14 +1508,10 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
}
type clientKeyExchangeMsg struct {
raw []byte
ciphertext []byte
}
func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
length := len(m.ciphertext)
x := make([]byte, length+4)
x[0] = typeClientKeyExchange
@ -1596,12 +1520,10 @@ func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
x[3] = uint8(length)
copy(x[4:], m.ciphertext)
m.raw = x
return x, nil
}
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 4 {
return false
}
@ -1614,28 +1536,20 @@ func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
}
type finishedMsg struct {
raw []byte
verifyData []byte
}
func (m *finishedMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeFinished)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.verifyData)
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *finishedMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
return s.Skip(1) &&
readUint24LengthPrefixed(&s, &m.verifyData) &&
@ -1643,7 +1557,6 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
}
type certificateRequestMsg struct {
raw []byte
// hasSignatureAlgorithm indicates whether this message includes a list of
// supported signature algorithms. This change was introduced with TLS 1.2.
hasSignatureAlgorithm bool
@ -1654,10 +1567,6 @@ type certificateRequestMsg struct {
}
func (m *certificateRequestMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
// See RFC 4346, Section 7.4.4.
length := 1 + len(m.certificateTypes) + 2
casLength := 0
@ -1704,13 +1613,10 @@ func (m *certificateRequestMsg) marshal() ([]byte, error) {
y = y[len(ca):]
}
m.raw = x
return m.raw, nil
return x, nil
}
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 5 {
return false
}
@ -1785,17 +1691,12 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
}
type certificateVerifyMsg struct {
raw []byte
hasSignatureAlgorithm bool // format change introduced in TLS 1.2
signatureAlgorithm SignatureScheme
signature []byte
}
func (m *certificateVerifyMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateVerify)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1807,13 +1708,10 @@ func (m *certificateVerifyMsg) marshal() ([]byte, error) {
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
return b.Bytes()
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
if !s.Skip(4) { // message type and uint24 length field
@ -1828,15 +1726,10 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
}
type newSessionTicketMsg struct {
raw []byte
ticket []byte
}
func (m *newSessionTicketMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
// See RFC 5077, Section 3.3.
ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen
@ -1849,14 +1742,10 @@ func (m *newSessionTicketMsg) marshal() ([]byte, error) {
x[9] = uint8(ticketLen)
copy(x[10:], m.ticket)
m.raw = x
return m.raw, nil
return x, nil
}
func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 10 {
return false
}
@ -1891,9 +1780,25 @@ type transcriptHash interface {
Write([]byte) (int, error)
}
// transcriptMsg is a helper used to marshal and hash messages which typically
// are not written to the wire, and as such aren't hashed during Conn.writeRecord.
// transcriptMsg is a helper used to hash messages which are not hashed when
// they are read from, or written to, the wire. This is typically the case for
// messages which are either not sent, or need to be hashed out of order from
// when they are read/written.
//
// For most messages, the message is marshalled using their marshal method,
// since their wire representation is idempotent. For clientHelloMsg and
// serverHelloMsg, we store the original wire representation of the message and
// use that for hashing, since unmarshal/marshal are not idempotent due to
// extension ordering and other malleable fields, which may cause differences
// between what was received and what we marshal.
func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
if orig := msgWithOrig.originalBytes(); orig != nil {
h.Write(msgWithOrig.originalBytes())
return nil
}
}
data, err := msg.marshal()
if err != nil {
return err

View File

@ -53,49 +53,61 @@ func TestMarshalUnmarshal(t *testing.T) {
for i, m := range tests {
ty := reflect.ValueOf(m).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("#%d: failed to create value", i)
break
t.Run(ty.String(), func(t *testing.T) {
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("#%d: failed to create value", i)
break
}
m1 := v.Interface().(handshakeMessage)
marshaled := mustMarshal(t, m1)
if !m.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break
}
m.marshal() // to fill any marshal cache in the message
m1 := v.Interface().(handshakeMessage)
marshaled := mustMarshal(t, m1)
if !m.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break
}
if m, ok := m.(*SessionState); ok {
m.activeCertHandles = nil
}
if m, ok := m.(*SessionState); ok {
m.activeCertHandles = nil
}
if !reflect.DeepEqual(m1, m) {
t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
break
}
// clientHelloMsg and serverHelloMsg, when unmarshalled, store
// their original representation, for later use in the handshake
// transcript. In order to prevent DeepEqual from failing since
// we didn't create the original message via unmarshalling, nil
// the field.
switch t := m.(type) {
case *clientHelloMsg:
t.original = nil
case *serverHelloMsg:
t.original = nil
}
if i >= 3 {
// The first three message types (ClientHello,
// ServerHello and Finished) are allowed to
// have parsable prefixes because the extension
// data is optional and the length of the
// Finished varies across versions.
for j := 0; j < len(marshaled); j++ {
if m.unmarshal(marshaled[0:j]) {
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
break
if !reflect.DeepEqual(m1, m) {
t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
break
}
if i >= 3 {
// The first three message types (ClientHello,
// ServerHello and Finished) are allowed to
// have parsable prefixes because the extension
// data is optional and the length of the
// Finished varies across versions.
for j := 0; j < len(marshaled); j++ {
if m.unmarshal(marshaled[0:j]) {
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
break
}
}
}
}
}
})
}
}