mirror of
https://github.com/golang/go
synced 2024-11-21 15:24:45 -07:00
exp/ssh: add Std{in,out,err}Pipe methods to Session
R=gustav.paul, cw, agl, rsc, n13m3y3r CC=golang-dev https://golang.org/cl/5433080
This commit is contained in:
parent
c0a53bbc4a
commit
c4d0ac0e2f
@ -54,9 +54,10 @@ type Session struct {
|
||||
|
||||
*clientChan // the channel backing this session
|
||||
|
||||
started bool // true once a Shell or Run is invoked.
|
||||
copyFuncs []func() error
|
||||
errch chan error // one send per copyFunc
|
||||
started bool // true once Start, Run or Shell is invoked.
|
||||
closeAfterWait []io.Closer
|
||||
copyFuncs []func() error
|
||||
errch chan error // one send per copyFunc
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.4.
|
||||
@ -231,7 +232,7 @@ func (s *Session) start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait waits for the remote command to exit.
|
||||
// Wait waits for the remote command to exit.
|
||||
func (s *Session) Wait() error {
|
||||
if !s.started {
|
||||
return errors.New("ssh: session not started")
|
||||
@ -244,11 +245,12 @@ func (s *Session) Wait() error {
|
||||
copyError = err
|
||||
}
|
||||
}
|
||||
|
||||
for _, fd := range s.closeAfterWait {
|
||||
fd.Close()
|
||||
}
|
||||
if waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
|
||||
return copyError
|
||||
}
|
||||
|
||||
@ -283,11 +285,15 @@ func (s *Session) stdin() error {
|
||||
s.Stdin = new(bytes.Buffer)
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(&chanWriter{
|
||||
w := &chanWriter{
|
||||
packetWriter: s,
|
||||
peersId: s.peersId,
|
||||
win: s.win,
|
||||
}, s.Stdin)
|
||||
}
|
||||
_, err := io.Copy(w, s.Stdin)
|
||||
if err1 := w.Close(); err == nil {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
@ -298,11 +304,12 @@ func (s *Session) stdout() error {
|
||||
s.Stdout = ioutil.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.Stdout, &chanReader{
|
||||
r := &chanReader{
|
||||
packetWriter: s,
|
||||
peersId: s.peersId,
|
||||
data: s.data,
|
||||
})
|
||||
}
|
||||
_, err := io.Copy(s.Stdout, r)
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
@ -313,16 +320,72 @@ func (s *Session) stderr() error {
|
||||
s.Stderr = ioutil.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.Stderr, &chanReader{
|
||||
r := &chanReader{
|
||||
packetWriter: s,
|
||||
peersId: s.peersId,
|
||||
data: s.dataExt,
|
||||
})
|
||||
}
|
||||
_, err := io.Copy(s.Stderr, r)
|
||||
return err
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// StdinPipe returns a pipe that will be connected to the
|
||||
// remote command's standard input when the command starts.
|
||||
func (s *Session) StdinPipe() (io.WriteCloser, error) {
|
||||
if s.Stdin != nil {
|
||||
return nil, errors.New("ssh: Stdin already set")
|
||||
}
|
||||
if s.started {
|
||||
return nil, errors.New("ssh: StdinPipe after process started")
|
||||
}
|
||||
pr, pw := io.Pipe()
|
||||
s.Stdin = pr
|
||||
s.closeAfterWait = append(s.closeAfterWait, pr)
|
||||
return pw, nil
|
||||
}
|
||||
|
||||
// StdoutPipe returns a pipe that will be connected to the
|
||||
// remote command's standard output when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StdoutPipe reader is
|
||||
// not serviced fast enought it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StdoutPipe() (io.ReadCloser, error) {
|
||||
if s.Stdout != nil {
|
||||
return nil, errors.New("ssh: Stdout already set")
|
||||
}
|
||||
if s.started {
|
||||
return nil, errors.New("ssh: StdoutPipe after process started")
|
||||
}
|
||||
pr, pw := io.Pipe()
|
||||
s.Stdout = pw
|
||||
s.closeAfterWait = append(s.closeAfterWait, pw)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
// StderrPipe returns a pipe that will be connected to the
|
||||
// remote command's standard error when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StderrPipe reader is
|
||||
// not serviced fast enought it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StderrPipe() (io.ReadCloser, error) {
|
||||
if s.Stderr != nil {
|
||||
return nil, errors.New("ssh: Stderr already set")
|
||||
}
|
||||
if s.started {
|
||||
return nil, errors.New("ssh: StderrPipe after process started")
|
||||
}
|
||||
pr, pw := io.Pipe()
|
||||
s.Stderr = pw
|
||||
s.closeAfterWait = append(s.closeAfterWait, pw)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
// TODO(dfc) add Output and CombinedOutput helpers
|
||||
|
||||
// NewSession returns a new interactive session on the remote host.
|
||||
func (c *ClientConn) NewSession() (*Session, error) {
|
||||
ch := c.newChan(c.transport)
|
||||
|
149
src/pkg/exp/ssh/session_test.go
Normal file
149
src/pkg/exp/ssh/session_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
// Session tests.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// dial constructs a new test server and returns a *ClientConn.
|
||||
func dial(t *testing.T) *ClientConn {
|
||||
pw := password("tiger")
|
||||
serverConfig.PasswordCallback = func(user, pass string) bool {
|
||||
return user == "testuser" && pass == string(pw)
|
||||
}
|
||||
serverConfig.PubKeyCallback = nil
|
||||
|
||||
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to listen: %s", err)
|
||||
}
|
||||
go func() {
|
||||
defer l.Close()
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Unable to accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := conn.Handshake(); err != nil {
|
||||
t.Errorf("Unable to handshake: %v", err)
|
||||
return
|
||||
}
|
||||
for {
|
||||
ch, err := conn.Accept()
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Unable to accept incoming channel request: %v", err)
|
||||
return
|
||||
}
|
||||
if ch.ChannelType() != "session" {
|
||||
ch.Reject(UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
ch.Accept()
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
// this string is returned to stdout
|
||||
shell := NewServerShell(ch, "golang")
|
||||
shell.ReadLine()
|
||||
type exitMsg struct {
|
||||
PeersId uint32
|
||||
Request string
|
||||
WantReply bool
|
||||
Status uint32
|
||||
}
|
||||
// TODO(dfc) casting to the concrete type should not be
|
||||
// necessary to send a packet.
|
||||
msg := exitMsg{
|
||||
PeersId: ch.(*channel).theirId,
|
||||
Request: "exit-status",
|
||||
WantReply: false,
|
||||
Status: 0,
|
||||
}
|
||||
ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
|
||||
}()
|
||||
}
|
||||
t.Log("done")
|
||||
}()
|
||||
|
||||
config := &ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ClientAuth{
|
||||
ClientAuthPassword(pw),
|
||||
},
|
||||
}
|
||||
|
||||
c, err := Dial("tcp", l.Addr().String(), config)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to dial remote side: %s", err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Test a simple string is returned to session.Stdout.
|
||||
func TestSessionShell(t *testing.T) {
|
||||
conn := dial(t)
|
||||
defer conn.Close()
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to request new session: %s", err)
|
||||
}
|
||||
defer session.Close()
|
||||
stdout := new(bytes.Buffer)
|
||||
session.Stdout = stdout
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("Unable to execute command: %s", err)
|
||||
}
|
||||
if err := session.Wait(); err != nil {
|
||||
t.Fatalf("Remote command did not exit cleanly: %s", err)
|
||||
}
|
||||
actual := stdout.String()
|
||||
if actual != "golang" {
|
||||
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
|
||||
|
||||
// Test a simple string is returned via StdoutPipe.
|
||||
func TestSessionStdoutPipe(t *testing.T) {
|
||||
conn := dial(t)
|
||||
defer conn.Close()
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to request new session: %s", err)
|
||||
}
|
||||
defer session.Close()
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to request StdoutPipe(): %v", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("Unable to execute command: %s", err)
|
||||
}
|
||||
done := make(chan bool, 1)
|
||||
go func() {
|
||||
if _, err := io.Copy(&buf, stdout); err != nil {
|
||||
t.Errorf("Copy of stdout failed: %v", err)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
if err := session.Wait(); err != nil {
|
||||
t.Fatalf("Remote command did not exit cleanly: %s", err)
|
||||
}
|
||||
<-done
|
||||
actual := buf.String()
|
||||
if actual != "golang" {
|
||||
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user