1
0
mirror of https://github.com/golang/go synced 2024-11-21 22:24:40 -07:00

exp/ssh: fix length header leaking into channel data streams.

The payload of a data message is defined as an SSH string type,
which uses the first four bytes to encode its length. When channelData
and channelExtendedData were added I defined Payload as []byte to
be able to use it directly without a string to []byte conversion. This
resulted in the length data leaking into the payload data.

This CL fixes the bug, and restores agl's original fast path code.

Additionally, a bug whereby s.lock was not released if a packet arrived
for an invalid channel has been fixed.

Finally, as they were no longer used, I have removed
the channelData and channelExtedendData structs.

R=agl, rsc
CC=golang-dev
https://golang.org/cl/5330053
This commit is contained in:
Dave Cheney 2011-10-29 14:22:30 -04:00 committed by Adam Langley
parent 604e10c34d
commit 0f6b80c694
3 changed files with 131 additions and 114 deletions

View File

@ -258,51 +258,71 @@ func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) {
// mainloop reads incoming messages and routes channel messages // mainloop reads incoming messages and routes channel messages
// to their respective ClientChans. // to their respective ClientChans.
func (c *ClientConn) mainLoop() { func (c *ClientConn) mainLoop() {
// TODO(dfc) signal the underlying close to all channels
defer c.Close()
for { for {
packet, err := c.readPacket() packet, err := c.readPacket()
if err != nil { if err != nil {
// TODO(dfc) signal the underlying close to all channels break
c.Close()
return
} }
// TODO(dfc) A note on blocking channel use. // TODO(dfc) A note on blocking channel use.
// The msg, win, data and dataExt channels of a clientChan can // The msg, win, data and dataExt channels of a clientChan can
// cause this loop to block indefinately if the consumer does // cause this loop to block indefinately if the consumer does
// not service them. // not service them.
switch msg := decode(packet).(type) { switch packet[0] {
case *channelOpenMsg: case msgChannelData:
c.getChan(msg.PeersId).msg <- msg if len(packet) < 9 {
case *channelOpenConfirmMsg: // malformed data packet
c.getChan(msg.PeersId).msg <- msg break
case *channelOpenFailureMsg: }
c.getChan(msg.PeersId).msg <- msg peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
case *channelCloseMsg: if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
ch := c.getChan(msg.PeersId) packet = packet[9:]
close(ch.win) c.getChan(peersId).data <- packet[:length]
close(ch.data) }
close(ch.dataExt) case msgChannelExtendedData:
c.chanlist.remove(msg.PeersId) if len(packet) < 13 {
case *channelEOFMsg: // malformed data packet
c.getChan(msg.PeersId).msg <- msg break
case *channelRequestSuccessMsg: }
c.getChan(msg.PeersId).msg <- msg peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
case *channelRequestFailureMsg: datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
c.getChan(msg.PeersId).msg <- msg if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
case *channelRequestMsg: packet = packet[13:]
c.getChan(msg.PeersId).msg <- msg // RFC 4254 5.2 defines data_type_code 1 to be data destined
case *windowAdjustMsg: // for stderr on interactive sessions. Other data types are
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes) // silently discarded.
case *channelData: if datatype == 1 {
c.getChan(msg.PeersId).data <- msg.Payload c.getChan(peersId).dataExt <- packet[:length]
case *channelExtendedData: }
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if msg.Datatype == 1 {
c.getChan(msg.PeersId).dataExt <- msg.Payload
} }
default: default:
fmt.Printf("mainLoop: unhandled %#v\n", msg) switch msg := decode(packet).(type) {
case *channelOpenMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenConfirmMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelCloseMsg:
ch := c.getChan(msg.PeersId)
close(ch.win)
close(ch.data)
close(ch.dataExt)
c.chanlist.remove(msg.PeersId)
case *channelEOFMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestSuccessMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestMsg:
c.getChan(msg.PeersId).msg <- msg
case *windowAdjustMsg:
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
default:
fmt.Printf("mainLoop: unhandled %#v\n", msg)
}
} }
} }
} }

View File

@ -144,19 +144,6 @@ type channelOpenFailureMsg struct {
Language string Language string
} }
// See RFC 4254, section 5.2.
type channelData struct {
PeersId uint32
Payload []byte `ssh:"rest"`
}
// See RFC 4254, section 5.2.
type channelExtendedData struct {
PeersId uint32
Datatype uint32
Payload []byte `ssh:"rest"`
}
type channelRequestMsg struct { type channelRequestMsg struct {
PeersId uint32 PeersId uint32
Request string Request string
@ -612,10 +599,6 @@ func decode(packet []byte) interface{} {
msg = new(channelOpenFailureMsg) msg = new(channelOpenFailureMsg)
case msgChannelWindowAdjust: case msgChannelWindowAdjust:
msg = new(windowAdjustMsg) msg = new(windowAdjustMsg)
case msgChannelData:
msg = new(channelData)
case msgChannelExtendedData:
msg = new(channelExtendedData)
case msgChannelEOF: case msgChannelEOF:
msg = new(channelEOFMsg) msg = new(channelEOFMsg)
case msgChannelClose: case msgChannelClose:

View File

@ -581,75 +581,89 @@ func (s *ServerConn) Accept() (Channel, os.Error) {
return nil, err return nil, err
} }
switch msg := decode(packet).(type) { switch packet[0] {
case *channelOpenMsg: case msgChannelData:
c := new(channel) if len(packet) < 9 {
c.chanType = msg.ChanType // malformed data packet
c.theirId = msg.PeersId return nil, ParseError{msgChannelData}
c.theirWindow = msg.PeersWindow }
c.maxPacketSize = msg.MaxPacketSize peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
c.extraData = msg.TypeSpecificData
c.myWindow = defaultWindowSize
c.serverConn = s
c.cond = sync.NewCond(&c.lock)
c.pendingData = make([]byte, c.myWindow)
s.lock.Lock() s.lock.Lock()
c.myId = s.nextChanId c, ok := s.channels[peersId]
s.nextChanId++
s.channels[c.myId] = c
s.lock.Unlock()
return c, nil
case *channelRequestMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok { if !ok {
s.lock.Unlock()
continue continue
} }
c.handlePacket(msg) if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
s.lock.Unlock() packet = packet[9:]
c.handleData(packet[:length])
case *channelData:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
continue
} }
c.handleData(msg.Payload)
s.lock.Unlock() s.lock.Unlock()
case *channelEOFMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelCloseMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *globalRequestMsg:
if msg.WantReply {
if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
case UnexpectedMessageError:
return nil, msg
case *disconnectMsg:
return nil, os.EOF
default: default:
// Unknown message. Ignore. switch msg := decode(packet).(type) {
case *channelOpenMsg:
c := new(channel)
c.chanType = msg.ChanType
c.theirId = msg.PeersId
c.theirWindow = msg.PeersWindow
c.maxPacketSize = msg.MaxPacketSize
c.extraData = msg.TypeSpecificData
c.myWindow = defaultWindowSize
c.serverConn = s
c.cond = sync.NewCond(&c.lock)
c.pendingData = make([]byte, c.myWindow)
s.lock.Lock()
c.myId = s.nextChanId
s.nextChanId++
s.channels[c.myId] = c
s.lock.Unlock()
return c, nil
case *channelRequestMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelEOFMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelCloseMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *globalRequestMsg:
if msg.WantReply {
if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
case UnexpectedMessageError:
return nil, msg
case *disconnectMsg:
return nil, os.EOF
default:
// Unknown message. Ignore.
}
} }
} }