mirror of
https://github.com/golang/go
synced 2024-11-22 03:24:41 -07:00
exp/ssh: add Std{in,out,err}Pipe methods to Session
R=gustav.paul, cw, agl, rsc, n13m3y3r CC=golang-dev https://golang.org/cl/5433080
This commit is contained in:
parent
c0a53bbc4a
commit
c4d0ac0e2f
@ -54,7 +54,8 @@ type Session struct {
|
|||||||
|
|
||||||
*clientChan // the channel backing this session
|
*clientChan // the channel backing this session
|
||||||
|
|
||||||
started bool // true once a Shell or Run is invoked.
|
started bool // true once Start, Run or Shell is invoked.
|
||||||
|
closeAfterWait []io.Closer
|
||||||
copyFuncs []func() error
|
copyFuncs []func() error
|
||||||
errch chan error // one send per copyFunc
|
errch chan error // one send per copyFunc
|
||||||
}
|
}
|
||||||
@ -244,11 +245,12 @@ func (s *Session) Wait() error {
|
|||||||
copyError = err
|
copyError = err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, fd := range s.closeAfterWait {
|
||||||
|
fd.Close()
|
||||||
|
}
|
||||||
if waitErr != nil {
|
if waitErr != nil {
|
||||||
return waitErr
|
return waitErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return copyError
|
return copyError
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,11 +285,15 @@ func (s *Session) stdin() error {
|
|||||||
s.Stdin = new(bytes.Buffer)
|
s.Stdin = new(bytes.Buffer)
|
||||||
}
|
}
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||||
_, err := io.Copy(&chanWriter{
|
w := &chanWriter{
|
||||||
packetWriter: s,
|
packetWriter: s,
|
||||||
peersId: s.peersId,
|
peersId: s.peersId,
|
||||||
win: s.win,
|
win: s.win,
|
||||||
}, s.Stdin)
|
}
|
||||||
|
_, err := io.Copy(w, s.Stdin)
|
||||||
|
if err1 := w.Close(); err == nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@ -298,11 +304,12 @@ func (s *Session) stdout() error {
|
|||||||
s.Stdout = ioutil.Discard
|
s.Stdout = ioutil.Discard
|
||||||
}
|
}
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||||
_, err := io.Copy(s.Stdout, &chanReader{
|
r := &chanReader{
|
||||||
packetWriter: s,
|
packetWriter: s,
|
||||||
peersId: s.peersId,
|
peersId: s.peersId,
|
||||||
data: s.data,
|
data: s.data,
|
||||||
})
|
}
|
||||||
|
_, err := io.Copy(s.Stdout, r)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@ -313,16 +320,72 @@ func (s *Session) stderr() error {
|
|||||||
s.Stderr = ioutil.Discard
|
s.Stderr = ioutil.Discard
|
||||||
}
|
}
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||||
_, err := io.Copy(s.Stderr, &chanReader{
|
r := &chanReader{
|
||||||
packetWriter: s,
|
packetWriter: s,
|
||||||
peersId: s.peersId,
|
peersId: s.peersId,
|
||||||
data: s.dataExt,
|
data: s.dataExt,
|
||||||
})
|
}
|
||||||
|
_, err := io.Copy(s.Stderr, r)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StdinPipe returns a pipe that will be connected to the
|
||||||
|
// remote command's standard input when the command starts.
|
||||||
|
func (s *Session) StdinPipe() (io.WriteCloser, error) {
|
||||||
|
if s.Stdin != nil {
|
||||||
|
return nil, errors.New("ssh: Stdin already set")
|
||||||
|
}
|
||||||
|
if s.started {
|
||||||
|
return nil, errors.New("ssh: StdinPipe after process started")
|
||||||
|
}
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
s.Stdin = pr
|
||||||
|
s.closeAfterWait = append(s.closeAfterWait, pr)
|
||||||
|
return pw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StdoutPipe returns a pipe that will be connected to the
|
||||||
|
// remote command's standard output when the command starts.
|
||||||
|
// There is a fixed amount of buffering that is shared between
|
||||||
|
// stdout and stderr streams. If the StdoutPipe reader is
|
||||||
|
// not serviced fast enought it may eventually cause the
|
||||||
|
// remote command to block.
|
||||||
|
func (s *Session) StdoutPipe() (io.ReadCloser, error) {
|
||||||
|
if s.Stdout != nil {
|
||||||
|
return nil, errors.New("ssh: Stdout already set")
|
||||||
|
}
|
||||||
|
if s.started {
|
||||||
|
return nil, errors.New("ssh: StdoutPipe after process started")
|
||||||
|
}
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
s.Stdout = pw
|
||||||
|
s.closeAfterWait = append(s.closeAfterWait, pw)
|
||||||
|
return pr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StderrPipe returns a pipe that will be connected to the
|
||||||
|
// remote command's standard error when the command starts.
|
||||||
|
// There is a fixed amount of buffering that is shared between
|
||||||
|
// stdout and stderr streams. If the StderrPipe reader is
|
||||||
|
// not serviced fast enought it may eventually cause the
|
||||||
|
// remote command to block.
|
||||||
|
func (s *Session) StderrPipe() (io.ReadCloser, error) {
|
||||||
|
if s.Stderr != nil {
|
||||||
|
return nil, errors.New("ssh: Stderr already set")
|
||||||
|
}
|
||||||
|
if s.started {
|
||||||
|
return nil, errors.New("ssh: StderrPipe after process started")
|
||||||
|
}
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
s.Stderr = pw
|
||||||
|
s.closeAfterWait = append(s.closeAfterWait, pw)
|
||||||
|
return pr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(dfc) add Output and CombinedOutput helpers
|
||||||
|
|
||||||
// NewSession returns a new interactive session on the remote host.
|
// NewSession returns a new interactive session on the remote host.
|
||||||
func (c *ClientConn) NewSession() (*Session, error) {
|
func (c *ClientConn) NewSession() (*Session, error) {
|
||||||
ch := c.newChan(c.transport)
|
ch := c.newChan(c.transport)
|
||||||
|
149
src/pkg/exp/ssh/session_test.go
Normal file
149
src/pkg/exp/ssh/session_test.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
// Session tests.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dial constructs a new test server and returns a *ClientConn.
|
||||||
|
func dial(t *testing.T) *ClientConn {
|
||||||
|
pw := password("tiger")
|
||||||
|
serverConfig.PasswordCallback = func(user, pass string) bool {
|
||||||
|
return user == "testuser" && pass == string(pw)
|
||||||
|
}
|
||||||
|
serverConfig.PubKeyCallback = nil
|
||||||
|
|
||||||
|
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to listen: %s", err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer l.Close()
|
||||||
|
conn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unable to accept: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
if err := conn.Handshake(); err != nil {
|
||||||
|
t.Errorf("Unable to handshake: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
ch, err := conn.Accept()
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unable to accept incoming channel request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ch.ChannelType() != "session" {
|
||||||
|
ch.Reject(UnknownChannelType, "unknown channel type")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ch.Accept()
|
||||||
|
go func() {
|
||||||
|
defer ch.Close()
|
||||||
|
// this string is returned to stdout
|
||||||
|
shell := NewServerShell(ch, "golang")
|
||||||
|
shell.ReadLine()
|
||||||
|
type exitMsg struct {
|
||||||
|
PeersId uint32
|
||||||
|
Request string
|
||||||
|
WantReply bool
|
||||||
|
Status uint32
|
||||||
|
}
|
||||||
|
// TODO(dfc) casting to the concrete type should not be
|
||||||
|
// necessary to send a packet.
|
||||||
|
msg := exitMsg{
|
||||||
|
PeersId: ch.(*channel).theirId,
|
||||||
|
Request: "exit-status",
|
||||||
|
WantReply: false,
|
||||||
|
Status: 0,
|
||||||
|
}
|
||||||
|
ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
t.Log("done")
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := &ClientConfig{
|
||||||
|
User: "testuser",
|
||||||
|
Auth: []ClientAuth{
|
||||||
|
ClientAuthPassword(pw),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := Dial("tcp", l.Addr().String(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to dial remote side: %s", err)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test a simple string is returned to session.Stdout.
|
||||||
|
func TestSessionShell(t *testing.T) {
|
||||||
|
conn := dial(t)
|
||||||
|
defer conn.Close()
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to request new session: %s", err)
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
stdout := new(bytes.Buffer)
|
||||||
|
session.Stdout = stdout
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
t.Fatalf("Unable to execute command: %s", err)
|
||||||
|
}
|
||||||
|
if err := session.Wait(); err != nil {
|
||||||
|
t.Fatalf("Remote command did not exit cleanly: %s", err)
|
||||||
|
}
|
||||||
|
actual := stdout.String()
|
||||||
|
if actual != "golang" {
|
||||||
|
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
|
||||||
|
|
||||||
|
// Test a simple string is returned via StdoutPipe.
|
||||||
|
func TestSessionStdoutPipe(t *testing.T) {
|
||||||
|
conn := dial(t)
|
||||||
|
defer conn.Close()
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to request new session: %s", err)
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
stdout, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to request StdoutPipe(): %v", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
t.Fatalf("Unable to execute command: %s", err)
|
||||||
|
}
|
||||||
|
done := make(chan bool, 1)
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(&buf, stdout); err != nil {
|
||||||
|
t.Errorf("Copy of stdout failed: %v", err)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
if err := session.Wait(); err != nil {
|
||||||
|
t.Fatalf("Remote command did not exit cleanly: %s", err)
|
||||||
|
}
|
||||||
|
<-done
|
||||||
|
actual := buf.String()
|
||||||
|
if actual != "golang" {
|
||||||
|
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user