mirror of
https://github.com/golang/go
synced 2024-11-21 21:34:40 -07:00
exp/ssh: simplify client channel open logic
This is part one of a small set of CL's that aim to resolve the outstanding TODOs relating to channel close and blocking behavior. Firstly, the hairy handling of assigning the peersId is now done in one place. The cost of this change is the slightly paradoxical construction of the partially created clientChan. Secondly, by creating clientChan.stdin/out/err when the channel is opened, the creation of consumers like tcpchan and Session is simplified; they just have to wire themselves up to the relevant readers/writers. R=agl, gustav.paul, rsc CC=golang-dev https://golang.org/cl/5448073
This commit is contained in:
parent
0a5508c692
commit
bbbd41f4ff
@ -200,7 +200,7 @@ func (c *ClientConn) mainLoop() {
|
||||
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
||||
if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
|
||||
packet = packet[9:]
|
||||
c.getChan(peersId).data <- packet[:length]
|
||||
c.getChan(peersId).stdout.data <- packet[:length]
|
||||
}
|
||||
case msgChannelExtendedData:
|
||||
if len(packet) < 13 {
|
||||
@ -215,7 +215,7 @@ func (c *ClientConn) mainLoop() {
|
||||
// for stderr on interactive sessions. Other data types are
|
||||
// silently discarded.
|
||||
if datatype == 1 {
|
||||
c.getChan(peersId).dataExt <- packet[:length]
|
||||
c.getChan(peersId).stderr.data <- packet[:length]
|
||||
}
|
||||
}
|
||||
default:
|
||||
@ -228,9 +228,9 @@ func (c *ClientConn) mainLoop() {
|
||||
c.getChan(msg.PeersId).msg <- msg
|
||||
case *channelCloseMsg:
|
||||
ch := c.getChan(msg.PeersId)
|
||||
close(ch.win)
|
||||
close(ch.data)
|
||||
close(ch.dataExt)
|
||||
close(ch.stdin.win)
|
||||
close(ch.stdout.data)
|
||||
close(ch.stderr.data)
|
||||
c.chanlist.remove(msg.PeersId)
|
||||
case *channelEOFMsg:
|
||||
c.getChan(msg.PeersId).msg <- msg
|
||||
@ -241,7 +241,7 @@ func (c *ClientConn) mainLoop() {
|
||||
case *channelRequestMsg:
|
||||
c.getChan(msg.PeersId).msg <- msg
|
||||
case *windowAdjustMsg:
|
||||
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
|
||||
c.getChan(msg.PeersId).stdin.win <- int(msg.AdditionalBytes)
|
||||
default:
|
||||
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
|
||||
}
|
||||
@ -290,21 +290,49 @@ func (c *ClientConfig) rand() io.Reader {
|
||||
type clientChan struct {
|
||||
packetWriter
|
||||
id, peersId uint32
|
||||
data chan []byte // receives the payload of channelData messages
|
||||
dataExt chan []byte // receives the payload of channelExtendedData messages
|
||||
win chan int // receives window adjustments
|
||||
stdin *chanWriter // receives window adjustments
|
||||
stdout *chanReader // receives the payload of channelData messages
|
||||
stderr *chanReader // receives the payload of channelExtendedData messages
|
||||
msg chan interface{} // incoming messages
|
||||
}
|
||||
|
||||
// newClientChan returns a partially constructed *clientChan
|
||||
// using the local id provided. To be usable clientChan.peersId
|
||||
// needs to be assigned once known.
|
||||
func newClientChan(t *transport, id uint32) *clientChan {
|
||||
return &clientChan{
|
||||
c := &clientChan{
|
||||
packetWriter: t,
|
||||
id: id,
|
||||
data: make(chan []byte, 16),
|
||||
dataExt: make(chan []byte, 16),
|
||||
win: make(chan int, 16),
|
||||
msg: make(chan interface{}, 16),
|
||||
}
|
||||
c.stdin = &chanWriter{
|
||||
win: make(chan int, 16),
|
||||
clientChan: c,
|
||||
}
|
||||
c.stdout = &chanReader{
|
||||
data: make(chan []byte, 16),
|
||||
clientChan: c,
|
||||
}
|
||||
c.stderr = &chanReader{
|
||||
data: make(chan []byte, 16),
|
||||
clientChan: c,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// waitForChannelOpenResponse, if successful, fills out
|
||||
// the peerId and records any initial window advertisement.
|
||||
func (c *clientChan) waitForChannelOpenResponse() error {
|
||||
switch msg := (<-c.msg).(type) {
|
||||
case *channelOpenConfirmMsg:
|
||||
// fixup peersId field
|
||||
c.peersId = msg.MyId
|
||||
c.stdin.win <- int(msg.MyWindow)
|
||||
return nil
|
||||
case *channelOpenFailureMsg:
|
||||
return errors.New(safeString(msg.Message))
|
||||
}
|
||||
return errors.New("unexpected packet")
|
||||
}
|
||||
|
||||
// Close closes the channel. This does not close the underlying connection.
|
||||
@ -356,9 +384,8 @@ func (c *chanlist) remove(id uint32) {
|
||||
// A chanWriter represents the stdin of a remote process.
|
||||
type chanWriter struct {
|
||||
win chan int // receives window adjustments
|
||||
peersId uint32 // the peer's id
|
||||
rwin int // current rwin size
|
||||
packetWriter // for sending channelDataMsg
|
||||
clientChan *clientChan // the channel backing this writer
|
||||
}
|
||||
|
||||
// Write writes data to the remote process's standard input.
|
||||
@ -372,12 +399,13 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
|
||||
w.rwin += win
|
||||
continue
|
||||
}
|
||||
peersId := w.clientChan.peersId
|
||||
n = len(data)
|
||||
packet := make([]byte, 0, 9+n)
|
||||
packet = append(packet, msgChannelData,
|
||||
byte(w.peersId>>24), byte(w.peersId>>16), byte(w.peersId>>8), byte(w.peersId),
|
||||
byte(peersId>>24), byte(peersId>>16), byte(peersId>>8), byte(peersId),
|
||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
||||
err = w.writePacket(append(packet, data...))
|
||||
err = w.clientChan.writePacket(append(packet, data...))
|
||||
w.rwin -= n
|
||||
return
|
||||
}
|
||||
@ -385,7 +413,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (w *chanWriter) Close() error {
|
||||
return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.peersId}))
|
||||
return w.clientChan.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.clientChan.peersId}))
|
||||
}
|
||||
|
||||
// A chanReader represents stdout or stderr of a remote process.
|
||||
@ -394,8 +422,7 @@ type chanReader struct {
|
||||
// If writes to this channel block, they will block mainLoop, making
|
||||
// it unable to receive new messages from the remote side.
|
||||
data chan []byte // receives data from remote
|
||||
peersId uint32 // the peer's id
|
||||
packetWriter // for sending windowAdjustMsg
|
||||
clientChan *clientChan // the channel backing this reader
|
||||
buf []byte
|
||||
}
|
||||
|
||||
@ -407,10 +434,10 @@ func (r *chanReader) Read(data []byte) (int, error) {
|
||||
n := copy(data, r.buf)
|
||||
r.buf = r.buf[n:]
|
||||
msg := windowAdjustMsg{
|
||||
PeersId: r.peersId,
|
||||
PeersId: r.clientChan.peersId,
|
||||
AdditionalBytes: uint32(n),
|
||||
}
|
||||
return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
|
||||
return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg))
|
||||
}
|
||||
r.buf, ok = <-r.data
|
||||
if !ok {
|
||||
|
@ -285,13 +285,8 @@ func (s *Session) stdin() error {
|
||||
s.Stdin = new(bytes.Buffer)
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
w := &chanWriter{
|
||||
packetWriter: s,
|
||||
peersId: s.peersId,
|
||||
win: s.win,
|
||||
}
|
||||
_, err := io.Copy(w, s.Stdin)
|
||||
if err1 := w.Close(); err == nil {
|
||||
_, err := io.Copy(s.clientChan.stdin, s.Stdin)
|
||||
if err1 := s.clientChan.stdin.Close(); err == nil {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
@ -304,12 +299,7 @@ func (s *Session) stdout() error {
|
||||
s.Stdout = ioutil.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
r := &chanReader{
|
||||
packetWriter: s,
|
||||
peersId: s.peersId,
|
||||
data: s.data,
|
||||
}
|
||||
_, err := io.Copy(s.Stdout, r)
|
||||
_, err := io.Copy(s.Stdout, s.clientChan.stdout)
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
@ -320,12 +310,7 @@ func (s *Session) stderr() error {
|
||||
s.Stderr = ioutil.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
r := &chanReader{
|
||||
packetWriter: s,
|
||||
peersId: s.peersId,
|
||||
data: s.dataExt,
|
||||
}
|
||||
_, err := io.Copy(s.Stderr, r)
|
||||
_, err := io.Copy(s.Stderr, s.clientChan.stderr)
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
@ -398,19 +383,11 @@ func (c *ClientConn) NewSession() (*Session, error) {
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, err
|
||||
}
|
||||
// wait for response
|
||||
msg := <-ch.msg
|
||||
switch msg := msg.(type) {
|
||||
case *channelOpenConfirmMsg:
|
||||
ch.peersId = msg.MyId
|
||||
ch.win <- int(msg.MyWindow)
|
||||
if err := ch.waitForChannelOpenResponse(); err != nil {
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, fmt.Errorf("ssh: unable to open session: %v", err)
|
||||
}
|
||||
return &Session{
|
||||
clientChan: ch,
|
||||
}, nil
|
||||
case *channelOpenFailureMsg:
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, fmt.Errorf("ssh: channel open failed: %s", msg.Message)
|
||||
}
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, fmt.Errorf("ssh: unexpected message %T: %v", msg, msg)
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
@ -42,9 +43,6 @@ func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, err
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
|
||||
// strings and are expected to be resolveable at the remote end.
|
||||
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
|
||||
// RFC 4254 7.2
|
||||
type channelOpenDirectMsg struct {
|
||||
ChanType string
|
||||
@ -56,6 +54,10 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc
|
||||
laddr string
|
||||
lport uint32
|
||||
}
|
||||
|
||||
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
|
||||
// strings and are expected to be resolveable at the remote end.
|
||||
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
|
||||
ch := c.newChan(c.transport)
|
||||
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
|
||||
ChanType: "direct-tcpip",
|
||||
@ -70,30 +72,14 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, err
|
||||
}
|
||||
// wait for response
|
||||
switch msg := (<-ch.msg).(type) {
|
||||
case *channelOpenConfirmMsg:
|
||||
ch.peersId = msg.MyId
|
||||
ch.win <- int(msg.MyWindow)
|
||||
case *channelOpenFailureMsg:
|
||||
if err := ch.waitForChannelOpenResponse(); err != nil {
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
|
||||
default:
|
||||
c.chanlist.remove(ch.id)
|
||||
return nil, errors.New("ssh: unexpected packet")
|
||||
return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err)
|
||||
}
|
||||
return &tcpchan{
|
||||
clientChan: ch,
|
||||
Reader: &chanReader{
|
||||
packetWriter: ch,
|
||||
peersId: ch.peersId,
|
||||
data: ch.data,
|
||||
},
|
||||
Writer: &chanWriter{
|
||||
packetWriter: ch,
|
||||
peersId: ch.peersId,
|
||||
win: ch.win,
|
||||
},
|
||||
Reader: ch.stdout,
|
||||
Writer: ch.stdin,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user