mirror of
https://github.com/golang/go
synced 2024-11-21 15:54:43 -07:00
exp/ssh: new package.
The typical UNIX method for controlling long running process is to send the process signals. Since this doesn't get you very far, various ad-hoc, remote-control protocols have been used over time by programs like Apache and BIND. Implementing an SSH server means that Go code will have a standard, secure way to do this in the future. R=bradfitz, borman, dave, gustavo, dsymonds, r, adg, rsc, rogpeppe, lvd, kevlar, raul.san CC=golang-dev https://golang.org/cl/4962064
This commit is contained in:
parent
b71a805cd5
commit
605e57d8fe
16
src/pkg/exp/ssh/Makefile
Normal file
16
src/pkg/exp/ssh/Makefile
Normal file
@ -0,0 +1,16 @@
|
||||
# Copyright 2011 The Go Authors. All rights reserved.
|
||||
# Use of this source code is governed by a BSD-style
|
||||
# license that can be found in the LICENSE file.
|
||||
|
||||
include ../../../Make.inc
|
||||
|
||||
TARG=exp/ssh
|
||||
GOFILES=\
|
||||
common.go\
|
||||
messages.go\
|
||||
server.go\
|
||||
transport.go\
|
||||
channel.go\
|
||||
server_shell.go\
|
||||
|
||||
include ../../../Make.pkg
|
317
src/pkg/exp/ssh/channel.go
Normal file
317
src/pkg/exp/ssh/channel.go
Normal file
@ -0,0 +1,317 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
|
||||
// SSH connection.
|
||||
type Channel interface {
|
||||
// Accept accepts the channel creation request.
|
||||
Accept() os.Error
|
||||
// Reject rejects the channel creation request. After calling this, no
|
||||
// other methods on the Channel may be called. If they are then the
|
||||
// peer is likely to signal a protocol error and drop the connection.
|
||||
Reject(reason RejectionReason, message string) os.Error
|
||||
|
||||
// Read may return a ChannelRequest as an os.Error.
|
||||
Read(data []byte) (int, os.Error)
|
||||
Write(data []byte) (int, os.Error)
|
||||
Close() os.Error
|
||||
|
||||
// AckRequest either sends an ack or nack to the channel request.
|
||||
AckRequest(ok bool) os.Error
|
||||
|
||||
// ChannelType returns the type of the channel, as supplied by the
|
||||
// client.
|
||||
ChannelType() string
|
||||
// ExtraData returns the arbitary payload for this channel, as supplied
|
||||
// by the client. This data is specific to the channel type.
|
||||
ExtraData() []byte
|
||||
}
|
||||
|
||||
// ChannelRequest represents a request sent on a channel, outside of the normal
|
||||
// stream of bytes. It may result from calling Read on a Channel.
|
||||
type ChannelRequest struct {
|
||||
Request string
|
||||
WantReply bool
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func (c ChannelRequest) String() string {
|
||||
return "channel request received"
|
||||
}
|
||||
|
||||
// RejectionReason is an enumeration used when rejecting channel creation
|
||||
// requests. See RFC 4254, section 5.1.
|
||||
type RejectionReason int
|
||||
|
||||
const (
|
||||
Prohibited RejectionReason = iota + 1
|
||||
ConnectionFailed
|
||||
UnknownChannelType
|
||||
ResourceShortage
|
||||
)
|
||||
|
||||
type channel struct {
|
||||
// immutable once created
|
||||
chanType string
|
||||
extraData []byte
|
||||
|
||||
theyClosed bool
|
||||
theySentEOF bool
|
||||
weClosed bool
|
||||
dead bool
|
||||
|
||||
serverConn *ServerConnection
|
||||
myId, theirId uint32
|
||||
myWindow, theirWindow uint32
|
||||
maxPacketSize uint32
|
||||
err os.Error
|
||||
|
||||
pendingRequests []ChannelRequest
|
||||
pendingData []byte
|
||||
head, length int
|
||||
|
||||
// This lock is inferior to serverConn.lock
|
||||
lock sync.Mutex
|
||||
cond *sync.Cond
|
||||
}
|
||||
|
||||
func (c *channel) Accept() os.Error {
|
||||
c.serverConn.lock.Lock()
|
||||
defer c.serverConn.lock.Unlock()
|
||||
|
||||
if c.serverConn.err != nil {
|
||||
return c.serverConn.err
|
||||
}
|
||||
|
||||
confirm := channelOpenConfirmMsg{
|
||||
PeersId: c.theirId,
|
||||
MyId: c.myId,
|
||||
MyWindow: c.myWindow,
|
||||
MaxPacketSize: c.maxPacketSize,
|
||||
}
|
||||
return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm))
|
||||
}
|
||||
|
||||
func (c *channel) Reject(reason RejectionReason, message string) os.Error {
|
||||
c.serverConn.lock.Lock()
|
||||
defer c.serverConn.lock.Unlock()
|
||||
|
||||
if c.serverConn.err != nil {
|
||||
return c.serverConn.err
|
||||
}
|
||||
|
||||
reject := channelOpenFailureMsg{
|
||||
PeersId: c.theirId,
|
||||
Reason: uint32(reason),
|
||||
Message: message,
|
||||
Language: "en",
|
||||
}
|
||||
return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject))
|
||||
}
|
||||
|
||||
func (c *channel) handlePacket(packet interface{}) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
switch packet := packet.(type) {
|
||||
case *channelRequestMsg:
|
||||
req := ChannelRequest{
|
||||
Request: packet.Request,
|
||||
WantReply: packet.WantReply,
|
||||
Payload: packet.RequestSpecificData,
|
||||
}
|
||||
|
||||
c.pendingRequests = append(c.pendingRequests, req)
|
||||
c.cond.Signal()
|
||||
case *channelCloseMsg:
|
||||
c.theyClosed = true
|
||||
c.cond.Signal()
|
||||
case *channelEOFMsg:
|
||||
c.theySentEOF = true
|
||||
c.cond.Signal()
|
||||
default:
|
||||
panic("unknown packet type")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *channel) handleData(data []byte) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
// The other side should never send us more than our window.
|
||||
if len(data)+c.length > len(c.pendingData) {
|
||||
// TODO(agl): we should tear down the channel with a protocol
|
||||
// error.
|
||||
return
|
||||
}
|
||||
|
||||
c.myWindow -= uint32(len(data))
|
||||
for i := 0; i < 2; i++ {
|
||||
tail := c.head + c.length
|
||||
if tail > len(c.pendingData) {
|
||||
tail -= len(c.pendingData)
|
||||
}
|
||||
n := copy(c.pendingData[tail:], data)
|
||||
data = data[n:]
|
||||
c.length += n
|
||||
}
|
||||
|
||||
c.cond.Signal()
|
||||
}
|
||||
|
||||
func (c *channel) Read(data []byte) (n int, err os.Error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.err != nil {
|
||||
return 0, c.err
|
||||
}
|
||||
|
||||
if c.myWindow <= uint32(len(c.pendingData))/2 {
|
||||
packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
|
||||
PeersId: c.theirId,
|
||||
AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow,
|
||||
})
|
||||
if err := c.serverConn.out.writePacket(packet); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
if c.theySentEOF || c.theyClosed || c.dead {
|
||||
return 0, os.EOF
|
||||
}
|
||||
|
||||
if len(c.pendingRequests) > 0 {
|
||||
req := c.pendingRequests[0]
|
||||
if len(c.pendingRequests) == 1 {
|
||||
c.pendingRequests = nil
|
||||
} else {
|
||||
oldPendingRequests := c.pendingRequests
|
||||
c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
|
||||
copy(c.pendingRequests, oldPendingRequests[1:])
|
||||
}
|
||||
|
||||
return 0, req
|
||||
}
|
||||
|
||||
if c.length > 0 {
|
||||
tail := c.head + c.length
|
||||
if tail > len(c.pendingData) {
|
||||
tail -= len(c.pendingData)
|
||||
}
|
||||
n = copy(data, c.pendingData[c.head:tail])
|
||||
c.head += n
|
||||
c.length -= n
|
||||
if c.head == len(c.pendingData) {
|
||||
c.head = 0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.cond.Wait()
|
||||
}
|
||||
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (c *channel) Write(data []byte) (n int, err os.Error) {
|
||||
for len(data) > 0 {
|
||||
c.lock.Lock()
|
||||
if c.dead || c.weClosed {
|
||||
return 0, os.EOF
|
||||
}
|
||||
|
||||
if c.theirWindow == 0 {
|
||||
c.cond.Wait()
|
||||
continue
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
todo := data
|
||||
if uint32(len(todo)) > c.theirWindow {
|
||||
todo = todo[:c.theirWindow]
|
||||
}
|
||||
|
||||
packet := make([]byte, 1+4+4+len(todo))
|
||||
packet[0] = msgChannelData
|
||||
packet[1] = byte(c.theirId) >> 24
|
||||
packet[2] = byte(c.theirId) >> 16
|
||||
packet[3] = byte(c.theirId) >> 8
|
||||
packet[4] = byte(c.theirId)
|
||||
packet[5] = byte(len(todo)) >> 24
|
||||
packet[6] = byte(len(todo)) >> 16
|
||||
packet[7] = byte(len(todo)) >> 8
|
||||
packet[8] = byte(len(todo))
|
||||
copy(packet[9:], todo)
|
||||
|
||||
c.serverConn.lock.Lock()
|
||||
if err = c.serverConn.out.writePacket(packet); err != nil {
|
||||
c.serverConn.lock.Unlock()
|
||||
return
|
||||
}
|
||||
c.serverConn.lock.Unlock()
|
||||
|
||||
n += len(todo)
|
||||
data = data[len(todo):]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *channel) Close() os.Error {
|
||||
c.serverConn.lock.Lock()
|
||||
defer c.serverConn.lock.Unlock()
|
||||
|
||||
if c.serverConn.err != nil {
|
||||
return c.serverConn.err
|
||||
}
|
||||
|
||||
if c.weClosed {
|
||||
return os.NewError("ssh: channel already closed")
|
||||
}
|
||||
c.weClosed = true
|
||||
|
||||
closeMsg := channelCloseMsg{
|
||||
PeersId: c.theirId,
|
||||
}
|
||||
return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg))
|
||||
}
|
||||
|
||||
func (c *channel) AckRequest(ok bool) os.Error {
|
||||
c.serverConn.lock.Lock()
|
||||
defer c.serverConn.lock.Unlock()
|
||||
|
||||
if c.serverConn.err != nil {
|
||||
return c.serverConn.err
|
||||
}
|
||||
|
||||
if ok {
|
||||
ack := channelRequestSuccessMsg{
|
||||
PeersId: c.theirId,
|
||||
}
|
||||
return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack))
|
||||
} else {
|
||||
ack := channelRequestFailureMsg{
|
||||
PeersId: c.theirId,
|
||||
}
|
||||
return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack))
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (c *channel) ChannelType() string {
|
||||
return c.chanType
|
||||
}
|
||||
|
||||
func (c *channel) ExtraData() []byte {
|
||||
return c.extraData
|
||||
}
|
96
src/pkg/exp/ssh/common.go
Normal file
96
src/pkg/exp/ssh/common.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// These are string constants in the SSH protocol.
|
||||
const (
|
||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
|
||||
hostAlgoRSA = "ssh-rsa"
|
||||
cipherAES128CTR = "aes128-ctr"
|
||||
macSHA196 = "hmac-sha1-96"
|
||||
compressionNone = "none"
|
||||
serviceUserAuth = "ssh-userauth"
|
||||
serviceSSH = "ssh-connection"
|
||||
)
|
||||
|
||||
// UnexpectedMessageError results when the SSH message that we received didn't
|
||||
// match what we wanted.
|
||||
type UnexpectedMessageError struct {
|
||||
expected, got uint8
|
||||
}
|
||||
|
||||
func (u UnexpectedMessageError) String() string {
|
||||
return "ssh: unexpected message type " + strconv.Itoa(int(u.got)) + " (expected " + strconv.Itoa(int(u.expected)) + ")"
|
||||
}
|
||||
|
||||
// ParseError results from a malformed SSH message.
|
||||
type ParseError struct {
|
||||
msgType uint8
|
||||
}
|
||||
|
||||
func (p ParseError) String() string {
|
||||
return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
|
||||
}
|
||||
|
||||
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
|
||||
for _, clientAlgo := range clientAlgos {
|
||||
for _, serverAlgo := range serverAlgos {
|
||||
if clientAlgo == serverAlgo {
|
||||
return clientAlgo, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
|
||||
kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
clientToServer.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
serverToClient.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
clientToServer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
serverToClient.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
clientToServer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
serverToClient.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ok = true
|
||||
return
|
||||
}
|
79
src/pkg/exp/ssh/doc.go
Normal file
79
src/pkg/exp/ssh/doc.go
Normal file
@ -0,0 +1,79 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package ssh implements an SSH server.
|
||||
|
||||
SSH is a transport security protocol, an authentication protocol and a
|
||||
family of application protocols. The most typical application level
|
||||
protocol is a remote shell and this is specifically implemented. However,
|
||||
the multiplexed nature of SSH is exposed to users that wish to support
|
||||
others.
|
||||
|
||||
An SSH server is represented by a Server, which manages a number of
|
||||
ServerConnections and handles authentication.
|
||||
|
||||
var s Server
|
||||
s.PubKeyCallback = pubKeyAuth
|
||||
s.PasswordCallback = passwordAuth
|
||||
|
||||
pemBytes, err := ioutil.ReadFile("id_rsa")
|
||||
if err != nil {
|
||||
panic("Failed to load private key")
|
||||
}
|
||||
err = s.SetRSAPrivateKey(pemBytes)
|
||||
if err != nil {
|
||||
panic("Failed to parse private key")
|
||||
}
|
||||
|
||||
Once a Server has been set up, connections can be attached.
|
||||
|
||||
var sConn ServerConnection
|
||||
sConn.Server = &s
|
||||
err = sConn.Handshake(conn)
|
||||
if err != nil {
|
||||
panic("failed to handshake")
|
||||
}
|
||||
|
||||
An SSH connection multiplexes several channels, which must be accepted themselves:
|
||||
|
||||
|
||||
for {
|
||||
channel, err := sConn.Accept()
|
||||
if err != nil {
|
||||
panic("error from Accept")
|
||||
}
|
||||
|
||||
...
|
||||
}
|
||||
|
||||
Accept reads from the connection, demultiplexes packets to their corresponding
|
||||
channels and returns when a new channel request is seen. Some goroutine must
|
||||
always be calling Accept; otherwise no messages will be forwarded to the
|
||||
channels.
|
||||
|
||||
Channels have a type, depending on the application level protocol intended. In
|
||||
the case of a shell, the type is "session" and ServerShell may be used to
|
||||
present a simple terminal interface.
|
||||
|
||||
if channel.ChannelType() != "session" {
|
||||
c.Reject(RejectUnknownChannelType, "unknown channel type")
|
||||
return
|
||||
}
|
||||
channel.Accept()
|
||||
|
||||
shell := NewServerShell(channel, "> ")
|
||||
go func() {
|
||||
defer channel.Close()
|
||||
for {
|
||||
line, err := shell.ReadLine()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
println(line)
|
||||
}
|
||||
return
|
||||
}()
|
||||
*/
|
||||
package ssh
|
557
src/pkg/exp/ssh/messages.go
Normal file
557
src/pkg/exp/ssh/messages.go
Normal file
@ -0,0 +1,557 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"big"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// These are SSH message type numbers. They are scattered around several
|
||||
// documents but many were taken from
|
||||
// http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
||||
const (
|
||||
msgDisconnect = 1
|
||||
msgIgnore = 2
|
||||
msgUnimplemented = 3
|
||||
msgDebug = 4
|
||||
msgServiceRequest = 5
|
||||
msgServiceAccept = 6
|
||||
|
||||
msgKexInit = 20
|
||||
msgNewKeys = 21
|
||||
|
||||
msgKexDHInit = 30
|
||||
msgKexDHReply = 31
|
||||
|
||||
msgUserAuthRequest = 50
|
||||
msgUserAuthFailure = 51
|
||||
msgUserAuthSuccess = 52
|
||||
msgUserAuthBanner = 53
|
||||
msgUserAuthPubKeyOk = 60
|
||||
|
||||
msgGlobalRequest = 80
|
||||
msgRequestSuccess = 81
|
||||
msgRequestFailure = 82
|
||||
|
||||
msgChannelOpen = 90
|
||||
msgChannelOpenConfirm = 91
|
||||
msgChannelOpenFailure = 92
|
||||
msgChannelWindowAdjust = 93
|
||||
msgChannelData = 94
|
||||
msgChannelExtendedData = 95
|
||||
msgChannelEOF = 96
|
||||
msgChannelClose = 97
|
||||
msgChannelRequest = 98
|
||||
msgChannelSuccess = 99
|
||||
msgChannelFailure = 100
|
||||
)
|
||||
|
||||
// SSH messages:
|
||||
//
|
||||
// These structures mirror the wire format of the corresponding SSH messages.
|
||||
// They are marshaled using reflection with the marshal and unmarshal functions
|
||||
// in this file. The only wrinkle is that a final member of type []byte with a
|
||||
// tag of "rest" receives the remainder of a packet when unmarshaling.
|
||||
|
||||
// See RFC 4253, section 7.1.
|
||||
type kexInitMsg struct {
|
||||
Cookie [16]byte
|
||||
KexAlgos []string
|
||||
ServerHostKeyAlgos []string
|
||||
CiphersClientServer []string
|
||||
CiphersServerClient []string
|
||||
MACsClientServer []string
|
||||
MACsServerClient []string
|
||||
CompressionClientServer []string
|
||||
CompressionServerClient []string
|
||||
LanguagesClientServer []string
|
||||
LanguagesServerClient []string
|
||||
FirstKexFollows bool
|
||||
Reserved uint32
|
||||
}
|
||||
|
||||
// See RFC 4253, section 8.
|
||||
type kexDHInitMsg struct {
|
||||
X *big.Int
|
||||
}
|
||||
|
||||
type kexDHReplyMsg struct {
|
||||
HostKey []byte
|
||||
Y *big.Int
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
// See RFC 4253, section 10.
|
||||
type serviceRequestMsg struct {
|
||||
Service string
|
||||
}
|
||||
|
||||
// See RFC 4253, section 10.
|
||||
type serviceAcceptMsg struct {
|
||||
Service string
|
||||
}
|
||||
|
||||
// See RFC 4252, section 5.
|
||||
type userAuthRequestMsg struct {
|
||||
User string
|
||||
Service string
|
||||
Method string
|
||||
Payload []byte "rest"
|
||||
}
|
||||
|
||||
// See RFC 4252, section 5.1
|
||||
type userAuthFailureMsg struct {
|
||||
Methods []string
|
||||
PartialSuccess bool
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
type channelOpenMsg struct {
|
||||
ChanType string
|
||||
PeersId uint32
|
||||
PeersWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte "rest"
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
type channelOpenConfirmMsg struct {
|
||||
PeersId uint32
|
||||
MyId uint32
|
||||
MyWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte "rest"
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
type channelOpenFailureMsg struct {
|
||||
PeersId uint32
|
||||
Reason uint32
|
||||
Message string
|
||||
Language string
|
||||
}
|
||||
|
||||
type channelRequestMsg struct {
|
||||
PeersId uint32
|
||||
Request string
|
||||
WantReply bool
|
||||
RequestSpecificData []byte "rest"
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
type channelRequestSuccessMsg struct {
|
||||
PeersId uint32
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
type channelRequestFailureMsg struct {
|
||||
PeersId uint32
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
type channelCloseMsg struct {
|
||||
PeersId uint32
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
type channelEOFMsg struct {
|
||||
PeersId uint32
|
||||
}
|
||||
|
||||
// See RFC 4254, section 4
|
||||
type globalRequestMsg struct {
|
||||
Type string
|
||||
WantReply bool
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.2
|
||||
type windowAdjustMsg struct {
|
||||
PeersId uint32
|
||||
AdditionalBytes uint32
|
||||
}
|
||||
|
||||
// See RFC 4252, section 7
|
||||
type userAuthPubKeyOkMsg struct {
|
||||
Algo string
|
||||
PubKey string
|
||||
}
|
||||
|
||||
// unmarshal parses the SSH wire data in packet into out using reflection.
|
||||
// expectedType is the expected SSH message type. It either returns nil on
|
||||
// success, or a ParseError or UnexpectedMessageError on error.
|
||||
func unmarshal(out interface{}, packet []byte, expectedType uint8) os.Error {
|
||||
if len(packet) == 0 {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
if packet[0] != expectedType {
|
||||
return UnexpectedMessageError{expectedType, packet[0]}
|
||||
}
|
||||
packet = packet[1:]
|
||||
|
||||
v := reflect.ValueOf(out).Elem()
|
||||
structType := v.Type()
|
||||
var ok bool
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
t := field.Type()
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
if len(packet) < 1 {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
field.SetBool(packet[0] != 0)
|
||||
packet = packet[1:]
|
||||
case reflect.Array:
|
||||
if t.Elem().Kind() != reflect.Uint8 {
|
||||
panic("array of non-uint8")
|
||||
}
|
||||
if len(packet) < t.Len() {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
for j := 0; j < t.Len(); j++ {
|
||||
field.Index(j).Set(reflect.ValueOf(packet[j]))
|
||||
}
|
||||
packet = packet[t.Len():]
|
||||
case reflect.Uint32:
|
||||
var u32 uint32
|
||||
if u32, packet, ok = parseUint32(packet); !ok {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
field.SetUint(uint64(u32))
|
||||
case reflect.String:
|
||||
var s []byte
|
||||
if s, packet, ok = parseString(packet); !ok {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
field.SetString(string(s))
|
||||
case reflect.Slice:
|
||||
switch t.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
if structType.Field(i).Tag == "rest" {
|
||||
field.Set(reflect.ValueOf(packet))
|
||||
packet = nil
|
||||
} else {
|
||||
var s []byte
|
||||
if s, packet, ok = parseString(packet); !ok {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
field.Set(reflect.ValueOf(s))
|
||||
}
|
||||
case reflect.String:
|
||||
var nl []string
|
||||
if nl, packet, ok = parseNameList(packet); !ok {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
field.Set(reflect.ValueOf(nl))
|
||||
default:
|
||||
panic("slice of unknown type")
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if t == bigIntType {
|
||||
var n *big.Int
|
||||
if n, packet, ok = parseInt(packet); !ok {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
field.Set(reflect.ValueOf(n))
|
||||
} else {
|
||||
panic("pointer to unknown type")
|
||||
}
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
if len(packet) != 0 {
|
||||
return ParseError{expectedType}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// marshal serializes the message in msg, using the given message type.
|
||||
func marshal(msgType uint8, msg interface{}) []byte {
|
||||
var out []byte
|
||||
out = append(out, msgType)
|
||||
|
||||
v := reflect.ValueOf(msg)
|
||||
structType := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
t := field.Type()
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
var v uint8
|
||||
if field.Bool() {
|
||||
v = 1
|
||||
}
|
||||
out = append(out, v)
|
||||
case reflect.Array:
|
||||
if t.Elem().Kind() != reflect.Uint8 {
|
||||
panic("array of non-uint8")
|
||||
}
|
||||
for j := 0; j < t.Len(); j++ {
|
||||
out = append(out, byte(field.Index(j).Uint()))
|
||||
}
|
||||
case reflect.Uint32:
|
||||
u32 := uint32(field.Uint())
|
||||
out = append(out, byte(u32>>24))
|
||||
out = append(out, byte(u32>>16))
|
||||
out = append(out, byte(u32>>8))
|
||||
out = append(out, byte(u32))
|
||||
case reflect.String:
|
||||
s := field.String()
|
||||
out = append(out, byte(len(s)>>24))
|
||||
out = append(out, byte(len(s)>>16))
|
||||
out = append(out, byte(len(s)>>8))
|
||||
out = append(out, byte(len(s)))
|
||||
out = append(out, []byte(s)...)
|
||||
case reflect.Slice:
|
||||
switch t.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
length := field.Len()
|
||||
if structType.Field(i).Tag != "rest" {
|
||||
out = append(out, byte(length>>24))
|
||||
out = append(out, byte(length>>16))
|
||||
out = append(out, byte(length>>8))
|
||||
out = append(out, byte(length))
|
||||
}
|
||||
for j := 0; j < length; j++ {
|
||||
out = append(out, byte(field.Index(j).Uint()))
|
||||
}
|
||||
case reflect.String:
|
||||
var length int
|
||||
for j := 0; j < field.Len(); j++ {
|
||||
if j != 0 {
|
||||
length++ /* comma */
|
||||
}
|
||||
length += len(field.Index(j).String())
|
||||
}
|
||||
|
||||
out = append(out, byte(length>>24))
|
||||
out = append(out, byte(length>>16))
|
||||
out = append(out, byte(length>>8))
|
||||
out = append(out, byte(length))
|
||||
for j := 0; j < field.Len(); j++ {
|
||||
if j != 0 {
|
||||
out = append(out, ',')
|
||||
}
|
||||
out = append(out, []byte(field.Index(j).String())...)
|
||||
}
|
||||
default:
|
||||
panic("slice of unknown type")
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if t == bigIntType {
|
||||
var n *big.Int
|
||||
nValue := reflect.ValueOf(&n)
|
||||
nValue.Elem().Set(field)
|
||||
needed := intLength(n)
|
||||
oldLength := len(out)
|
||||
|
||||
if cap(out)-len(out) < needed {
|
||||
newOut := make([]byte, len(out), 2*(len(out)+needed))
|
||||
copy(newOut, out)
|
||||
out = newOut
|
||||
}
|
||||
out = out[:oldLength+needed]
|
||||
marshalInt(out[oldLength:], n)
|
||||
} else {
|
||||
panic("pointer to unknown type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var bigOne = big.NewInt(1)
|
||||
|
||||
func parseString(in []byte) (out, rest []byte, ok bool) {
|
||||
if len(in) < 4 {
|
||||
return
|
||||
}
|
||||
length := uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
|
||||
if uint32(len(in)) < 4+length {
|
||||
return
|
||||
}
|
||||
out = in[4 : 4+length]
|
||||
rest = in[4+length:]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
var comma = []byte{','}
|
||||
|
||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
||||
contents, rest, ok := parseString(in)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if len(contents) == 0 {
|
||||
return
|
||||
}
|
||||
parts := bytes.Split(contents, comma)
|
||||
out = make([]string, len(parts))
|
||||
for i, part := range parts {
|
||||
out[i] = string(part)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
|
||||
contents, rest, ok := parseString(in)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
out = new(big.Int)
|
||||
|
||||
if len(contents) > 0 && contents[0]&0x80 == 0x80 {
|
||||
// This is a negative number
|
||||
notBytes := make([]byte, len(contents))
|
||||
for i := range notBytes {
|
||||
notBytes[i] = ^contents[i]
|
||||
}
|
||||
out.SetBytes(notBytes)
|
||||
out.Add(out, bigOne)
|
||||
out.Neg(out)
|
||||
} else {
|
||||
// Positive number
|
||||
out.SetBytes(contents)
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
|
||||
if len(in) < 4 {
|
||||
return
|
||||
}
|
||||
out = uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
|
||||
rest = in[4:]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
const maxPacketSize = 36000
|
||||
|
||||
func nameListLength(namelist []string) int {
|
||||
length := 4 /* uint32 length prefix */
|
||||
for i, name := range namelist {
|
||||
if i != 0 {
|
||||
length++ /* comma */
|
||||
}
|
||||
length += len(name)
|
||||
}
|
||||
return length
|
||||
}
|
||||
|
||||
func intLength(n *big.Int) int {
|
||||
length := 4 /* length bytes */
|
||||
if n.Sign() < 0 {
|
||||
nMinus1 := new(big.Int).Neg(n)
|
||||
nMinus1.Sub(nMinus1, bigOne)
|
||||
bitLen := nMinus1.BitLen()
|
||||
if bitLen%8 == 0 {
|
||||
// The number will need 0xff padding
|
||||
length++
|
||||
}
|
||||
length += (bitLen + 7) / 8
|
||||
} else if n.Sign() == 0 {
|
||||
// A zero is the zero length string
|
||||
} else {
|
||||
bitLen := n.BitLen()
|
||||
if bitLen%8 == 0 {
|
||||
// The number will need 0x00 padding
|
||||
length++
|
||||
}
|
||||
length += (bitLen + 7) / 8
|
||||
}
|
||||
|
||||
return length
|
||||
}
|
||||
|
||||
func marshalInt(to []byte, n *big.Int) []byte {
|
||||
lengthBytes := to
|
||||
to = to[4:]
|
||||
length := 0
|
||||
|
||||
if n.Sign() < 0 {
|
||||
// A negative number has to be converted to two's-complement
|
||||
// form. So we'll subtract 1 and invert. If the
|
||||
// most-significant-bit isn't set then we'll need to pad the
|
||||
// beginning with 0xff in order to keep the number negative.
|
||||
nMinus1 := new(big.Int).Neg(n)
|
||||
nMinus1.Sub(nMinus1, bigOne)
|
||||
bytes := nMinus1.Bytes()
|
||||
for i := range bytes {
|
||||
bytes[i] ^= 0xff
|
||||
}
|
||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
|
||||
to[0] = 0xff
|
||||
to = to[1:]
|
||||
length++
|
||||
}
|
||||
nBytes := copy(to, bytes)
|
||||
to = to[nBytes:]
|
||||
length += nBytes
|
||||
} else if n.Sign() == 0 {
|
||||
// A zero is the zero length string
|
||||
} else {
|
||||
bytes := n.Bytes()
|
||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
|
||||
// We'll have to pad this with a 0x00 in order to
|
||||
// stop it looking like a negative number.
|
||||
to[0] = 0
|
||||
to = to[1:]
|
||||
length++
|
||||
}
|
||||
nBytes := copy(to, bytes)
|
||||
to = to[nBytes:]
|
||||
length += nBytes
|
||||
}
|
||||
|
||||
lengthBytes[0] = byte(length >> 24)
|
||||
lengthBytes[1] = byte(length >> 16)
|
||||
lengthBytes[2] = byte(length >> 8)
|
||||
lengthBytes[3] = byte(length)
|
||||
return to
|
||||
}
|
||||
|
||||
func writeInt(w io.Writer, n *big.Int) {
|
||||
length := intLength(n)
|
||||
buf := make([]byte, length)
|
||||
marshalInt(buf, n)
|
||||
w.Write(buf)
|
||||
}
|
||||
|
||||
func writeString(w io.Writer, s []byte) {
|
||||
var lengthBytes [4]byte
|
||||
lengthBytes[0] = byte(len(s) >> 24)
|
||||
lengthBytes[1] = byte(len(s) >> 16)
|
||||
lengthBytes[2] = byte(len(s) >> 8)
|
||||
lengthBytes[3] = byte(len(s))
|
||||
w.Write(lengthBytes[:])
|
||||
w.Write(s)
|
||||
}
|
||||
|
||||
func stringLength(s []byte) int {
|
||||
return 4 + len(s)
|
||||
}
|
||||
|
||||
func marshalString(to []byte, s []byte) []byte {
|
||||
to[0] = byte(len(s) >> 24)
|
||||
to[1] = byte(len(s) >> 16)
|
||||
to[2] = byte(len(s) >> 8)
|
||||
to[3] = byte(len(s))
|
||||
to = to[4:]
|
||||
copy(to, s)
|
||||
return to[len(s):]
|
||||
}
|
||||
|
||||
var bigIntType = reflect.TypeOf((*big.Int)(nil))
|
125
src/pkg/exp/ssh/messages_test.go
Normal file
125
src/pkg/exp/ssh/messages_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"big"
|
||||
"rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
)
|
||||
|
||||
var intLengthTests = []struct {
|
||||
val, length int
|
||||
}{
|
||||
{0, 4 + 0},
|
||||
{1, 4 + 1},
|
||||
{127, 4 + 1},
|
||||
{128, 4 + 2},
|
||||
{-1, 4 + 1},
|
||||
}
|
||||
|
||||
func TestIntLength(t *testing.T) {
|
||||
for _, test := range intLengthTests {
|
||||
v := new(big.Int).SetInt64(int64(test.val))
|
||||
length := intLength(v)
|
||||
if length != test.length {
|
||||
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var messageTypes = []interface{}{
|
||||
&kexInitMsg{},
|
||||
&kexDHInitMsg{},
|
||||
&serviceRequestMsg{},
|
||||
&serviceAcceptMsg{},
|
||||
&userAuthRequestMsg{},
|
||||
&channelOpenMsg{},
|
||||
&channelOpenConfirmMsg{},
|
||||
&channelRequestMsg{},
|
||||
&channelRequestSuccessMsg{},
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshal(t *testing.T) {
|
||||
rand := rand.New(rand.NewSource(0))
|
||||
for i, iface := range messageTypes {
|
||||
ty := reflect.ValueOf(iface).Type()
|
||||
|
||||
n := 100
|
||||
if testing.Short() {
|
||||
n = 5
|
||||
}
|
||||
for j := 0; j < n; j++ {
|
||||
v, ok := quick.Value(ty, rand)
|
||||
if !ok {
|
||||
t.Errorf("#%d: failed to create value", i)
|
||||
break
|
||||
}
|
||||
|
||||
m1 := v.Elem().Interface()
|
||||
m2 := iface
|
||||
|
||||
marshaled := marshal(msgIgnore, m1)
|
||||
if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
|
||||
t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
|
||||
break
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(v.Interface(), m2) {
|
||||
t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomBytes(out []byte, rand *rand.Rand) {
|
||||
for i := 0; i < len(out); i++ {
|
||||
out[i] = byte(rand.Int31())
|
||||
}
|
||||
}
|
||||
|
||||
func randomNameList(rand *rand.Rand) []string {
|
||||
ret := make([]string, rand.Int31()&15)
|
||||
for i := range ret {
|
||||
s := make([]byte, 1+(rand.Int31()&15))
|
||||
for j := range s {
|
||||
s[j] = 'a' + uint8(rand.Int31()&15)
|
||||
}
|
||||
ret[i] = string(s)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func randomInt(rand *rand.Rand) *big.Int {
|
||||
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
|
||||
}
|
||||
|
||||
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
ki := &kexInitMsg{}
|
||||
randomBytes(ki.Cookie[:], rand)
|
||||
ki.KexAlgos = randomNameList(rand)
|
||||
ki.ServerHostKeyAlgos = randomNameList(rand)
|
||||
ki.CiphersClientServer = randomNameList(rand)
|
||||
ki.CiphersServerClient = randomNameList(rand)
|
||||
ki.MACsClientServer = randomNameList(rand)
|
||||
ki.MACsServerClient = randomNameList(rand)
|
||||
ki.CompressionClientServer = randomNameList(rand)
|
||||
ki.CompressionServerClient = randomNameList(rand)
|
||||
ki.LanguagesClientServer = randomNameList(rand)
|
||||
ki.LanguagesServerClient = randomNameList(rand)
|
||||
if rand.Int31()&1 == 1 {
|
||||
ki.FirstKexFollows = true
|
||||
}
|
||||
return reflect.ValueOf(ki)
|
||||
}
|
||||
|
||||
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
dhi := &kexDHInitMsg{}
|
||||
dhi.X = randomInt(rand)
|
||||
return reflect.ValueOf(dhi)
|
||||
}
|
711
src/pkg/exp/ssh/server.go
Normal file
711
src/pkg/exp/ssh/server.go
Normal file
@ -0,0 +1,711 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"big"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
_ "crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var supportedKexAlgos = []string{kexAlgoDH14SHA1}
|
||||
var supportedHostKeyAlgos = []string{hostAlgoRSA}
|
||||
var supportedCiphers = []string{cipherAES128CTR}
|
||||
var supportedMACs = []string{macSHA196}
|
||||
var supportedCompressions = []string{compressionNone}
|
||||
|
||||
// Server represents an SSH server. A Server may have several ServerConnections.
|
||||
type Server struct {
|
||||
rsa *rsa.PrivateKey
|
||||
rsaSerialized []byte
|
||||
|
||||
// NoClientAuth is true if clients are allowed to connect without
|
||||
// authenticating.
|
||||
NoClientAuth bool
|
||||
|
||||
// PasswordCallback, if non-nil, is called when a user attempts to
|
||||
// authenticate using a password. It may be called concurrently from
|
||||
// several goroutines.
|
||||
PasswordCallback func(user, password string) bool
|
||||
|
||||
// PubKeyCallback, if non-nil, is called when a client attempts public
|
||||
// key authentication. It must return true iff the given public key is
|
||||
// valid for the given user.
|
||||
PubKeyCallback func(user, algo string, pubkey []byte) bool
|
||||
}
|
||||
|
||||
// SetRSAPrivateKey sets the private key for a Server. A Server must have a
|
||||
// private key configured in order to accept connections. The private key must
|
||||
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
|
||||
// typically contains such a key.
|
||||
func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return os.NewError("ssh: no key found")
|
||||
}
|
||||
var err os.Error
|
||||
s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.rsaSerialized = marshalRSA(s.rsa)
|
||||
return nil
|
||||
}
|
||||
|
||||
// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
|
||||
func marshalRSA(priv *rsa.PrivateKey) []byte {
|
||||
e := new(big.Int).SetInt64(int64(priv.E))
|
||||
length := stringLength([]byte(hostAlgoRSA))
|
||||
length += intLength(e)
|
||||
length += intLength(priv.N)
|
||||
|
||||
ret := make([]byte, length)
|
||||
r := marshalString(ret, []byte(hostAlgoRSA))
|
||||
r = marshalInt(r, e)
|
||||
r = marshalInt(r, priv.N)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// parseRSA parses an RSA key according to RFC 4256, section 6.6.
|
||||
func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
|
||||
algo, in, ok := parseString(in)
|
||||
if !ok || string(algo) != hostAlgoRSA {
|
||||
return nil, false
|
||||
}
|
||||
bigE, in, ok := parseInt(in)
|
||||
if !ok || bigE.BitLen() > 24 {
|
||||
return nil, false
|
||||
}
|
||||
e := bigE.Int64()
|
||||
if e < 3 || e&1 == 0 {
|
||||
return nil, false
|
||||
}
|
||||
N, in, ok := parseInt(in)
|
||||
if !ok || len(in) > 0 {
|
||||
return nil, false
|
||||
}
|
||||
return &rsa.PublicKey{
|
||||
N: N,
|
||||
E: int(e),
|
||||
}, true
|
||||
}
|
||||
|
||||
func parseRSASig(in []byte) (sig []byte, ok bool) {
|
||||
algo, in, ok := parseString(in)
|
||||
if !ok || string(algo) != hostAlgoRSA {
|
||||
return nil, false
|
||||
}
|
||||
sig, in, ok = parseString(in)
|
||||
if len(in) > 0 {
|
||||
ok = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// cachedPubKey contains the results of querying whether a public key is
|
||||
// acceptable for a user. The cache only applies to a single ServerConnection.
|
||||
type cachedPubKey struct {
|
||||
user, algo string
|
||||
pubKey []byte
|
||||
result bool
|
||||
}
|
||||
|
||||
const maxCachedPubKeys = 16
|
||||
|
||||
// ServerConnection represents an incomming connection to a Server.
|
||||
type ServerConnection struct {
|
||||
Server *Server
|
||||
|
||||
in, out *halfConnection
|
||||
|
||||
channels map[uint32]*channel
|
||||
nextChanId uint32
|
||||
|
||||
// lock protects err and also allows Channels to serialise their writes
|
||||
// to out.
|
||||
lock sync.RWMutex
|
||||
err os.Error
|
||||
|
||||
// cachedPubKeys contains the cache results of tests for public keys.
|
||||
// Since SSH clients will query whether a public key is acceptable
|
||||
// before attempting to authenticate with it, we end up with duplicate
|
||||
// queries for public key validity.
|
||||
cachedPubKeys []cachedPubKey
|
||||
}
|
||||
|
||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
||||
type dhGroup struct {
|
||||
g, p *big.Int
|
||||
}
|
||||
|
||||
// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
|
||||
// Oakley Group 14 in RFC 3526.
|
||||
var dhGroup14 *dhGroup
|
||||
|
||||
var dhGroup14Once sync.Once
|
||||
|
||||
func initDHGroup14() {
|
||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
|
||||
|
||||
dhGroup14 = &dhGroup{
|
||||
g: new(big.Int).SetInt64(2),
|
||||
p: p,
|
||||
}
|
||||
}
|
||||
|
||||
type handshakeMagics struct {
|
||||
clientVersion, serverVersion []byte
|
||||
clientKexInit, serverKexInit []byte
|
||||
}
|
||||
|
||||
// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
|
||||
// returned values are given the same names as in RFC 4253, section 8.
|
||||
func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
|
||||
packet, err := s.in.readPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var kexDHInit kexDHInitMsg
|
||||
if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
|
||||
return nil, nil, os.NewError("client DH parameter out of bounds")
|
||||
}
|
||||
|
||||
y, err := rand.Int(rand.Reader, group.p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
Y := new(big.Int).Exp(group.g, y, group.p)
|
||||
kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
|
||||
|
||||
var serializedHostKey []byte
|
||||
switch hostKeyAlgo {
|
||||
case hostAlgoRSA:
|
||||
serializedHostKey = s.Server.rsaSerialized
|
||||
default:
|
||||
return nil, nil, os.NewError("internal error")
|
||||
}
|
||||
|
||||
h := hashFunc.New()
|
||||
writeString(h, magics.clientVersion)
|
||||
writeString(h, magics.serverVersion)
|
||||
writeString(h, magics.clientKexInit)
|
||||
writeString(h, magics.serverKexInit)
|
||||
writeString(h, serializedHostKey)
|
||||
writeInt(h, kexDHInit.X)
|
||||
writeInt(h, Y)
|
||||
K = make([]byte, intLength(kInt))
|
||||
marshalInt(K, kInt)
|
||||
h.Write(K)
|
||||
|
||||
H = h.Sum()
|
||||
|
||||
h.Reset()
|
||||
h.Write(H)
|
||||
hh := h.Sum()
|
||||
|
||||
var sig []byte
|
||||
switch hostKeyAlgo {
|
||||
case hostAlgoRSA:
|
||||
sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
default:
|
||||
return nil, nil, os.NewError("internal error")
|
||||
}
|
||||
|
||||
serializedSig := serializeRSASignature(sig)
|
||||
|
||||
kexDHReply := kexDHReplyMsg{
|
||||
HostKey: serializedHostKey,
|
||||
Y: Y,
|
||||
Signature: serializedSig,
|
||||
}
|
||||
packet = marshal(msgKexDHReply, kexDHReply)
|
||||
|
||||
err = s.out.writePacket(packet)
|
||||
return
|
||||
}
|
||||
|
||||
func serializeRSASignature(sig []byte) []byte {
|
||||
length := stringLength([]byte(hostAlgoRSA))
|
||||
length += stringLength(sig)
|
||||
|
||||
ret := make([]byte, length)
|
||||
r := marshalString(ret, []byte(hostAlgoRSA))
|
||||
r = marshalString(r, sig)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// serverVersion is the fixed identification string that Server will use.
|
||||
var serverVersion = []byte("SSH-2.0-Go\r\n")
|
||||
|
||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||
// posession of a private key. See RFC 4252, section 7.
|
||||
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
|
||||
user := []byte(req.User)
|
||||
service := []byte(req.Service)
|
||||
method := []byte(req.Method)
|
||||
|
||||
length := stringLength(sessionId)
|
||||
length += 1
|
||||
length += stringLength(user)
|
||||
length += stringLength(service)
|
||||
length += stringLength(method)
|
||||
length += 1
|
||||
length += stringLength(algo)
|
||||
length += stringLength(pubKey)
|
||||
|
||||
ret := make([]byte, length)
|
||||
r := marshalString(ret, sessionId)
|
||||
r[0] = msgUserAuthRequest
|
||||
r = r[1:]
|
||||
r = marshalString(r, user)
|
||||
r = marshalString(r, service)
|
||||
r = marshalString(r, method)
|
||||
r[0] = 1
|
||||
r = r[1:]
|
||||
r = marshalString(r, algo)
|
||||
r = marshalString(r, pubKey)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Handshake performs an SSH transport and client authentication on the given ServerConnection.
|
||||
func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
|
||||
var magics handshakeMagics
|
||||
inBuf := bufio.NewReader(conn)
|
||||
|
||||
_, err := conn.Write(serverVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
magics.serverVersion = serverVersion[:len(serverVersion)-2]
|
||||
serverKexInit := kexInitMsg{
|
||||
KexAlgos: supportedKexAlgos,
|
||||
ServerHostKeyAlgos: supportedHostKeyAlgos,
|
||||
CiphersClientServer: supportedCiphers,
|
||||
CiphersServerClient: supportedCiphers,
|
||||
MACsClientServer: supportedMACs,
|
||||
MACsServerClient: supportedMACs,
|
||||
CompressionClientServer: supportedCompressions,
|
||||
CompressionServerClient: supportedCompressions,
|
||||
}
|
||||
kexInitPacket := marshal(msgKexInit, serverKexInit)
|
||||
magics.serverKexInit = kexInitPacket
|
||||
|
||||
var out halfConnection
|
||||
out.out = conn
|
||||
out.rand = rand.Reader
|
||||
s.out = &out
|
||||
err = out.writePacket(kexInitPacket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
version, ok := readVersion(inBuf)
|
||||
if !ok {
|
||||
return os.NewError("failed to read version string from client")
|
||||
}
|
||||
magics.clientVersion = version
|
||||
|
||||
var in halfConnection
|
||||
in.in = inBuf
|
||||
s.in = &in
|
||||
packet, err := in.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
magics.clientKexInit = packet
|
||||
|
||||
var clientKexInit kexInitMsg
|
||||
if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit)
|
||||
if !ok {
|
||||
return os.NewError("ssh: no common algorithms")
|
||||
}
|
||||
|
||||
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
|
||||
// The client sent a Kex message for the wrong algorithm,
|
||||
// which we have to ignore.
|
||||
_, err := in.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var H, K []byte
|
||||
var hashFunc crypto.Hash
|
||||
switch kexAlgo {
|
||||
case kexAlgoDH14SHA1:
|
||||
hashFunc = crypto.SHA1
|
||||
dhGroup14Once.Do(initDHGroup14)
|
||||
H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
|
||||
default:
|
||||
err = os.NewError("ssh: internal error")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
packet = []byte{msgNewKeys}
|
||||
if err = out.writePacket(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if packet, err = in.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if packet[0] != msgNewKeys {
|
||||
return UnexpectedMessageError{msgNewKeys, packet[0]}
|
||||
}
|
||||
|
||||
in.setupKeys(clientKeys, K, H, H, hashFunc)
|
||||
|
||||
packet, err = in.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var serviceRequest serviceRequestMsg
|
||||
if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
if serviceRequest.Service != serviceUserAuth {
|
||||
return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
|
||||
}
|
||||
|
||||
serviceAccept := serviceAcceptMsg{
|
||||
Service: serviceUserAuth,
|
||||
}
|
||||
packet = marshal(msgServiceAccept, serviceAccept)
|
||||
if err = out.writePacket(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = s.authenticate(H); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.channels = make(map[uint32]*channel)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAcceptableAlgo(algo string) bool {
|
||||
return algo == hostAlgoRSA
|
||||
}
|
||||
|
||||
// testPubKey returns true if the given public key is acceptable for the user.
|
||||
func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
|
||||
if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, c := range s.cachedPubKeys {
|
||||
if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
|
||||
return c.result
|
||||
}
|
||||
}
|
||||
|
||||
result := s.Server.PubKeyCallback(user, algo, pubKey)
|
||||
if len(s.cachedPubKeys) < maxCachedPubKeys {
|
||||
c := cachedPubKey{
|
||||
user: user,
|
||||
algo: algo,
|
||||
pubKey: make([]byte, len(pubKey)),
|
||||
result: result,
|
||||
}
|
||||
copy(c.pubKey, pubKey)
|
||||
s.cachedPubKeys = append(s.cachedPubKeys, c)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *ServerConnection) authenticate(H []byte) os.Error {
|
||||
var userAuthReq userAuthRequestMsg
|
||||
var err os.Error
|
||||
var packet []byte
|
||||
|
||||
userAuthLoop:
|
||||
for {
|
||||
if packet, err = s.in.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuthReq.Service != serviceSSH {
|
||||
return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
|
||||
}
|
||||
|
||||
switch userAuthReq.Method {
|
||||
case "none":
|
||||
if s.Server.NoClientAuth {
|
||||
break userAuthLoop
|
||||
}
|
||||
case "password":
|
||||
if s.Server.PasswordCallback == nil {
|
||||
break
|
||||
}
|
||||
payload := userAuthReq.Payload
|
||||
if len(payload) < 1 || payload[0] != 0 {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
payload = payload[1:]
|
||||
password, payload, ok := parseString(payload)
|
||||
if !ok || len(payload) > 0 {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
|
||||
if s.Server.PasswordCallback(userAuthReq.User, string(password)) {
|
||||
break userAuthLoop
|
||||
}
|
||||
case "publickey":
|
||||
if s.Server.PubKeyCallback == nil {
|
||||
break
|
||||
}
|
||||
payload := userAuthReq.Payload
|
||||
if len(payload) < 1 {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
isQuery := payload[0] == 0
|
||||
payload = payload[1:]
|
||||
algoBytes, payload, ok := parseString(payload)
|
||||
if !ok {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
algo := string(algoBytes)
|
||||
|
||||
pubKey, payload, ok := parseString(payload)
|
||||
if !ok {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
if isQuery {
|
||||
// The client can query if the given public key
|
||||
// would be ok.
|
||||
if len(payload) > 0 {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
if s.testPubKey(userAuthReq.User, algo, pubKey) {
|
||||
okMsg := userAuthPubKeyOkMsg{
|
||||
Algo: algo,
|
||||
PubKey: string(pubKey),
|
||||
}
|
||||
if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
|
||||
return err
|
||||
}
|
||||
continue userAuthLoop
|
||||
}
|
||||
} else {
|
||||
sig, payload, ok := parseString(payload)
|
||||
if !ok || len(payload) > 0 {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
if !isAcceptableAlgo(algo) {
|
||||
break
|
||||
}
|
||||
rsaSig, ok := parseRSASig(sig)
|
||||
if !ok {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey)
|
||||
switch algo {
|
||||
case hostAlgoRSA:
|
||||
hashFunc := crypto.SHA1
|
||||
h := hashFunc.New()
|
||||
h.Write(signedData)
|
||||
digest := h.Sum()
|
||||
rsaKey, ok := parseRSA(pubKey)
|
||||
if !ok {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil {
|
||||
return ParseError{msgUserAuthRequest}
|
||||
}
|
||||
default:
|
||||
return os.NewError("ssh: isAcceptableAlgo incorrect")
|
||||
}
|
||||
if s.testPubKey(userAuthReq.User, algo, pubKey) {
|
||||
break userAuthLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var failureMsg userAuthFailureMsg
|
||||
if s.Server.PasswordCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "password")
|
||||
}
|
||||
if s.Server.PubKeyCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "publickey")
|
||||
}
|
||||
|
||||
if len(failureMsg.Methods) == 0 {
|
||||
return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false")
|
||||
}
|
||||
|
||||
if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
packet = []byte{msgUserAuthSuccess}
|
||||
if err = s.out.writePacket(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const defaultWindowSize = 32768
|
||||
|
||||
// Accept reads and processes messages on a ServerConnection. It must be called
|
||||
// in order to demultiplex messages to any resulting Channels.
|
||||
func (s *ServerConnection) Accept() (Channel, os.Error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
|
||||
for {
|
||||
packet, err := s.in.readPacket()
|
||||
if err != nil {
|
||||
|
||||
s.lock.Lock()
|
||||
s.err = err
|
||||
s.lock.Unlock()
|
||||
|
||||
for _, c := range s.channels {
|
||||
c.dead = true
|
||||
c.handleData(nil)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch packet[0] {
|
||||
case msgChannelOpen:
|
||||
var chanOpen channelOpenMsg
|
||||
if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := new(channel)
|
||||
c.chanType = chanOpen.ChanType
|
||||
c.theirId = chanOpen.PeersId
|
||||
c.theirWindow = chanOpen.PeersWindow
|
||||
c.maxPacketSize = chanOpen.MaxPacketSize
|
||||
c.extraData = chanOpen.TypeSpecificData
|
||||
c.myWindow = defaultWindowSize
|
||||
c.serverConn = s
|
||||
c.cond = sync.NewCond(&c.lock)
|
||||
c.pendingData = make([]byte, c.myWindow)
|
||||
|
||||
s.lock.Lock()
|
||||
c.myId = s.nextChanId
|
||||
s.nextChanId++
|
||||
s.channels[c.myId] = c
|
||||
s.lock.Unlock()
|
||||
return c, nil
|
||||
|
||||
case msgChannelRequest:
|
||||
var chanRequest channelRequestMsg
|
||||
if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
c, ok := s.channels[chanRequest.PeersId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
c.handlePacket(&chanRequest)
|
||||
s.lock.Unlock()
|
||||
|
||||
case msgChannelData:
|
||||
if len(packet) < 5 {
|
||||
return nil, ParseError{msgChannelData}
|
||||
}
|
||||
chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
||||
|
||||
s.lock.Lock()
|
||||
c, ok := s.channels[chanId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
c.handleData(packet[9:])
|
||||
s.lock.Unlock()
|
||||
|
||||
case msgChannelEOF:
|
||||
var eofMsg channelEOFMsg
|
||||
if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
c, ok := s.channels[eofMsg.PeersId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
c.handlePacket(&eofMsg)
|
||||
s.lock.Unlock()
|
||||
|
||||
case msgChannelClose:
|
||||
var closeMsg channelCloseMsg
|
||||
if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
c, ok := s.channels[closeMsg.PeersId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
c.handlePacket(&closeMsg)
|
||||
s.lock.Unlock()
|
||||
|
||||
case msgGlobalRequest:
|
||||
var request globalRequestMsg
|
||||
if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.WantReply {
|
||||
if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Unknown message. Ignore.
|
||||
}
|
||||
}
|
||||
|
||||
panic("unreachable")
|
||||
}
|
399
src/pkg/exp/ssh/server_shell.go
Normal file
399
src/pkg/exp/ssh/server_shell.go
Normal file
@ -0,0 +1,399 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// ServerShell contains the state for running a VT100 terminal that is capable
|
||||
// of reading lines of input.
|
||||
type ServerShell struct {
|
||||
c Channel
|
||||
prompt string
|
||||
|
||||
// line is the current line being entered.
|
||||
line []byte
|
||||
// pos is the logical position of the cursor in line
|
||||
pos int
|
||||
|
||||
// cursorX contains the current X value of the cursor where the left
|
||||
// edge is 0. cursorY contains the row number where the first row of
|
||||
// the current line is 0.
|
||||
cursorX, cursorY int
|
||||
// maxLine is the greatest value of cursorY so far.
|
||||
maxLine int
|
||||
|
||||
termWidth, termHeight int
|
||||
|
||||
// outBuf contains the terminal data to be sent.
|
||||
outBuf []byte
|
||||
// remainder contains the remainder of any partial key sequences after
|
||||
// a read. It aliases into inBuf.
|
||||
remainder []byte
|
||||
inBuf [256]byte
|
||||
}
|
||||
|
||||
// NewServerShell runs a VT100 terminal on the given channel. prompt is a
|
||||
// string that is written at the start of each input line. For example: "> ".
|
||||
func NewServerShell(c Channel, prompt string) *ServerShell {
|
||||
return &ServerShell{
|
||||
c: c,
|
||||
prompt: prompt,
|
||||
termWidth: 80,
|
||||
termHeight: 24,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
keyCtrlD = 4
|
||||
keyEnter = '\r'
|
||||
keyEscape = 27
|
||||
keyBackspace = 127
|
||||
keyUnknown = 256 + iota
|
||||
keyUp
|
||||
keyDown
|
||||
keyLeft
|
||||
keyRight
|
||||
keyAltLeft
|
||||
keyAltRight
|
||||
)
|
||||
|
||||
// bytesToKey tries to parse a key sequence from b. If successful, it returns
|
||||
// the key and the remainder of the input. Otherwise it returns -1.
|
||||
func bytesToKey(b []byte) (int, []byte) {
|
||||
if len(b) == 0 {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
if b[0] != keyEscape {
|
||||
return int(b[0]), b[1:]
|
||||
}
|
||||
|
||||
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
|
||||
switch b[2] {
|
||||
case 'A':
|
||||
return keyUp, b[3:]
|
||||
case 'B':
|
||||
return keyDown, b[3:]
|
||||
case 'C':
|
||||
return keyRight, b[3:]
|
||||
case 'D':
|
||||
return keyLeft, b[3:]
|
||||
}
|
||||
}
|
||||
|
||||
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
|
||||
switch b[5] {
|
||||
case 'C':
|
||||
return keyAltRight, b[6:]
|
||||
case 'D':
|
||||
return keyAltLeft, b[6:]
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here then we have a key that we don't recognise, or a
|
||||
// partial sequence. It's not clear how one should find the end of a
|
||||
// sequence without knowing them all, but it seems that [a-zA-Z] only
|
||||
// appears at the end of a sequence.
|
||||
for i, c := range b[0:] {
|
||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
|
||||
return keyUnknown, b[i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
return -1, b
|
||||
}
|
||||
|
||||
// queue appends data to the end of ss.outBuf
|
||||
func (ss *ServerShell) queue(data []byte) {
|
||||
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
|
||||
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
|
||||
copy(newOutBuf, ss.outBuf)
|
||||
ss.outBuf = newOutBuf
|
||||
}
|
||||
|
||||
oldLen := len(ss.outBuf)
|
||||
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
|
||||
copy(ss.outBuf[oldLen:], data)
|
||||
}
|
||||
|
||||
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
|
||||
|
||||
func isPrintable(key int) bool {
|
||||
return key >= 32 && key < 127
|
||||
}
|
||||
|
||||
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
|
||||
// given, logical position in the text.
|
||||
func (ss *ServerShell) moveCursorToPos(pos int) {
|
||||
x := len(ss.prompt) + pos
|
||||
y := x / ss.termWidth
|
||||
x = x % ss.termWidth
|
||||
|
||||
up := 0
|
||||
if y < ss.cursorY {
|
||||
up = ss.cursorY - y
|
||||
}
|
||||
|
||||
down := 0
|
||||
if y > ss.cursorY {
|
||||
down = y - ss.cursorY
|
||||
}
|
||||
|
||||
left := 0
|
||||
if x < ss.cursorX {
|
||||
left = ss.cursorX - x
|
||||
}
|
||||
|
||||
right := 0
|
||||
if x > ss.cursorX {
|
||||
right = x - ss.cursorX
|
||||
}
|
||||
|
||||
movement := make([]byte, 3*(up+down+left+right))
|
||||
m := movement
|
||||
for i := 0; i < up; i++ {
|
||||
m[0] = keyEscape
|
||||
m[1] = '['
|
||||
m[2] = 'A'
|
||||
m = m[3:]
|
||||
}
|
||||
for i := 0; i < down; i++ {
|
||||
m[0] = keyEscape
|
||||
m[1] = '['
|
||||
m[2] = 'B'
|
||||
m = m[3:]
|
||||
}
|
||||
for i := 0; i < left; i++ {
|
||||
m[0] = keyEscape
|
||||
m[1] = '['
|
||||
m[2] = 'D'
|
||||
m = m[3:]
|
||||
}
|
||||
for i := 0; i < right; i++ {
|
||||
m[0] = keyEscape
|
||||
m[1] = '['
|
||||
m[2] = 'C'
|
||||
m = m[3:]
|
||||
}
|
||||
|
||||
ss.cursorX = x
|
||||
ss.cursorY = y
|
||||
ss.queue(movement)
|
||||
}
|
||||
|
||||
const maxLineLength = 4096
|
||||
|
||||
// handleKey processes the given key and, optionally, returns a line of text
|
||||
// that the user has entered.
|
||||
func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
|
||||
switch key {
|
||||
case keyBackspace:
|
||||
if ss.pos == 0 {
|
||||
return
|
||||
}
|
||||
ss.pos--
|
||||
|
||||
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
|
||||
ss.line = ss.line[:len(ss.line)-1]
|
||||
ss.writeLine(ss.line[ss.pos:])
|
||||
ss.moveCursorToPos(ss.pos)
|
||||
ss.queue(eraseUnderCursor)
|
||||
case keyAltLeft:
|
||||
// move left by a word.
|
||||
if ss.pos == 0 {
|
||||
return
|
||||
}
|
||||
ss.pos--
|
||||
for ss.pos > 0 {
|
||||
if ss.line[ss.pos] != ' ' {
|
||||
break
|
||||
}
|
||||
ss.pos--
|
||||
}
|
||||
for ss.pos > 0 {
|
||||
if ss.line[ss.pos] == ' ' {
|
||||
ss.pos++
|
||||
break
|
||||
}
|
||||
ss.pos--
|
||||
}
|
||||
ss.moveCursorToPos(ss.pos)
|
||||
case keyAltRight:
|
||||
// move right by a word.
|
||||
for ss.pos < len(ss.line) {
|
||||
if ss.line[ss.pos] == ' ' {
|
||||
break
|
||||
}
|
||||
ss.pos++
|
||||
}
|
||||
for ss.pos < len(ss.line) {
|
||||
if ss.line[ss.pos] != ' ' {
|
||||
break
|
||||
}
|
||||
ss.pos++
|
||||
}
|
||||
ss.moveCursorToPos(ss.pos)
|
||||
case keyLeft:
|
||||
if ss.pos == 0 {
|
||||
return
|
||||
}
|
||||
ss.pos--
|
||||
ss.moveCursorToPos(ss.pos)
|
||||
case keyRight:
|
||||
if ss.pos == len(ss.line) {
|
||||
return
|
||||
}
|
||||
ss.pos++
|
||||
ss.moveCursorToPos(ss.pos)
|
||||
case keyEnter:
|
||||
ss.moveCursorToPos(len(ss.line))
|
||||
ss.queue([]byte("\r\n"))
|
||||
line = string(ss.line)
|
||||
ok = true
|
||||
ss.line = ss.line[:0]
|
||||
ss.pos = 0
|
||||
ss.cursorX = 0
|
||||
ss.cursorY = 0
|
||||
ss.maxLine = 0
|
||||
default:
|
||||
if !isPrintable(key) {
|
||||
return
|
||||
}
|
||||
if len(ss.line) == maxLineLength {
|
||||
return
|
||||
}
|
||||
if len(ss.line) == cap(ss.line) {
|
||||
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
|
||||
copy(newLine, ss.line)
|
||||
ss.line = newLine
|
||||
}
|
||||
ss.line = ss.line[:len(ss.line)+1]
|
||||
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
|
||||
ss.line[ss.pos] = byte(key)
|
||||
ss.writeLine(ss.line[ss.pos:])
|
||||
ss.pos++
|
||||
ss.moveCursorToPos(ss.pos)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ss *ServerShell) writeLine(line []byte) {
|
||||
for len(line) != 0 {
|
||||
if ss.cursorX == ss.termWidth {
|
||||
ss.queue([]byte("\r\n"))
|
||||
ss.cursorX = 0
|
||||
ss.cursorY++
|
||||
if ss.cursorY > ss.maxLine {
|
||||
ss.maxLine = ss.cursorY
|
||||
}
|
||||
}
|
||||
|
||||
remainingOnLine := ss.termWidth - ss.cursorX
|
||||
todo := len(line)
|
||||
if todo > remainingOnLine {
|
||||
todo = remainingOnLine
|
||||
}
|
||||
ss.queue(line[:todo])
|
||||
ss.cursorX += todo
|
||||
line = line[todo:]
|
||||
}
|
||||
}
|
||||
|
||||
// parsePtyRequest parses the payload of the pty-req message and extracts the
|
||||
// dimensions of the terminal. See RFC 4254, section 6.2.
|
||||
func parsePtyRequest(s []byte) (width, height int, ok bool) {
|
||||
_, s, ok = parseString(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
width32, s, ok := parseUint32(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, _, ok := parseUint32(s)
|
||||
width = int(width32)
|
||||
height = int(height32)
|
||||
if width < 1 {
|
||||
ok = false
|
||||
}
|
||||
if height < 1 {
|
||||
ok = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ss *ServerShell) Write(buf []byte) (n int, err os.Error) {
|
||||
return ss.c.Write(buf)
|
||||
}
|
||||
|
||||
// ReadLine returns a line of input from the terminal.
|
||||
func (ss *ServerShell) ReadLine() (line string, err os.Error) {
|
||||
ss.writeLine([]byte(ss.prompt))
|
||||
ss.c.Write(ss.outBuf)
|
||||
ss.outBuf = ss.outBuf[:0]
|
||||
|
||||
for {
|
||||
// ss.remainder is a slice at the beginning of ss.inBuf
|
||||
// containing a partial key sequence
|
||||
readBuf := ss.inBuf[len(ss.remainder):]
|
||||
n, err := ss.c.Read(readBuf)
|
||||
if err == nil {
|
||||
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
|
||||
rest := ss.remainder
|
||||
lineOk := false
|
||||
for !lineOk {
|
||||
var key int
|
||||
key, rest = bytesToKey(rest)
|
||||
if key < 0 {
|
||||
break
|
||||
}
|
||||
if key == keyCtrlD {
|
||||
return "", os.EOF
|
||||
}
|
||||
line, lineOk = ss.handleKey(key)
|
||||
}
|
||||
if len(rest) > 0 {
|
||||
n := copy(ss.inBuf[:], rest)
|
||||
ss.remainder = ss.inBuf[:n]
|
||||
} else {
|
||||
ss.remainder = nil
|
||||
}
|
||||
ss.c.Write(ss.outBuf)
|
||||
ss.outBuf = ss.outBuf[:0]
|
||||
if lineOk {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if req, ok := err.(ChannelRequest); ok {
|
||||
ok := false
|
||||
switch req.Request {
|
||||
case "pty-req":
|
||||
ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
|
||||
if !ok {
|
||||
ss.termWidth = 80
|
||||
ss.termHeight = 24
|
||||
}
|
||||
case "shell":
|
||||
ok = true
|
||||
if len(req.Payload) > 0 {
|
||||
// We don't accept any commands, only the default shell.
|
||||
ok = false
|
||||
}
|
||||
case "env":
|
||||
ok = true
|
||||
}
|
||||
if req.WantReply {
|
||||
ss.c.AckRequest(ok)
|
||||
}
|
||||
} else {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
134
src/pkg/exp/ssh/server_shell_test.go
Normal file
134
src/pkg/exp/ssh/server_shell_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"os"
|
||||
)
|
||||
|
||||
type MockChannel struct {
|
||||
toSend []byte
|
||||
bytesPerRead int
|
||||
received []byte
|
||||
}
|
||||
|
||||
func (c *MockChannel) Accept() os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) Reject(RejectionReason, string) os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) Read(data []byte) (n int, err os.Error) {
|
||||
n = len(data)
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
if n > len(c.toSend) {
|
||||
n = len(c.toSend)
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, os.EOF
|
||||
}
|
||||
if c.bytesPerRead > 0 && n > c.bytesPerRead {
|
||||
n = c.bytesPerRead
|
||||
}
|
||||
copy(data, c.toSend[:n])
|
||||
c.toSend = c.toSend[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *MockChannel) Write(data []byte) (n int, err os.Error) {
|
||||
c.received = append(c.received, data...)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) Close() os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) AckRequest(ok bool) os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) ChannelType() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *MockChannel) ExtraData() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
c := &MockChannel{}
|
||||
ss := NewServerShell(c, "> ")
|
||||
line, err := ss.ReadLine()
|
||||
if line != "" {
|
||||
t.Errorf("Expected empty line but got: %s", line)
|
||||
}
|
||||
if err != os.EOF {
|
||||
t.Errorf("Error should have been EOF but got: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var keyPressTests = []struct {
|
||||
in string
|
||||
line string
|
||||
err os.Error
|
||||
}{
|
||||
{
|
||||
"",
|
||||
"",
|
||||
os.EOF,
|
||||
},
|
||||
{
|
||||
"\r",
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"foo\r",
|
||||
"foo",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"a\x1b[Cb\r", // right
|
||||
"ab",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"a\x1b[Db\r", // left
|
||||
"ba",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"a\177b\r", // backspace
|
||||
"b",
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
func TestKeyPresses(t *testing.T) {
|
||||
for i, test := range keyPressTests {
|
||||
for j := 0; j < len(test.in); j++ {
|
||||
c := &MockChannel{
|
||||
toSend: []byte(test.in),
|
||||
bytesPerRead: j,
|
||||
}
|
||||
ss := NewServerShell(c, "> ")
|
||||
line, err := ss.ReadLine()
|
||||
if line != test.line {
|
||||
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
|
||||
break
|
||||
}
|
||||
if err != test.err {
|
||||
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
308
src/pkg/exp/ssh/transport.go
Normal file
308
src/pkg/exp/ssh/transport.go
Normal file
@ -0,0 +1,308 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/subtle"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
// halfConnection represents one direction of an SSH connection. It maintains
|
||||
// the cipher state needed to process messages.
|
||||
type halfConnection struct {
|
||||
// Only one of these two will be non-nil
|
||||
in *bufio.Reader
|
||||
out net.Conn
|
||||
|
||||
rand io.Reader
|
||||
cipherAlgo string
|
||||
macAlgo string
|
||||
compressionAlgo string
|
||||
paddingMultiple int
|
||||
|
||||
seqNum uint32
|
||||
|
||||
mac hash.Hash
|
||||
cipher cipher.Stream
|
||||
}
|
||||
|
||||
func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) {
|
||||
var lengthBytes [5]byte
|
||||
|
||||
_, err = io.ReadFull(hc.in, lengthBytes[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if hc.cipher != nil {
|
||||
hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
|
||||
}
|
||||
|
||||
macSize := 0
|
||||
if hc.mac != nil {
|
||||
hc.mac.Reset()
|
||||
var seqNumBytes [4]byte
|
||||
seqNumBytes[0] = byte(hc.seqNum >> 24)
|
||||
seqNumBytes[1] = byte(hc.seqNum >> 16)
|
||||
seqNumBytes[2] = byte(hc.seqNum >> 8)
|
||||
seqNumBytes[3] = byte(hc.seqNum)
|
||||
hc.mac.Write(seqNumBytes[:])
|
||||
hc.mac.Write(lengthBytes[:])
|
||||
macSize = hc.mac.Size()
|
||||
}
|
||||
|
||||
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
|
||||
|
||||
paddingLength := uint32(lengthBytes[4])
|
||||
|
||||
if length <= paddingLength+1 {
|
||||
return nil, os.NewError("invalid packet length")
|
||||
}
|
||||
if length > maxPacketSize {
|
||||
return nil, os.NewError("packet too large")
|
||||
}
|
||||
|
||||
packet = make([]byte, length-1+uint32(macSize))
|
||||
_, err = io.ReadFull(hc.in, packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mac := packet[length-1:]
|
||||
if hc.cipher != nil {
|
||||
hc.cipher.XORKeyStream(packet, packet[:length-1])
|
||||
}
|
||||
|
||||
if hc.mac != nil {
|
||||
hc.mac.Write(packet[:length-1])
|
||||
if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 {
|
||||
return nil, os.NewError("ssh: MAC failure")
|
||||
}
|
||||
}
|
||||
|
||||
hc.seqNum++
|
||||
packet = packet[:length-paddingLength-1]
|
||||
return
|
||||
}
|
||||
|
||||
func (hc *halfConnection) readPacket() (packet []byte, err os.Error) {
|
||||
for {
|
||||
packet, err := hc.readOnePacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if packet[0] != msgIgnore && packet[0] != msgDebug {
|
||||
return packet, nil
|
||||
}
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (hc *halfConnection) writePacket(packet []byte) os.Error {
|
||||
paddingMultiple := hc.paddingMultiple
|
||||
if paddingMultiple == 0 {
|
||||
paddingMultiple = 8
|
||||
}
|
||||
|
||||
paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple
|
||||
if paddingLength < 4 {
|
||||
paddingLength += paddingMultiple
|
||||
}
|
||||
|
||||
var lengthBytes [5]byte
|
||||
length := len(packet) + 1 + paddingLength
|
||||
lengthBytes[0] = byte(length >> 24)
|
||||
lengthBytes[1] = byte(length >> 16)
|
||||
lengthBytes[2] = byte(length >> 8)
|
||||
lengthBytes[3] = byte(length)
|
||||
lengthBytes[4] = byte(paddingLength)
|
||||
|
||||
var padding [32]byte
|
||||
_, err := io.ReadFull(hc.rand, padding[:paddingLength])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hc.mac != nil {
|
||||
hc.mac.Reset()
|
||||
var seqNumBytes [4]byte
|
||||
seqNumBytes[0] = byte(hc.seqNum >> 24)
|
||||
seqNumBytes[1] = byte(hc.seqNum >> 16)
|
||||
seqNumBytes[2] = byte(hc.seqNum >> 8)
|
||||
seqNumBytes[3] = byte(hc.seqNum)
|
||||
hc.mac.Write(seqNumBytes[:])
|
||||
hc.mac.Write(lengthBytes[:])
|
||||
hc.mac.Write(packet)
|
||||
hc.mac.Write(padding[:paddingLength])
|
||||
}
|
||||
|
||||
if hc.cipher != nil {
|
||||
hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
|
||||
hc.cipher.XORKeyStream(packet, packet)
|
||||
hc.cipher.XORKeyStream(padding[:], padding[:paddingLength])
|
||||
}
|
||||
|
||||
_, err = hc.out.Write(lengthBytes[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = hc.out.Write(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = hc.out.Write(padding[:paddingLength])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hc.mac != nil {
|
||||
_, err = hc.out.Write(hc.mac.Sum())
|
||||
}
|
||||
|
||||
hc.seqNum++
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
const (
|
||||
serverKeys = iota
|
||||
clientKeys
|
||||
)
|
||||
|
||||
// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as
|
||||
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
||||
// (to setup server->client keys) or clientKeys (for client->server keys).
|
||||
func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
|
||||
h := hashFunc.New()
|
||||
|
||||
// We only support these algorithms for now.
|
||||
if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 {
|
||||
return os.NewError("ssh: setupServerKeys internal error")
|
||||
}
|
||||
|
||||
blockSize := 16
|
||||
keySize := 16
|
||||
macKeySize := 20
|
||||
|
||||
var ivTag, keyTag, macKeyTag byte
|
||||
if direction == serverKeys {
|
||||
ivTag, keyTag, macKeyTag = 'B', 'D', 'F'
|
||||
} else {
|
||||
ivTag, keyTag, macKeyTag = 'A', 'C', 'E'
|
||||
}
|
||||
|
||||
iv := make([]byte, blockSize)
|
||||
key := make([]byte, keySize)
|
||||
macKey := make([]byte, macKeySize)
|
||||
generateKeyMaterial(iv, ivTag, K, H, sessionId, h)
|
||||
generateKeyMaterial(key, keyTag, K, H, sessionId, h)
|
||||
generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h)
|
||||
|
||||
hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hc.cipher = cipher.NewCTR(aes, iv)
|
||||
hc.paddingMultiple = 16
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateKeyMaterial fills out with key material generated from tag, K, H
|
||||
// and sessionId, as specified in RFC 4253, section 7.2.
|
||||
func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) {
|
||||
var digestsSoFar []byte
|
||||
|
||||
for len(out) > 0 {
|
||||
h.Reset()
|
||||
h.Write(K)
|
||||
h.Write(H)
|
||||
|
||||
if len(digestsSoFar) == 0 {
|
||||
h.Write([]byte{tag})
|
||||
h.Write(sessionId)
|
||||
} else {
|
||||
h.Write(digestsSoFar)
|
||||
}
|
||||
|
||||
digest := h.Sum()
|
||||
n := copy(out, digest)
|
||||
out = out[n:]
|
||||
if len(out) > 0 {
|
||||
digestsSoFar = append(digestsSoFar, digest...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
||||
// a given size.
|
||||
type truncatingMAC struct {
|
||||
length int
|
||||
hmac hash.Hash
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Write(data []byte) (int, os.Error) {
|
||||
return t.hmac.Write(data)
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Sum() []byte {
|
||||
digest := t.hmac.Sum()
|
||||
return digest[:t.length]
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Reset() {
|
||||
t.hmac.Reset()
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Size() int {
|
||||
return t.length
|
||||
}
|
||||
|
||||
// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
|
||||
// version string. In the event that the client is talking a different protocol
|
||||
// we need to set a limit otherwise we will keep using more and more memory
|
||||
// while searching for the end of the version handshake.
|
||||
const maxVersionStringBytes = 1024
|
||||
|
||||
func readVersion(r *bufio.Reader) (versionString []byte, ok bool) {
|
||||
versionString = make([]byte, 0, 64)
|
||||
seenCR := false
|
||||
|
||||
forEachByte:
|
||||
for len(versionString) < maxVersionStringBytes {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !seenCR {
|
||||
if b == '\r' {
|
||||
seenCR = true
|
||||
}
|
||||
} else {
|
||||
if b == '\n' {
|
||||
ok = true
|
||||
break forEachByte
|
||||
} else {
|
||||
seenCR = false
|
||||
}
|
||||
}
|
||||
versionString = append(versionString, b)
|
||||
}
|
||||
|
||||
if ok {
|
||||
// We need to remove the CR from versionString
|
||||
versionString = versionString[:len(versionString)-1]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue
Block a user