From 67ee9a7db103b5ae5c8d077fef9e21cf6f137f3a Mon Sep 17 00:00:00 2001 From: Dave Cheney Date: Thu, 6 Sep 2012 17:50:26 +1000 Subject: [PATCH] crypto/tls: fix data race on conn.err Fixes #3862. There were many areas where conn.err was being accessed outside the mutex. This proposal moves the err value to an embedded struct to make it more obvious when the error value is being accessed. As there are no Benchmark tests in this package I cannot feel confident of the impact of this additional locking, although most will be uncontended. R=dvyukov, agl CC=golang-dev https://golang.org/cl/6497070 --- src/pkg/crypto/tls/conn.go | 59 +++++++++++++------------- src/pkg/crypto/tls/handshake_client.go | 4 +- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go index 455910af415..5dc344bed5b 100644 --- a/src/pkg/crypto/tls/conn.go +++ b/src/pkg/crypto/tls/conn.go @@ -44,8 +44,7 @@ type Conn struct { clientProtocolFallback bool // first permanent error - errMutex sync.Mutex - err error + connErr // input/output in, out halfConn // in.Mutex < out.Mutex @@ -56,21 +55,25 @@ type Conn struct { tmp [16]byte } -func (c *Conn) setError(err error) error { - c.errMutex.Lock() - defer c.errMutex.Unlock() +type connErr struct { + mu sync.Mutex + value error +} - if c.err == nil { - c.err = err +func (e *connErr) setError(err error) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.value == nil { + e.value = err } return err } -func (c *Conn) error() error { - c.errMutex.Lock() - defer c.errMutex.Unlock() - - return c.err +func (e *connErr) error() error { + e.mu.Lock() + defer e.mu.Unlock() + return e.value } // Access to net.Conn methods. @@ -660,8 +663,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { c.tmp[0] = alertLevelError c.tmp[1] = byte(err.(alert)) c.writeRecord(recordTypeAlert, c.tmp[0:2]) - c.err = &net.OpError{Op: "local error", Err: err} - return n, c.err + return n, c.setError(&net.OpError{Op: "local error", Err: err}) } } return @@ -672,8 +674,8 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { // c.in.Mutex < L; c.out.Mutex < L. func (c *Conn) readHandshake() (interface{}, error) { for c.hand.Len() < 4 { - if c.err != nil { - return nil, c.err + if err := c.error(); err != nil { + return nil, err } if err := c.readRecord(recordTypeHandshake); err != nil { return nil, err @@ -684,11 +686,11 @@ func (c *Conn) readHandshake() (interface{}, error) { n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if n > maxHandshake { c.sendAlert(alertInternalError) - return nil, c.err + return nil, c.error() } for c.hand.Len() < 4+n { - if c.err != nil { - return nil, c.err + if err := c.error(); err != nil { + return nil, err } if err := c.readRecord(recordTypeHandshake); err != nil { return nil, err @@ -738,12 +740,12 @@ func (c *Conn) readHandshake() (interface{}, error) { // Write writes data to the connection. func (c *Conn) Write(b []byte) (int, error) { - if c.err != nil { - return 0, c.err + if err := c.error(); err != nil { + return 0, err } - if c.err = c.Handshake(); c.err != nil { - return 0, c.err + if err := c.Handshake(); err != nil { + return 0, c.setError(err) } c.out.Lock() @@ -753,9 +755,8 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, alertInternalError } - var n int - n, c.err = c.writeRecord(recordTypeApplicationData, b) - return n, c.err + n, err := c.writeRecord(recordTypeApplicationData, b) + return n, c.setError(err) } // Read can be made to time out and return a net.Error with Timeout() == true @@ -768,14 +769,14 @@ func (c *Conn) Read(b []byte) (n int, err error) { c.in.Lock() defer c.in.Unlock() - for c.input == nil && c.err == nil { + for c.input == nil && c.error() == nil { if err := c.readRecord(recordTypeApplicationData); err != nil { // Soft error, like EAGAIN return 0, err } } - if c.err != nil { - return 0, c.err + if err := c.error(); err != nil { + return 0, err } n, err = c.input.Read(b) if c.input.off >= len(c.input.data) { diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go index 2877f17387d..c6637c593c9 100644 --- a/src/pkg/crypto/tls/handshake_client.go +++ b/src/pkg/crypto/tls/handshake_client.go @@ -306,8 +306,8 @@ func (c *Conn) clientHandshake() error { serverHash := suite.mac(c.vers, serverMAC) c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) c.readRecord(recordTypeChangeCipherSpec) - if c.err != nil { - return c.err + if err := c.error(); err != nil { + return err } msg, err = c.readHandshake()