diff --git a/src/pkg/crypto/tls/handshake_messages.go b/src/pkg/crypto/tls/handshake_messages.go index f11232d8ee..5438e749ce 100644 --- a/src/pkg/crypto/tls/handshake_messages.go +++ b/src/pkg/crypto/tls/handshake_messages.go @@ -4,6 +4,8 @@ package tls +import "bytes" + type clientHelloMsg struct { raw []byte vers uint16 @@ -18,6 +20,25 @@ type clientHelloMsg struct { supportedPoints []uint8 } +func (m *clientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*clientHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.compressionMethods, m1.compressionMethods) && + m.nextProtoNeg == m1.nextProtoNeg && + m.serverName == m1.serverName && + m.ocspStapling == m1.ocspStapling && + eqUint16s(m.supportedCurves, m1.supportedCurves) && + bytes.Equal(m.supportedPoints, m1.supportedPoints) +} + func (m *clientHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -309,6 +330,23 @@ type serverHelloMsg struct { ocspStapling bool } +func (m *serverHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*serverHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + m.cipherSuite == m1.cipherSuite && + m.compressionMethod == m1.compressionMethod && + m.nextProtoNeg == m1.nextProtoNeg && + eqStrings(m.nextProtos, m1.nextProtos) && + m.ocspStapling == m1.ocspStapling +} + func (m *serverHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -463,6 +501,16 @@ type certificateMsg struct { certificates [][]byte } +func (m *certificateMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + eqByteSlices(m.certificates, m1.certificates) +} + func (m *certificateMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct { key []byte } +func (m *serverKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*serverKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.key, m1.key) +} + func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -571,6 +629,17 @@ type certificateStatusMsg struct { response []byte } +func (m *certificateStatusMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateStatusMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.statusType == m1.statusType && + bytes.Equal(m.response, m1.response) +} + func (m *certificateStatusMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} +func (m *serverHelloDoneMsg) equal(i interface{}) bool { + _, ok := i.(*serverHelloDoneMsg) + return ok +} + func (m *serverHelloDoneMsg) marshal() []byte { x := make([]byte, 4) x[0] = typeServerHelloDone @@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct { ciphertext []byte } +func (m *clientKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*clientKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ciphertext, m1.ciphertext) +} + func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -671,6 +755,16 @@ type finishedMsg struct { verifyData []byte } +func (m *finishedMsg) equal(i interface{}) bool { + m1, ok := i.(*finishedMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.verifyData, m1.verifyData) +} + func (m *finishedMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -698,6 +792,16 @@ type nextProtoMsg struct { proto string } +func (m *nextProtoMsg) equal(i interface{}) bool { + m1, ok := i.(*nextProtoMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.proto == m1.proto +} + func (m *nextProtoMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -759,6 +863,17 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } +func (m *certificateRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.certificateTypes, m1.certificateTypes) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) +} + func (m *certificateRequestMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -859,6 +974,16 @@ type certificateVerifyMsg struct { signature []byte } +func (m *certificateVerifyMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateVerifyMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.signature, m1.signature) +} + func (m *certificateVerifyMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { return true } + +func eqUint16s(x, y []uint16) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqStrings(x, y []string) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqByteSlices(x, y [][]byte) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if !bytes.Equal(v, y[i]) { + return false + } + } + return true +} diff --git a/src/pkg/crypto/tls/handshake_messages_test.go b/src/pkg/crypto/tls/handshake_messages_test.go index 87e8f7e428..e62a9d581b 100644 --- a/src/pkg/crypto/tls/handshake_messages_test.go +++ b/src/pkg/crypto/tls/handshake_messages_test.go @@ -27,10 +27,12 @@ var tests = []interface{}{ type testMessage interface { marshal() []byte unmarshal([]byte) bool + equal(interface{}) bool } func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) + for i, iface := range tests { ty := reflect.ValueOf(iface).Type() @@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) { } m2.marshal() // to fill any marshal cache in the message - if !reflect.DeepEqual(m1, m2) { + if !m1.equal(m2) { t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) break }