mirror of
https://github.com/golang/go
synced 2024-11-12 01:50:22 -07:00
crypto/tls: check errors from (*Conn).writeRecord
This promotes a connection hang during TLS handshake to a proper error. This doesn't fully address #14539 because the error reported in that case is a write-on-socket-not-connected error, which implies that an earlier error during connection setup is not being checked, but it is an improvement over the current behaviour. Updates #14539. Change-Id: I0571a752d32d5303db48149ab448226868b19495 Reviewed-on: https://go-review.googlesource.com/19990 Reviewed-by: Adam Langley <agl@golang.org>
This commit is contained in:
parent
1012892f1e
commit
37c28759ca
@ -694,12 +694,14 @@ func (c *Conn) sendAlertLocked(err alert) error {
|
||||
c.tmp[0] = alertLevelError
|
||||
}
|
||||
c.tmp[1] = byte(err)
|
||||
c.writeRecord(recordTypeAlert, c.tmp[0:2])
|
||||
// closeNotify is a special case in that it isn't an error:
|
||||
if err != alertCloseNotify {
|
||||
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
|
||||
|
||||
_, writeErr := c.writeRecord(recordTypeAlert, c.tmp[0:2])
|
||||
if err == alertCloseNotify {
|
||||
// closeNotify is a special case in that it isn't an error.
|
||||
return writeErr
|
||||
}
|
||||
return nil
|
||||
|
||||
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
@ -713,8 +715,11 @@ func (c *Conn) sendAlert(err alert) error {
|
||||
// writeRecord writes a TLS record with the given type and payload
|
||||
// to the connection and updates the record layer state.
|
||||
// c.out.Mutex <= L.
|
||||
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
|
||||
func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
|
||||
b := c.out.newBlock()
|
||||
defer c.out.freeBlock(b)
|
||||
|
||||
var n int
|
||||
for len(data) > 0 {
|
||||
m := len(data)
|
||||
if m > maxPlaintext {
|
||||
@ -759,34 +764,27 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
|
||||
if explicitIVIsSeq {
|
||||
copy(explicitIV, c.out.seq[:])
|
||||
} else {
|
||||
if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
|
||||
break
|
||||
if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
copy(b.data[recordHeaderLen+explicitIVLen:], data)
|
||||
c.out.encrypt(b, explicitIVLen)
|
||||
_, err = c.conn.Write(b.data)
|
||||
if err != nil {
|
||||
break
|
||||
if _, err := c.conn.Write(b.data); err != nil {
|
||||
return n, err
|
||||
}
|
||||
n += m
|
||||
data = data[m:]
|
||||
}
|
||||
c.out.freeBlock(b)
|
||||
|
||||
if typ == recordTypeChangeCipherSpec {
|
||||
err = c.out.changeCipherSpec()
|
||||
if err != nil {
|
||||
// Cannot call sendAlert directly,
|
||||
// because we already hold c.out.Mutex.
|
||||
c.tmp[0] = alertLevelError
|
||||
c.tmp[1] = byte(err.(alert))
|
||||
c.writeRecord(recordTypeAlert, c.tmp[0:2])
|
||||
return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
|
||||
if err := c.out.changeCipherSpec(); err != nil {
|
||||
return n, c.sendAlertLocked(err.(alert))
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// readHandshake reads the next handshake message from
|
||||
|
@ -138,7 +138,9 @@ NextCipherSuite:
|
||||
}
|
||||
}
|
||||
|
||||
c.writeRecord(recordTypeHandshake, hello.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
@ -419,7 +421,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
certMsg.certificates = chainToSend.Certificate
|
||||
}
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certMsg.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
|
||||
@ -429,7 +433,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
}
|
||||
if ckx != nil {
|
||||
hs.finishedHash.Write(ckx.marshal())
|
||||
c.writeRecord(recordTypeHandshake, ckx.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if chainToSend != nil {
|
||||
@ -471,7 +477,9 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(certVerify.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certVerify.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
|
||||
@ -615,7 +623,9 @@ func (hs *clientHandshakeState) readSessionTicket() error {
|
||||
func (hs *clientHandshakeState) sendFinished(out []byte) error {
|
||||
c := hs.c
|
||||
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
|
||||
return err
|
||||
}
|
||||
if hs.serverHello.nextProtoNeg {
|
||||
nextProto := new(nextProtoMsg)
|
||||
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
|
||||
@ -624,13 +634,17 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
|
||||
c.clientProtocolFallback = fallback
|
||||
|
||||
hs.finishedHash.Write(nextProto.marshal())
|
||||
c.writeRecord(recordTypeHandshake, nextProto.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
|
||||
hs.finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
copy(out, finished.verifyData)
|
||||
return nil
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -725,3 +726,51 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
|
||||
t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
// brokenConn wraps a net.Conn and causes all Writes after a certain number to
|
||||
// fail with brokenConnErr.
|
||||
type brokenConn struct {
|
||||
net.Conn
|
||||
|
||||
// breakAfter is the number of successful writes that will be allowed
|
||||
// before all subsequent writes fail.
|
||||
breakAfter int
|
||||
|
||||
// numWrites is the number of writes that have been done.
|
||||
numWrites int
|
||||
}
|
||||
|
||||
// brokenConnErr is the error that brokenConn returns once exhausted.
|
||||
var brokenConnErr = errors.New("too many writes to brokenConn")
|
||||
|
||||
func (b *brokenConn) Write(data []byte) (int, error) {
|
||||
if b.numWrites >= b.breakAfter {
|
||||
return 0, brokenConnErr
|
||||
}
|
||||
|
||||
b.numWrites++
|
||||
return b.Conn.Write(data)
|
||||
}
|
||||
|
||||
func TestFailedWrite(t *testing.T) {
|
||||
// Test that a write error during the handshake is returned.
|
||||
for _, breakAfter := range []int{0, 1, 2, 3} {
|
||||
c, s := net.Pipe()
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
Server(s, testConfig).Handshake()
|
||||
s.Close()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
|
||||
err := Client(brokenC, testConfig).Handshake()
|
||||
if err != brokenConnErr {
|
||||
t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
|
||||
}
|
||||
brokenC.Close()
|
||||
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
@ -322,7 +322,9 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
|
||||
hs.finishedHash.discardHandshakeBuffer()
|
||||
hs.finishedHash.Write(hs.clientHello.marshal())
|
||||
hs.finishedHash.Write(hs.hello.marshal())
|
||||
c.writeRecord(recordTypeHandshake, hs.hello.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hs.sessionState.certificates) > 0 {
|
||||
if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
|
||||
@ -354,19 +356,25 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
||||
}
|
||||
hs.finishedHash.Write(hs.clientHello.marshal())
|
||||
hs.finishedHash.Write(hs.hello.marshal())
|
||||
c.writeRecord(recordTypeHandshake, hs.hello.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMsg := new(certificateMsg)
|
||||
certMsg.certificates = hs.cert.Certificate
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certMsg.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hs.hello.ocspStapling {
|
||||
certStatus := new(certificateStatusMsg)
|
||||
certStatus.statusType = statusTypeOCSP
|
||||
certStatus.response = hs.cert.OCSPStaple
|
||||
hs.finishedHash.Write(certStatus.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certStatus.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
keyAgreement := hs.suite.ka(c.vers)
|
||||
@ -377,7 +385,9 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
||||
}
|
||||
if skx != nil {
|
||||
hs.finishedHash.Write(skx.marshal())
|
||||
c.writeRecord(recordTypeHandshake, skx.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if config.ClientAuth >= RequestClientCert {
|
||||
@ -401,12 +411,16 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
||||
certReq.certificateAuthorities = config.ClientCAs.Subjects()
|
||||
}
|
||||
hs.finishedHash.Write(certReq.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certReq.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
helloDone := new(serverHelloDoneMsg)
|
||||
hs.finishedHash.Write(helloDone.marshal())
|
||||
c.writeRecord(recordTypeHandshake, helloDone.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pub crypto.PublicKey // public key for client auth, if any
|
||||
|
||||
@ -632,7 +646,9 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(m.marshal())
|
||||
c.writeRecord(recordTypeHandshake, m.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -640,12 +656,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
|
||||
func (hs *serverHandshakeState) sendFinished(out []byte) error {
|
||||
c := hs.c
|
||||
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
|
||||
hs.finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.cipherSuite = hs.suite.id
|
||||
copy(out, finished.verifyData)
|
||||
|
@ -80,7 +80,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
|
||||
cli.writeRecord(recordTypeHandshake, m.marshal())
|
||||
c.Close()
|
||||
}()
|
||||
err := Server(s, serverConfig).Handshake()
|
||||
hs := serverHandshakeState{
|
||||
c: Server(s, serverConfig),
|
||||
}
|
||||
_, err := hs.readClientHello()
|
||||
s.Close()
|
||||
if len(expectedSubStr) == 0 {
|
||||
if err != nil && err != io.EOF {
|
||||
|
Loading…
Reference in New Issue
Block a user