diff --git a/src/pkg/exp/ssh/client.go b/src/pkg/exp/ssh/client.go index 429dee975bc..d89b908cdcf 100644 --- a/src/pkg/exp/ssh/client.go +++ b/src/pkg/exp/ssh/client.go @@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() { peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 { packet = packet[9:] - c.getChan(peersId).data <- packet[:length] + c.getChan(peersId).stdout.data <- packet[:length] } case msgChannelExtendedData: if len(packet) < 13 { @@ -215,7 +215,7 @@ func (c *ClientConn) mainLoop() { // for stderr on interactive sessions. Other data types are // silently discarded. if datatype == 1 { - c.getChan(peersId).dataExt <- packet[:length] + c.getChan(peersId).stderr.data <- packet[:length] } } default: @@ -228,9 +228,9 @@ func (c *ClientConn) mainLoop() { c.getChan(msg.PeersId).msg <- msg case *channelCloseMsg: ch := c.getChan(msg.PeersId) - close(ch.win) - close(ch.data) - close(ch.dataExt) + close(ch.stdin.win) + close(ch.stdout.data) + close(ch.stderr.data) c.chanlist.remove(msg.PeersId) case *channelEOFMsg: c.getChan(msg.PeersId).msg <- msg @@ -241,7 +241,7 @@ func (c *ClientConn) mainLoop() { case *channelRequestMsg: c.getChan(msg.PeersId).msg <- msg case *windowAdjustMsg: - c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes) + c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes) default: fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg) } @@ -290,21 +290,49 @@ func (c *ClientConfig) rand() io.Reader { type clientChan struct { packetWriter id, peersId uint32 - data chan []byte // receives the payload of channelData messages - dataExt chan []byte // receives the payload of channelExtendedData messages - win chan int // receives window adjustments + stdin *chanWriter // receives window adjustments + stdout *chanReader // receives the payload of channelData messages + stderr *chanReader // receives the payload of channelExtendedData messages msg chan interface{} // incoming messages } +// newClientChan returns a partially constructed *clientChan +// using the local id provided. To be usable clientChan.peersId +// needs to be assigned once known. func newClientChan(t *transport, id uint32) *clientChan { - return &clientChan{ + c := &clientChan{ packetWriter: t, id: id, - data: make(chan []byte, 16), - dataExt: make(chan []byte, 16), - win: make(chan int, 16), msg: make(chan interface{}, 16), } + c.stdin = &chanWriter{ + win: make(chan int, 16), + clientChan: c, + } + c.stdout = &chanReader{ + data: make(chan []byte, 16), + clientChan: c, + } + c.stderr = &chanReader{ + data: make(chan []byte, 16), + clientChan: c, + } + return c +} + +// waitForChannelOpenResponse, if successful, fills out +// the peerId and records any initial window advertisement. +func (c *clientChan) waitForChannelOpenResponse() error { + switch msg := (<-c.msg).(type) { + case *channelOpenConfirmMsg: + // fixup peersId field + c.peersId = msg.MyId + c.stdin.win <- int(msg.MyWindow) + return nil + case *channelOpenFailureMsg: + return errors.New(safeString(msg.Message)) + } + return errors.New("unexpected packet") } // Close closes the channel. This does not close the underlying connection. @@ -355,10 +383,9 @@ func (c *chanlist) remove(id uint32) { // A chanWriter represents the stdin of a remote process. type chanWriter struct { - win chan int // receives window adjustments - peersId uint32 // the peer's id - rwin int // current rwin size - packetWriter // for sending channelDataMsg + win chan int // receives window adjustments + rwin int // current rwin size + clientChan *clientChan // the channel backing this writer } // Write writes data to the remote process's standard input. @@ -372,12 +399,13 @@ func (w *chanWriter) Write(data []byte) (n int, err error) { w.rwin += win continue } + peersId := w.clientChan.peersId n = len(data) packet := make([]byte, 0, 9+n) packet = append(packet, msgChannelData, - byte(w.peersId>>24), byte(w.peersId>>16), byte(w.peersId>>8), byte(w.peersId), + byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId), byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) - err = w.writePacket(append(packet, data...)) + err = w.clientChan.writePacket(append(packet, data...)) w.rwin -= n return } @@ -385,7 +413,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) { } func (w *chanWriter) Close() error { - return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.peersId})) + return w.clientChan.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.clientChan.peersId})) } // A chanReader represents stdout or stderr of a remote process. @@ -393,10 +421,9 @@ type chanReader struct { // TODO(dfc) a fixed size channel may not be the right data structure. // If writes to this channel block, they will block mainLoop, making // it unable to receive new messages from the remote side. - data chan []byte // receives data from remote - peersId uint32 // the peer's id - packetWriter // for sending windowAdjustMsg - buf []byte + data chan []byte // receives data from remote + clientChan *clientChan // the channel backing this reader + buf []byte } // Read reads data from the remote process's stdout or stderr. @@ -407,10 +434,10 @@ func (r *chanReader) Read(data []byte) (int, error) { n := copy(data, r.buf) r.buf = r.buf[n:] msg := windowAdjustMsg{ - PeersId: r.peersId, + PeersId: r.clientChan.peersId, AdditionalBytes: uint32(n), } - return n, r.writePacket(marshal(msgChannelWindowAdjust, msg)) + return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg)) } r.buf, ok = <-r.data if !ok { diff --git a/src/pkg/exp/ssh/session.go b/src/pkg/exp/ssh/session.go index 5f98a8d58c6..23ea18c29ae 100644 --- a/src/pkg/exp/ssh/session.go +++ b/src/pkg/exp/ssh/session.go @@ -285,13 +285,8 @@ func (s *Session) stdin() error { s.Stdin = new(bytes.Buffer) } s.copyFuncs = append(s.copyFuncs, func() error { - w := &chanWriter{ - packetWriter: s, - peersId: s.peersId, - win: s.win, - } - _, err := io.Copy(w, s.Stdin) - if err1 := w.Close(); err == nil { + _, err := io.Copy(s.clientChan.stdin, s.Stdin) + if err1 := s.clientChan.stdin.Close(); err == nil { err = err1 } return err @@ -304,12 +299,7 @@ func (s *Session) stdout() error { s.Stdout = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - r := &chanReader{ - packetWriter: s, - peersId: s.peersId, - data: s.data, - } - _, err := io.Copy(s.Stdout, r) + _, err := io.Copy(s.Stdout, s.clientChan.stdout) return err }) return nil @@ -320,12 +310,7 @@ func (s *Session) stderr() error { s.Stderr = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - r := &chanReader{ - packetWriter: s, - peersId: s.peersId, - data: s.dataExt, - } - _, err := io.Copy(s.Stderr, r) + _, err := io.Copy(s.Stderr, s.clientChan.stderr) return err }) return nil @@ -398,19 +383,11 @@ func (c *ClientConn) NewSession() (*Session, error) { c.chanlist.remove(ch.id) return nil, err } - // wait for response - msg := <-ch.msg - switch msg := msg.(type) { - case *channelOpenConfirmMsg: - ch.peersId = msg.MyId - ch.win <- int(msg.MyWindow) - return &Session{ - clientChan: ch, - }, nil - case *channelOpenFailureMsg: + if err := ch.waitForChannelOpenResponse(); err != nil { c.chanlist.remove(ch.id) - return nil, fmt.Errorf("ssh: channel open failed: %s", msg.Message) + return nil, fmt.Errorf("ssh: unable to open session: %v", err) } - c.chanlist.remove(ch.id) - return nil, fmt.Errorf("ssh: unexpected message %T: %v", msg, msg) + return &Session{ + clientChan: ch, + }, nil } diff --git a/src/pkg/exp/ssh/tcpip.go b/src/pkg/exp/ssh/tcpip.go index f3bbac5d19e..a85044ace9c 100644 --- a/src/pkg/exp/ssh/tcpip.go +++ b/src/pkg/exp/ssh/tcpip.go @@ -6,6 +6,7 @@ package ssh import ( "errors" + "fmt" "io" "net" ) @@ -42,20 +43,21 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err }, nil } +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + ChanType string + PeersId uint32 + PeersWindow uint32 + MaxPacketSize uint32 + raddr string + rport uint32 + laddr string + lport uint32 +} + // dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as // strings and are expected to be resolveable at the remote end. func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) { - // RFC 4254 7.2 - type channelOpenDirectMsg struct { - ChanType string - PeersId uint32 - PeersWindow uint32 - MaxPacketSize uint32 - raddr string - rport uint32 - laddr string - lport uint32 - } ch := c.newChan(c.transport) if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{ ChanType: "direct-tcpip", @@ -70,30 +72,14 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc c.chanlist.remove(ch.id) return nil, err } - // wait for response - switch msg := (<-ch.msg).(type) { - case *channelOpenConfirmMsg: - ch.peersId = msg.MyId - ch.win <- int(msg.MyWindow) - case *channelOpenFailureMsg: + if err := ch.waitForChannelOpenResponse(); err != nil { c.chanlist.remove(ch.id) - return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message) - default: - c.chanlist.remove(ch.id) - return nil, errors.New("ssh: unexpected packet") + return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err) } return &tcpchan{ clientChan: ch, - Reader: &chanReader{ - packetWriter: ch, - peersId: ch.peersId, - data: ch.data, - }, - Writer: &chanWriter{ - packetWriter: ch, - peersId: ch.peersId, - win: ch.win, - }, + Reader: ch.stdout, + Writer: ch.stdin, }, nil }