mirror of
https://github.com/golang/go
synced 2024-11-24 23:47:55 -07:00
exp/ssh: alter Session to match the exec.Cmd API
This CL inverts the direction of the Stdin/out/err members of the Session struct so they reflect the API of the exec.Cmd. In doing so it borrows heavily from the exec package. Additionally Shell now returns immediately, wait for completion using Wait. Exec calls Wait internally and so blocks until the remote command is complete. Credit to Gustavo Niemeyer for the impetus for this CL. R=rsc, agl, n13m3y3r, huin, bradfitz CC=cw, golang-dev https://golang.org/cl/5322055
This commit is contained in:
parent
05d8d112fe
commit
fb57134d47
@ -342,17 +342,6 @@ func (c *clientChan) Close() error {
|
||||
}))
|
||||
}
|
||||
|
||||
func (c *clientChan) sendChanReq(req channelRequestMsg) error {
|
||||
if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
|
||||
return err
|
||||
}
|
||||
msg := <-c.msg
|
||||
if _, ok := msg.(*channelRequestSuccessMsg); ok {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
|
||||
}
|
||||
|
||||
// Thread safe channel list.
|
||||
type chanlist struct {
|
||||
// protects concurrent access to chans
|
||||
|
@ -8,66 +8,104 @@ package ssh
|
||||
// "RFC 4254, section 6".
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// A Session represents a connection to a remote command or shell.
|
||||
type Session struct {
|
||||
// Writes to Stdin are made available to the remote command's standard input.
|
||||
// Closing Stdin causes the command to observe an EOF on its standard input.
|
||||
Stdin io.WriteCloser
|
||||
// Stdin specifies the remote process's standard input.
|
||||
// If Stdin is nil, the remote process reads from an empty
|
||||
// bytes.Buffer.
|
||||
Stdin io.Reader
|
||||
|
||||
// Reads from Stdout and Stderr consume from the remote command's standard
|
||||
// output and error streams, respectively.
|
||||
// There is a fixed amount of buffering that is shared for the two streams.
|
||||
// Failing to read from either may eventually cause the command to block.
|
||||
// Closing Stdout unblocks such writes and causes them to return errors.
|
||||
Stdout io.ReadCloser
|
||||
Stderr io.Reader
|
||||
// Stdout and Stderr specify the remote process's standard
|
||||
// output and error.
|
||||
//
|
||||
// If either is nil, Run connects the corresponding file
|
||||
// descriptor to an instance of ioutil.Discard. There is a
|
||||
// fixed amount of buffering that is shared for the two streams.
|
||||
// If either blocks it may eventually cause the remote
|
||||
// command to block.
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
|
||||
*clientChan // the channel backing this session
|
||||
|
||||
started bool // started is set to true once a Shell or Exec is invoked.
|
||||
started bool // true once a Shell or Exec is invoked.
|
||||
copyFuncs []func() error
|
||||
errch chan error // one send per copyFunc
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.4.
|
||||
type setenvRequest struct {
|
||||
PeersId uint32
|
||||
Request string
|
||||
WantReply bool
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Setenv sets an environment variable that will be applied to any
|
||||
// command executed by Shell or Exec.
|
||||
func (s *Session) Setenv(name, value string) error {
|
||||
n, v := []byte(name), []byte(value)
|
||||
nlen, vlen := stringLength(n), stringLength(v)
|
||||
payload := make([]byte, nlen+vlen)
|
||||
marshalString(payload[:nlen], n)
|
||||
marshalString(payload[nlen:], v)
|
||||
|
||||
return s.sendChanReq(channelRequestMsg{
|
||||
req := setenvRequest{
|
||||
PeersId: s.id,
|
||||
Request: "env",
|
||||
WantReply: true,
|
||||
RequestSpecificData: payload,
|
||||
})
|
||||
Name: name,
|
||||
Value: value,
|
||||
}
|
||||
if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.waitForResponse()
|
||||
}
|
||||
|
||||
// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8.
|
||||
var emptyModeList = []byte{0, 0, 0, 1, 0}
|
||||
// An empty mode list, see RFC 4254 Section 8.
|
||||
var emptyModelist = "\x00"
|
||||
|
||||
// RFC 4254 Section 6.2.
|
||||
type ptyRequestMsg struct {
|
||||
PeersId uint32
|
||||
Request string
|
||||
WantReply bool
|
||||
Term string
|
||||
Columns uint32
|
||||
Rows uint32
|
||||
Width uint32
|
||||
Height uint32
|
||||
Modelist string
|
||||
}
|
||||
|
||||
// RequestPty requests the association of a pty with the session on the remote host.
|
||||
func (s *Session) RequestPty(term string, h, w int) error {
|
||||
buf := make([]byte, 4+len(term)+16+len(emptyModeList))
|
||||
b := marshalString(buf, []byte(term))
|
||||
binary.BigEndian.PutUint32(b, uint32(h))
|
||||
binary.BigEndian.PutUint32(b[4:], uint32(w))
|
||||
binary.BigEndian.PutUint32(b[8:], uint32(h*8))
|
||||
binary.BigEndian.PutUint32(b[12:], uint32(w*8))
|
||||
copy(b[16:], emptyModeList)
|
||||
|
||||
return s.sendChanReq(channelRequestMsg{
|
||||
req := ptyRequestMsg{
|
||||
PeersId: s.id,
|
||||
Request: "pty-req",
|
||||
WantReply: true,
|
||||
RequestSpecificData: buf,
|
||||
})
|
||||
Term: term,
|
||||
Columns: uint32(w),
|
||||
Rows: uint32(h),
|
||||
Width: uint32(w * 8),
|
||||
Height: uint32(h * 8),
|
||||
Modelist: emptyModelist,
|
||||
}
|
||||
if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.waitForResponse()
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type execMsg struct {
|
||||
PeersId uint32
|
||||
Request string
|
||||
WantReply bool
|
||||
Command string
|
||||
}
|
||||
|
||||
// Exec runs cmd on the remote host. Typically, the remote
|
||||
@ -75,34 +113,166 @@ func (s *Session) RequestPty(term string, h, w int) error {
|
||||
// A Session only accepts one call to Exec or Shell.
|
||||
func (s *Session) Exec(cmd string) error {
|
||||
if s.started {
|
||||
return errors.New("session already started")
|
||||
return errors.New("ssh: session already started")
|
||||
}
|
||||
cmdLen := stringLength([]byte(cmd))
|
||||
payload := make([]byte, cmdLen)
|
||||
marshalString(payload, []byte(cmd))
|
||||
s.started = true
|
||||
|
||||
return s.sendChanReq(channelRequestMsg{
|
||||
req := execMsg{
|
||||
PeersId: s.id,
|
||||
Request: "exec",
|
||||
WantReply: true,
|
||||
RequestSpecificData: payload,
|
||||
})
|
||||
Command: cmd,
|
||||
}
|
||||
if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.waitForResponse(); err != nil {
|
||||
return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err)
|
||||
}
|
||||
if err := s.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Wait()
|
||||
}
|
||||
|
||||
// Shell starts a login shell on the remote host. A Session only
|
||||
// accepts one call to Exec or Shell.
|
||||
func (s *Session) Shell() error {
|
||||
if s.started {
|
||||
return errors.New("session already started")
|
||||
return errors.New("ssh: session already started")
|
||||
}
|
||||
s.started = true
|
||||
|
||||
return s.sendChanReq(channelRequestMsg{
|
||||
req := channelRequestMsg{
|
||||
PeersId: s.id,
|
||||
Request: "shell",
|
||||
WantReply: true,
|
||||
}
|
||||
if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.waitForResponse(); err != nil {
|
||||
return fmt.Errorf("ssh: cound not execute shell: %v", err)
|
||||
}
|
||||
return s.start()
|
||||
}
|
||||
|
||||
func (s *Session) waitForResponse() error {
|
||||
msg := <-s.msg
|
||||
switch msg.(type) {
|
||||
case *channelRequestSuccessMsg:
|
||||
return nil
|
||||
case *channelRequestFailureMsg:
|
||||
return errors.New("request failed")
|
||||
}
|
||||
return fmt.Errorf("unknown packet %T received: %v", msg, msg)
|
||||
}
|
||||
|
||||
func (s *Session) start() error {
|
||||
s.started = true
|
||||
|
||||
type F func(*Session) error
|
||||
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
|
||||
if err := setupFd(s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.errch = make(chan error, len(s.copyFuncs))
|
||||
for _, fn := range s.copyFuncs {
|
||||
go func(fn func() error) {
|
||||
s.errch <- fn()
|
||||
}(fn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait waits for the remote command to exit.
|
||||
func (s *Session) Wait() error {
|
||||
if !s.started {
|
||||
return errors.New("ssh: session not started")
|
||||
}
|
||||
waitErr := s.wait()
|
||||
|
||||
var copyError error
|
||||
for _ = range s.copyFuncs {
|
||||
if err := <-s.errch; err != nil && copyError == nil {
|
||||
copyError = err
|
||||
}
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
|
||||
return copyError
|
||||
}
|
||||
|
||||
func (s *Session) wait() error {
|
||||
for {
|
||||
switch msg := (<-s.msg).(type) {
|
||||
case *channelRequestMsg:
|
||||
// TODO(dfc) improve this behavior to match os.Waitmsg
|
||||
switch msg.Request {
|
||||
case "exit-status":
|
||||
d := msg.RequestSpecificData
|
||||
status := int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
|
||||
if status > 0 {
|
||||
return fmt.Errorf("remote process exited with %d", status)
|
||||
}
|
||||
return nil
|
||||
case "exit-signal":
|
||||
// TODO(dfc) make a more readable error message
|
||||
return fmt.Errorf("%v", msg.RequestSpecificData)
|
||||
default:
|
||||
return fmt.Errorf("wait: unexpected channel request: %v", msg)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg)
|
||||
}
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (s *Session) stdin() error {
|
||||
if s.Stdin == nil {
|
||||
s.Stdin = new(bytes.Buffer)
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(&chanWriter{
|
||||
packetWriter: s,
|
||||
id: s.id,
|
||||
win: s.win,
|
||||
}, s.Stdin)
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) stdout() error {
|
||||
if s.Stdout == nil {
|
||||
s.Stdout = ioutil.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.Stdout, &chanReader{
|
||||
packetWriter: s,
|
||||
id: s.id,
|
||||
data: s.data,
|
||||
})
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) stderr() error {
|
||||
if s.Stderr == nil {
|
||||
s.Stderr = ioutil.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.Stderr, &chanReader{
|
||||
packetWriter: s,
|
||||
id: s.id,
|
||||
data: s.dataExt,
|
||||
})
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSession returns a new interactive session on the remote host.
|
||||
@ -112,21 +282,6 @@ func (c *ClientConn) NewSession() (*Session, error) {
|
||||
return nil, err
|
||||
}
|
||||
return &Session{
|
||||
Stdin: &chanWriter{
|
||||
packetWriter: ch,
|
||||
id: ch.id,
|
||||
win: ch.win,
|
||||
},
|
||||
Stdout: &chanReader{
|
||||
packetWriter: ch,
|
||||
id: ch.id,
|
||||
data: ch.data,
|
||||
},
|
||||
Stderr: &chanReader{
|
||||
packetWriter: ch,
|
||||
id: ch.id,
|
||||
data: ch.dataExt,
|
||||
},
|
||||
clientChan: ch,
|
||||
}, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user