diff --git a/api/next/44886.txt b/api/next/44886.txt new file mode 100644 index 0000000000..b3ab6996ea --- /dev/null +++ b/api/next/44886.txt @@ -0,0 +1,41 @@ +pkg crypto/tls, const QUICEncryptionLevelApplication = 2 #44886 +pkg crypto/tls, const QUICEncryptionLevelApplication QUICEncryptionLevel #44886 +pkg crypto/tls, const QUICEncryptionLevelHandshake = 1 #44886 +pkg crypto/tls, const QUICEncryptionLevelHandshake QUICEncryptionLevel #44886 +pkg crypto/tls, const QUICEncryptionLevelInitial = 0 #44886 +pkg crypto/tls, const QUICEncryptionLevelInitial QUICEncryptionLevel #44886 +pkg crypto/tls, const QUICHandshakeDone = 6 #44886 +pkg crypto/tls, const QUICHandshakeDone QUICEventKind #44886 +pkg crypto/tls, const QUICNoEvent = 0 #44886 +pkg crypto/tls, const QUICNoEvent QUICEventKind #44886 +pkg crypto/tls, const QUICSetReadSecret = 1 #44886 +pkg crypto/tls, const QUICSetReadSecret QUICEventKind #44886 +pkg crypto/tls, const QUICSetWriteSecret = 2 #44886 +pkg crypto/tls, const QUICSetWriteSecret QUICEventKind #44886 +pkg crypto/tls, const QUICTransportParameters = 4 #44886 +pkg crypto/tls, const QUICTransportParameters QUICEventKind #44886 +pkg crypto/tls, const QUICTransportParametersRequired = 5 #44886 +pkg crypto/tls, const QUICTransportParametersRequired QUICEventKind #44886 +pkg crypto/tls, const QUICWriteData = 3 #44886 +pkg crypto/tls, const QUICWriteData QUICEventKind #44886 +pkg crypto/tls, func QUICClient(*QUICConfig) *QUICConn #44886 +pkg crypto/tls, func QUICServer(*QUICConfig) *QUICConn #44886 +pkg crypto/tls, method (*QUICConn) Close() error #44886 +pkg crypto/tls, method (*QUICConn) ConnectionState() ConnectionState #44886 +pkg crypto/tls, method (*QUICConn) HandleData(QUICEncryptionLevel, []uint8) error #44886 +pkg crypto/tls, method (*QUICConn) NextEvent() QUICEvent #44886 +pkg crypto/tls, method (*QUICConn) SetTransportParameters([]uint8) #44886 +pkg crypto/tls, method (*QUICConn) Start(context.Context) error #44886 +pkg crypto/tls, method (AlertError) Error() string #44886 +pkg crypto/tls, method (QUICEncryptionLevel) String() string #44886 +pkg crypto/tls, type AlertError uint8 #44886 +pkg crypto/tls, type QUICConfig struct #44886 +pkg crypto/tls, type QUICConfig struct, TLSConfig *Config #44886 +pkg crypto/tls, type QUICConn struct #44886 +pkg crypto/tls, type QUICEncryptionLevel int #44886 +pkg crypto/tls, type QUICEvent struct #44886 +pkg crypto/tls, type QUICEvent struct, Data []uint8 #44886 +pkg crypto/tls, type QUICEvent struct, Kind QUICEventKind #44886 +pkg crypto/tls, type QUICEvent struct, Level QUICEncryptionLevel #44886 +pkg crypto/tls, type QUICEvent struct, Suite uint16 #44886 +pkg crypto/tls, type QUICEventKind int #44886 diff --git a/src/crypto/tls/alert.go b/src/crypto/tls/alert.go index 4790b73724..33022cd2b4 100644 --- a/src/crypto/tls/alert.go +++ b/src/crypto/tls/alert.go @@ -6,6 +6,16 @@ package tls import "strconv" +// An AlertError is a TLS alert. +// +// When using a QUIC transport, QUICConn methods will return an error +// which wraps AlertError rather than sending a TLS alert. +type AlertError uint8 + +func (e AlertError) Error() string { + return alert(e).String() +} + type alert uint8 const ( diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 5394d64ac6..b8332e90fd 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -99,6 +99,7 @@ const ( extensionCertificateAuthorities uint16 = 47 extensionSignatureAlgorithmsCert uint16 = 50 extensionKeyShare uint16 = 51 + extensionQUICTransportParameters uint16 = 57 extensionRenegotiationInfo uint16 = 0xff01 ) diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 847d3f8f06..e3607c8fec 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -29,6 +29,7 @@ type Conn struct { conn net.Conn isClient bool handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + quic *quicState // nil for non-QUIC connections // isHandshakeComplete is true if the connection is currently transferring // application data (i.e. is not currently processing a handshake). @@ -176,7 +177,8 @@ type halfConn struct { nextCipher any // next encryption state nextMac hash.Hash // next MAC algorithm - trafficSecret []byte // current TLS 1.3 traffic secret + level QUICEncryptionLevel // current QUIC encryption level + trafficSecret []byte // current TLS 1.3 traffic secret } type permanentError struct { @@ -221,8 +223,9 @@ func (hc *halfConn) changeCipherSpec() error { return nil } -func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) { hc.trafficSecret = secret + hc.level = level key, iv := suite.trafficKey(secret) hc.cipher = suite.aead(key, iv) for i := range hc.seq { @@ -613,6 +616,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { } c.input.Reset(nil) + if c.quic != nil { + return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport")) + } + // Read header, payload. if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify @@ -702,6 +709,9 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) case recordTypeAlert: + if c.quic != nil { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } if len(data) != 2 { return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } @@ -819,6 +829,10 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error { // sendAlertLocked sends a TLS alert message. func (c *Conn) sendAlertLocked(err alert) error { + if c.quic != nil { + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + } + switch err { case alertNoRenegotiation, alertCloseNotify: c.tmp[0] = alertLevelWarning @@ -953,6 +967,19 @@ var outBufPool = sync.Pool{ // writeRecordLocked writes a TLS record with the given type and payload to the // connection and updates the record layer state. func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + if c.quic != nil { + if typ != recordTypeHandshake { + return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport") + } + c.quicWriteCryptoData(c.out.level, data) + if !c.buffering { + if _, err := c.flush(); err != nil { + return 0, err + } + } + return len(data), nil + } + outBufPtr := outBufPool.Get().(*[]byte) outBuf := *outBufPtr defer func() { @@ -1037,28 +1064,40 @@ func (c *Conn) writeChangeCipherRecord() error { return err } +// readHandshakeBytes reads handshake data until c.hand contains at least n bytes. +func (c *Conn) readHandshakeBytes(n int) error { + if c.quic != nil { + return c.quicReadHandshakeBytes(n) + } + for c.hand.Len() < n { + if err := c.readRecord(); err != nil { + return err + } + } + return nil +} + // readHandshake reads the next handshake message from // the record layer. If transcript is non-nil, the message // is written to the passed transcriptHash. func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { - for c.hand.Len() < 4 { - if err := c.readRecord(); err != nil { - return nil, err - } + if err := c.readHandshakeBytes(4); err != nil { + return nil, err } - data := c.hand.Bytes() n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if n > maxHandshake { c.sendAlertLocked(alertInternalError) return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) } - for c.hand.Len() < 4+n { - if err := c.readRecord(); err != nil { - return nil, err - } + if err := c.readHandshakeBytes(4 + n); err != nil { + return nil, err } data = c.hand.Next(4 + n) + return c.unmarshalHandshakeMessage(data, transcript) +} + +func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) { var m handshakeMessage switch data[0] { case typeHelloRequest: @@ -1249,7 +1288,6 @@ func (c *Conn) handlePostHandshakeMessage() error { if err != nil { return err } - c.retryCount++ if c.retryCount > maxUselessRecords { c.sendAlert(alertUnexpectedMessage) @@ -1261,20 +1299,28 @@ func (c *Conn) handlePostHandshakeMessage() error { return c.handleNewSessionTicket(msg) case *keyUpdateMsg: return c.handleKeyUpdate(msg) - default: - c.sendAlert(alertUnexpectedMessage) - return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) } + // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest + // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an + // unexpected_message alert here doesn't provide it with enough information to distinguish + // this condition from other unexpected messages. This is probably fine. + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) } func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + if c.quic != nil { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: received unexpected key update message")) + } + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) if cipherSuite == nil { return c.in.setErrorLocked(c.sendAlert(alertInternalError)) } newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) - c.in.setTrafficSecret(cipherSuite, newSecret) + c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) if keyUpdate.updateRequested { c.out.Lock() @@ -1293,7 +1339,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { } newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) - c.out.setTrafficSecret(cipherSuite, newSecret) + c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) } return nil @@ -1454,12 +1500,15 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { // this cancellation. In the former case, we need to close the connection. defer cancel() - // Start the "interrupter" goroutine, if this context might be canceled. - // (The background context cannot). - // - // The interrupter goroutine waits for the input context to be done and - // closes the connection if this happens before the function returns. - if ctx.Done() != nil { + if c.quic != nil { + c.quic.cancelc = handshakeCtx.Done() + c.quic.cancel = cancel + } else if ctx.Done() != nil { + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. done := make(chan struct{}) interruptRes := make(chan error, 1) defer func() { @@ -1510,6 +1559,30 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { panic("tls: internal error: handshake returned an error but is marked successful") } + if c.quic != nil { + if c.handshakeErr == nil { + c.quicHandshakeComplete() + // Provide the 1-RTT read secret now that the handshake is complete. + // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing + // the handshake (RFC 9001, Section 5.7). + c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret) + } else { + var a alert + c.out.Lock() + if !errors.As(c.out.err, &a) { + a = alertInternalError + } + c.out.Unlock() + // Return an error which wraps both the handshake error and + // any alert error we may have sent, or alertInternalError + // if we didn't send an alert. + // Truncate the text of the alert to 0 characters. + c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a)) + } + close(c.quic.blockedc) + close(c.quic.signalc) + } + return c.handshakeErr } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 63d86b9f3a..9f74cc4ef9 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -71,7 +71,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { vers: clientHelloVersion, compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), - sessionId: make([]byte, 32), ocspStapling: true, scts: true, serverName: hostnameInSNI(config.ServerName), @@ -114,8 +113,13 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { // A random session ID is used to detect when the server accepted a ticket // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as // a compatibility measure (see RFC 8446, Section 4.1.2). - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + // + // The session ID is not set for QUIC connections (see RFC 9001, Section 8.4). + if c.quic == nil { + hello.sessionId = make([]byte, 32) + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } } if hello.vers >= VersionTLS12 { @@ -144,6 +148,17 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } + if c.quic != nil { + p, err := c.quicGetTransportParameters() + if err != nil { + return nil, nil, err + } + if p == nil { + p = []byte{} + } + hello.quicTransportParameters = p + } + return hello, key, nil } @@ -271,7 +286,10 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } // Try to resume a previously negotiated TLS session, if available. - cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + cacheKey = c.clientSessionCacheKey() + if cacheKey == "" { + return "", nil, nil, nil, nil + } session, ok := c.config.ClientSessionCache.Get(cacheKey) if !ok || session == nil { return cacheKey, nil, nil, nil, nil @@ -722,7 +740,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { } } - if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { + if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil { c.sendAlert(alertUnsupportedExtension) return false, err } @@ -760,8 +778,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { // checkALPN ensure that the server's choice of ALPN protocol is compatible with // the protocols that we advertised in the Client Hello. -func checkALPN(clientProtos []string, serverProto string) error { +func checkALPN(clientProtos []string, serverProto string, quic bool) error { if serverProto == "" { + if quic && len(clientProtos) > 0 { + // RFC 9001, Section 8.1 + return errors.New("tls: server did not select an ALPN protocol") + } return nil } if len(clientProtos) == 0 { @@ -1003,11 +1025,14 @@ func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, // clientSessionCacheKey returns a key used to cache sessionTickets that could // be used to resume previously negotiated TLS sessions with a server. -func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { - if len(config.ServerName) > 0 { - return config.ServerName +func (c *Conn) clientSessionCacheKey() string { + if len(c.config.ServerName) > 0 { + return c.config.ServerName } - return serverAddr.String() + if c.conn != nil { + return c.conn.RemoteAddr().String() + } + return "" } // hostnameInSNI converts name into an appropriate hostname for SNI. diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 4a8661085e..15e0a74848 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -172,6 +172,9 @@ func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { // sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility // with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.c.quic != nil { + return nil + } if hs.sentDummyCCS { return nil } @@ -383,10 +386,18 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { clientSecret := hs.suite.deriveSecret(handshakeSecret, clientHandshakeTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, clientSecret) + c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) serverSecret := hs.suite.deriveSecret(handshakeSecret, serverHandshakeTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) + + if c.quic != nil { + if c.hand.Len() != 0 { + c.sendAlert(alertUnexpectedMessage) + } + c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret) + c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret) + } err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) if err != nil { @@ -419,12 +430,30 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { return unexpectedMessageError(encryptedExtensions, msg) } - if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { - c.sendAlert(alertUnsupportedExtension) + if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil { + // RFC 8446 specifies that no_application_protocol is sent by servers, but + // does not specify how clients handle the selection of an incompatible protocol. + // RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol + // in this case. Always sending no_application_protocol seems reasonable. + c.sendAlert(alertNoApplicationProtocol) return err } c.clientProtocol = encryptedExtensions.alpnProtocol + if c.quic != nil { + if encryptedExtensions.quicTransportParameters == nil { + // RFC 9001 Section 8.2. + c.sendAlert(alertMissingExtension) + return errors.New("tls: server did not send a quic_transport_parameters extension") + } + c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters) + } else { + if encryptedExtensions.quicTransportParameters != nil { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent an unexpected quic_transport_parameters extension") + } + } + return nil } @@ -552,7 +581,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { clientApplicationTrafficLabel, hs.transcript) serverSecret := hs.suite.deriveSecret(hs.masterSecret, serverApplicationTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) if err != nil { @@ -648,13 +677,20 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { return err } - c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, resumptionLabel, hs.transcript) } + if c.quic != nil { + if c.hand.Len() != 0 { + c.sendAlert(alertUnexpectedMessage) + } + c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret) + } + return nil } @@ -702,8 +738,10 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { scts: c.scts, } - cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - c.config.ClientSessionCache.Put(cacheKey, session) + cacheKey := c.clientSessionCacheKey() + if cacheKey != "" { + c.config.ClientSessionCache.Put(cacheKey, session) + } return nil } diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index 695aacf126..eac01fd085 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -93,6 +93,7 @@ type clientHelloMsg struct { pskModes []uint8 pskIdentities []pskIdentity pskBinders [][]byte + quicTransportParameters []byte } func (m *clientHelloMsg) marshal() ([]byte, error) { @@ -246,6 +247,13 @@ func (m *clientHelloMsg) marshal() ([]byte, error) { }) }) } + if m.quicTransportParameters != nil { // marshal zero-length parameters when present + // RFC 9001, Section 8.2 + exts.AddUint16(extensionQUICTransportParameters) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.quicTransportParameters) + }) + } if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension // RFC 8446, Section 4.2.11 exts.AddUint16(extensionPreSharedKey) @@ -560,6 +568,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { if !readUint8LengthPrefixed(&extData, &m.pskModes) { return false } + case extensionQUICTransportParameters: + m.quicTransportParameters = make([]byte, len(extData)) + if !extData.CopyBytes(m.quicTransportParameters) { + return false + } case extensionPreSharedKey: // RFC 8446, Section 4.2.11 if !extensions.Empty() { @@ -860,8 +873,9 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } type encryptedExtensionsMsg struct { - raw []byte - alpnProtocol string + raw []byte + alpnProtocol string + quicTransportParameters []byte } func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { @@ -883,6 +897,13 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { }) }) } + if m.quicTransportParameters != nil { // marshal zero-length parameters when present + // draft-ietf-quic-tls-32, Section 8.2 + b.AddUint16(extensionQUICTransportParameters) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.quicTransportParameters) + }) + } }) }) @@ -921,6 +942,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { return false } m.alpnProtocol = string(proto) + case extensionQUICTransportParameters: + m.quicTransportParameters = make([]byte, len(extData)) + if !extData.CopyBytes(m.quicTransportParameters) { + return false + } default: // Ignore unknown extensions. continue diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go index 206e2fb024..1ef6c432ff 100644 --- a/src/crypto/tls/handshake_messages_test.go +++ b/src/crypto/tls/handshake_messages_test.go @@ -197,6 +197,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.pskIdentities = append(m.pskIdentities, psk) m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) } + if rand.Intn(10) > 5 { + m.quicTransportParameters = randomBytes(rand.Intn(500), rand) + } if rand.Intn(10) > 5 { m.earlyData = true } diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index a17ba2fe27..450c5f7714 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -218,7 +218,7 @@ func (hs *serverHandshakeState) processClientHello() error { c.serverName = hs.clientHello.serverName } - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false) if err != nil { c.sendAlert(alertNoApplicationProtocol) return err @@ -279,8 +279,12 @@ func (hs *serverHandshakeState) processClientHello() error { // negotiateALPN picks a shared ALPN protocol that both sides support in server // preference order. If ALPN is not configured or the peer doesn't support it, // it returns "" and no error. -func negotiateALPN(serverProtos, clientProtos []string) (string, error) { +func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) { if len(serverProtos) == 0 || len(clientProtos) == 0 { + if quic && len(serverProtos) != 0 { + // RFC 9001, Section 8.1 + return "", fmt.Errorf("tls: client did not request an application protocol") + } return "", nil } var http11fallback bool diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index b7b568cd84..69ebe1c7d5 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -226,6 +226,20 @@ GroupSelection: return errors.New("tls: invalid client key share") } + if c.quic != nil { + if hs.clientHello.quicTransportParameters == nil { + // RFC 9001 Section 8.2. + c.sendAlert(alertMissingExtension) + return errors.New("tls: client did not send a quic_transport_parameters extension") + } + c.quicSetTransportParameters(hs.clientHello.quicTransportParameters) + } else { + if hs.clientHello.quicTransportParameters != nil { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent an unexpected quic_transport_parameters extension") + } + } + c.serverName = hs.clientHello.serverName return nil } @@ -397,6 +411,9 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error { // sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility // with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.c.quic != nil { + return nil + } if hs.sentDummyCCS { return nil } @@ -548,10 +565,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, clientHandshakeTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, clientSecret) + c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, serverHandshakeTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) + + if c.quic != nil { + if c.hand.Len() != 0 { + c.sendAlert(alertUnexpectedMessage) + } + c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret) + c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret) + } err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) if err != nil { @@ -566,7 +591,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions := new(encryptedExtensionsMsg) - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil) if err != nil { c.sendAlert(alertNoApplicationProtocol) return err @@ -574,6 +599,14 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions.alpnProtocol = selectedProto c.clientProtocol = selectedProto + if c.quic != nil { + p, err := c.quicGetTransportParameters() + if err != nil { + return err + } + encryptedExtensions.quicTransportParameters = p + } + if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil { return err } @@ -672,7 +705,15 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { clientApplicationTrafficLabel, hs.transcript) serverSecret := hs.suite.deriveSecret(hs.masterSecret, serverApplicationTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) + + if c.quic != nil { + if c.hand.Len() != 0 { + // TODO: Handle this in setTrafficSecret? + c.sendAlert(alertUnexpectedMessage) + } + c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret) + } err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) if err != nil { @@ -887,7 +928,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error { return errors.New("tls: invalid client finished hash") } - c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) return nil } diff --git a/src/crypto/tls/quic.go b/src/crypto/tls/quic.go new file mode 100644 index 0000000000..a59b893738 --- /dev/null +++ b/src/crypto/tls/quic.go @@ -0,0 +1,376 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "context" + "errors" + "fmt" +) + +// QUICEncryptionLevel represents a QUIC encryption level used to transmit +// handshake messages. +type QUICEncryptionLevel int + +const ( + QUICEncryptionLevelInitial = QUICEncryptionLevel(iota) + QUICEncryptionLevelHandshake + QUICEncryptionLevelApplication +) + +func (l QUICEncryptionLevel) String() string { + switch l { + case QUICEncryptionLevelInitial: + return "Initial" + case QUICEncryptionLevelHandshake: + return "Handshake" + case QUICEncryptionLevelApplication: + return "Application" + default: + return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l)) + } +} + +// A QUICConn represents a connection which uses a QUIC implementation as the underlying +// transport as described in RFC 9001. +// +// Methods of QUICConn are not safe for concurrent use. +type QUICConn struct { + conn *Conn +} + +// A QUICConfig configures a QUICConn. +type QUICConfig struct { + TLSConfig *Config +} + +// A QUICEventKind is a type of operation on a QUIC connection. +type QUICEventKind int + +const ( + // QUICNoEvent indicates that there are no events available. + QUICNoEvent QUICEventKind = iota + + // QUICSetReadSecret and QUICSetWriteSecret provide the read and write + // secrets for a given encryption level. + // QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set. + // + // Secrets for the Initial encryption level are derived from the initial + // destination connection ID, and are not provided by the QUICConn. + QUICSetReadSecret + QUICSetWriteSecret + + // QUICWriteData provides data to send to the peer in CRYPTO frames. + // QUICEvent.Data is set. + QUICWriteData + + // QUICTransportParameters provides the peer's QUIC transport parameters. + // QUICEvent.Data is set. + QUICTransportParameters + + // QUICTransportParametersRequired indicates that the caller must provide + // QUIC transport parameters to send to the peer. The caller should set + // the transport parameters with QUICConn.SetTransportParameters and call + // QUICConn.NextEvent again. + // + // If transport parameters are set before calling QUICConn.Start, the + // connection will never generate a QUICTransportParametersRequired event. + QUICTransportParametersRequired + + // QUICHandshakeDone indicates that the TLS handshake has completed. + QUICHandshakeDone +) + +// A QUICEvent is an event occurring on a QUIC connection. +// +// The type of event is specified by the Kind field. +// The contents of the other fields are kind-specific. +type QUICEvent struct { + Kind QUICEventKind + + // Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData. + Level QUICEncryptionLevel + + // Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData. + // The contents are owned by crypto/tls, and are valid until the next NextEvent call. + Data []byte + + // Set for QUICSetReadSecret and QUICSetWriteSecret. + Suite uint16 +} + +type quicState struct { + events []QUICEvent + nextEvent int + + // eventArr is a statically allocated event array, large enough to handle + // the usual maximum number of events resulting from a single call: + // transport parameters, Initial data, Handshake write and read secrets, + // Handshake data, Application write secret, Application data. + eventArr [7]QUICEvent + + started bool + signalc chan struct{} // handshake data is available to be read + blockedc chan struct{} // handshake is waiting for data, closed when done + cancelc <-chan struct{} // handshake has been canceled + cancel context.CancelFunc + + // readbuf is shared between HandleData and the handshake goroutine. + // HandshakeCryptoData passes ownership to the handshake goroutine by + // reading from signalc, and reclaims ownership by reading from blockedc. + readbuf []byte + + transportParams []byte // to send to the peer +} + +// QUICClient returns a new TLS client side connection using QUICTransport as the +// underlying transport. The config cannot be nil. +// +// The config's MinVersion must be at least TLS 1.3. +func QUICClient(config *QUICConfig) *QUICConn { + return newQUICConn(Client(nil, config.TLSConfig)) +} + +// QUICServer returns a new TLS server side connection using QUICTransport as the +// underlying transport. The config cannot be nil. +// +// The config's MinVersion must be at least TLS 1.3. +func QUICServer(config *QUICConfig) *QUICConn { + return newQUICConn(Server(nil, config.TLSConfig)) +} + +func newQUICConn(conn *Conn) *QUICConn { + conn.quic = &quicState{ + signalc: make(chan struct{}), + blockedc: make(chan struct{}), + } + conn.quic.events = conn.quic.eventArr[:0] + return &QUICConn{ + conn: conn, + } +} + +// Start starts the client or server handshake protocol. +// It may produce connection events, which may be read with NextEvent. +// +// Start must be called at most once. +func (q *QUICConn) Start(ctx context.Context) error { + if q.conn.quic.started { + return quicError(errors.New("tls: Start called more than once")) + } + q.conn.quic.started = true + if q.conn.config.MinVersion < VersionTLS13 { + return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13")) + } + go q.conn.HandshakeContext(ctx) + if _, ok := <-q.conn.quic.blockedc; !ok { + return q.conn.handshakeErr + } + return nil +} + +// NextEvent returns the next event occurring on the connection. +// It returns an event with a Kind of QUICNoEvent when no events are available. +func (q *QUICConn) NextEvent() QUICEvent { + qs := q.conn.quic + if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 { + // Write over some of the previous event's data, + // to catch callers erroniously retaining it. + qs.events[last].Data[0] = 0 + } + if qs.nextEvent >= len(qs.events) { + qs.events = qs.events[:0] + qs.nextEvent = 0 + return QUICEvent{Kind: QUICNoEvent} + } + e := qs.events[qs.nextEvent] + qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data + qs.nextEvent++ + return e +} + +// Close closes the connection and stops any in-progress handshake. +func (q *QUICConn) Close() error { + if q.conn.quic.cancel == nil { + return nil // never started + } + q.conn.quic.cancel() + for range q.conn.quic.blockedc { + // Wait for the handshake goroutine to return. + } + return q.conn.handshakeErr +} + +// HandleData handles handshake bytes received from the peer. +// It may produce connection events, which may be read with NextEvent. +func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error { + c := q.conn + if c.in.level != level { + return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level"))) + } + c.quic.readbuf = data + <-c.quic.signalc + _, ok := <-c.quic.blockedc + if ok { + // The handshake goroutine is waiting for more data. + return nil + } + // The handshake goroutine has exited. + c.hand.Write(c.quic.readbuf) + c.quic.readbuf = nil + for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil { + b := q.conn.hand.Bytes() + n := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) + if 4+n < len(b) { + return nil + } + if err := q.conn.handlePostHandshakeMessage(); err != nil { + return quicError(err) + } + } + if q.conn.handshakeErr != nil { + return quicError(q.conn.handshakeErr) + } + return nil +} + +// ConnectionState returns basic TLS details about the connection. +func (q *QUICConn) ConnectionState() ConnectionState { + return q.conn.ConnectionState() +} + +// SetTransportParameters sets the transport parameters to send to the peer. +// +// Server connections may delay setting the transport parameters until after +// receiving the client's transport parameters. See QUICTransportParametersRequired. +func (q *QUICConn) SetTransportParameters(params []byte) { + if params == nil { + params = []byte{} + } + q.conn.quic.transportParams = params + if q.conn.quic.started { + <-q.conn.quic.signalc + <-q.conn.quic.blockedc + } +} + +// quicError ensures err is an AlertError. +// If err is not already, quicError wraps it with alertInternalError. +func quicError(err error) error { + if err == nil { + return nil + } + var ae AlertError + if errors.As(err, &ae) { + return err + } + var a alert + if !errors.As(err, &a) { + a = alertInternalError + } + // Return an error wrapping the original error and an AlertError. + // Truncate the text of the alert to 0 characters. + return fmt.Errorf("%w%.0w", err, AlertError(a)) +} + +func (c *Conn) quicReadHandshakeBytes(n int) error { + for c.hand.Len() < n { + if err := c.quicWaitForSignal(); err != nil { + return err + } + } + return nil +} + +func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { + c.quic.events = append(c.quic.events, QUICEvent{ + Kind: QUICSetReadSecret, + Level: level, + Suite: suite, + Data: secret, + }) +} + +func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { + c.quic.events = append(c.quic.events, QUICEvent{ + Kind: QUICSetWriteSecret, + Level: level, + Suite: suite, + Data: secret, + }) +} + +func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) { + var last *QUICEvent + if len(c.quic.events) > 0 { + last = &c.quic.events[len(c.quic.events)-1] + } + if last == nil || last.Kind != QUICWriteData || last.Level != level { + c.quic.events = append(c.quic.events, QUICEvent{ + Kind: QUICWriteData, + Level: level, + }) + last = &c.quic.events[len(c.quic.events)-1] + } + last.Data = append(last.Data, data...) +} + +func (c *Conn) quicSetTransportParameters(params []byte) { + c.quic.events = append(c.quic.events, QUICEvent{ + Kind: QUICTransportParameters, + Data: params, + }) +} + +func (c *Conn) quicGetTransportParameters() ([]byte, error) { + if c.quic.transportParams == nil { + c.quic.events = append(c.quic.events, QUICEvent{ + Kind: QUICTransportParametersRequired, + }) + } + for c.quic.transportParams == nil { + if err := c.quicWaitForSignal(); err != nil { + return nil, err + } + } + return c.quic.transportParams, nil +} + +func (c *Conn) quicHandshakeComplete() { + c.quic.events = append(c.quic.events, QUICEvent{ + Kind: QUICHandshakeDone, + }) +} + +// quicWaitForSignal notifies the QUICConn that handshake progress is blocked, +// and waits for a signal that the handshake should proceed. +// +// The handshake may become blocked waiting for handshake bytes +// or for the user to provide transport parameters. +func (c *Conn) quicWaitForSignal() error { + // Drop the handshake mutex while blocked to allow the user + // to call ConnectionState before the handshake completes. + c.handshakeMutex.Unlock() + defer c.handshakeMutex.Lock() + // Send on blockedc to notify the QUICConn that the handshake is blocked. + // Exported methods of QUICConn wait for the handshake to become blocked + // before returning to the user. + select { + case c.quic.blockedc <- struct{}{}: + case <-c.quic.cancelc: + return c.sendAlertLocked(alertCloseNotify) + } + // The QUICConn reads from signalc to notify us that the handshake may + // be able to proceed. (The QUICConn reads, because we close signalc to + // indicate that the handshake has completed.) + select { + case c.quic.signalc <- struct{}{}: + c.hand.Write(c.quic.readbuf) + c.quic.readbuf = nil + case <-c.quic.cancelc: + return c.sendAlertLocked(alertCloseNotify) + } + return nil +} diff --git a/src/crypto/tls/quic_test.go b/src/crypto/tls/quic_test.go new file mode 100644 index 0000000000..58054de80d --- /dev/null +++ b/src/crypto/tls/quic_test.go @@ -0,0 +1,430 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "context" + "errors" + "reflect" + "testing" +) + +type testQUICConn struct { + t *testing.T + conn *QUICConn + readSecret map[QUICEncryptionLevel]suiteSecret + writeSecret map[QUICEncryptionLevel]suiteSecret + gotParams []byte + complete bool +} + +func newTestQUICClient(t *testing.T, config *Config) *testQUICConn { + q := &testQUICConn{t: t} + q.conn = QUICClient(&QUICConfig{ + TLSConfig: config, + }) + t.Cleanup(func() { + q.conn.Close() + }) + return q +} + +func newTestQUICServer(t *testing.T, config *Config) *testQUICConn { + q := &testQUICConn{t: t} + q.conn = QUICServer(&QUICConfig{ + TLSConfig: config, + }) + t.Cleanup(func() { + q.conn.Close() + }) + return q +} + +type suiteSecret struct { + suite uint16 + secret []byte +} + +func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { + if _, ok := q.writeSecret[level]; !ok { + q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level) + } + if level == QUICEncryptionLevelApplication && !q.complete { + q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level) + } + if _, ok := q.readSecret[level]; ok { + q.t.Errorf("SetReadSecret for level %v called twice", level) + } + if q.readSecret == nil { + q.readSecret = map[QUICEncryptionLevel]suiteSecret{} + } + switch level { + case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication: + q.readSecret[level] = suiteSecret{suite, secret} + default: + q.t.Errorf("SetReadSecret for unexpected level %v", level) + } +} + +func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { + if _, ok := q.writeSecret[level]; ok { + q.t.Errorf("SetWriteSecret for level %v called twice", level) + } + if q.writeSecret == nil { + q.writeSecret = map[QUICEncryptionLevel]suiteSecret{} + } + switch level { + case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication: + q.writeSecret[level] = suiteSecret{suite, secret} + default: + q.t.Errorf("SetWriteSecret for unexpected level %v", level) + } +} + +var errTransportParametersRequired = errors.New("transport parameters required") + +func runTestQUICConnection(ctx context.Context, a, b *testQUICConn, onHandleCryptoData func()) error { + for _, c := range []*testQUICConn{a, b} { + if !c.conn.conn.quic.started { + if err := c.conn.Start(ctx); err != nil { + return err + } + } + } + idleCount := 0 + for { + e := a.conn.NextEvent() + switch e.Kind { + case QUICNoEvent: + idleCount++ + if idleCount == 2 { + if !a.complete || !b.complete { + return errors.New("handshake incomplete") + } + return nil + } + a, b = b, a + case QUICSetReadSecret: + a.setReadSecret(e.Level, e.Suite, e.Data) + case QUICSetWriteSecret: + a.setWriteSecret(e.Level, e.Suite, e.Data) + case QUICWriteData: + if err := b.conn.HandleData(e.Level, e.Data); err != nil { + return err + } + case QUICTransportParameters: + a.gotParams = e.Data + if a.gotParams == nil { + a.gotParams = []byte{} + } + case QUICTransportParametersRequired: + return errTransportParametersRequired + case QUICHandshakeDone: + a.complete = true + } + if e.Kind != QUICNoEvent { + idleCount = 0 + } + } +} + +func TestQUICConnection(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok { + t.Errorf("client has no Handshake secret") + } + if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok { + t.Errorf("client has no Application secret") + } + if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok { + t.Errorf("server has no Handshake secret") + } + if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok { + t.Errorf("server has no Application secret") + } + for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} { + if _, ok := cli.readSecret[level]; !ok { + t.Errorf("client has no %v read secret", level) + } + if _, ok := srv.readSecret[level]; !ok { + t.Errorf("server has no %v read secret", level) + } + if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) { + t.Errorf("client read secret does not match server write secret for level %v", level) + } + if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) { + t.Errorf("client write secret does not match server read secret for level %v", level) + } + } +} + +func TestQUICSessionResumption(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.ServerName = "example.go.dev" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, clientConfig) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, serverConfig) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during first connection handshake: %v", err) + } + if cli.conn.ConnectionState().DidResume { + t.Errorf("first connection unexpectedly used session resumption") + } + + cli2 := newTestQUICClient(t, clientConfig) + cli2.conn.SetTransportParameters(nil) + srv2 := newTestQUICServer(t, serverConfig) + srv2.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil { + t.Fatalf("error during second connection handshake: %v", err) + } + if !cli2.conn.ConnectionState().DidResume { + t.Errorf("second connection did not use session resumption") + } +} + +func TestQUICPostHandshakeClientAuthentication(t *testing.T) { + // RFC 9001, Section 4.4. + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + certReqBytes, err := certReq.marshal() + if err != nil { + t.Fatal(err) + } + if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{ + byte(typeCertificateRequest), + byte(0), byte(0), byte(len(certReqBytes)), + }, certReqBytes...)); err == nil { + t.Fatalf("post-handshake authentication request: got no error, want one") + } +} + +func TestQUICPostHandshakeKeyUpdate(t *testing.T) { + // RFC 9001, Section 6. + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + keyUpdate := new(keyUpdateMsg) + keyUpdateBytes, err := keyUpdate.marshal() + if err != nil { + t.Fatal(err) + } + if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{ + byte(typeKeyUpdate), + byte(0), byte(0), byte(len(keyUpdateBytes)), + }, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) { + t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err) + } +} + +func TestQUICHandshakeError(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.InsecureSkipVerify = false + clientConfig.ServerName = "name" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, clientConfig) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, serverConfig) + srv.conn.SetTransportParameters(nil) + err := runTestQUICConnection(context.Background(), cli, srv, nil) + if !errors.Is(err, AlertError(alertBadCertificate)) { + t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err) + } + var e *CertificateVerificationError + if !errors.As(err, &e) { + t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err) + } +} + +// Test that QUICConn.ConnectionState can be used during the handshake, +// and that it reports the application protocol as soon as it has been +// negotiated. +func TestQUICConnectionState(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + config.NextProtos = []string{"h3"} + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + onHandleCryptoData := func() { + cliCS := cli.conn.ConnectionState() + cliWantALPN := "" + if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok { + cliWantALPN = "h3" + } + if want, got := cliCS.NegotiatedProtocol, cliWantALPN; want != got { + t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + } + + srvCS := srv.conn.ConnectionState() + srvWantALPN := "" + if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok { + srvWantALPN = "h3" + } + if want, got := srvCS.NegotiatedProtocol, srvWantALPN; want != got { + t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + } + } + if err := runTestQUICConnection(context.Background(), cli, srv, onHandleCryptoData); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } +} + +func TestQUICStartContextPropagation(t *testing.T) { + const key = "key" + const value = "value" + ctx := context.WithValue(context.Background(), key, value) + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + calls := 0 + config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) { + calls++ + got, _ := info.Context().Value(key).(string) + if got != value { + t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value) + } + return nil, nil + } + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + if calls != 1 { + t.Errorf("GetConfigForClient called %v times, want 1", calls) + } +} + +func TestQUICDelayedTransportParameters(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.ServerName = "example.go.dev" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cliParams := "client params" + srvParams := "server params" + + cli := newTestQUICClient(t, clientConfig) + srv := newTestQUICServer(t, serverConfig) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired { + t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err) + } + cli.conn.SetTransportParameters([]byte(cliParams)) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired { + t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err) + } + srv.conn.SetTransportParameters([]byte(srvParams)) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + if got, want := string(cli.gotParams), srvParams; got != want { + t.Errorf("client got transport params: %q, want %q", got, want) + } + if got, want := string(srv.gotParams), cliParams; got != want { + t.Errorf("server got transport params: %q, want %q", got, want) + } +} + +func TestQUICEmptyTransportParameters(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + if cli.gotParams == nil { + t.Errorf("client did not get transport params") + } + if srv.gotParams == nil { + t.Errorf("server did not get transport params") + } + if len(cli.gotParams) != 0 { + t.Errorf("client got transport params: %v, want empty", cli.gotParams) + } + if len(srv.gotParams) != 0 { + t.Errorf("server got transport params: %v, want empty", srv.gotParams) + } +} + +func TestQUICCanceledWaitingForData(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + cli.conn.Start(context.Background()) + for cli.conn.NextEvent().Kind != QUICNoEvent { + } + err := cli.conn.Close() + if !errors.Is(err, alertCloseNotify) { + t.Errorf("conn.Close() = %v, want alertCloseNotify", err) + } +} + +func TestQUICCanceledWaitingForTransportParams(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.Start(context.Background()) + for cli.conn.NextEvent().Kind != QUICTransportParametersRequired { + } + err := cli.conn.Close() + if !errors.Is(err, alertCloseNotify) { + t.Errorf("conn.Close() = %v, want alertCloseNotify", err) + } +}