From fd3978552ba514bff23ebadac6a75066c36f651d Mon Sep 17 00:00:00 2001 From: Dave Cheney Date: Tue, 20 Sep 2011 12:21:50 -0400 Subject: [PATCH] exp/ssh: refactor halfConnection to transport This CL generalises the pair of halfConnection members that the serverConn holds into a single transport struct that is shared by both Server and Client, see also CL 5037047. This CL is a replacement for 5040046 which I closed by accident. R=agl, bradfitz CC=golang-dev https://golang.org/cl/5075042 --- src/pkg/exp/ssh/channel.go | 14 +- src/pkg/exp/ssh/common.go | 2 +- src/pkg/exp/ssh/server.go | 83 +++++------ src/pkg/exp/ssh/transport.go | 232 ++++++++++++++++-------------- src/pkg/exp/ssh/transport_test.go | 37 +++++ 5 files changed, 211 insertions(+), 157 deletions(-) create mode 100644 src/pkg/exp/ssh/transport_test.go diff --git a/src/pkg/exp/ssh/channel.go b/src/pkg/exp/ssh/channel.go index 10f62354f43..922584f6317 100644 --- a/src/pkg/exp/ssh/channel.go +++ b/src/pkg/exp/ssh/channel.go @@ -97,7 +97,7 @@ func (c *channel) Accept() os.Error { MyWindow: c.myWindow, MaxPacketSize: c.maxPacketSize, } - return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm)) + return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm)) } func (c *channel) Reject(reason RejectionReason, message string) os.Error { @@ -114,7 +114,7 @@ func (c *channel) Reject(reason RejectionReason, message string) os.Error { Message: message, Language: "en", } - return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject)) + return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject)) } func (c *channel) handlePacket(packet interface{}) { @@ -180,7 +180,7 @@ func (c *channel) Read(data []byte) (n int, err os.Error) { PeersId: c.theirId, AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow, }) - if err := c.serverConn.out.writePacket(packet); err != nil { + if err := c.serverConn.writePacket(packet); err != nil { return 0, err } } @@ -254,7 +254,7 @@ func (c *channel) Write(data []byte) (n int, err os.Error) { copy(packet[9:], todo) c.serverConn.lock.Lock() - if err = c.serverConn.out.writePacket(packet); err != nil { + if err = c.serverConn.writePacket(packet); err != nil { c.serverConn.lock.Unlock() return } @@ -283,7 +283,7 @@ func (c *channel) Close() os.Error { closeMsg := channelCloseMsg{ PeersId: c.theirId, } - return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg)) + return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg)) } func (c *channel) AckRequest(ok bool) os.Error { @@ -298,12 +298,12 @@ func (c *channel) AckRequest(ok bool) os.Error { ack := channelRequestSuccessMsg{ PeersId: c.theirId, } - return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack)) + return c.serverConn.writePacket(marshal(msgChannelSuccess, ack)) } else { ack := channelRequestFailureMsg{ PeersId: c.theirId, } - return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack)) + return c.serverConn.writePacket(marshal(msgChannelFailure, ack)) } panic("unreachable") } diff --git a/src/pkg/exp/ssh/common.go b/src/pkg/exp/ssh/common.go index c951d1a753e..698db60b8e6 100644 --- a/src/pkg/exp/ssh/common.go +++ b/src/pkg/exp/ssh/common.go @@ -50,7 +50,7 @@ func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo return } -func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) { +func findAgreedAlgorithms(clientToServer, serverToClient *transport, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) { kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos) if !ok { return diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go index 57cd5971063..bc0af13e821 100644 --- a/src/pkg/exp/ssh/server.go +++ b/src/pkg/exp/ssh/server.go @@ -129,7 +129,7 @@ const maxCachedPubKeys = 16 type ServerConnection struct { Server *Server - in, out *halfConnection + *transport channels map[uint32]*channel nextChanId uint32 @@ -174,7 +174,7 @@ type handshakeMagics struct { // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The // returned values are given the same names as in RFC 4253, section 8. func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { - packet, err := s.in.readPacket() + packet, err := s.readPacket() if err != nil { return } @@ -241,7 +241,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h } packet = marshal(msgKexDHReply, kexDHReply) - err = s.out.writePacket(packet) + err = s.writePacket(packet) return } @@ -292,14 +292,30 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK // Handshake performs an SSH transport and client authentication on the given ServerConnection. func (s *ServerConnection) Handshake(conn net.Conn) os.Error { var magics handshakeMagics - inBuf := bufio.NewReader(conn) - - _, err := conn.Write(serverVersion) - if err != nil { - return err + s.transport = &transport{ + reader: reader{ + Reader: bufio.NewReader(conn), + }, + writer: writer{ + Writer: bufio.NewWriter(conn), + rand: rand.Reader, + }, + Close: func() os.Error { + return conn.Close() + }, } + if _, err := conn.Write(serverVersion); err != nil { + return err + } magics.serverVersion = serverVersion[:len(serverVersion)-2] + + version, ok := readVersion(s.transport) + if !ok { + return os.NewError("failed to read version string from client") + } + magics.clientVersion = version + serverKexInit := kexInitMsg{ KexAlgos: supportedKexAlgos, ServerHostKeyAlgos: supportedHostKeyAlgos, @@ -313,28 +329,15 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { kexInitPacket := marshal(msgKexInit, serverKexInit) magics.serverKexInit = kexInitPacket - var out halfConnection - out.out = conn - out.rand = rand.Reader - s.out = &out - err = out.writePacket(kexInitPacket) + if err := s.writePacket(kexInitPacket); err != nil { + return err + } + + packet, err := s.readPacket() if err != nil { return err } - version, ok := readVersion(inBuf) - if !ok { - return os.NewError("failed to read version string from client") - } - magics.clientVersion = version - - var in halfConnection - in.in = inBuf - s.in = &in - packet, err := in.readPacket() - if err != nil { - return err - } magics.clientKexInit = packet var clientKexInit kexInitMsg @@ -342,7 +345,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { return err } - kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit) + kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, s.transport, &clientKexInit, &serverKexInit) if !ok { return os.NewError("ssh: no common algorithms") } @@ -350,7 +353,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { // The client sent a Kex message for the wrong algorithm, // which we have to ignore. - _, err := in.readPacket() + _, err := s.readPacket() if err != nil { return err } @@ -372,23 +375,23 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { } packet = []byte{msgNewKeys} - if err = out.writePacket(packet); err != nil { + if err = s.writePacket(packet); err != nil { return err } - if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { + if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { return err } - if packet, err = in.readPacket(); err != nil { + if packet, err = s.readPacket(); err != nil { return err } if packet[0] != msgNewKeys { return UnexpectedMessageError{msgNewKeys, packet[0]} } - in.setupKeys(clientKeys, K, H, H, hashFunc) + s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) - packet, err = in.readPacket() + packet, err = s.readPacket() if err != nil { return err } @@ -405,7 +408,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { Service: serviceUserAuth, } packet = marshal(msgServiceAccept, serviceAccept) - if err = out.writePacket(packet); err != nil { + if err = s.writePacket(packet); err != nil { return err } @@ -455,7 +458,7 @@ func (s *ServerConnection) authenticate(H []byte) os.Error { userAuthLoop: for { - if packet, err = s.in.readPacket(); err != nil { + if packet, err = s.readPacket(); err != nil { return err } if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil { @@ -519,7 +522,7 @@ userAuthLoop: Algo: algo, PubKey: string(pubKey), } - if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil { + if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil { return err } continue userAuthLoop @@ -571,13 +574,13 @@ userAuthLoop: return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false") } - if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { + if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { return err } } packet = []byte{msgUserAuthSuccess} - if err = s.out.writePacket(packet); err != nil { + if err = s.writePacket(packet); err != nil { return err } @@ -594,7 +597,7 @@ func (s *ServerConnection) Accept() (Channel, os.Error) { } for { - packet, err := s.in.readPacket() + packet, err := s.readPacket() if err != nil { s.lock.Lock() @@ -697,7 +700,7 @@ func (s *ServerConnection) Accept() (Channel, os.Error) { } if request.WantReply { - if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil { + if err := s.writePacket([]byte{msgRequestFailure}); err != nil { return nil, err } } diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go index 919759ff989..5a474a9f2c2 100644 --- a/src/pkg/exp/ssh/transport.go +++ b/src/pkg/exp/ssh/transport.go @@ -13,56 +13,74 @@ import ( "crypto/subtle" "hash" "io" - "net" "os" ) -// halfConnection represents one direction of an SSH connection. It maintains -// the cipher state needed to process messages. -type halfConnection struct { - // Only one of these two will be non-nil - in *bufio.Reader - out net.Conn +const ( + paddingMultiple = 16 // TODO(dfc) does this need to be configurable? +) + +// transport represents the SSH connection to the remote peer. +type transport struct { + reader + writer - rand io.Reader cipherAlgo string macAlgo string compressionAlgo string + + Close func() os.Error +} + +// reader represents the incoming connection state. +type reader struct { + io.Reader + common +} + +// writer represnts the outgoing connection state. +type writer struct { + *bufio.Writer paddingMultiple int + rand io.Reader + common +} +// common represents the cipher state needed to process messages in a single +// direction. +type common struct { seqNum uint32 - mac hash.Hash cipher cipher.Stream } -func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) { - var lengthBytes [5]byte +// Read and decrypt a single packet from the remote peer. +func (r *reader) readOnePacket() ([]byte, os.Error) { + var lengthBytes = make([]byte, 5) + var macSize uint32 - _, err = io.ReadFull(hc.in, lengthBytes[:]) - if err != nil { - return + if _, err := io.ReadFull(r, lengthBytes); err != nil { + return nil, err } - if hc.cipher != nil { - hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:]) + if r.cipher != nil { + r.cipher.XORKeyStream(lengthBytes, lengthBytes) } - macSize := 0 - if hc.mac != nil { - hc.mac.Reset() - var seqNumBytes [4]byte - seqNumBytes[0] = byte(hc.seqNum >> 24) - seqNumBytes[1] = byte(hc.seqNum >> 16) - seqNumBytes[2] = byte(hc.seqNum >> 8) - seqNumBytes[3] = byte(hc.seqNum) - hc.mac.Write(seqNumBytes[:]) - hc.mac.Write(lengthBytes[:]) - macSize = hc.mac.Size() + if r.mac != nil { + r.mac.Reset() + seqNumBytes := []byte{ + byte(r.seqNum >> 24), + byte(r.seqNum >> 16), + byte(r.seqNum >> 8), + byte(r.seqNum), + } + r.mac.Write(seqNumBytes) + r.mac.Write(lengthBytes) + macSize = uint32(r.mac.Size()) } length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3]) - paddingLength := uint32(lengthBytes[4]) if length <= paddingLength+1 { @@ -72,31 +90,30 @@ func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) { return nil, os.NewError("packet too large") } - packet = make([]byte, length-1+uint32(macSize)) - _, err = io.ReadFull(hc.in, packet) - if err != nil { + packet := make([]byte, length-1+macSize) + if _, err := io.ReadFull(r, packet); err != nil { return nil, err } mac := packet[length-1:] - if hc.cipher != nil { - hc.cipher.XORKeyStream(packet, packet[:length-1]) + if r.cipher != nil { + r.cipher.XORKeyStream(packet, packet[:length-1]) } - if hc.mac != nil { - hc.mac.Write(packet[:length-1]) - if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 { + if r.mac != nil { + r.mac.Write(packet[:length-1]) + if subtle.ConstantTimeCompare(r.mac.Sum(), mac) != 1 { return nil, os.NewError("ssh: MAC failure") } } - hc.seqNum++ - packet = packet[:length-paddingLength-1] - return + r.seqNum++ + return packet[:length-paddingLength-1], nil } -func (hc *halfConnection) readPacket() (packet []byte, err os.Error) { +// Read and decrypt next packet discarding debug and noop messages. +func (t *transport) readPacket() ([]byte, os.Error) { for { - packet, err := hc.readOnePacket() + packet, err := t.readOnePacket() if err != nil { return nil, err } @@ -107,119 +124,113 @@ func (hc *halfConnection) readPacket() (packet []byte, err os.Error) { panic("unreachable") } -func (hc *halfConnection) writePacket(packet []byte) os.Error { - paddingMultiple := hc.paddingMultiple - if paddingMultiple == 0 { - paddingMultiple = 8 - } - - paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple +// Encrypt and send a packet of data to the remote peer. +func (w *writer) writePacket(packet []byte) os.Error { + paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple if paddingLength < 4 { paddingLength += paddingMultiple } - var lengthBytes [5]byte length := len(packet) + 1 + paddingLength - lengthBytes[0] = byte(length >> 24) - lengthBytes[1] = byte(length >> 16) - lengthBytes[2] = byte(length >> 8) - lengthBytes[3] = byte(length) - lengthBytes[4] = byte(paddingLength) - - var padding [32]byte - _, err := io.ReadFull(hc.rand, padding[:paddingLength]) + lengthBytes := []byte{ + byte(length >> 24), + byte(length >> 16), + byte(length >> 8), + byte(length), + byte(paddingLength), + } + padding := make([]byte, paddingLength) + _, err := io.ReadFull(w.rand, padding) if err != nil { return err } - if hc.mac != nil { - hc.mac.Reset() - var seqNumBytes [4]byte - seqNumBytes[0] = byte(hc.seqNum >> 24) - seqNumBytes[1] = byte(hc.seqNum >> 16) - seqNumBytes[2] = byte(hc.seqNum >> 8) - seqNumBytes[3] = byte(hc.seqNum) - hc.mac.Write(seqNumBytes[:]) - hc.mac.Write(lengthBytes[:]) - hc.mac.Write(packet) - hc.mac.Write(padding[:paddingLength]) + if w.mac != nil { + w.mac.Reset() + seqNumBytes := []byte{ + byte(w.seqNum >> 24), + byte(w.seqNum >> 16), + byte(w.seqNum >> 8), + byte(w.seqNum), + } + w.mac.Write(seqNumBytes) + w.mac.Write(lengthBytes) + w.mac.Write(packet) + w.mac.Write(padding) } - if hc.cipher != nil { - hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:]) - hc.cipher.XORKeyStream(packet, packet) - hc.cipher.XORKeyStream(padding[:], padding[:paddingLength]) + // TODO(dfc) lengthBytes, packet and padding should be + // subslices of a single buffer + if w.cipher != nil { + w.cipher.XORKeyStream(lengthBytes, lengthBytes) + w.cipher.XORKeyStream(packet, packet) + w.cipher.XORKeyStream(padding, padding) } - _, err = hc.out.Write(lengthBytes[:]) - if err != nil { + if _, err := w.Write(lengthBytes); err != nil { return err } - _, err = hc.out.Write(packet) - if err != nil { + if _, err := w.Write(packet); err != nil { return err } - _, err = hc.out.Write(padding[:paddingLength]) - if err != nil { + if _, err := w.Write(padding); err != nil { return err } - if hc.mac != nil { - _, err = hc.out.Write(hc.mac.Sum()) + if w.mac != nil { + if _, err := w.Write(w.mac.Sum()); err != nil { + return err + } } - hc.seqNum++ - + if err := w.Flush(); err != nil { + return err + } + w.seqNum++ return err } -const ( - serverKeys = iota - clientKeys +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +// TODO(dfc) can this be made a constant ? +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} ) -// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as +// setupKeys sets the cipher and MAC keys from K, H and sessionId, as // described in RFC 4253, section 6.4. direction should either be serverKeys // (to setup server->client keys) or clientKeys (for client->server keys). -func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error { +func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error { h := hashFunc.New() - // We only support these algorithms for now. - if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 { - return os.NewError("ssh: setupServerKeys internal error") - } - blockSize := 16 keySize := 16 macKeySize := 20 - var ivTag, keyTag, macKeyTag byte - if direction == serverKeys { - ivTag, keyTag, macKeyTag = 'B', 'D', 'F' - } else { - ivTag, keyTag, macKeyTag = 'A', 'C', 'E' - } - iv := make([]byte, blockSize) key := make([]byte, keySize) macKey := make([]byte, macKeySize) - generateKeyMaterial(iv, ivTag, K, H, sessionId, h) - generateKeyMaterial(key, keyTag, K, H, sessionId, h) - generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h) + generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h) + generateKeyMaterial(key, d.keyTag, K, H, sessionId, h) + generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h) - hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)} + c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)} aes, err := aes.NewCipher(key) if err != nil { return err } - hc.cipher = cipher.NewCTR(aes, iv) - hc.paddingMultiple = 16 + c.cipher = cipher.NewCTR(aes, iv) return nil } // generateKeyMaterial fills out with key material generated from tag, K, H // and sessionId, as specified in RFC 4253, section 7.2. -func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) { +func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) { var digestsSoFar []byte for len(out) > 0 { @@ -228,7 +239,7 @@ func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Ha h.Write(H) if len(digestsSoFar) == 0 { - h.Write([]byte{tag}) + h.Write(tag) h.Write(sessionId) } else { h.Write(digestsSoFar) @@ -273,16 +284,19 @@ func (t truncatingMAC) Size() int { // while searching for the end of the version handshake. const maxVersionStringBytes = 1024 -func readVersion(r *bufio.Reader) (versionString []byte, ok bool) { +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) (versionString []byte, ok bool) { versionString = make([]byte, 0, 64) seenCR := false + var buf [1]byte forEachByte: for len(versionString) < maxVersionStringBytes { - b, err := r.ReadByte() + _, err := io.ReadFull(r, buf[:]) if err != nil { return } + b := buf[0] if !seenCR { if b == '\r' { diff --git a/src/pkg/exp/ssh/transport_test.go b/src/pkg/exp/ssh/transport_test.go new file mode 100644 index 00000000000..9a610a7803c --- /dev/null +++ b/src/pkg/exp/ssh/transport_test.go @@ -0,0 +1,37 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "bytes" + "testing" +) + +func TestReadVersion(t *testing.T) { + buf := []byte(serverVersion) + result, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))) + if !ok { + t.Error("readVersion didn't read version correctly") + } + if !bytes.Equal(buf[:len(buf)-2], result) { + t.Error("version read did not match expected") + } +} + +func TestReadVersionTooLong(t *testing.T) { + buf := make([]byte, maxVersionStringBytes+1) + if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok { + t.Errorf("readVersion consumed %d bytes without error", len(buf)) + } +} + +func TestReadVersionWithoutCRLF(t *testing.T) { + buf := []byte(serverVersion) + buf = buf[:len(buf)-1] + if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok { + t.Error("readVersion did not notice \\n was missing") + } +}