diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go index bef7570512..a50fa96fab 100644 --- a/src/crypto/tls/handshake_messages_test.go +++ b/src/crypto/tls/handshake_messages_test.go @@ -308,11 +308,9 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { s := &sessionState{} s.vers = uint16(rand.Intn(10000)) s.cipherSuite = uint16(rand.Intn(10000)) - s.masterSecret = randomBytes(rand.Intn(100), rand) - numCerts := rand.Intn(20) - s.certificates = make([][]byte, numCerts) - for i := 0; i < numCerts; i++ { - s.certificates[i] = randomBytes(rand.Intn(10)+1, rand) + s.masterSecret = randomBytes(rand.Intn(100)+1, rand) + for i := 0; i < rand.Intn(20); i++ { + s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) } return reflect.ValueOf(s) } diff --git a/src/crypto/tls/ticket.go b/src/crypto/tls/ticket.go index c873e43a70..dda0443ff4 100644 --- a/src/crypto/tls/ticket.go +++ b/src/crypto/tls/ticket.go @@ -12,8 +12,9 @@ import ( "crypto/sha256" "crypto/subtle" "errors" - "golang.org/x/crypto/cryptobyte" "io" + + "golang.org/x/crypto/cryptobyte" ) // sessionState contains the information that is serialized into a session @@ -21,88 +22,56 @@ import ( type sessionState struct { vers uint16 cipherSuite uint16 - masterSecret []byte - certificates [][]byte + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^32-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^16-1>; + // usedOldKey is true if the ticket from which this session came from // was encrypted with an older key and thus should be refreshed. usedOldKey bool } -func (s *sessionState) marshal() []byte { - length := 2 + 2 + 2 + len(s.masterSecret) + 2 - for _, cert := range s.certificates { - length += 4 + len(cert) - } - - ret := make([]byte, length) - x := ret - x[0] = byte(s.vers >> 8) - x[1] = byte(s.vers) - x[2] = byte(s.cipherSuite >> 8) - x[3] = byte(s.cipherSuite) - x[4] = byte(len(s.masterSecret) >> 8) - x[5] = byte(len(s.masterSecret)) - x = x[6:] - copy(x, s.masterSecret) - x = x[len(s.masterSecret):] - - x[0] = byte(len(s.certificates) >> 8) - x[1] = byte(len(s.certificates)) - x = x[2:] - - for _, cert := range s.certificates { - x[0] = byte(len(cert) >> 24) - x[1] = byte(len(cert) >> 16) - x[2] = byte(len(cert) >> 8) - x[3] = byte(len(cert)) - copy(x[4:], cert) - x = x[4+len(cert):] - } - - return ret +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint32LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() } -func (s *sessionState) unmarshal(data []byte) bool { - if len(data) < 8 { +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + m.vers != VersionTLS13 && + s.ReadUint16(&m.cipherSuite) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { return false } - - s.vers = uint16(data[0])<<8 | uint16(data[1]) - s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) - masterSecretLen := int(data[4])<<8 | int(data[5]) - data = data[6:] - if len(data) < masterSecretLen { + var certList cryptobyte.String + if !s.ReadUint16LengthPrefixed(&certList) { return false } - - s.masterSecret = data[:masterSecretLen] - data = data[masterSecretLen:] - - if len(data) < 2 { - return false + for !certList.Empty() { + var certLen uint32 + certList.ReadUint32(&certLen) + var cert []byte + if certLen == 0 || !certList.ReadBytes(&cert, int(certLen)) { + return false + } + m.certificates = append(m.certificates, cert) } - - numCerts := int(data[0])<<8 | int(data[1]) - data = data[2:] - - s.certificates = make([][]byte, numCerts) - for i := range s.certificates { - if len(data) < 4 { - return false - } - certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - data = data[4:] - if certLen < 0 { - return false - } - if len(data) < certLen { - return false - } - s.certificates[i] = data[:certLen] - data = data[certLen:] - } - - return len(data) == 0 + return s.Empty() } // sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first