1
0
mirror of https://github.com/golang/go synced 2024-11-24 22:00:09 -07:00

exp/ssh: new package.

The typical UNIX method for controlling long running process is to
send the process signals. Since this doesn't get you very far, various
ad-hoc, remote-control protocols have been used over time by programs
like Apache and BIND.

Implementing an SSH server means that Go code will have a standard,
secure way to do this in the future.

R=bradfitz, borman, dave, gustavo, dsymonds, r, adg, rsc, rogpeppe, lvd, kevlar, raul.san
CC=golang-dev
https://golang.org/cl/4962064
This commit is contained in:
Adam Langley 2011-09-17 15:57:24 -04:00
parent b71a805cd5
commit 605e57d8fe
10 changed files with 2742 additions and 0 deletions

16
src/pkg/exp/ssh/Makefile Normal file
View File

@ -0,0 +1,16 @@
# Copyright 2011 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
include ../../../Make.inc
TARG=exp/ssh
GOFILES=\
common.go\
messages.go\
server.go\
transport.go\
channel.go\
server_shell.go\
include ../../../Make.pkg

317
src/pkg/exp/ssh/channel.go Normal file
View File

@ -0,0 +1,317 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"os"
"sync"
)
// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
// SSH connection.
type Channel interface {
// Accept accepts the channel creation request.
Accept() os.Error
// Reject rejects the channel creation request. After calling this, no
// other methods on the Channel may be called. If they are then the
// peer is likely to signal a protocol error and drop the connection.
Reject(reason RejectionReason, message string) os.Error
// Read may return a ChannelRequest as an os.Error.
Read(data []byte) (int, os.Error)
Write(data []byte) (int, os.Error)
Close() os.Error
// AckRequest either sends an ack or nack to the channel request.
AckRequest(ok bool) os.Error
// ChannelType returns the type of the channel, as supplied by the
// client.
ChannelType() string
// ExtraData returns the arbitary payload for this channel, as supplied
// by the client. This data is specific to the channel type.
ExtraData() []byte
}
// ChannelRequest represents a request sent on a channel, outside of the normal
// stream of bytes. It may result from calling Read on a Channel.
type ChannelRequest struct {
Request string
WantReply bool
Payload []byte
}
func (c ChannelRequest) String() string {
return "channel request received"
}
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason int
const (
Prohibited RejectionReason = iota + 1
ConnectionFailed
UnknownChannelType
ResourceShortage
)
type channel struct {
// immutable once created
chanType string
extraData []byte
theyClosed bool
theySentEOF bool
weClosed bool
dead bool
serverConn *ServerConnection
myId, theirId uint32
myWindow, theirWindow uint32
maxPacketSize uint32
err os.Error
pendingRequests []ChannelRequest
pendingData []byte
head, length int
// This lock is inferior to serverConn.lock
lock sync.Mutex
cond *sync.Cond
}
func (c *channel) Accept() os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
confirm := channelOpenConfirmMsg{
PeersId: c.theirId,
MyId: c.myId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxPacketSize,
}
return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm))
}
func (c *channel) Reject(reason RejectionReason, message string) os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
reject := channelOpenFailureMsg{
PeersId: c.theirId,
Reason: uint32(reason),
Message: message,
Language: "en",
}
return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject))
}
func (c *channel) handlePacket(packet interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
switch packet := packet.(type) {
case *channelRequestMsg:
req := ChannelRequest{
Request: packet.Request,
WantReply: packet.WantReply,
Payload: packet.RequestSpecificData,
}
c.pendingRequests = append(c.pendingRequests, req)
c.cond.Signal()
case *channelCloseMsg:
c.theyClosed = true
c.cond.Signal()
case *channelEOFMsg:
c.theySentEOF = true
c.cond.Signal()
default:
panic("unknown packet type")
}
}
func (c *channel) handleData(data []byte) {
c.lock.Lock()
defer c.lock.Unlock()
// The other side should never send us more than our window.
if len(data)+c.length > len(c.pendingData) {
// TODO(agl): we should tear down the channel with a protocol
// error.
return
}
c.myWindow -= uint32(len(data))
for i := 0; i < 2; i++ {
tail := c.head + c.length
if tail > len(c.pendingData) {
tail -= len(c.pendingData)
}
n := copy(c.pendingData[tail:], data)
data = data[n:]
c.length += n
}
c.cond.Signal()
}
func (c *channel) Read(data []byte) (n int, err os.Error) {
c.lock.Lock()
defer c.lock.Unlock()
if c.err != nil {
return 0, c.err
}
if c.myWindow <= uint32(len(c.pendingData))/2 {
packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
PeersId: c.theirId,
AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow,
})
if err := c.serverConn.out.writePacket(packet); err != nil {
return 0, err
}
}
for {
if c.theySentEOF || c.theyClosed || c.dead {
return 0, os.EOF
}
if len(c.pendingRequests) > 0 {
req := c.pendingRequests[0]
if len(c.pendingRequests) == 1 {
c.pendingRequests = nil
} else {
oldPendingRequests := c.pendingRequests
c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
copy(c.pendingRequests, oldPendingRequests[1:])
}
return 0, req
}
if c.length > 0 {
tail := c.head + c.length
if tail > len(c.pendingData) {
tail -= len(c.pendingData)
}
n = copy(data, c.pendingData[c.head:tail])
c.head += n
c.length -= n
if c.head == len(c.pendingData) {
c.head = 0
}
return
}
c.cond.Wait()
}
panic("unreachable")
}
func (c *channel) Write(data []byte) (n int, err os.Error) {
for len(data) > 0 {
c.lock.Lock()
if c.dead || c.weClosed {
return 0, os.EOF
}
if c.theirWindow == 0 {
c.cond.Wait()
continue
}
c.lock.Unlock()
todo := data
if uint32(len(todo)) > c.theirWindow {
todo = todo[:c.theirWindow]
}
packet := make([]byte, 1+4+4+len(todo))
packet[0] = msgChannelData
packet[1] = byte(c.theirId) >> 24
packet[2] = byte(c.theirId) >> 16
packet[3] = byte(c.theirId) >> 8
packet[4] = byte(c.theirId)
packet[5] = byte(len(todo)) >> 24
packet[6] = byte(len(todo)) >> 16
packet[7] = byte(len(todo)) >> 8
packet[8] = byte(len(todo))
copy(packet[9:], todo)
c.serverConn.lock.Lock()
if err = c.serverConn.out.writePacket(packet); err != nil {
c.serverConn.lock.Unlock()
return
}
c.serverConn.lock.Unlock()
n += len(todo)
data = data[len(todo):]
}
return
}
func (c *channel) Close() os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
if c.weClosed {
return os.NewError("ssh: channel already closed")
}
c.weClosed = true
closeMsg := channelCloseMsg{
PeersId: c.theirId,
}
return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg))
}
func (c *channel) AckRequest(ok bool) os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
if ok {
ack := channelRequestSuccessMsg{
PeersId: c.theirId,
}
return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack))
} else {
ack := channelRequestFailureMsg{
PeersId: c.theirId,
}
return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack))
}
panic("unreachable")
}
func (c *channel) ChannelType() string {
return c.chanType
}
func (c *channel) ExtraData() []byte {
return c.extraData
}

96
src/pkg/exp/ssh/common.go Normal file
View File

@ -0,0 +1,96 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"strconv"
)
// These are string constants in the SSH protocol.
const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa"
cipherAES128CTR = "aes128-ctr"
macSHA196 = "hmac-sha1-96"
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
serviceSSH = "ssh-connection"
)
// UnexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
type UnexpectedMessageError struct {
expected, got uint8
}
func (u UnexpectedMessageError) String() string {
return "ssh: unexpected message type " + strconv.Itoa(int(u.got)) + " (expected " + strconv.Itoa(int(u.expected)) + ")"
}
// ParseError results from a malformed SSH message.
type ParseError struct {
msgType uint8
}
func (p ParseError) String() string {
return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
}
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
for _, clientAlgo := range clientAlgos {
for _, serverAlgo := range serverAlgos {
if clientAlgo == serverAlgo {
return clientAlgo, true
}
}
}
return
}
func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if !ok {
return
}
hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if !ok {
return
}
clientToServer.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
return
}
serverToClient.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
return
}
clientToServer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
return
}
serverToClient.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
return
}
clientToServer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
return
}
serverToClient.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
return
}
ok = true
return
}

79
src/pkg/exp/ssh/doc.go Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package ssh implements an SSH server.
SSH is a transport security protocol, an authentication protocol and a
family of application protocols. The most typical application level
protocol is a remote shell and this is specifically implemented. However,
the multiplexed nature of SSH is exposed to users that wish to support
others.
An SSH server is represented by a Server, which manages a number of
ServerConnections and handles authentication.
var s Server
s.PubKeyCallback = pubKeyAuth
s.PasswordCallback = passwordAuth
pemBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
panic("Failed to load private key")
}
err = s.SetRSAPrivateKey(pemBytes)
if err != nil {
panic("Failed to parse private key")
}
Once a Server has been set up, connections can be attached.
var sConn ServerConnection
sConn.Server = &s
err = sConn.Handshake(conn)
if err != nil {
panic("failed to handshake")
}
An SSH connection multiplexes several channels, which must be accepted themselves:
for {
channel, err := sConn.Accept()
if err != nil {
panic("error from Accept")
}
...
}
Accept reads from the connection, demultiplexes packets to their corresponding
channels and returns when a new channel request is seen. Some goroutine must
always be calling Accept; otherwise no messages will be forwarded to the
channels.
Channels have a type, depending on the application level protocol intended. In
the case of a shell, the type is "session" and ServerShell may be used to
present a simple terminal interface.
if channel.ChannelType() != "session" {
c.Reject(RejectUnknownChannelType, "unknown channel type")
return
}
channel.Accept()
shell := NewServerShell(channel, "> ")
go func() {
defer channel.Close()
for {
line, err := shell.ReadLine()
if err != nil {
break
}
println(line)
}
return
}()
*/
package ssh

557
src/pkg/exp/ssh/messages.go Normal file
View File

@ -0,0 +1,557 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"big"
"bytes"
"io"
"os"
"reflect"
)
// These are SSH message type numbers. They are scattered around several
// documents but many were taken from
// http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
const (
msgDisconnect = 1
msgIgnore = 2
msgUnimplemented = 3
msgDebug = 4
msgServiceRequest = 5
msgServiceAccept = 6
msgKexInit = 20
msgNewKeys = 21
msgKexDHInit = 30
msgKexDHReply = 31
msgUserAuthRequest = 50
msgUserAuthFailure = 51
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
msgUserAuthPubKeyOk = 60
msgGlobalRequest = 80
msgRequestSuccess = 81
msgRequestFailure = 82
msgChannelOpen = 90
msgChannelOpenConfirm = 91
msgChannelOpenFailure = 92
msgChannelWindowAdjust = 93
msgChannelData = 94
msgChannelExtendedData = 95
msgChannelEOF = 96
msgChannelClose = 97
msgChannelRequest = 98
msgChannelSuccess = 99
msgChannelFailure = 100
)
// SSH messages:
//
// These structures mirror the wire format of the corresponding SSH messages.
// They are marshaled using reflection with the marshal and unmarshal functions
// in this file. The only wrinkle is that a final member of type []byte with a
// tag of "rest" receives the remainder of a packet when unmarshaling.
// See RFC 4253, section 7.1.
type kexInitMsg struct {
Cookie [16]byte
KexAlgos []string
ServerHostKeyAlgos []string
CiphersClientServer []string
CiphersServerClient []string
MACsClientServer []string
MACsServerClient []string
CompressionClientServer []string
CompressionServerClient []string
LanguagesClientServer []string
LanguagesServerClient []string
FirstKexFollows bool
Reserved uint32
}
// See RFC 4253, section 8.
type kexDHInitMsg struct {
X *big.Int
}
type kexDHReplyMsg struct {
HostKey []byte
Y *big.Int
Signature []byte
}
// See RFC 4253, section 10.
type serviceRequestMsg struct {
Service string
}
// See RFC 4253, section 10.
type serviceAcceptMsg struct {
Service string
}
// See RFC 4252, section 5.
type userAuthRequestMsg struct {
User string
Service string
Method string
Payload []byte "rest"
}
// See RFC 4252, section 5.1
type userAuthFailureMsg struct {
Methods []string
PartialSuccess bool
}
// See RFC 4254, section 5.1.
type channelOpenMsg struct {
ChanType string
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte "rest"
}
// See RFC 4254, section 5.1.
type channelOpenConfirmMsg struct {
PeersId uint32
MyId uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte "rest"
}
// See RFC 4254, section 5.1.
type channelOpenFailureMsg struct {
PeersId uint32
Reason uint32
Message string
Language string
}
type channelRequestMsg struct {
PeersId uint32
Request string
WantReply bool
RequestSpecificData []byte "rest"
}
// See RFC 4254, section 5.4.
type channelRequestSuccessMsg struct {
PeersId uint32
}
// See RFC 4254, section 5.4.
type channelRequestFailureMsg struct {
PeersId uint32
}
// See RFC 4254, section 5.3
type channelCloseMsg struct {
PeersId uint32
}
// See RFC 4254, section 5.3
type channelEOFMsg struct {
PeersId uint32
}
// See RFC 4254, section 4
type globalRequestMsg struct {
Type string
WantReply bool
}
// See RFC 4254, section 5.2
type windowAdjustMsg struct {
PeersId uint32
AdditionalBytes uint32
}
// See RFC 4252, section 7
type userAuthPubKeyOkMsg struct {
Algo string
PubKey string
}
// unmarshal parses the SSH wire data in packet into out using reflection.
// expectedType is the expected SSH message type. It either returns nil on
// success, or a ParseError or UnexpectedMessageError on error.
func unmarshal(out interface{}, packet []byte, expectedType uint8) os.Error {
if len(packet) == 0 {
return ParseError{expectedType}
}
if packet[0] != expectedType {
return UnexpectedMessageError{expectedType, packet[0]}
}
packet = packet[1:]
v := reflect.ValueOf(out).Elem()
structType := v.Type()
var ok bool
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
t := field.Type()
switch t.Kind() {
case reflect.Bool:
if len(packet) < 1 {
return ParseError{expectedType}
}
field.SetBool(packet[0] != 0)
packet = packet[1:]
case reflect.Array:
if t.Elem().Kind() != reflect.Uint8 {
panic("array of non-uint8")
}
if len(packet) < t.Len() {
return ParseError{expectedType}
}
for j := 0; j < t.Len(); j++ {
field.Index(j).Set(reflect.ValueOf(packet[j]))
}
packet = packet[t.Len():]
case reflect.Uint32:
var u32 uint32
if u32, packet, ok = parseUint32(packet); !ok {
return ParseError{expectedType}
}
field.SetUint(uint64(u32))
case reflect.String:
var s []byte
if s, packet, ok = parseString(packet); !ok {
return ParseError{expectedType}
}
field.SetString(string(s))
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
if structType.Field(i).Tag == "rest" {
field.Set(reflect.ValueOf(packet))
packet = nil
} else {
var s []byte
if s, packet, ok = parseString(packet); !ok {
return ParseError{expectedType}
}
field.Set(reflect.ValueOf(s))
}
case reflect.String:
var nl []string
if nl, packet, ok = parseNameList(packet); !ok {
return ParseError{expectedType}
}
field.Set(reflect.ValueOf(nl))
default:
panic("slice of unknown type")
}
case reflect.Ptr:
if t == bigIntType {
var n *big.Int
if n, packet, ok = parseInt(packet); !ok {
return ParseError{expectedType}
}
field.Set(reflect.ValueOf(n))
} else {
panic("pointer to unknown type")
}
default:
panic("unknown type")
}
}
if len(packet) != 0 {
return ParseError{expectedType}
}
return nil
}
// marshal serializes the message in msg, using the given message type.
func marshal(msgType uint8, msg interface{}) []byte {
var out []byte
out = append(out, msgType)
v := reflect.ValueOf(msg)
structType := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
t := field.Type()
switch t.Kind() {
case reflect.Bool:
var v uint8
if field.Bool() {
v = 1
}
out = append(out, v)
case reflect.Array:
if t.Elem().Kind() != reflect.Uint8 {
panic("array of non-uint8")
}
for j := 0; j < t.Len(); j++ {
out = append(out, byte(field.Index(j).Uint()))
}
case reflect.Uint32:
u32 := uint32(field.Uint())
out = append(out, byte(u32>>24))
out = append(out, byte(u32>>16))
out = append(out, byte(u32>>8))
out = append(out, byte(u32))
case reflect.String:
s := field.String()
out = append(out, byte(len(s)>>24))
out = append(out, byte(len(s)>>16))
out = append(out, byte(len(s)>>8))
out = append(out, byte(len(s)))
out = append(out, []byte(s)...)
case reflect.Slice:
switch t.Elem().Kind() {
case reflect.Uint8:
length := field.Len()
if structType.Field(i).Tag != "rest" {
out = append(out, byte(length>>24))
out = append(out, byte(length>>16))
out = append(out, byte(length>>8))
out = append(out, byte(length))
}
for j := 0; j < length; j++ {
out = append(out, byte(field.Index(j).Uint()))
}
case reflect.String:
var length int
for j := 0; j < field.Len(); j++ {
if j != 0 {
length++ /* comma */
}
length += len(field.Index(j).String())
}
out = append(out, byte(length>>24))
out = append(out, byte(length>>16))
out = append(out, byte(length>>8))
out = append(out, byte(length))
for j := 0; j < field.Len(); j++ {
if j != 0 {
out = append(out, ',')
}
out = append(out, []byte(field.Index(j).String())...)
}
default:
panic("slice of unknown type")
}
case reflect.Ptr:
if t == bigIntType {
var n *big.Int
nValue := reflect.ValueOf(&n)
nValue.Elem().Set(field)
needed := intLength(n)
oldLength := len(out)
if cap(out)-len(out) < needed {
newOut := make([]byte, len(out), 2*(len(out)+needed))
copy(newOut, out)
out = newOut
}
out = out[:oldLength+needed]
marshalInt(out[oldLength:], n)
} else {
panic("pointer to unknown type")
}
}
}
return out
}
var bigOne = big.NewInt(1)
func parseString(in []byte) (out, rest []byte, ok bool) {
if len(in) < 4 {
return
}
length := uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
if uint32(len(in)) < 4+length {
return
}
out = in[4 : 4+length]
rest = in[4+length:]
ok = true
return
}
var comma = []byte{','}
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
contents, rest, ok := parseString(in)
if !ok {
return
}
if len(contents) == 0 {
return
}
parts := bytes.Split(contents, comma)
out = make([]string, len(parts))
for i, part := range parts {
out[i] = string(part)
}
return
}
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
contents, rest, ok := parseString(in)
if !ok {
return
}
out = new(big.Int)
if len(contents) > 0 && contents[0]&0x80 == 0x80 {
// This is a negative number
notBytes := make([]byte, len(contents))
for i := range notBytes {
notBytes[i] = ^contents[i]
}
out.SetBytes(notBytes)
out.Add(out, bigOne)
out.Neg(out)
} else {
// Positive number
out.SetBytes(contents)
}
ok = true
return
}
func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
if len(in) < 4 {
return
}
out = uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
rest = in[4:]
ok = true
return
}
const maxPacketSize = 36000
func nameListLength(namelist []string) int {
length := 4 /* uint32 length prefix */
for i, name := range namelist {
if i != 0 {
length++ /* comma */
}
length += len(name)
}
return length
}
func intLength(n *big.Int) int {
length := 4 /* length bytes */
if n.Sign() < 0 {
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bitLen := nMinus1.BitLen()
if bitLen%8 == 0 {
// The number will need 0xff padding
length++
}
length += (bitLen + 7) / 8
} else if n.Sign() == 0 {
// A zero is the zero length string
} else {
bitLen := n.BitLen()
if bitLen%8 == 0 {
// The number will need 0x00 padding
length++
}
length += (bitLen + 7) / 8
}
return length
}
func marshalInt(to []byte, n *big.Int) []byte {
lengthBytes := to
to = to[4:]
length := 0
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll subtract 1 and invert. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
to[0] = 0xff
to = to[1:]
length++
}
nBytes := copy(to, bytes)
to = to[nBytes:]
length += nBytes
} else if n.Sign() == 0 {
// A zero is the zero length string
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with a 0x00 in order to
// stop it looking like a negative number.
to[0] = 0
to = to[1:]
length++
}
nBytes := copy(to, bytes)
to = to[nBytes:]
length += nBytes
}
lengthBytes[0] = byte(length >> 24)
lengthBytes[1] = byte(length >> 16)
lengthBytes[2] = byte(length >> 8)
lengthBytes[3] = byte(length)
return to
}
func writeInt(w io.Writer, n *big.Int) {
length := intLength(n)
buf := make([]byte, length)
marshalInt(buf, n)
w.Write(buf)
}
func writeString(w io.Writer, s []byte) {
var lengthBytes [4]byte
lengthBytes[0] = byte(len(s) >> 24)
lengthBytes[1] = byte(len(s) >> 16)
lengthBytes[2] = byte(len(s) >> 8)
lengthBytes[3] = byte(len(s))
w.Write(lengthBytes[:])
w.Write(s)
}
func stringLength(s []byte) int {
return 4 + len(s)
}
func marshalString(to []byte, s []byte) []byte {
to[0] = byte(len(s) >> 24)
to[1] = byte(len(s) >> 16)
to[2] = byte(len(s) >> 8)
to[3] = byte(len(s))
to = to[4:]
copy(to, s)
return to[len(s):]
}
var bigIntType = reflect.TypeOf((*big.Int)(nil))

View File

@ -0,0 +1,125 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"big"
"rand"
"reflect"
"testing"
"testing/quick"
)
var intLengthTests = []struct {
val, length int
}{
{0, 4 + 0},
{1, 4 + 1},
{127, 4 + 1},
{128, 4 + 2},
{-1, 4 + 1},
}
func TestIntLength(t *testing.T) {
for _, test := range intLengthTests {
v := new(big.Int).SetInt64(int64(test.val))
length := intLength(v)
if length != test.length {
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
}
}
}
var messageTypes = []interface{}{
&kexInitMsg{},
&kexDHInitMsg{},
&serviceRequestMsg{},
&serviceAcceptMsg{},
&userAuthRequestMsg{},
&channelOpenMsg{},
&channelOpenConfirmMsg{},
&channelRequestMsg{},
&channelRequestSuccessMsg{},
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for i, iface := range messageTypes {
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("#%d: failed to create value", i)
break
}
m1 := v.Elem().Interface()
m2 := iface
marshaled := marshal(msgIgnore, m1)
if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
break
}
if !reflect.DeepEqual(v.Interface(), m2) {
t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
break
}
}
}
}
func randomBytes(out []byte, rand *rand.Rand) {
for i := 0; i < len(out); i++ {
out[i] = byte(rand.Int31())
}
}
func randomNameList(rand *rand.Rand) []string {
ret := make([]string, rand.Int31()&15)
for i := range ret {
s := make([]byte, 1+(rand.Int31()&15))
for j := range s {
s[j] = 'a' + uint8(rand.Int31()&15)
}
ret[i] = string(s)
}
return ret
}
func randomInt(rand *rand.Rand) *big.Int {
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
}
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
ki := &kexInitMsg{}
randomBytes(ki.Cookie[:], rand)
ki.KexAlgos = randomNameList(rand)
ki.ServerHostKeyAlgos = randomNameList(rand)
ki.CiphersClientServer = randomNameList(rand)
ki.CiphersServerClient = randomNameList(rand)
ki.MACsClientServer = randomNameList(rand)
ki.MACsServerClient = randomNameList(rand)
ki.CompressionClientServer = randomNameList(rand)
ki.CompressionServerClient = randomNameList(rand)
ki.LanguagesClientServer = randomNameList(rand)
ki.LanguagesServerClient = randomNameList(rand)
if rand.Int31()&1 == 1 {
ki.FirstKexFollows = true
}
return reflect.ValueOf(ki)
}
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
dhi := &kexDHInitMsg{}
dhi.X = randomInt(rand)
return reflect.ValueOf(dhi)
}

711
src/pkg/exp/ssh/server.go Normal file
View File

@ -0,0 +1,711 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"big"
"bufio"
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
_ "crypto/sha1"
"crypto/x509"
"encoding/pem"
"net"
"os"
"sync"
)
var supportedKexAlgos = []string{kexAlgoDH14SHA1}
var supportedHostKeyAlgos = []string{hostAlgoRSA}
var supportedCiphers = []string{cipherAES128CTR}
var supportedMACs = []string{macSHA196}
var supportedCompressions = []string{compressionNone}
// Server represents an SSH server. A Server may have several ServerConnections.
type Server struct {
rsa *rsa.PrivateKey
rsaSerialized []byte
// NoClientAuth is true if clients are allowed to connect without
// authenticating.
NoClientAuth bool
// PasswordCallback, if non-nil, is called when a user attempts to
// authenticate using a password. It may be called concurrently from
// several goroutines.
PasswordCallback func(user, password string) bool
// PubKeyCallback, if non-nil, is called when a client attempts public
// key authentication. It must return true iff the given public key is
// valid for the given user.
PubKeyCallback func(user, algo string, pubkey []byte) bool
}
// SetRSAPrivateKey sets the private key for a Server. A Server must have a
// private key configured in order to accept connections. The private key must
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
// typically contains such a key.
func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error {
block, _ := pem.Decode(pemBytes)
if block == nil {
return os.NewError("ssh: no key found")
}
var err os.Error
s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
s.rsaSerialized = marshalRSA(s.rsa)
return nil
}
// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
func marshalRSA(priv *rsa.PrivateKey) []byte {
e := new(big.Int).SetInt64(int64(priv.E))
length := stringLength([]byte(hostAlgoRSA))
length += intLength(e)
length += intLength(priv.N)
ret := make([]byte, length)
r := marshalString(ret, []byte(hostAlgoRSA))
r = marshalInt(r, e)
r = marshalInt(r, priv.N)
return ret
}
// parseRSA parses an RSA key according to RFC 4256, section 6.6.
func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
algo, in, ok := parseString(in)
if !ok || string(algo) != hostAlgoRSA {
return nil, false
}
bigE, in, ok := parseInt(in)
if !ok || bigE.BitLen() > 24 {
return nil, false
}
e := bigE.Int64()
if e < 3 || e&1 == 0 {
return nil, false
}
N, in, ok := parseInt(in)
if !ok || len(in) > 0 {
return nil, false
}
return &rsa.PublicKey{
N: N,
E: int(e),
}, true
}
func parseRSASig(in []byte) (sig []byte, ok bool) {
algo, in, ok := parseString(in)
if !ok || string(algo) != hostAlgoRSA {
return nil, false
}
sig, in, ok = parseString(in)
if len(in) > 0 {
ok = false
}
return
}
// cachedPubKey contains the results of querying whether a public key is
// acceptable for a user. The cache only applies to a single ServerConnection.
type cachedPubKey struct {
user, algo string
pubKey []byte
result bool
}
const maxCachedPubKeys = 16
// ServerConnection represents an incomming connection to a Server.
type ServerConnection struct {
Server *Server
in, out *halfConnection
channels map[uint32]*channel
nextChanId uint32
// lock protects err and also allows Channels to serialise their writes
// to out.
lock sync.RWMutex
err os.Error
// cachedPubKeys contains the cache results of tests for public keys.
// Since SSH clients will query whether a public key is acceptable
// before attempting to authenticate with it, we end up with duplicate
// queries for public key validity.
cachedPubKeys []cachedPubKey
}
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
type dhGroup struct {
g, p *big.Int
}
// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
// Oakley Group 14 in RFC 3526.
var dhGroup14 *dhGroup
var dhGroup14Once sync.Once
func initDHGroup14() {
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
dhGroup14 = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
}
}
type handshakeMagics struct {
clientVersion, serverVersion []byte
clientKexInit, serverKexInit []byte
}
// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
// returned values are given the same names as in RFC 4253, section 8.
func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
packet, err := s.in.readPacket()
if err != nil {
return
}
var kexDHInit kexDHInitMsg
if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
return
}
if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
return nil, nil, os.NewError("client DH parameter out of bounds")
}
y, err := rand.Int(rand.Reader, group.p)
if err != nil {
return
}
Y := new(big.Int).Exp(group.g, y, group.p)
kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
var serializedHostKey []byte
switch hostKeyAlgo {
case hostAlgoRSA:
serializedHostKey = s.Server.rsaSerialized
default:
return nil, nil, os.NewError("internal error")
}
h := hashFunc.New()
writeString(h, magics.clientVersion)
writeString(h, magics.serverVersion)
writeString(h, magics.clientKexInit)
writeString(h, magics.serverKexInit)
writeString(h, serializedHostKey)
writeInt(h, kexDHInit.X)
writeInt(h, Y)
K = make([]byte, intLength(kInt))
marshalInt(K, kInt)
h.Write(K)
H = h.Sum()
h.Reset()
h.Write(H)
hh := h.Sum()
var sig []byte
switch hostKeyAlgo {
case hostAlgoRSA:
sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh)
if err != nil {
return
}
default:
return nil, nil, os.NewError("internal error")
}
serializedSig := serializeRSASignature(sig)
kexDHReply := kexDHReplyMsg{
HostKey: serializedHostKey,
Y: Y,
Signature: serializedSig,
}
packet = marshal(msgKexDHReply, kexDHReply)
err = s.out.writePacket(packet)
return
}
func serializeRSASignature(sig []byte) []byte {
length := stringLength([]byte(hostAlgoRSA))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(hostAlgoRSA))
r = marshalString(r, sig)
return ret
}
// serverVersion is the fixed identification string that Server will use.
var serverVersion = []byte("SSH-2.0-Go\r\n")
// buildDataSignedForAuth returns the data that is signed in order to prove
// posession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}
// Handshake performs an SSH transport and client authentication on the given ServerConnection.
func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
var magics handshakeMagics
inBuf := bufio.NewReader(conn)
_, err := conn.Write(serverVersion)
if err != nil {
return err
}
magics.serverVersion = serverVersion[:len(serverVersion)-2]
serverKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers,
CiphersServerClient: supportedCiphers,
MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions,
CompressionServerClient: supportedCompressions,
}
kexInitPacket := marshal(msgKexInit, serverKexInit)
magics.serverKexInit = kexInitPacket
var out halfConnection
out.out = conn
out.rand = rand.Reader
s.out = &out
err = out.writePacket(kexInitPacket)
if err != nil {
return err
}
version, ok := readVersion(inBuf)
if !ok {
return os.NewError("failed to read version string from client")
}
magics.clientVersion = version
var in halfConnection
in.in = inBuf
s.in = &in
packet, err := in.readPacket()
if err != nil {
return err
}
magics.clientKexInit = packet
var clientKexInit kexInitMsg
if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
return err
}
kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit)
if !ok {
return os.NewError("ssh: no common algorithms")
}
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
_, err := in.readPacket()
if err != nil {
return err
}
}
var H, K []byte
var hashFunc crypto.Hash
switch kexAlgo {
case kexAlgoDH14SHA1:
hashFunc = crypto.SHA1
dhGroup14Once.Do(initDHGroup14)
H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
default:
err = os.NewError("ssh: internal error")
}
if err != nil {
return err
}
packet = []byte{msgNewKeys}
if err = out.writePacket(packet); err != nil {
return err
}
if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err
}
if packet, err = in.readPacket(); err != nil {
return err
}
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
in.setupKeys(clientKeys, K, H, H, hashFunc)
packet, err = in.readPacket()
if err != nil {
return err
}
var serviceRequest serviceRequestMsg
if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
return err
}
if serviceRequest.Service != serviceUserAuth {
return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
}
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
packet = marshal(msgServiceAccept, serviceAccept)
if err = out.writePacket(packet); err != nil {
return err
}
if err = s.authenticate(H); err != nil {
return err
}
s.channels = make(map[uint32]*channel)
return nil
}
func isAcceptableAlgo(algo string) bool {
return algo == hostAlgoRSA
}
// testPubKey returns true if the given public key is acceptable for the user.
func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
return false
}
for _, c := range s.cachedPubKeys {
if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
return c.result
}
}
result := s.Server.PubKeyCallback(user, algo, pubKey)
if len(s.cachedPubKeys) < maxCachedPubKeys {
c := cachedPubKey{
user: user,
algo: algo,
pubKey: make([]byte, len(pubKey)),
result: result,
}
copy(c.pubKey, pubKey)
s.cachedPubKeys = append(s.cachedPubKeys, c)
}
return result
}
func (s *ServerConnection) authenticate(H []byte) os.Error {
var userAuthReq userAuthRequestMsg
var err os.Error
var packet []byte
userAuthLoop:
for {
if packet, err = s.in.readPacket(); err != nil {
return err
}
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
return err
}
if userAuthReq.Service != serviceSSH {
return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
}
switch userAuthReq.Method {
case "none":
if s.Server.NoClientAuth {
break userAuthLoop
}
case "password":
if s.Server.PasswordCallback == nil {
break
}
payload := userAuthReq.Payload
if len(payload) < 1 || payload[0] != 0 {
return ParseError{msgUserAuthRequest}
}
payload = payload[1:]
password, payload, ok := parseString(payload)
if !ok || len(payload) > 0 {
return ParseError{msgUserAuthRequest}
}
if s.Server.PasswordCallback(userAuthReq.User, string(password)) {
break userAuthLoop
}
case "publickey":
if s.Server.PubKeyCallback == nil {
break
}
payload := userAuthReq.Payload
if len(payload) < 1 {
return ParseError{msgUserAuthRequest}
}
isQuery := payload[0] == 0
payload = payload[1:]
algoBytes, payload, ok := parseString(payload)
if !ok {
return ParseError{msgUserAuthRequest}
}
algo := string(algoBytes)
pubKey, payload, ok := parseString(payload)
if !ok {
return ParseError{msgUserAuthRequest}
}
if isQuery {
// The client can query if the given public key
// would be ok.
if len(payload) > 0 {
return ParseError{msgUserAuthRequest}
}
if s.testPubKey(userAuthReq.User, algo, pubKey) {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
PubKey: string(pubKey),
}
if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
return err
}
continue userAuthLoop
}
} else {
sig, payload, ok := parseString(payload)
if !ok || len(payload) > 0 {
return ParseError{msgUserAuthRequest}
}
if !isAcceptableAlgo(algo) {
break
}
rsaSig, ok := parseRSASig(sig)
if !ok {
return ParseError{msgUserAuthRequest}
}
signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey)
switch algo {
case hostAlgoRSA:
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(signedData)
digest := h.Sum()
rsaKey, ok := parseRSA(pubKey)
if !ok {
return ParseError{msgUserAuthRequest}
}
if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil {
return ParseError{msgUserAuthRequest}
}
default:
return os.NewError("ssh: isAcceptableAlgo incorrect")
}
if s.testPubKey(userAuthReq.User, algo, pubKey) {
break userAuthLoop
}
}
}
var failureMsg userAuthFailureMsg
if s.Server.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
if s.Server.PubKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey")
}
if len(failureMsg.Methods) == 0 {
return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false")
}
if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
return err
}
}
packet = []byte{msgUserAuthSuccess}
if err = s.out.writePacket(packet); err != nil {
return err
}
return nil
}
const defaultWindowSize = 32768
// Accept reads and processes messages on a ServerConnection. It must be called
// in order to demultiplex messages to any resulting Channels.
func (s *ServerConnection) Accept() (Channel, os.Error) {
if s.err != nil {
return nil, s.err
}
for {
packet, err := s.in.readPacket()
if err != nil {
s.lock.Lock()
s.err = err
s.lock.Unlock()
for _, c := range s.channels {
c.dead = true
c.handleData(nil)
}
return nil, err
}
switch packet[0] {
case msgChannelOpen:
var chanOpen channelOpenMsg
if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
return nil, err
}
c := new(channel)
c.chanType = chanOpen.ChanType
c.theirId = chanOpen.PeersId
c.theirWindow = chanOpen.PeersWindow
c.maxPacketSize = chanOpen.MaxPacketSize
c.extraData = chanOpen.TypeSpecificData
c.myWindow = defaultWindowSize
c.serverConn = s
c.cond = sync.NewCond(&c.lock)
c.pendingData = make([]byte, c.myWindow)
s.lock.Lock()
c.myId = s.nextChanId
s.nextChanId++
s.channels[c.myId] = c
s.lock.Unlock()
return c, nil
case msgChannelRequest:
var chanRequest channelRequestMsg
if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
return nil, err
}
s.lock.Lock()
c, ok := s.channels[chanRequest.PeersId]
if !ok {
continue
}
c.handlePacket(&chanRequest)
s.lock.Unlock()
case msgChannelData:
if len(packet) < 5 {
return nil, ParseError{msgChannelData}
}
chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
s.lock.Lock()
c, ok := s.channels[chanId]
if !ok {
continue
}
c.handleData(packet[9:])
s.lock.Unlock()
case msgChannelEOF:
var eofMsg channelEOFMsg
if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
return nil, err
}
s.lock.Lock()
c, ok := s.channels[eofMsg.PeersId]
if !ok {
continue
}
c.handlePacket(&eofMsg)
s.lock.Unlock()
case msgChannelClose:
var closeMsg channelCloseMsg
if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
return nil, err
}
s.lock.Lock()
c, ok := s.channels[closeMsg.PeersId]
if !ok {
continue
}
c.handlePacket(&closeMsg)
s.lock.Unlock()
case msgGlobalRequest:
var request globalRequestMsg
if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
return nil, err
}
if request.WantReply {
if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
default:
// Unknown message. Ignore.
}
}
panic("unreachable")
}

View File

@ -0,0 +1,399 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"os"
)
// ServerShell contains the state for running a VT100 terminal that is capable
// of reading lines of input.
type ServerShell struct {
c Channel
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewServerShell runs a VT100 terminal on the given channel. prompt is a
// string that is written at the start of each input line. For example: "> ".
func NewServerShell(c Channel, prompt string) *ServerShell {
return &ServerShell{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
// queue appends data to the end of ss.outBuf
func (ss *ServerShell) queue(data []byte) {
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
copy(newOutBuf, ss.outBuf)
ss.outBuf = newOutBuf
}
oldLen := len(ss.outBuf)
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
copy(ss.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
// given, logical position in the text.
func (ss *ServerShell) moveCursorToPos(pos int) {
x := len(ss.prompt) + pos
y := x / ss.termWidth
x = x % ss.termWidth
up := 0
if y < ss.cursorY {
up = ss.cursorY - y
}
down := 0
if y > ss.cursorY {
down = y - ss.cursorY
}
left := 0
if x < ss.cursorX {
left = ss.cursorX - x
}
right := 0
if x > ss.cursorX {
right = x - ss.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
ss.cursorX = x
ss.cursorY = y
ss.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if ss.pos == 0 {
return
}
ss.pos--
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
ss.line = ss.line[:len(ss.line)-1]
ss.writeLine(ss.line[ss.pos:])
ss.moveCursorToPos(ss.pos)
ss.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if ss.pos == 0 {
return
}
ss.pos--
for ss.pos > 0 {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos--
}
for ss.pos > 0 {
if ss.line[ss.pos] == ' ' {
ss.pos++
break
}
ss.pos--
}
ss.moveCursorToPos(ss.pos)
case keyAltRight:
// move right by a word.
for ss.pos < len(ss.line) {
if ss.line[ss.pos] == ' ' {
break
}
ss.pos++
}
for ss.pos < len(ss.line) {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos++
}
ss.moveCursorToPos(ss.pos)
case keyLeft:
if ss.pos == 0 {
return
}
ss.pos--
ss.moveCursorToPos(ss.pos)
case keyRight:
if ss.pos == len(ss.line) {
return
}
ss.pos++
ss.moveCursorToPos(ss.pos)
case keyEnter:
ss.moveCursorToPos(len(ss.line))
ss.queue([]byte("\r\n"))
line = string(ss.line)
ok = true
ss.line = ss.line[:0]
ss.pos = 0
ss.cursorX = 0
ss.cursorY = 0
ss.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(ss.line) == maxLineLength {
return
}
if len(ss.line) == cap(ss.line) {
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
copy(newLine, ss.line)
ss.line = newLine
}
ss.line = ss.line[:len(ss.line)+1]
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
ss.line[ss.pos] = byte(key)
ss.writeLine(ss.line[ss.pos:])
ss.pos++
ss.moveCursorToPos(ss.pos)
}
return
}
func (ss *ServerShell) writeLine(line []byte) {
for len(line) != 0 {
if ss.cursorX == ss.termWidth {
ss.queue([]byte("\r\n"))
ss.cursorX = 0
ss.cursorY++
if ss.cursorY > ss.maxLine {
ss.maxLine = ss.cursorY
}
}
remainingOnLine := ss.termWidth - ss.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
ss.queue(line[:todo])
ss.cursorX += todo
line = line[todo:]
}
}
// parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
func (ss *ServerShell) Write(buf []byte) (n int, err os.Error) {
return ss.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *ServerShell) ReadLine() (line string, err os.Error) {
ss.writeLine([]byte(ss.prompt))
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
for {
// ss.remainder is a slice at the beginning of ss.inBuf
// containing a partial key sequence
readBuf := ss.inBuf[len(ss.remainder):]
n, err := ss.c.Read(readBuf)
if err == nil {
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
rest := ss.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", os.EOF
}
line, lineOk = ss.handleKey(key)
}
if len(rest) > 0 {
n := copy(ss.inBuf[:], rest)
ss.remainder = ss.inBuf[:n]
} else {
ss.remainder = nil
}
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
if lineOk {
return
}
continue
}
if req, ok := err.(ChannelRequest); ok {
ok := false
switch req.Request {
case "pty-req":
ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
if !ok {
ss.termWidth = 80
ss.termHeight = 24
}
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
if req.WantReply {
ss.c.AckRequest(ok)
}
} else {
return "", err
}
}
panic("unreachable")
}

View File

@ -0,0 +1,134 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"testing"
"os"
)
type MockChannel struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockChannel) Accept() os.Error {
return nil
}
func (c *MockChannel) Reject(RejectionReason, string) os.Error {
return nil
}
func (c *MockChannel) Read(data []byte) (n int, err os.Error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, os.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockChannel) Write(data []byte) (n int, err os.Error) {
c.received = append(c.received, data...)
return len(data), nil
}
func (c *MockChannel) Close() os.Error {
return nil
}
func (c *MockChannel) AckRequest(ok bool) os.Error {
return nil
}
func (c *MockChannel) ChannelType() string {
return ""
}
func (c *MockChannel) ExtraData() []byte {
return nil
}
func TestClose(t *testing.T) {
c := &MockChannel{}
ss := NewServerShell(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != os.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
err os.Error
}{
{
"",
"",
os.EOF,
},
{
"\r",
"",
nil,
},
{
"foo\r",
"foo",
nil,
},
{
"a\x1b[Cb\r", // right
"ab",
nil,
},
{
"a\x1b[Db\r", // left
"ba",
nil,
},
{
"a\177b\r", // backspace
"b",
nil,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 0; j < len(test.in); j++ {
c := &MockChannel{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewServerShell(c, "> ")
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}

View File

@ -0,0 +1,308 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bufio"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/subtle"
"hash"
"io"
"net"
"os"
)
// halfConnection represents one direction of an SSH connection. It maintains
// the cipher state needed to process messages.
type halfConnection struct {
// Only one of these two will be non-nil
in *bufio.Reader
out net.Conn
rand io.Reader
cipherAlgo string
macAlgo string
compressionAlgo string
paddingMultiple int
seqNum uint32
mac hash.Hash
cipher cipher.Stream
}
func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) {
var lengthBytes [5]byte
_, err = io.ReadFull(hc.in, lengthBytes[:])
if err != nil {
return
}
if hc.cipher != nil {
hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
}
macSize := 0
if hc.mac != nil {
hc.mac.Reset()
var seqNumBytes [4]byte
seqNumBytes[0] = byte(hc.seqNum >> 24)
seqNumBytes[1] = byte(hc.seqNum >> 16)
seqNumBytes[2] = byte(hc.seqNum >> 8)
seqNumBytes[3] = byte(hc.seqNum)
hc.mac.Write(seqNumBytes[:])
hc.mac.Write(lengthBytes[:])
macSize = hc.mac.Size()
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
paddingLength := uint32(lengthBytes[4])
if length <= paddingLength+1 {
return nil, os.NewError("invalid packet length")
}
if length > maxPacketSize {
return nil, os.NewError("packet too large")
}
packet = make([]byte, length-1+uint32(macSize))
_, err = io.ReadFull(hc.in, packet)
if err != nil {
return nil, err
}
mac := packet[length-1:]
if hc.cipher != nil {
hc.cipher.XORKeyStream(packet, packet[:length-1])
}
if hc.mac != nil {
hc.mac.Write(packet[:length-1])
if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 {
return nil, os.NewError("ssh: MAC failure")
}
}
hc.seqNum++
packet = packet[:length-paddingLength-1]
return
}
func (hc *halfConnection) readPacket() (packet []byte, err os.Error) {
for {
packet, err := hc.readOnePacket()
if err != nil {
return nil, err
}
if packet[0] != msgIgnore && packet[0] != msgDebug {
return packet, nil
}
}
panic("unreachable")
}
func (hc *halfConnection) writePacket(packet []byte) os.Error {
paddingMultiple := hc.paddingMultiple
if paddingMultiple == 0 {
paddingMultiple = 8
}
paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple
if paddingLength < 4 {
paddingLength += paddingMultiple
}
var lengthBytes [5]byte
length := len(packet) + 1 + paddingLength
lengthBytes[0] = byte(length >> 24)
lengthBytes[1] = byte(length >> 16)
lengthBytes[2] = byte(length >> 8)
lengthBytes[3] = byte(length)
lengthBytes[4] = byte(paddingLength)
var padding [32]byte
_, err := io.ReadFull(hc.rand, padding[:paddingLength])
if err != nil {
return err
}
if hc.mac != nil {
hc.mac.Reset()
var seqNumBytes [4]byte
seqNumBytes[0] = byte(hc.seqNum >> 24)
seqNumBytes[1] = byte(hc.seqNum >> 16)
seqNumBytes[2] = byte(hc.seqNum >> 8)
seqNumBytes[3] = byte(hc.seqNum)
hc.mac.Write(seqNumBytes[:])
hc.mac.Write(lengthBytes[:])
hc.mac.Write(packet)
hc.mac.Write(padding[:paddingLength])
}
if hc.cipher != nil {
hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
hc.cipher.XORKeyStream(packet, packet)
hc.cipher.XORKeyStream(padding[:], padding[:paddingLength])
}
_, err = hc.out.Write(lengthBytes[:])
if err != nil {
return err
}
_, err = hc.out.Write(packet)
if err != nil {
return err
}
_, err = hc.out.Write(padding[:paddingLength])
if err != nil {
return err
}
if hc.mac != nil {
_, err = hc.out.Write(hc.mac.Sum())
}
hc.seqNum++
return err
}
const (
serverKeys = iota
clientKeys
)
// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
h := hashFunc.New()
// We only support these algorithms for now.
if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 {
return os.NewError("ssh: setupServerKeys internal error")
}
blockSize := 16
keySize := 16
macKeySize := 20
var ivTag, keyTag, macKeyTag byte
if direction == serverKeys {
ivTag, keyTag, macKeyTag = 'B', 'D', 'F'
} else {
ivTag, keyTag, macKeyTag = 'A', 'C', 'E'
}
iv := make([]byte, blockSize)
key := make([]byte, keySize)
macKey := make([]byte, macKeySize)
generateKeyMaterial(iv, ivTag, K, H, sessionId, h)
generateKeyMaterial(key, keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h)
hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
aes, err := aes.NewCipher(key)
if err != nil {
return err
}
hc.cipher = cipher.NewCTR(aes, iv)
hc.paddingMultiple = 16
return nil
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) {
var digestsSoFar []byte
for len(out) > 0 {
h.Reset()
h.Write(K)
h.Write(H)
if len(digestsSoFar) == 0 {
h.Write([]byte{tag})
h.Write(sessionId)
} else {
h.Write(digestsSoFar)
}
digest := h.Sum()
n := copy(out, digest)
out = out[n:]
if len(out) > 0 {
digestsSoFar = append(digestsSoFar, digest...)
}
}
}
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
// a given size.
type truncatingMAC struct {
length int
hmac hash.Hash
}
func (t truncatingMAC) Write(data []byte) (int, os.Error) {
return t.hmac.Write(data)
}
func (t truncatingMAC) Sum() []byte {
digest := t.hmac.Sum()
return digest[:t.length]
}
func (t truncatingMAC) Reset() {
t.hmac.Reset()
}
func (t truncatingMAC) Size() int {
return t.length
}
// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
// version string. In the event that the client is talking a different protocol
// we need to set a limit otherwise we will keep using more and more memory
// while searching for the end of the version handshake.
const maxVersionStringBytes = 1024
func readVersion(r *bufio.Reader) (versionString []byte, ok bool) {
versionString = make([]byte, 0, 64)
seenCR := false
forEachByte:
for len(versionString) < maxVersionStringBytes {
b, err := r.ReadByte()
if err != nil {
return
}
if !seenCR {
if b == '\r' {
seenCR = true
}
} else {
if b == '\n' {
ok = true
break forEachByte
} else {
seenCR = false
}
}
versionString = append(versionString, b)
}
if ok {
// We need to remove the CR from versionString
versionString = versionString[:len(versionString)-1]
}
return
}