From 0f6b80c69498d2047d584d365e4056ced9f38adc Mon Sep 17 00:00:00 2001 From: Dave Cheney Date: Sat, 29 Oct 2011 14:22:30 -0400 Subject: [PATCH] exp/ssh: fix length header leaking into channel data streams. The payload of a data message is defined as an SSH string type, which uses the first four bytes to encode its length. When channelData and channelExtendedData were added I defined Payload as []byte to be able to use it directly without a string to []byte conversion. This resulted in the length data leaking into the payload data. This CL fixes the bug, and restores agl's original fast path code. Additionally, a bug whereby s.lock was not released if a packet arrived for an invalid channel has been fixed. Finally, as they were no longer used, I have removed the channelData and channelExtedendData structs. R=agl, rsc CC=golang-dev https://golang.org/cl/5330053 --- src/pkg/exp/ssh/client.go | 90 ++++++++++++++--------- src/pkg/exp/ssh/messages.go | 17 ----- src/pkg/exp/ssh/server.go | 138 ++++++++++++++++++++---------------- 3 files changed, 131 insertions(+), 114 deletions(-) diff --git a/src/pkg/exp/ssh/client.go b/src/pkg/exp/ssh/client.go index 9223b6c3cf2..fe76db16c4b 100644 --- a/src/pkg/exp/ssh/client.go +++ b/src/pkg/exp/ssh/client.go @@ -258,51 +258,71 @@ func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) { // mainloop reads incoming messages and routes channel messages // to their respective ClientChans. func (c *ClientConn) mainLoop() { + // TODO(dfc) signal the underlying close to all channels + defer c.Close() for { packet, err := c.readPacket() if err != nil { - // TODO(dfc) signal the underlying close to all channels - c.Close() - return + break } // TODO(dfc) A note on blocking channel use. // The msg, win, data and dataExt channels of a clientChan can // cause this loop to block indefinately if the consumer does // not service them. - switch msg := decode(packet).(type) { - case *channelOpenMsg: - c.getChan(msg.PeersId).msg <- msg - case *channelOpenConfirmMsg: - c.getChan(msg.PeersId).msg <- msg - case *channelOpenFailureMsg: - c.getChan(msg.PeersId).msg <- msg - case *channelCloseMsg: - ch := c.getChan(msg.PeersId) - close(ch.win) - close(ch.data) - close(ch.dataExt) - c.chanlist.remove(msg.PeersId) - case *channelEOFMsg: - c.getChan(msg.PeersId).msg <- msg - case *channelRequestSuccessMsg: - c.getChan(msg.PeersId).msg <- msg - case *channelRequestFailureMsg: - c.getChan(msg.PeersId).msg <- msg - case *channelRequestMsg: - c.getChan(msg.PeersId).msg <- msg - case *windowAdjustMsg: - c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes) - case *channelData: - c.getChan(msg.PeersId).data <- msg.Payload - case *channelExtendedData: - // RFC 4254 5.2 defines data_type_code 1 to be data destined - // for stderr on interactive sessions. Other data types are - // silently discarded. - if msg.Datatype == 1 { - c.getChan(msg.PeersId).dataExt <- msg.Payload + switch packet[0] { + case msgChannelData: + if len(packet) < 9 { + // malformed data packet + break + } + peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) + if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 { + packet = packet[9:] + c.getChan(peersId).data <- packet[:length] + } + case msgChannelExtendedData: + if len(packet) < 13 { + // malformed data packet + break + } + peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) + datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8]) + if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 { + packet = packet[13:] + // RFC 4254 5.2 defines data_type_code 1 to be data destined + // for stderr on interactive sessions. Other data types are + // silently discarded. + if datatype == 1 { + c.getChan(peersId).dataExt <- packet[:length] + } } default: - fmt.Printf("mainLoop: unhandled %#v\n", msg) + switch msg := decode(packet).(type) { + case *channelOpenMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelOpenConfirmMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelOpenFailureMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelCloseMsg: + ch := c.getChan(msg.PeersId) + close(ch.win) + close(ch.data) + close(ch.dataExt) + c.chanlist.remove(msg.PeersId) + case *channelEOFMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelRequestSuccessMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelRequestFailureMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelRequestMsg: + c.getChan(msg.PeersId).msg <- msg + case *windowAdjustMsg: + c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes) + default: + fmt.Printf("mainLoop: unhandled %#v\n", msg) + } } } } diff --git a/src/pkg/exp/ssh/messages.go b/src/pkg/exp/ssh/messages.go index 7771f2b242e..5f2c447142a 100644 --- a/src/pkg/exp/ssh/messages.go +++ b/src/pkg/exp/ssh/messages.go @@ -144,19 +144,6 @@ type channelOpenFailureMsg struct { Language string } -// See RFC 4254, section 5.2. -type channelData struct { - PeersId uint32 - Payload []byte `ssh:"rest"` -} - -// See RFC 4254, section 5.2. -type channelExtendedData struct { - PeersId uint32 - Datatype uint32 - Payload []byte `ssh:"rest"` -} - type channelRequestMsg struct { PeersId uint32 Request string @@ -612,10 +599,6 @@ func decode(packet []byte) interface{} { msg = new(channelOpenFailureMsg) case msgChannelWindowAdjust: msg = new(windowAdjustMsg) - case msgChannelData: - msg = new(channelData) - case msgChannelExtendedData: - msg = new(channelExtendedData) case msgChannelEOF: msg = new(channelEOFMsg) case msgChannelClose: diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go index 3a640fc081e..0dd24ecd6ee 100644 --- a/src/pkg/exp/ssh/server.go +++ b/src/pkg/exp/ssh/server.go @@ -581,75 +581,89 @@ func (s *ServerConn) Accept() (Channel, os.Error) { return nil, err } - switch msg := decode(packet).(type) { - case *channelOpenMsg: - c := new(channel) - c.chanType = msg.ChanType - c.theirId = msg.PeersId - c.theirWindow = msg.PeersWindow - c.maxPacketSize = msg.MaxPacketSize - c.extraData = msg.TypeSpecificData - c.myWindow = defaultWindowSize - c.serverConn = s - c.cond = sync.NewCond(&c.lock) - c.pendingData = make([]byte, c.myWindow) - + switch packet[0] { + case msgChannelData: + if len(packet) < 9 { + // malformed data packet + return nil, ParseError{msgChannelData} + } + peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) s.lock.Lock() - c.myId = s.nextChanId - s.nextChanId++ - s.channels[c.myId] = c - s.lock.Unlock() - return c, nil - - case *channelRequestMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] + c, ok := s.channels[peersId] if !ok { + s.lock.Unlock() continue } - c.handlePacket(msg) - s.lock.Unlock() - - case *channelData: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - continue + if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 { + packet = packet[9:] + c.handleData(packet[:length]) } - c.handleData(msg.Payload) s.lock.Unlock() - - case *channelEOFMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - continue - } - c.handlePacket(msg) - s.lock.Unlock() - - case *channelCloseMsg: - s.lock.Lock() - c, ok := s.channels[msg.PeersId] - if !ok { - continue - } - c.handlePacket(msg) - s.lock.Unlock() - - case *globalRequestMsg: - if msg.WantReply { - if err := s.writePacket([]byte{msgRequestFailure}); err != nil { - return nil, err - } - } - - case UnexpectedMessageError: - return nil, msg - case *disconnectMsg: - return nil, os.EOF default: - // Unknown message. Ignore. + switch msg := decode(packet).(type) { + case *channelOpenMsg: + c := new(channel) + c.chanType = msg.ChanType + c.theirId = msg.PeersId + c.theirWindow = msg.PeersWindow + c.maxPacketSize = msg.MaxPacketSize + c.extraData = msg.TypeSpecificData + c.myWindow = defaultWindowSize + c.serverConn = s + c.cond = sync.NewCond(&c.lock) + c.pendingData = make([]byte, c.myWindow) + + s.lock.Lock() + c.myId = s.nextChanId + s.nextChanId++ + s.channels[c.myId] = c + s.lock.Unlock() + return c, nil + + case *channelRequestMsg: + s.lock.Lock() + c, ok := s.channels[msg.PeersId] + if !ok { + s.lock.Unlock() + continue + } + c.handlePacket(msg) + s.lock.Unlock() + + case *channelEOFMsg: + s.lock.Lock() + c, ok := s.channels[msg.PeersId] + if !ok { + s.lock.Unlock() + continue + } + c.handlePacket(msg) + s.lock.Unlock() + + case *channelCloseMsg: + s.lock.Lock() + c, ok := s.channels[msg.PeersId] + if !ok { + s.lock.Unlock() + continue + } + c.handlePacket(msg) + s.lock.Unlock() + + case *globalRequestMsg: + if msg.WantReply { + if err := s.writePacket([]byte{msgRequestFailure}); err != nil { + return nil, err + } + } + + case UnexpectedMessageError: + return nil, msg + case *disconnectMsg: + return nil, os.EOF + default: + // Unknown message. Ignore. + } } }