mirror of
https://github.com/golang/go
synced 2024-11-25 14:07:56 -07:00
crypto/tls: extensions and Next Protocol Negotiation
Add support for TLS extensions in general and Next Protocol Negotiation in particular. R=rsc CC=golang-dev https://golang.org/cl/181045
This commit is contained in:
parent
7c9111434a
commit
9ebb59634e
@ -41,6 +41,7 @@ const (
|
|||||||
typeServerHelloDone uint8 = 14
|
typeServerHelloDone uint8 = 14
|
||||||
typeClientKeyExchange uint8 = 16
|
typeClientKeyExchange uint8 = 16
|
||||||
typeFinished uint8 = 20
|
typeFinished uint8 = 20
|
||||||
|
typeNextProtocol uint8 = 67 // Not IANA assigned
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLS cipher suites.
|
// TLS cipher suites.
|
||||||
@ -53,10 +54,17 @@ var (
|
|||||||
compressionNone uint8 = 0
|
compressionNone uint8 = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TLS extension numbers
|
||||||
|
var (
|
||||||
|
extensionServerName uint16 = 0
|
||||||
|
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
||||||
|
)
|
||||||
|
|
||||||
type ConnectionState struct {
|
type ConnectionState struct {
|
||||||
HandshakeComplete bool
|
HandshakeComplete bool
|
||||||
CipherSuite string
|
CipherSuite string
|
||||||
Error alertType
|
Error alertType
|
||||||
|
NegotiatedProtocol string
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Config structure is used to configure a TLS client or server. After one
|
// A Config structure is used to configure a TLS client or server. After one
|
||||||
@ -68,6 +76,9 @@ type Config struct {
|
|||||||
Time func() int64
|
Time func() int64
|
||||||
Certificates []Certificate
|
Certificates []Certificate
|
||||||
RootCAs *CASet
|
RootCAs *CASet
|
||||||
|
// NextProtos is a list of supported, application level protocols.
|
||||||
|
// Currently only server-side handling is supported.
|
||||||
|
NextProtos []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Certificate struct {
|
type Certificate struct {
|
||||||
|
@ -184,7 +184,7 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}
|
controlChan <- ConnectionState{HandshakeComplete: true, CipherSuite: "TLS_RSA_WITH_RC4_128_SHA"}
|
||||||
|
|
||||||
// This should just block forever.
|
// This should just block forever.
|
||||||
_ = h.readHandshakeMsg()
|
_ = h.readHandshakeMsg()
|
||||||
@ -218,7 +218,7 @@ func (h *clientHandshake) error(e alertType) {
|
|||||||
for _ = range h.msgChan {
|
for _ = range h.msgChan {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
h.controlChan <- ConnectionState{false, "", e}
|
h.controlChan <- ConnectionState{Error: e}
|
||||||
close(h.controlChan)
|
close(h.controlChan)
|
||||||
h.writeChan <- alert{alertLevelError, e}
|
h.writeChan <- alert{alertLevelError, e}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
package tls
|
package tls
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
type clientHelloMsg struct {
|
type clientHelloMsg struct {
|
||||||
raw []byte
|
raw []byte
|
||||||
major, minor uint8
|
major, minor uint8
|
||||||
@ -11,6 +13,8 @@ type clientHelloMsg struct {
|
|||||||
sessionId []byte
|
sessionId []byte
|
||||||
cipherSuites []uint16
|
cipherSuites []uint16
|
||||||
compressionMethods []uint8
|
compressionMethods []uint8
|
||||||
|
nextProtoNeg bool
|
||||||
|
serverName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *clientHelloMsg) marshal() []byte {
|
func (m *clientHelloMsg) marshal() []byte {
|
||||||
@ -19,6 +23,20 @@ func (m *clientHelloMsg) marshal() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
|
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
|
||||||
|
numExtensions := 0
|
||||||
|
extensionsLength := 0
|
||||||
|
if m.nextProtoNeg {
|
||||||
|
numExtensions++
|
||||||
|
}
|
||||||
|
if len(m.serverName) > 0 {
|
||||||
|
extensionsLength += 5 + len(m.serverName)
|
||||||
|
numExtensions++
|
||||||
|
}
|
||||||
|
if numExtensions > 0 {
|
||||||
|
extensionsLength += 4 * numExtensions
|
||||||
|
length += 2 + extensionsLength
|
||||||
|
}
|
||||||
|
|
||||||
x := make([]byte, 4+length)
|
x := make([]byte, 4+length)
|
||||||
x[0] = typeClientHello
|
x[0] = typeClientHello
|
||||||
x[1] = uint8(length >> 16)
|
x[1] = uint8(length >> 16)
|
||||||
@ -39,6 +57,53 @@ func (m *clientHelloMsg) marshal() []byte {
|
|||||||
z := y[2+len(m.cipherSuites)*2:]
|
z := y[2+len(m.cipherSuites)*2:]
|
||||||
z[0] = uint8(len(m.compressionMethods))
|
z[0] = uint8(len(m.compressionMethods))
|
||||||
copy(z[1:], m.compressionMethods)
|
copy(z[1:], m.compressionMethods)
|
||||||
|
|
||||||
|
z = z[1+len(m.compressionMethods):]
|
||||||
|
if numExtensions > 0 {
|
||||||
|
z[0] = byte(extensionsLength >> 8)
|
||||||
|
z[1] = byte(extensionsLength)
|
||||||
|
z = z[2:]
|
||||||
|
}
|
||||||
|
if m.nextProtoNeg {
|
||||||
|
z[0] = byte(extensionNextProtoNeg >> 8)
|
||||||
|
z[1] = byte(extensionNextProtoNeg)
|
||||||
|
// The length is always 0
|
||||||
|
z = z[4:]
|
||||||
|
}
|
||||||
|
if len(m.serverName) > 0 {
|
||||||
|
z[0] = byte(extensionServerName >> 8)
|
||||||
|
z[1] = byte(extensionServerName)
|
||||||
|
l := len(m.serverName) + 5
|
||||||
|
z[2] = byte(l >> 8)
|
||||||
|
z[3] = byte(l)
|
||||||
|
z = z[4:]
|
||||||
|
|
||||||
|
// RFC 3546, section 3.1
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// NameType name_type;
|
||||||
|
// select (name_type) {
|
||||||
|
// case host_name: HostName;
|
||||||
|
// } name;
|
||||||
|
// } ServerName;
|
||||||
|
//
|
||||||
|
// enum {
|
||||||
|
// host_name(0), (255)
|
||||||
|
// } NameType;
|
||||||
|
//
|
||||||
|
// opaque HostName<1..2^16-1>;
|
||||||
|
//
|
||||||
|
// struct {
|
||||||
|
// ServerName server_name_list<1..2^16-1>
|
||||||
|
// } ServerNameList;
|
||||||
|
|
||||||
|
z[1] = 1
|
||||||
|
z[3] = byte(len(m.serverName) >> 8)
|
||||||
|
z[4] = byte(len(m.serverName))
|
||||||
|
copy(z[5:], strings.Bytes(m.serverName))
|
||||||
|
z = z[l:]
|
||||||
|
}
|
||||||
|
|
||||||
m.raw = x
|
m.raw = x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@ -82,7 +147,68 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
|||||||
}
|
}
|
||||||
m.compressionMethods = data[1 : 1+compressionMethodsLen]
|
m.compressionMethods = data[1 : 1+compressionMethodsLen]
|
||||||
|
|
||||||
// A ClientHello may be following by trailing data: RFC 4346 section 7.4.1.2
|
data = data[1+compressionMethodsLen:]
|
||||||
|
|
||||||
|
m.nextProtoNeg = false
|
||||||
|
m.serverName = ""
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
// ClientHello is optionally followed by extension data
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(data) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
extensionsLength := int(data[0])<<8 | int(data[1])
|
||||||
|
data = data[2:]
|
||||||
|
if extensionsLength != len(data) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(data) != 0 {
|
||||||
|
if len(data) < 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
extension := uint16(data[0])<<8 | uint16(data[1])
|
||||||
|
length := int(data[2])<<8 | int(data[3])
|
||||||
|
data = data[4:]
|
||||||
|
if len(data) < length {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch extension {
|
||||||
|
case extensionServerName:
|
||||||
|
if length < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
numNames := int(data[0])<<8 | int(data[1])
|
||||||
|
d := data[2:]
|
||||||
|
for i := 0; i < numNames; i++ {
|
||||||
|
if len(d) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
nameType := d[0]
|
||||||
|
nameLen := int(d[1])<<8 | int(d[2])
|
||||||
|
d = d[3:]
|
||||||
|
if len(d) < nameLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if nameType == 0 {
|
||||||
|
m.serverName = string(d[0:nameLen])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
d = d[nameLen:]
|
||||||
|
}
|
||||||
|
case extensionNextProtoNeg:
|
||||||
|
if length > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.nextProtoNeg = true
|
||||||
|
}
|
||||||
|
data = data[length:]
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,6 +219,8 @@ type serverHelloMsg struct {
|
|||||||
sessionId []byte
|
sessionId []byte
|
||||||
cipherSuite uint16
|
cipherSuite uint16
|
||||||
compressionMethod uint8
|
compressionMethod uint8
|
||||||
|
nextProtoNeg bool
|
||||||
|
nextProtos []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverHelloMsg) marshal() []byte {
|
func (m *serverHelloMsg) marshal() []byte {
|
||||||
@ -101,6 +229,23 @@ func (m *serverHelloMsg) marshal() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
length := 38 + len(m.sessionId)
|
length := 38 + len(m.sessionId)
|
||||||
|
numExtensions := 0
|
||||||
|
extensionsLength := 0
|
||||||
|
|
||||||
|
nextProtoLen := 0
|
||||||
|
if m.nextProtoNeg {
|
||||||
|
numExtensions++
|
||||||
|
for _, v := range m.nextProtos {
|
||||||
|
nextProtoLen += len(v)
|
||||||
|
}
|
||||||
|
nextProtoLen += len(m.nextProtos)
|
||||||
|
extensionsLength += nextProtoLen
|
||||||
|
}
|
||||||
|
if numExtensions > 0 {
|
||||||
|
extensionsLength += 4 * numExtensions
|
||||||
|
length += 2 + extensionsLength
|
||||||
|
}
|
||||||
|
|
||||||
x := make([]byte, 4+length)
|
x := make([]byte, 4+length)
|
||||||
x[0] = typeServerHello
|
x[0] = typeServerHello
|
||||||
x[1] = uint8(length >> 16)
|
x[1] = uint8(length >> 16)
|
||||||
@ -115,11 +260,49 @@ func (m *serverHelloMsg) marshal() []byte {
|
|||||||
z[0] = uint8(m.cipherSuite >> 8)
|
z[0] = uint8(m.cipherSuite >> 8)
|
||||||
z[1] = uint8(m.cipherSuite)
|
z[1] = uint8(m.cipherSuite)
|
||||||
z[2] = uint8(m.compressionMethod)
|
z[2] = uint8(m.compressionMethod)
|
||||||
|
|
||||||
|
z = z[3:]
|
||||||
|
if numExtensions > 0 {
|
||||||
|
z[0] = byte(extensionsLength >> 8)
|
||||||
|
z[1] = byte(extensionsLength)
|
||||||
|
z = z[2:]
|
||||||
|
}
|
||||||
|
if m.nextProtoNeg {
|
||||||
|
z[0] = byte(extensionNextProtoNeg >> 8)
|
||||||
|
z[1] = byte(extensionNextProtoNeg)
|
||||||
|
z[2] = byte(nextProtoLen >> 8)
|
||||||
|
z[3] = byte(nextProtoLen)
|
||||||
|
z = z[4:]
|
||||||
|
|
||||||
|
for _, v := range m.nextProtos {
|
||||||
|
l := len(v)
|
||||||
|
if l > 255 {
|
||||||
|
l = 255
|
||||||
|
}
|
||||||
|
z[0] = byte(l)
|
||||||
|
copy(z[1:], strings.Bytes(v[0:l]))
|
||||||
|
z = z[1+l:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.raw = x
|
m.raw = x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func append(slice []string, elem string) []string {
|
||||||
|
if len(slice) < cap(slice) {
|
||||||
|
slice = slice[0 : len(slice)+1]
|
||||||
|
slice[len(slice)-1] = elem
|
||||||
|
return slice
|
||||||
|
}
|
||||||
|
|
||||||
|
fresh := make([]string, len(slice)+1, cap(slice)*2+1)
|
||||||
|
copy(fresh, slice)
|
||||||
|
fresh[len(slice)] = elem
|
||||||
|
return fresh
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
||||||
if len(data) < 42 {
|
if len(data) < 42 {
|
||||||
return false
|
return false
|
||||||
@ -139,8 +322,53 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
|||||||
}
|
}
|
||||||
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
|
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
|
||||||
m.compressionMethod = data[2]
|
m.compressionMethod = data[2]
|
||||||
|
data = data[3:]
|
||||||
|
|
||||||
|
m.nextProtoNeg = false
|
||||||
|
m.nextProtos = nil
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
// ServerHello is optionally followed by extension data
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(data) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
extensionsLength := int(data[0])<<8 | int(data[1])
|
||||||
|
data = data[2:]
|
||||||
|
if len(data) != extensionsLength {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(data) != 0 {
|
||||||
|
if len(data) < 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
extension := uint16(data[0])<<8 | uint16(data[1])
|
||||||
|
length := int(data[2])<<8 | int(data[3])
|
||||||
|
data = data[4:]
|
||||||
|
if len(data) < length {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch extension {
|
||||||
|
case extensionNextProtoNeg:
|
||||||
|
m.nextProtoNeg = true
|
||||||
|
d := data
|
||||||
|
for len(d) > 0 {
|
||||||
|
l := int(d[0])
|
||||||
|
d = d[1:]
|
||||||
|
if l == 0 || l > len(d) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.nextProtos = append(m.nextProtos, string(d[0:l]))
|
||||||
|
d = d[l:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data = data[length:]
|
||||||
|
}
|
||||||
|
|
||||||
// Trailing data is allowed because extensions may be present.
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,3 +523,63 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
|
|||||||
m.verifyData = data[4:]
|
m.verifyData = data[4:]
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type nextProtoMsg struct {
|
||||||
|
raw []byte
|
||||||
|
proto string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *nextProtoMsg) marshal() []byte {
|
||||||
|
if m.raw != nil {
|
||||||
|
return m.raw
|
||||||
|
}
|
||||||
|
l := len(m.proto)
|
||||||
|
if l > 255 {
|
||||||
|
l = 255
|
||||||
|
}
|
||||||
|
|
||||||
|
padding := 32 - (l+2)%32
|
||||||
|
length := l + padding + 2
|
||||||
|
x := make([]byte, length+4)
|
||||||
|
x[0] = typeNextProtocol
|
||||||
|
x[1] = uint8(length >> 16)
|
||||||
|
x[2] = uint8(length >> 8)
|
||||||
|
x[3] = uint8(length)
|
||||||
|
|
||||||
|
y := x[4:]
|
||||||
|
y[0] = byte(l)
|
||||||
|
copy(y[1:], strings.Bytes(m.proto[0:l]))
|
||||||
|
y = y[1+l:]
|
||||||
|
y[0] = byte(padding)
|
||||||
|
|
||||||
|
m.raw = x
|
||||||
|
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *nextProtoMsg) unmarshal(data []byte) bool {
|
||||||
|
m.raw = data
|
||||||
|
|
||||||
|
if len(data) < 5 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
data = data[4:]
|
||||||
|
protoLen := int(data[0])
|
||||||
|
data = data[1:]
|
||||||
|
if len(data) < protoLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.proto = string(data[0:protoLen])
|
||||||
|
data = data[protoLen:]
|
||||||
|
|
||||||
|
if len(data) < 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
paddingLen := int(data[0])
|
||||||
|
data = data[1:]
|
||||||
|
if len(data) != paddingLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
@ -14,9 +14,11 @@ import (
|
|||||||
var tests = []interface{}{
|
var tests = []interface{}{
|
||||||
&clientHelloMsg{},
|
&clientHelloMsg{},
|
||||||
&serverHelloMsg{},
|
&serverHelloMsg{},
|
||||||
|
|
||||||
&certificateMsg{},
|
&certificateMsg{},
|
||||||
&clientKeyExchangeMsg{},
|
&clientKeyExchangeMsg{},
|
||||||
&finishedMsg{},
|
&finishedMsg{},
|
||||||
|
&nextProtoMsg{},
|
||||||
}
|
}
|
||||||
|
|
||||||
type testMessage interface {
|
type testMessage interface {
|
||||||
@ -40,21 +42,26 @@ func TestMarshalUnmarshal(t *testing.T) {
|
|||||||
marshaled := m1.marshal()
|
marshaled := m1.marshal()
|
||||||
m2 := iface.(testMessage)
|
m2 := iface.(testMessage)
|
||||||
if !m2.unmarshal(marshaled) {
|
if !m2.unmarshal(marshaled) {
|
||||||
t.Errorf("#%d failed to unmarshal %#v", i, m1)
|
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m2.marshal() // to fill any marshal cache in the message
|
m2.marshal() // to fill any marshal cache in the message
|
||||||
|
|
||||||
if !reflect.DeepEqual(m1, m2) {
|
if !reflect.DeepEqual(m1, m2) {
|
||||||
t.Errorf("#%d got:%#v want:%#v", i, m1, m2)
|
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now check that all prefixes are invalid.
|
if i >= 2 {
|
||||||
for j := 0; j < len(marshaled); j++ {
|
// The first two message types (ClientHello and
|
||||||
if m2.unmarshal(marshaled[0:j]) {
|
// ServerHello) are allowed to have parsable
|
||||||
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
|
// prefixes because the extension data is
|
||||||
break
|
// optional.
|
||||||
|
for j := 0; j < len(marshaled); j++ {
|
||||||
|
if m2.unmarshal(marshaled[0:j]) {
|
||||||
|
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -83,6 +90,11 @@ func randomBytes(n int, rand *rand.Rand) []byte {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randomString(n int, rand *rand.Rand) string {
|
||||||
|
b := randomBytes(n, rand)
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
m := &clientHelloMsg{}
|
m := &clientHelloMsg{}
|
||||||
m.major = uint8(rand.Intn(256))
|
m.major = uint8(rand.Intn(256))
|
||||||
@ -94,6 +106,12 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
m.cipherSuites[i] = uint16(rand.Int31())
|
m.cipherSuites[i] = uint16(rand.Int31())
|
||||||
}
|
}
|
||||||
m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
|
m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
|
||||||
|
if rand.Intn(10) > 5 {
|
||||||
|
m.nextProtoNeg = true
|
||||||
|
}
|
||||||
|
if rand.Intn(10) > 5 {
|
||||||
|
m.serverName = randomString(rand.Intn(255), rand)
|
||||||
|
}
|
||||||
|
|
||||||
return reflect.NewValue(m)
|
return reflect.NewValue(m)
|
||||||
}
|
}
|
||||||
@ -106,6 +124,17 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
m.sessionId = randomBytes(rand.Intn(32), rand)
|
m.sessionId = randomBytes(rand.Intn(32), rand)
|
||||||
m.cipherSuite = uint16(rand.Int31())
|
m.cipherSuite = uint16(rand.Int31())
|
||||||
m.compressionMethod = uint8(rand.Intn(256))
|
m.compressionMethod = uint8(rand.Intn(256))
|
||||||
|
|
||||||
|
if rand.Intn(10) > 5 {
|
||||||
|
m.nextProtoNeg = true
|
||||||
|
|
||||||
|
n := rand.Intn(10)
|
||||||
|
m.nextProtos = make([]string, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
m.nextProtos[i] = randomString(20, rand)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return reflect.NewValue(m)
|
return reflect.NewValue(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,3 +159,9 @@ func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
m.verifyData = randomBytes(12, rand)
|
m.verifyData = randomBytes(12, rand)
|
||||||
return reflect.NewValue(m)
|
return reflect.NewValue(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
|
m := &nextProtoMsg{}
|
||||||
|
m.proto = randomString(rand.Intn(255), rand)
|
||||||
|
return reflect.NewValue(m)
|
||||||
|
}
|
||||||
|
@ -108,6 +108,10 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
hello.compressionMethod = compressionNone
|
hello.compressionMethod = compressionNone
|
||||||
|
if clientHello.nextProtoNeg {
|
||||||
|
hello.nextProtoNeg = true
|
||||||
|
hello.nextProtos = config.NextProtos
|
||||||
|
}
|
||||||
|
|
||||||
finishedHash.Write(hello.marshal())
|
finishedHash.Write(hello.marshal())
|
||||||
writeChan <- writerSetVersion{major, minor}
|
writeChan <- writerSetVersion{major, minor}
|
||||||
@ -165,6 +169,17 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
|
|||||||
cipher, _ := rc4.NewCipher(clientKey)
|
cipher, _ := rc4.NewCipher(clientKey)
|
||||||
controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
|
controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
|
||||||
|
|
||||||
|
clientProtocol := ""
|
||||||
|
if hello.nextProtoNeg {
|
||||||
|
nextProto, ok := h.readHandshakeMsg().(*nextProtoMsg)
|
||||||
|
if !ok {
|
||||||
|
h.error(alertUnexpectedMessage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
finishedHash.Write(nextProto.marshal())
|
||||||
|
clientProtocol = nextProto.proto
|
||||||
|
}
|
||||||
|
|
||||||
clientFinished, ok := h.readHandshakeMsg().(*finishedMsg)
|
clientFinished, ok := h.readHandshakeMsg().(*finishedMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
h.error(alertUnexpectedMessage)
|
h.error(alertUnexpectedMessage)
|
||||||
@ -178,7 +193,7 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}
|
controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, clientProtocol}
|
||||||
|
|
||||||
finishedHash.Write(clientFinished.marshal())
|
finishedHash.Write(clientFinished.marshal())
|
||||||
|
|
||||||
@ -228,7 +243,7 @@ func (h *serverHandshake) error(e alertType) {
|
|||||||
for _ = range h.msgChan {
|
for _ = range h.msgChan {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
h.controlChan <- ConnectionState{false, "", e}
|
h.controlChan <- ConnectionState{false, "", e, ""}
|
||||||
close(h.controlChan)
|
close(h.controlChan)
|
||||||
h.writeChan <- alert{alertLevelError, e}
|
h.writeChan <- alert{alertLevelError, e}
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert
|
|||||||
send := script.NewEvent("send", nil, script.Send{msgChan, clientHello})
|
send := script.NewEvent("send", nil, script.Send{msgChan, clientHello})
|
||||||
recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}})
|
recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}})
|
||||||
close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan})
|
close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan})
|
||||||
recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert}})
|
recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert, ""}})
|
||||||
close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan})
|
close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan})
|
||||||
|
|
||||||
err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2})
|
err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2})
|
||||||
@ -78,13 +78,13 @@ func TestRejectBadProtocolVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNoSuiteOverlap(t *testing.T) {
|
func TestNoSuiteOverlap(t *testing.T) {
|
||||||
clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}}
|
clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
|
||||||
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNoCompressionOverlap(t *testing.T) {
|
func TestNoCompressionOverlap(t *testing.T) {
|
||||||
clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}}
|
clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
|
||||||
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +165,7 @@ func TestFullHandshake(t *testing.T) {
|
|||||||
defer close(msgChan)
|
defer close(msgChan)
|
||||||
|
|
||||||
// The values for this test were obtained from running `gnutls-cli --insecure --debug 9`
|
// The values for this test were obtained from running `gnutls-cli --insecure --debug 9`
|
||||||
clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}}
|
clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}, false, ""}
|
||||||
|
|
||||||
sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello})
|
sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello})
|
||||||
setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}})
|
setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}})
|
||||||
@ -183,7 +183,7 @@ func TestFullHandshake(t *testing.T) {
|
|||||||
sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished})
|
sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished})
|
||||||
recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished})
|
recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished})
|
||||||
setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher})
|
setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher})
|
||||||
recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}})
|
recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, ""}})
|
||||||
|
|
||||||
err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished})
|
err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -291,6 +291,8 @@ func parseHandshakeMsg(data []byte) (interface{}, bool) {
|
|||||||
m = new(serverHelloDoneMsg)
|
m = new(serverHelloDoneMsg)
|
||||||
case typeClientKeyExchange:
|
case typeClientKeyExchange:
|
||||||
m = new(clientKeyExchangeMsg)
|
m = new(clientKeyExchangeMsg)
|
||||||
|
case typeNextProtocol:
|
||||||
|
m = new(nextProtoMsg)
|
||||||
default:
|
default:
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ func TestNullConnectionState(t *testing.T) {
|
|||||||
// Test a simple request for the connection state.
|
// Test a simple request for the connection state.
|
||||||
replyChan := make(chan ConnectionState)
|
replyChan := make(chan ConnectionState)
|
||||||
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}})
|
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}})
|
||||||
getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0}})
|
getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0, ""}})
|
||||||
|
|
||||||
err := script.Perform(0, []*script.Event{sendReq, getReply})
|
err := script.Perform(0, []*script.Event{sendReq, getReply})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -55,9 +55,9 @@ func TestWaitConnectionState(t *testing.T) {
|
|||||||
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}})
|
sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}})
|
||||||
replyChan2 := make(chan ConnectionState)
|
replyChan2 := make(chan ConnectionState)
|
||||||
sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}})
|
sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}})
|
||||||
getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0}})
|
getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0, ""}})
|
||||||
sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1}})
|
sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1, ""}})
|
||||||
getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1}})
|
getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1, ""}})
|
||||||
|
|
||||||
err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply})
|
err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -108,7 +108,7 @@ func TestApplicationData(t *testing.T) {
|
|||||||
// Test that the application data is forwarded after a successful Finished message.
|
// Test that the application data is forwarded after a successful Finished message.
|
||||||
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}})
|
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}})
|
||||||
recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}})
|
recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}})
|
||||||
send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0}})
|
send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0, ""}})
|
||||||
send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}})
|
send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}})
|
||||||
recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}})
|
recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}})
|
||||||
|
|
||||||
@ -126,7 +126,7 @@ func TestInvalidChangeCipherSpec(t *testing.T) {
|
|||||||
|
|
||||||
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}})
|
send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}})
|
||||||
recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}})
|
recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}})
|
||||||
send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42}})
|
send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42, ""}})
|
||||||
close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan})
|
close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan})
|
||||||
close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan})
|
close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan})
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ func recordReader(c chan<- *record, source io.Reader) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
var header [5]byte
|
var header [5]byte
|
||||||
n, _ := buf.Read(header[0:])
|
n, _ := buf.Read(&header)
|
||||||
if n != 5 {
|
if n != 5 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user