From 605e57d8fee696238f3338c415043f16a7743731 Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Sat, 17 Sep 2011 15:57:24 -0400 Subject: [PATCH] exp/ssh: new package. The typical UNIX method for controlling long running process is to send the process signals. Since this doesn't get you very far, various ad-hoc, remote-control protocols have been used over time by programs like Apache and BIND. Implementing an SSH server means that Go code will have a standard, secure way to do this in the future. R=bradfitz, borman, dave, gustavo, dsymonds, r, adg, rsc, rogpeppe, lvd, kevlar, raul.san CC=golang-dev https://golang.org/cl/4962064 --- src/pkg/exp/ssh/Makefile | 16 + src/pkg/exp/ssh/channel.go | 317 ++++++++++++ src/pkg/exp/ssh/common.go | 96 ++++ src/pkg/exp/ssh/doc.go | 79 +++ src/pkg/exp/ssh/messages.go | 557 +++++++++++++++++++++ src/pkg/exp/ssh/messages_test.go | 125 +++++ src/pkg/exp/ssh/server.go | 711 +++++++++++++++++++++++++++ src/pkg/exp/ssh/server_shell.go | 399 +++++++++++++++ src/pkg/exp/ssh/server_shell_test.go | 134 +++++ src/pkg/exp/ssh/transport.go | 308 ++++++++++++ 10 files changed, 2742 insertions(+) create mode 100644 src/pkg/exp/ssh/Makefile create mode 100644 src/pkg/exp/ssh/channel.go create mode 100644 src/pkg/exp/ssh/common.go create mode 100644 src/pkg/exp/ssh/doc.go create mode 100644 src/pkg/exp/ssh/messages.go create mode 100644 src/pkg/exp/ssh/messages_test.go create mode 100644 src/pkg/exp/ssh/server.go create mode 100644 src/pkg/exp/ssh/server_shell.go create mode 100644 src/pkg/exp/ssh/server_shell_test.go create mode 100644 src/pkg/exp/ssh/transport.go diff --git a/src/pkg/exp/ssh/Makefile b/src/pkg/exp/ssh/Makefile new file mode 100644 index 00000000000..e8f33b708c3 --- /dev/null +++ b/src/pkg/exp/ssh/Makefile @@ -0,0 +1,16 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../../Make.inc + +TARG=exp/ssh +GOFILES=\ + common.go\ + messages.go\ + server.go\ + transport.go\ + channel.go\ + server_shell.go\ + +include ../../../Make.pkg diff --git a/src/pkg/exp/ssh/channel.go b/src/pkg/exp/ssh/channel.go new file mode 100644 index 00000000000..10f62354f43 --- /dev/null +++ b/src/pkg/exp/ssh/channel.go @@ -0,0 +1,317 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "os" + "sync" +) + +// A Channel is an ordered, reliable, duplex stream that is multiplexed over an +// SSH connection. +type Channel interface { + // Accept accepts the channel creation request. + Accept() os.Error + // Reject rejects the channel creation request. After calling this, no + // other methods on the Channel may be called. If they are then the + // peer is likely to signal a protocol error and drop the connection. + Reject(reason RejectionReason, message string) os.Error + + // Read may return a ChannelRequest as an os.Error. + Read(data []byte) (int, os.Error) + Write(data []byte) (int, os.Error) + Close() os.Error + + // AckRequest either sends an ack or nack to the channel request. + AckRequest(ok bool) os.Error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + // ExtraData returns the arbitary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// ChannelRequest represents a request sent on a channel, outside of the normal +// stream of bytes. It may result from calling Read on a Channel. +type ChannelRequest struct { + Request string + WantReply bool + Payload []byte +} + +func (c ChannelRequest) String() string { + return "channel request received" +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason int + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +type channel struct { + // immutable once created + chanType string + extraData []byte + + theyClosed bool + theySentEOF bool + weClosed bool + dead bool + + serverConn *ServerConnection + myId, theirId uint32 + myWindow, theirWindow uint32 + maxPacketSize uint32 + err os.Error + + pendingRequests []ChannelRequest + pendingData []byte + head, length int + + // This lock is inferior to serverConn.lock + lock sync.Mutex + cond *sync.Cond +} + +func (c *channel) Accept() os.Error { + c.serverConn.lock.Lock() + defer c.serverConn.lock.Unlock() + + if c.serverConn.err != nil { + return c.serverConn.err + } + + confirm := channelOpenConfirmMsg{ + PeersId: c.theirId, + MyId: c.myId, + MyWindow: c.myWindow, + MaxPacketSize: c.maxPacketSize, + } + return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm)) +} + +func (c *channel) Reject(reason RejectionReason, message string) os.Error { + c.serverConn.lock.Lock() + defer c.serverConn.lock.Unlock() + + if c.serverConn.err != nil { + return c.serverConn.err + } + + reject := channelOpenFailureMsg{ + PeersId: c.theirId, + Reason: uint32(reason), + Message: message, + Language: "en", + } + return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject)) +} + +func (c *channel) handlePacket(packet interface{}) { + c.lock.Lock() + defer c.lock.Unlock() + + switch packet := packet.(type) { + case *channelRequestMsg: + req := ChannelRequest{ + Request: packet.Request, + WantReply: packet.WantReply, + Payload: packet.RequestSpecificData, + } + + c.pendingRequests = append(c.pendingRequests, req) + c.cond.Signal() + case *channelCloseMsg: + c.theyClosed = true + c.cond.Signal() + case *channelEOFMsg: + c.theySentEOF = true + c.cond.Signal() + default: + panic("unknown packet type") + } +} + +func (c *channel) handleData(data []byte) { + c.lock.Lock() + defer c.lock.Unlock() + + // The other side should never send us more than our window. + if len(data)+c.length > len(c.pendingData) { + // TODO(agl): we should tear down the channel with a protocol + // error. + return + } + + c.myWindow -= uint32(len(data)) + for i := 0; i < 2; i++ { + tail := c.head + c.length + if tail > len(c.pendingData) { + tail -= len(c.pendingData) + } + n := copy(c.pendingData[tail:], data) + data = data[n:] + c.length += n + } + + c.cond.Signal() +} + +func (c *channel) Read(data []byte) (n int, err os.Error) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.err != nil { + return 0, c.err + } + + if c.myWindow <= uint32(len(c.pendingData))/2 { + packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{ + PeersId: c.theirId, + AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow, + }) + if err := c.serverConn.out.writePacket(packet); err != nil { + return 0, err + } + } + + for { + if c.theySentEOF || c.theyClosed || c.dead { + return 0, os.EOF + } + + if len(c.pendingRequests) > 0 { + req := c.pendingRequests[0] + if len(c.pendingRequests) == 1 { + c.pendingRequests = nil + } else { + oldPendingRequests := c.pendingRequests + c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1) + copy(c.pendingRequests, oldPendingRequests[1:]) + } + + return 0, req + } + + if c.length > 0 { + tail := c.head + c.length + if tail > len(c.pendingData) { + tail -= len(c.pendingData) + } + n = copy(data, c.pendingData[c.head:tail]) + c.head += n + c.length -= n + if c.head == len(c.pendingData) { + c.head = 0 + } + return + } + + c.cond.Wait() + } + + panic("unreachable") +} + +func (c *channel) Write(data []byte) (n int, err os.Error) { + for len(data) > 0 { + c.lock.Lock() + if c.dead || c.weClosed { + return 0, os.EOF + } + + if c.theirWindow == 0 { + c.cond.Wait() + continue + } + c.lock.Unlock() + + todo := data + if uint32(len(todo)) > c.theirWindow { + todo = todo[:c.theirWindow] + } + + packet := make([]byte, 1+4+4+len(todo)) + packet[0] = msgChannelData + packet[1] = byte(c.theirId) >> 24 + packet[2] = byte(c.theirId) >> 16 + packet[3] = byte(c.theirId) >> 8 + packet[4] = byte(c.theirId) + packet[5] = byte(len(todo)) >> 24 + packet[6] = byte(len(todo)) >> 16 + packet[7] = byte(len(todo)) >> 8 + packet[8] = byte(len(todo)) + copy(packet[9:], todo) + + c.serverConn.lock.Lock() + if err = c.serverConn.out.writePacket(packet); err != nil { + c.serverConn.lock.Unlock() + return + } + c.serverConn.lock.Unlock() + + n += len(todo) + data = data[len(todo):] + } + + return +} + +func (c *channel) Close() os.Error { + c.serverConn.lock.Lock() + defer c.serverConn.lock.Unlock() + + if c.serverConn.err != nil { + return c.serverConn.err + } + + if c.weClosed { + return os.NewError("ssh: channel already closed") + } + c.weClosed = true + + closeMsg := channelCloseMsg{ + PeersId: c.theirId, + } + return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg)) +} + +func (c *channel) AckRequest(ok bool) os.Error { + c.serverConn.lock.Lock() + defer c.serverConn.lock.Unlock() + + if c.serverConn.err != nil { + return c.serverConn.err + } + + if ok { + ack := channelRequestSuccessMsg{ + PeersId: c.theirId, + } + return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack)) + } else { + ack := channelRequestFailureMsg{ + PeersId: c.theirId, + } + return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack)) + } + panic("unreachable") +} + +func (c *channel) ChannelType() string { + return c.chanType +} + +func (c *channel) ExtraData() []byte { + return c.extraData +} diff --git a/src/pkg/exp/ssh/common.go b/src/pkg/exp/ssh/common.go new file mode 100644 index 00000000000..c951d1a753e --- /dev/null +++ b/src/pkg/exp/ssh/common.go @@ -0,0 +1,96 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "strconv" +) + +// These are string constants in the SSH protocol. +const ( + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + hostAlgoRSA = "ssh-rsa" + cipherAES128CTR = "aes128-ctr" + macSHA196 = "hmac-sha1-96" + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// UnexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +type UnexpectedMessageError struct { + expected, got uint8 +} + +func (u UnexpectedMessageError) String() string { + return "ssh: unexpected message type " + strconv.Itoa(int(u.got)) + " (expected " + strconv.Itoa(int(u.expected)) + ")" +} + +// ParseError results from a malformed SSH message. +type ParseError struct { + msgType uint8 +} + +func (p ParseError) String() string { + return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType)) +} + +func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) { + for _, clientAlgo := range clientAlgos { + for _, serverAlgo := range serverAlgos { + if clientAlgo == serverAlgo { + return clientAlgo, true + } + } + } + + return +} + +func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) { + kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if !ok { + return + } + + hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if !ok { + return + } + + clientToServer.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if !ok { + return + } + + serverToClient.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if !ok { + return + } + + clientToServer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if !ok { + return + } + + serverToClient.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if !ok { + return + } + + clientToServer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if !ok { + return + } + + serverToClient.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if !ok { + return + } + + ok = true + return +} diff --git a/src/pkg/exp/ssh/doc.go b/src/pkg/exp/ssh/doc.go new file mode 100644 index 00000000000..8dbdb0777c4 --- /dev/null +++ b/src/pkg/exp/ssh/doc.go @@ -0,0 +1,79 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +An SSH server is represented by a Server, which manages a number of +ServerConnections and handles authentication. + + var s Server + s.PubKeyCallback = pubKeyAuth + s.PasswordCallback = passwordAuth + + pemBytes, err := ioutil.ReadFile("id_rsa") + if err != nil { + panic("Failed to load private key") + } + err = s.SetRSAPrivateKey(pemBytes) + if err != nil { + panic("Failed to parse private key") + } + +Once a Server has been set up, connections can be attached. + + var sConn ServerConnection + sConn.Server = &s + err = sConn.Handshake(conn) + if err != nil { + panic("failed to handshake") + } + +An SSH connection multiplexes several channels, which must be accepted themselves: + + + for { + channel, err := sConn.Accept() + if err != nil { + panic("error from Accept") + } + + ... + } + +Accept reads from the connection, demultiplexes packets to their corresponding +channels and returns when a new channel request is seen. Some goroutine must +always be calling Accept; otherwise no messages will be forwarded to the +channels. + +Channels have a type, depending on the application level protocol intended. In +the case of a shell, the type is "session" and ServerShell may be used to +present a simple terminal interface. + + if channel.ChannelType() != "session" { + c.Reject(RejectUnknownChannelType, "unknown channel type") + return + } + channel.Accept() + + shell := NewServerShell(channel, "> ") + go func() { + defer channel.Close() + for { + line, err := shell.ReadLine() + if err != nil { + break + } + println(line) + } + return + }() +*/ +package ssh diff --git a/src/pkg/exp/ssh/messages.go b/src/pkg/exp/ssh/messages.go new file mode 100644 index 00000000000..d375eafae96 --- /dev/null +++ b/src/pkg/exp/ssh/messages.go @@ -0,0 +1,557 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "big" + "bytes" + "io" + "os" + "reflect" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from +// http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 +const ( + msgDisconnect = 1 + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgServiceRequest = 5 + msgServiceAccept = 6 + + msgKexInit = 20 + msgNewKeys = 21 + + msgKexDHInit = 30 + msgKexDHReply = 31 + + msgUserAuthRequest = 50 + msgUserAuthFailure = 51 + msgUserAuthSuccess = 52 + msgUserAuthBanner = 53 + msgUserAuthPubKeyOk = 60 + + msgGlobalRequest = 80 + msgRequestSuccess = 81 + msgRequestFailure = 82 + + msgChannelOpen = 90 + msgChannelOpenConfirm = 91 + msgChannelOpenFailure = 92 + msgChannelWindowAdjust = 93 + msgChannelData = 94 + msgChannelExtendedData = 95 + msgChannelEOF = 96 + msgChannelClose = 97 + msgChannelRequest = 98 + msgChannelSuccess = 99 + msgChannelFailure = 100 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 7.1. +type kexInitMsg struct { + Cookie [16]byte + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. +type kexDHInitMsg struct { + X *big.Int +} + +type kexDHReplyMsg struct { + HostKey []byte + Y *big.Int + Signature []byte +} + +// See RFC 4253, section 10. +type serviceRequestMsg struct { + Service string +} + +// See RFC 4253, section 10. +type serviceAcceptMsg struct { + Service string +} + +// See RFC 4252, section 5. +type userAuthRequestMsg struct { + User string + Service string + Method string + Payload []byte "rest" +} + +// See RFC 4252, section 5.1 +type userAuthFailureMsg struct { + Methods []string + PartialSuccess bool +} + +// See RFC 4254, section 5.1. +type channelOpenMsg struct { + ChanType string + PeersId uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte "rest" +} + +// See RFC 4254, section 5.1. +type channelOpenConfirmMsg struct { + PeersId uint32 + MyId uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte "rest" +} + +// See RFC 4254, section 5.1. +type channelOpenFailureMsg struct { + PeersId uint32 + Reason uint32 + Message string + Language string +} + +type channelRequestMsg struct { + PeersId uint32 + Request string + WantReply bool + RequestSpecificData []byte "rest" +} + +// See RFC 4254, section 5.4. +type channelRequestSuccessMsg struct { + PeersId uint32 +} + +// See RFC 4254, section 5.4. +type channelRequestFailureMsg struct { + PeersId uint32 +} + +// See RFC 4254, section 5.3 +type channelCloseMsg struct { + PeersId uint32 +} + +// See RFC 4254, section 5.3 +type channelEOFMsg struct { + PeersId uint32 +} + +// See RFC 4254, section 4 +type globalRequestMsg struct { + Type string + WantReply bool +} + +// See RFC 4254, section 5.2 +type windowAdjustMsg struct { + PeersId uint32 + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +type userAuthPubKeyOkMsg struct { + Algo string + PubKey string +} + +// unmarshal parses the SSH wire data in packet into out using reflection. +// expectedType is the expected SSH message type. It either returns nil on +// success, or a ParseError or UnexpectedMessageError on error. +func unmarshal(out interface{}, packet []byte, expectedType uint8) os.Error { + if len(packet) == 0 { + return ParseError{expectedType} + } + if packet[0] != expectedType { + return UnexpectedMessageError{expectedType, packet[0]} + } + packet = packet[1:] + + v := reflect.ValueOf(out).Elem() + structType := v.Type() + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(packet) < 1 { + return ParseError{expectedType} + } + field.SetBool(packet[0] != 0) + packet = packet[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic("array of non-uint8") + } + if len(packet) < t.Len() { + return ParseError{expectedType} + } + for j := 0; j < t.Len(); j++ { + field.Index(j).Set(reflect.ValueOf(packet[j])) + } + packet = packet[t.Len():] + case reflect.Uint32: + var u32 uint32 + if u32, packet, ok = parseUint32(packet); !ok { + return ParseError{expectedType} + } + field.SetUint(uint64(u32)) + case reflect.String: + var s []byte + if s, packet, ok = parseString(packet); !ok { + return ParseError{expectedType} + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag == "rest" { + field.Set(reflect.ValueOf(packet)) + packet = nil + } else { + var s []byte + if s, packet, ok = parseString(packet); !ok { + return ParseError{expectedType} + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, packet, ok = parseNameList(packet); !ok { + return ParseError{expectedType} + } + field.Set(reflect.ValueOf(nl)) + default: + panic("slice of unknown type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, packet, ok = parseInt(packet); !ok { + return ParseError{expectedType} + } + field.Set(reflect.ValueOf(n)) + } else { + panic("pointer to unknown type") + } + default: + panic("unknown type") + } + } + + if len(packet) != 0 { + return ParseError{expectedType} + } + + return nil +} + +// marshal serializes the message in msg, using the given message type. +func marshal(msgType uint8, msg interface{}) []byte { + var out []byte + out = append(out, msgType) + + v := reflect.ValueOf(msg) + structType := v.Type() + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic("array of non-uint8") + } + for j := 0; j < t.Len(); j++ { + out = append(out, byte(field.Index(j).Uint())) + } + case reflect.Uint32: + u32 := uint32(field.Uint()) + out = append(out, byte(u32>>24)) + out = append(out, byte(u32>>16)) + out = append(out, byte(u32>>8)) + out = append(out, byte(u32)) + case reflect.String: + s := field.String() + out = append(out, byte(len(s)>>24)) + out = append(out, byte(len(s)>>16)) + out = append(out, byte(len(s)>>8)) + out = append(out, byte(len(s))) + out = append(out, []byte(s)...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + length := field.Len() + if structType.Field(i).Tag != "rest" { + out = append(out, byte(length>>24)) + out = append(out, byte(length>>16)) + out = append(out, byte(length>>8)) + out = append(out, byte(length)) + } + for j := 0; j < length; j++ { + out = append(out, byte(field.Index(j).Uint())) + } + case reflect.String: + var length int + for j := 0; j < field.Len(); j++ { + if j != 0 { + length++ /* comma */ + } + length += len(field.Index(j).String()) + } + + out = append(out, byte(length>>24)) + out = append(out, byte(length>>16)) + out = append(out, byte(length>>8)) + out = append(out, byte(length)) + for j := 0; j < field.Len(); j++ { + if j != 0 { + out = append(out, ',') + } + out = append(out, []byte(field.Index(j).String())...) + } + default: + panic("slice of unknown type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic("pointer to unknown type") + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3]) + if uint32(len(in)) < 4+length { + return + } + out = in[4 : 4+length] + rest = in[4+length:] + ok = true + return +} + +var comma = []byte{','} + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (out uint32, rest []byte, ok bool) { + if len(in) < 4 { + return + } + out = uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3]) + rest = in[4:] + ok = true + return +} + +const maxPacketSize = 36000 + +func nameListLength(namelist []string) int { + length := 4 /* uint32 length prefix */ + for i, name := range namelist { + if i != 0 { + length++ /* comma */ + } + length += len(name) + } + return length +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(s []byte) int { + return 4 + len(s) +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) diff --git a/src/pkg/exp/ssh/messages_test.go b/src/pkg/exp/ssh/messages_test.go new file mode 100644 index 00000000000..629f3d3b145 --- /dev/null +++ b/src/pkg/exp/ssh/messages_test.go @@ -0,0 +1,125 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "big" + "rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +var messageTypes = []interface{}{ + &kexInitMsg{}, + &kexDHInitMsg{}, + &serviceRequestMsg{}, + &serviceAcceptMsg{}, + &userAuthRequestMsg{}, + &channelOpenMsg{}, + &channelOpenConfirmMsg{}, + &channelRequestMsg{}, + &channelRequestSuccessMsg{}, +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + for i, iface := range messageTypes { + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("#%d: failed to create value", i) + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := marshal(msgIgnore, m1) + if err := unmarshal(m2, marshaled, msgIgnore); err != nil { + t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled) + break + } + } + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go new file mode 100644 index 00000000000..57cd5971063 --- /dev/null +++ b/src/pkg/exp/ssh/server.go @@ -0,0 +1,711 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "big" + "bufio" + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + _ "crypto/sha1" + "crypto/x509" + "encoding/pem" + "net" + "os" + "sync" +) + +var supportedKexAlgos = []string{kexAlgoDH14SHA1} +var supportedHostKeyAlgos = []string{hostAlgoRSA} +var supportedCiphers = []string{cipherAES128CTR} +var supportedMACs = []string{macSHA196} +var supportedCompressions = []string{compressionNone} + +// Server represents an SSH server. A Server may have several ServerConnections. +type Server struct { + rsa *rsa.PrivateKey + rsaSerialized []byte + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + NoClientAuth bool + + // PasswordCallback, if non-nil, is called when a user attempts to + // authenticate using a password. It may be called concurrently from + // several goroutines. + PasswordCallback func(user, password string) bool + + // PubKeyCallback, if non-nil, is called when a client attempts public + // key authentication. It must return true iff the given public key is + // valid for the given user. + PubKeyCallback func(user, algo string, pubkey []byte) bool +} + +// SetRSAPrivateKey sets the private key for a Server. A Server must have a +// private key configured in order to accept connections. The private key must +// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" +// typically contains such a key. +func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error { + block, _ := pem.Decode(pemBytes) + if block == nil { + return os.NewError("ssh: no key found") + } + var err os.Error + s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return err + } + + s.rsaSerialized = marshalRSA(s.rsa) + return nil +} + +// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6. +func marshalRSA(priv *rsa.PrivateKey) []byte { + e := new(big.Int).SetInt64(int64(priv.E)) + length := stringLength([]byte(hostAlgoRSA)) + length += intLength(e) + length += intLength(priv.N) + + ret := make([]byte, length) + r := marshalString(ret, []byte(hostAlgoRSA)) + r = marshalInt(r, e) + r = marshalInt(r, priv.N) + + return ret +} + +// parseRSA parses an RSA key according to RFC 4256, section 6.6. +func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) { + algo, in, ok := parseString(in) + if !ok || string(algo) != hostAlgoRSA { + return nil, false + } + bigE, in, ok := parseInt(in) + if !ok || bigE.BitLen() > 24 { + return nil, false + } + e := bigE.Int64() + if e < 3 || e&1 == 0 { + return nil, false + } + N, in, ok := parseInt(in) + if !ok || len(in) > 0 { + return nil, false + } + return &rsa.PublicKey{ + N: N, + E: int(e), + }, true +} + +func parseRSASig(in []byte) (sig []byte, ok bool) { + algo, in, ok := parseString(in) + if !ok || string(algo) != hostAlgoRSA { + return nil, false + } + sig, in, ok = parseString(in) + if len(in) > 0 { + ok = false + } + return +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. The cache only applies to a single ServerConnection. +type cachedPubKey struct { + user, algo string + pubKey []byte + result bool +} + +const maxCachedPubKeys = 16 + +// ServerConnection represents an incomming connection to a Server. +type ServerConnection struct { + Server *Server + + in, out *halfConnection + + channels map[uint32]*channel + nextChanId uint32 + + // lock protects err and also allows Channels to serialise their writes + // to out. + lock sync.RWMutex + err os.Error + + // cachedPubKeys contains the cache results of tests for public keys. + // Since SSH clients will query whether a public key is acceptable + // before attempting to authenticate with it, we end up with duplicate + // queries for public key validity. + cachedPubKeys []cachedPubKey +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p *big.Int +} + +// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and +// Oakley Group 14 in RFC 3526. +var dhGroup14 *dhGroup + +var dhGroup14Once sync.Once + +func initDHGroup14() { + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + + dhGroup14 = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + } +} + +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The +// returned values are given the same names as in RFC 4253, section 8. +func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { + packet, err := s.in.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil { + return + } + + if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 { + return nil, nil, os.NewError("client DH parameter out of bounds") + } + + y, err := rand.Int(rand.Reader, group.p) + if err != nil { + return + } + + Y := new(big.Int).Exp(group.g, y, group.p) + kInt := new(big.Int).Exp(kexDHInit.X, y, group.p) + + var serializedHostKey []byte + switch hostKeyAlgo { + case hostAlgoRSA: + serializedHostKey = s.Server.rsaSerialized + default: + return nil, nil, os.NewError("internal error") + } + + h := hashFunc.New() + writeString(h, magics.clientVersion) + writeString(h, magics.serverVersion) + writeString(h, magics.clientKexInit) + writeString(h, magics.serverKexInit) + writeString(h, serializedHostKey) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + K = make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H = h.Sum() + + h.Reset() + h.Write(H) + hh := h.Sum() + + var sig []byte + switch hostKeyAlgo { + case hostAlgoRSA: + sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh) + if err != nil { + return + } + default: + return nil, nil, os.NewError("internal error") + } + + serializedSig := serializeRSASignature(sig) + + kexDHReply := kexDHReplyMsg{ + HostKey: serializedHostKey, + Y: Y, + Signature: serializedSig, + } + packet = marshal(msgKexDHReply, kexDHReply) + + err = s.out.writePacket(packet) + return +} + +func serializeRSASignature(sig []byte) []byte { + length := stringLength([]byte(hostAlgoRSA)) + length += stringLength(sig) + + ret := make([]byte, length) + r := marshalString(ret, []byte(hostAlgoRSA)) + r = marshalString(r, sig) + + return ret +} + +// serverVersion is the fixed identification string that Server will use. +var serverVersion = []byte("SSH-2.0-Go\r\n") + +// buildDataSignedForAuth returns the data that is signed in order to prove +// posession of a private key. See RFC 4252, section 7. +func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { + user := []byte(req.User) + service := []byte(req.Service) + method := []byte(req.Method) + + length := stringLength(sessionId) + length += 1 + length += stringLength(user) + length += stringLength(service) + length += stringLength(method) + length += 1 + length += stringLength(algo) + length += stringLength(pubKey) + + ret := make([]byte, length) + r := marshalString(ret, sessionId) + r[0] = msgUserAuthRequest + r = r[1:] + r = marshalString(r, user) + r = marshalString(r, service) + r = marshalString(r, method) + r[0] = 1 + r = r[1:] + r = marshalString(r, algo) + r = marshalString(r, pubKey) + return ret +} + +// Handshake performs an SSH transport and client authentication on the given ServerConnection. +func (s *ServerConnection) Handshake(conn net.Conn) os.Error { + var magics handshakeMagics + inBuf := bufio.NewReader(conn) + + _, err := conn.Write(serverVersion) + if err != nil { + return err + } + + magics.serverVersion = serverVersion[:len(serverVersion)-2] + serverKexInit := kexInitMsg{ + KexAlgos: supportedKexAlgos, + ServerHostKeyAlgos: supportedHostKeyAlgos, + CiphersClientServer: supportedCiphers, + CiphersServerClient: supportedCiphers, + MACsClientServer: supportedMACs, + MACsServerClient: supportedMACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + kexInitPacket := marshal(msgKexInit, serverKexInit) + magics.serverKexInit = kexInitPacket + + var out halfConnection + out.out = conn + out.rand = rand.Reader + s.out = &out + err = out.writePacket(kexInitPacket) + if err != nil { + return err + } + + version, ok := readVersion(inBuf) + if !ok { + return os.NewError("failed to read version string from client") + } + magics.clientVersion = version + + var in halfConnection + in.in = inBuf + s.in = &in + packet, err := in.readPacket() + if err != nil { + return err + } + magics.clientKexInit = packet + + var clientKexInit kexInitMsg + if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil { + return err + } + + kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit) + if !ok { + return os.NewError("ssh: no common algorithms") + } + + if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { + // The client sent a Kex message for the wrong algorithm, + // which we have to ignore. + _, err := in.readPacket() + if err != nil { + return err + } + } + + var H, K []byte + var hashFunc crypto.Hash + switch kexAlgo { + case kexAlgoDH14SHA1: + hashFunc = crypto.SHA1 + dhGroup14Once.Do(initDHGroup14) + H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) + default: + err = os.NewError("ssh: internal error") + } + + if err != nil { + return err + } + + packet = []byte{msgNewKeys} + if err = out.writePacket(packet); err != nil { + return err + } + if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { + return err + } + + if packet, err = in.readPacket(); err != nil { + return err + } + if packet[0] != msgNewKeys { + return UnexpectedMessageError{msgNewKeys, packet[0]} + } + + in.setupKeys(clientKeys, K, H, H, hashFunc) + + packet, err = in.readPacket() + if err != nil { + return err + } + + var serviceRequest serviceRequestMsg + if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil { + return err + } + if serviceRequest.Service != serviceUserAuth { + return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + packet = marshal(msgServiceAccept, serviceAccept) + if err = out.writePacket(packet); err != nil { + return err + } + + if err = s.authenticate(H); err != nil { + return err + } + + s.channels = make(map[uint32]*channel) + return nil +} + +func isAcceptableAlgo(algo string) bool { + return algo == hostAlgoRSA +} + +// testPubKey returns true if the given public key is acceptable for the user. +func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { + if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) { + return false + } + + for _, c := range s.cachedPubKeys { + if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) { + return c.result + } + } + + result := s.Server.PubKeyCallback(user, algo, pubKey) + if len(s.cachedPubKeys) < maxCachedPubKeys { + c := cachedPubKey{ + user: user, + algo: algo, + pubKey: make([]byte, len(pubKey)), + result: result, + } + copy(c.pubKey, pubKey) + s.cachedPubKeys = append(s.cachedPubKeys, c) + } + + return result +} + +func (s *ServerConnection) authenticate(H []byte) os.Error { + var userAuthReq userAuthRequestMsg + var err os.Error + var packet []byte + +userAuthLoop: + for { + if packet, err = s.in.readPacket(); err != nil { + return err + } + if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil { + return err + } + + if userAuthReq.Service != serviceSSH { + return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + switch userAuthReq.Method { + case "none": + if s.Server.NoClientAuth { + break userAuthLoop + } + case "password": + if s.Server.PasswordCallback == nil { + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return ParseError{msgUserAuthRequest} + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return ParseError{msgUserAuthRequest} + } + + if s.Server.PasswordCallback(userAuthReq.User, string(password)) { + break userAuthLoop + } + case "publickey": + if s.Server.PubKeyCallback == nil { + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return ParseError{msgUserAuthRequest} + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return ParseError{msgUserAuthRequest} + } + algo := string(algoBytes) + + pubKey, payload, ok := parseString(payload) + if !ok { + return ParseError{msgUserAuthRequest} + } + if isQuery { + // The client can query if the given public key + // would be ok. + if len(payload) > 0 { + return ParseError{msgUserAuthRequest} + } + if s.testPubKey(userAuthReq.User, algo, pubKey) { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: string(pubKey), + } + if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil { + return err + } + continue userAuthLoop + } + } else { + sig, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return ParseError{msgUserAuthRequest} + } + if !isAcceptableAlgo(algo) { + break + } + rsaSig, ok := parseRSASig(sig) + if !ok { + return ParseError{msgUserAuthRequest} + } + signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey) + switch algo { + case hostAlgoRSA: + hashFunc := crypto.SHA1 + h := hashFunc.New() + h.Write(signedData) + digest := h.Sum() + rsaKey, ok := parseRSA(pubKey) + if !ok { + return ParseError{msgUserAuthRequest} + } + if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil { + return ParseError{msgUserAuthRequest} + } + default: + return os.NewError("ssh: isAcceptableAlgo incorrect") + } + if s.testPubKey(userAuthReq.User, algo, pubKey) { + break userAuthLoop + } + } + } + + var failureMsg userAuthFailureMsg + if s.Server.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if s.Server.PubKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + + if len(failureMsg.Methods) == 0 { + return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { + return err + } + } + + packet = []byte{msgUserAuthSuccess} + if err = s.out.writePacket(packet); err != nil { + return err + } + + return nil +} + +const defaultWindowSize = 32768 + +// Accept reads and processes messages on a ServerConnection. It must be called +// in order to demultiplex messages to any resulting Channels. +func (s *ServerConnection) Accept() (Channel, os.Error) { + if s.err != nil { + return nil, s.err + } + + for { + packet, err := s.in.readPacket() + if err != nil { + + s.lock.Lock() + s.err = err + s.lock.Unlock() + + for _, c := range s.channels { + c.dead = true + c.handleData(nil) + } + + return nil, err + } + + switch packet[0] { + case msgChannelOpen: + var chanOpen channelOpenMsg + if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil { + return nil, err + } + + c := new(channel) + c.chanType = chanOpen.ChanType + c.theirId = chanOpen.PeersId + c.theirWindow = chanOpen.PeersWindow + c.maxPacketSize = chanOpen.MaxPacketSize + c.extraData = chanOpen.TypeSpecificData + c.myWindow = defaultWindowSize + c.serverConn = s + c.cond = sync.NewCond(&c.lock) + c.pendingData = make([]byte, c.myWindow) + + s.lock.Lock() + c.myId = s.nextChanId + s.nextChanId++ + s.channels[c.myId] = c + s.lock.Unlock() + return c, nil + + case msgChannelRequest: + var chanRequest channelRequestMsg + if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil { + return nil, err + } + + s.lock.Lock() + c, ok := s.channels[chanRequest.PeersId] + if !ok { + continue + } + c.handlePacket(&chanRequest) + s.lock.Unlock() + + case msgChannelData: + if len(packet) < 5 { + return nil, ParseError{msgChannelData} + } + chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) + + s.lock.Lock() + c, ok := s.channels[chanId] + if !ok { + continue + } + c.handleData(packet[9:]) + s.lock.Unlock() + + case msgChannelEOF: + var eofMsg channelEOFMsg + if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil { + return nil, err + } + + s.lock.Lock() + c, ok := s.channels[eofMsg.PeersId] + if !ok { + continue + } + c.handlePacket(&eofMsg) + s.lock.Unlock() + + case msgChannelClose: + var closeMsg channelCloseMsg + if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil { + return nil, err + } + + s.lock.Lock() + c, ok := s.channels[closeMsg.PeersId] + if !ok { + continue + } + c.handlePacket(&closeMsg) + s.lock.Unlock() + + case msgGlobalRequest: + var request globalRequestMsg + if err := unmarshal(&request, packet, msgGlobalRequest); err != nil { + return nil, err + } + + if request.WantReply { + if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil { + return nil, err + } + } + + default: + // Unknown message. Ignore. + } + } + + panic("unreachable") +} diff --git a/src/pkg/exp/ssh/server_shell.go b/src/pkg/exp/ssh/server_shell.go new file mode 100644 index 00000000000..53a3241f5e0 --- /dev/null +++ b/src/pkg/exp/ssh/server_shell.go @@ -0,0 +1,399 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "os" +) + +// ServerShell contains the state for running a VT100 terminal that is capable +// of reading lines of input. +type ServerShell struct { + c Channel + prompt string + + // line is the current line being entered. + line []byte + // pos is the logical position of the cursor in line + pos int + + // cursorX contains the current X value of the cursor where the left + // edge is 0. cursorY contains the row number where the first row of + // the current line is 0. + cursorX, cursorY int + // maxLine is the greatest value of cursorY so far. + maxLine int + + termWidth, termHeight int + + // outBuf contains the terminal data to be sent. + outBuf []byte + // remainder contains the remainder of any partial key sequences after + // a read. It aliases into inBuf. + remainder []byte + inBuf [256]byte +} + +// NewServerShell runs a VT100 terminal on the given channel. prompt is a +// string that is written at the start of each input line. For example: "> ". +func NewServerShell(c Channel, prompt string) *ServerShell { + return &ServerShell{ + c: c, + prompt: prompt, + termWidth: 80, + termHeight: 24, + } +} + +const ( + keyCtrlD = 4 + keyEnter = '\r' + keyEscape = 27 + keyBackspace = 127 + keyUnknown = 256 + iota + keyUp + keyDown + keyLeft + keyRight + keyAltLeft + keyAltRight +) + +// bytesToKey tries to parse a key sequence from b. If successful, it returns +// the key and the remainder of the input. Otherwise it returns -1. +func bytesToKey(b []byte) (int, []byte) { + if len(b) == 0 { + return -1, nil + } + + if b[0] != keyEscape { + return int(b[0]), b[1:] + } + + if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { + switch b[2] { + case 'A': + return keyUp, b[3:] + case 'B': + return keyDown, b[3:] + case 'C': + return keyRight, b[3:] + case 'D': + return keyLeft, b[3:] + } + } + + if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { + switch b[5] { + case 'C': + return keyAltRight, b[6:] + case 'D': + return keyAltLeft, b[6:] + } + } + + // If we get here then we have a key that we don't recognise, or a + // partial sequence. It's not clear how one should find the end of a + // sequence without knowing them all, but it seems that [a-zA-Z] only + // appears at the end of a sequence. + for i, c := range b[0:] { + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' { + return keyUnknown, b[i+1:] + } + } + + return -1, b +} + +// queue appends data to the end of ss.outBuf +func (ss *ServerShell) queue(data []byte) { + if len(ss.outBuf)+len(data) > cap(ss.outBuf) { + newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data))) + copy(newOutBuf, ss.outBuf) + ss.outBuf = newOutBuf + } + + oldLen := len(ss.outBuf) + ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)] + copy(ss.outBuf[oldLen:], data) +} + +var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'} + +func isPrintable(key int) bool { + return key >= 32 && key < 127 +} + +// moveCursorToPos appends data to ss.outBuf which will move the cursor to the +// given, logical position in the text. +func (ss *ServerShell) moveCursorToPos(pos int) { + x := len(ss.prompt) + pos + y := x / ss.termWidth + x = x % ss.termWidth + + up := 0 + if y < ss.cursorY { + up = ss.cursorY - y + } + + down := 0 + if y > ss.cursorY { + down = y - ss.cursorY + } + + left := 0 + if x < ss.cursorX { + left = ss.cursorX - x + } + + right := 0 + if x > ss.cursorX { + right = x - ss.cursorX + } + + movement := make([]byte, 3*(up+down+left+right)) + m := movement + for i := 0; i < up; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'A' + m = m[3:] + } + for i := 0; i < down; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'B' + m = m[3:] + } + for i := 0; i < left; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'D' + m = m[3:] + } + for i := 0; i < right; i++ { + m[0] = keyEscape + m[1] = '[' + m[2] = 'C' + m = m[3:] + } + + ss.cursorX = x + ss.cursorY = y + ss.queue(movement) +} + +const maxLineLength = 4096 + +// handleKey processes the given key and, optionally, returns a line of text +// that the user has entered. +func (ss *ServerShell) handleKey(key int) (line string, ok bool) { + switch key { + case keyBackspace: + if ss.pos == 0 { + return + } + ss.pos-- + + copy(ss.line[ss.pos:], ss.line[1+ss.pos:]) + ss.line = ss.line[:len(ss.line)-1] + ss.writeLine(ss.line[ss.pos:]) + ss.moveCursorToPos(ss.pos) + ss.queue(eraseUnderCursor) + case keyAltLeft: + // move left by a word. + if ss.pos == 0 { + return + } + ss.pos-- + for ss.pos > 0 { + if ss.line[ss.pos] != ' ' { + break + } + ss.pos-- + } + for ss.pos > 0 { + if ss.line[ss.pos] == ' ' { + ss.pos++ + break + } + ss.pos-- + } + ss.moveCursorToPos(ss.pos) + case keyAltRight: + // move right by a word. + for ss.pos < len(ss.line) { + if ss.line[ss.pos] == ' ' { + break + } + ss.pos++ + } + for ss.pos < len(ss.line) { + if ss.line[ss.pos] != ' ' { + break + } + ss.pos++ + } + ss.moveCursorToPos(ss.pos) + case keyLeft: + if ss.pos == 0 { + return + } + ss.pos-- + ss.moveCursorToPos(ss.pos) + case keyRight: + if ss.pos == len(ss.line) { + return + } + ss.pos++ + ss.moveCursorToPos(ss.pos) + case keyEnter: + ss.moveCursorToPos(len(ss.line)) + ss.queue([]byte("\r\n")) + line = string(ss.line) + ok = true + ss.line = ss.line[:0] + ss.pos = 0 + ss.cursorX = 0 + ss.cursorY = 0 + ss.maxLine = 0 + default: + if !isPrintable(key) { + return + } + if len(ss.line) == maxLineLength { + return + } + if len(ss.line) == cap(ss.line) { + newLine := make([]byte, len(ss.line), 2*(1+len(ss.line))) + copy(newLine, ss.line) + ss.line = newLine + } + ss.line = ss.line[:len(ss.line)+1] + copy(ss.line[ss.pos+1:], ss.line[ss.pos:]) + ss.line[ss.pos] = byte(key) + ss.writeLine(ss.line[ss.pos:]) + ss.pos++ + ss.moveCursorToPos(ss.pos) + } + return +} + +func (ss *ServerShell) writeLine(line []byte) { + for len(line) != 0 { + if ss.cursorX == ss.termWidth { + ss.queue([]byte("\r\n")) + ss.cursorX = 0 + ss.cursorY++ + if ss.cursorY > ss.maxLine { + ss.maxLine = ss.cursorY + } + } + + remainingOnLine := ss.termWidth - ss.cursorX + todo := len(line) + if todo > remainingOnLine { + todo = remainingOnLine + } + ss.queue(line[:todo]) + ss.cursorX += todo + line = line[todo:] + } +} + +// parsePtyRequest parses the payload of the pty-req message and extracts the +// dimensions of the terminal. See RFC 4254, section 6.2. +func parsePtyRequest(s []byte) (width, height int, ok bool) { + _, s, ok = parseString(s) + if !ok { + return + } + width32, s, ok := parseUint32(s) + if !ok { + return + } + height32, _, ok := parseUint32(s) + width = int(width32) + height = int(height32) + if width < 1 { + ok = false + } + if height < 1 { + ok = false + } + return +} + +func (ss *ServerShell) Write(buf []byte) (n int, err os.Error) { + return ss.c.Write(buf) +} + +// ReadLine returns a line of input from the terminal. +func (ss *ServerShell) ReadLine() (line string, err os.Error) { + ss.writeLine([]byte(ss.prompt)) + ss.c.Write(ss.outBuf) + ss.outBuf = ss.outBuf[:0] + + for { + // ss.remainder is a slice at the beginning of ss.inBuf + // containing a partial key sequence + readBuf := ss.inBuf[len(ss.remainder):] + n, err := ss.c.Read(readBuf) + if err == nil { + ss.remainder = ss.inBuf[:n+len(ss.remainder)] + rest := ss.remainder + lineOk := false + for !lineOk { + var key int + key, rest = bytesToKey(rest) + if key < 0 { + break + } + if key == keyCtrlD { + return "", os.EOF + } + line, lineOk = ss.handleKey(key) + } + if len(rest) > 0 { + n := copy(ss.inBuf[:], rest) + ss.remainder = ss.inBuf[:n] + } else { + ss.remainder = nil + } + ss.c.Write(ss.outBuf) + ss.outBuf = ss.outBuf[:0] + if lineOk { + return + } + continue + } + + if req, ok := err.(ChannelRequest); ok { + ok := false + switch req.Request { + case "pty-req": + ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload) + if !ok { + ss.termWidth = 80 + ss.termHeight = 24 + } + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + if req.WantReply { + ss.c.AckRequest(ok) + } + } else { + return "", err + } + } + panic("unreachable") +} diff --git a/src/pkg/exp/ssh/server_shell_test.go b/src/pkg/exp/ssh/server_shell_test.go new file mode 100644 index 00000000000..622cf7cfada --- /dev/null +++ b/src/pkg/exp/ssh/server_shell_test.go @@ -0,0 +1,134 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "testing" + "os" +) + +type MockChannel struct { + toSend []byte + bytesPerRead int + received []byte +} + +func (c *MockChannel) Accept() os.Error { + return nil +} + +func (c *MockChannel) Reject(RejectionReason, string) os.Error { + return nil +} + +func (c *MockChannel) Read(data []byte) (n int, err os.Error) { + n = len(data) + if n == 0 { + return + } + if n > len(c.toSend) { + n = len(c.toSend) + } + if n == 0 { + return 0, os.EOF + } + if c.bytesPerRead > 0 && n > c.bytesPerRead { + n = c.bytesPerRead + } + copy(data, c.toSend[:n]) + c.toSend = c.toSend[n:] + return +} + +func (c *MockChannel) Write(data []byte) (n int, err os.Error) { + c.received = append(c.received, data...) + return len(data), nil +} + +func (c *MockChannel) Close() os.Error { + return nil +} + +func (c *MockChannel) AckRequest(ok bool) os.Error { + return nil +} + +func (c *MockChannel) ChannelType() string { + return "" +} + +func (c *MockChannel) ExtraData() []byte { + return nil +} + +func TestClose(t *testing.T) { + c := &MockChannel{} + ss := NewServerShell(c, "> ") + line, err := ss.ReadLine() + if line != "" { + t.Errorf("Expected empty line but got: %s", line) + } + if err != os.EOF { + t.Errorf("Error should have been EOF but got: %s", err) + } +} + +var keyPressTests = []struct { + in string + line string + err os.Error +}{ + { + "", + "", + os.EOF, + }, + { + "\r", + "", + nil, + }, + { + "foo\r", + "foo", + nil, + }, + { + "a\x1b[Cb\r", // right + "ab", + nil, + }, + { + "a\x1b[Db\r", // left + "ba", + nil, + }, + { + "a\177b\r", // backspace + "b", + nil, + }, +} + +func TestKeyPresses(t *testing.T) { + for i, test := range keyPressTests { + for j := 0; j < len(test.in); j++ { + c := &MockChannel{ + toSend: []byte(test.in), + bytesPerRead: j, + } + ss := NewServerShell(c, "> ") + line, err := ss.ReadLine() + if line != test.line { + t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) + break + } + if err != test.err { + t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err) + break + } + } + } +} diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go new file mode 100644 index 00000000000..919759ff989 --- /dev/null +++ b/src/pkg/exp/ssh/transport.go @@ -0,0 +1,308 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/subtle" + "hash" + "io" + "net" + "os" +) + +// halfConnection represents one direction of an SSH connection. It maintains +// the cipher state needed to process messages. +type halfConnection struct { + // Only one of these two will be non-nil + in *bufio.Reader + out net.Conn + + rand io.Reader + cipherAlgo string + macAlgo string + compressionAlgo string + paddingMultiple int + + seqNum uint32 + + mac hash.Hash + cipher cipher.Stream +} + +func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) { + var lengthBytes [5]byte + + _, err = io.ReadFull(hc.in, lengthBytes[:]) + if err != nil { + return + } + + if hc.cipher != nil { + hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:]) + } + + macSize := 0 + if hc.mac != nil { + hc.mac.Reset() + var seqNumBytes [4]byte + seqNumBytes[0] = byte(hc.seqNum >> 24) + seqNumBytes[1] = byte(hc.seqNum >> 16) + seqNumBytes[2] = byte(hc.seqNum >> 8) + seqNumBytes[3] = byte(hc.seqNum) + hc.mac.Write(seqNumBytes[:]) + hc.mac.Write(lengthBytes[:]) + macSize = hc.mac.Size() + } + + length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3]) + + paddingLength := uint32(lengthBytes[4]) + + if length <= paddingLength+1 { + return nil, os.NewError("invalid packet length") + } + if length > maxPacketSize { + return nil, os.NewError("packet too large") + } + + packet = make([]byte, length-1+uint32(macSize)) + _, err = io.ReadFull(hc.in, packet) + if err != nil { + return nil, err + } + mac := packet[length-1:] + if hc.cipher != nil { + hc.cipher.XORKeyStream(packet, packet[:length-1]) + } + + if hc.mac != nil { + hc.mac.Write(packet[:length-1]) + if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 { + return nil, os.NewError("ssh: MAC failure") + } + } + + hc.seqNum++ + packet = packet[:length-paddingLength-1] + return +} + +func (hc *halfConnection) readPacket() (packet []byte, err os.Error) { + for { + packet, err := hc.readOnePacket() + if err != nil { + return nil, err + } + if packet[0] != msgIgnore && packet[0] != msgDebug { + return packet, nil + } + } + panic("unreachable") +} + +func (hc *halfConnection) writePacket(packet []byte) os.Error { + paddingMultiple := hc.paddingMultiple + if paddingMultiple == 0 { + paddingMultiple = 8 + } + + paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple + if paddingLength < 4 { + paddingLength += paddingMultiple + } + + var lengthBytes [5]byte + length := len(packet) + 1 + paddingLength + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + lengthBytes[4] = byte(paddingLength) + + var padding [32]byte + _, err := io.ReadFull(hc.rand, padding[:paddingLength]) + if err != nil { + return err + } + + if hc.mac != nil { + hc.mac.Reset() + var seqNumBytes [4]byte + seqNumBytes[0] = byte(hc.seqNum >> 24) + seqNumBytes[1] = byte(hc.seqNum >> 16) + seqNumBytes[2] = byte(hc.seqNum >> 8) + seqNumBytes[3] = byte(hc.seqNum) + hc.mac.Write(seqNumBytes[:]) + hc.mac.Write(lengthBytes[:]) + hc.mac.Write(packet) + hc.mac.Write(padding[:paddingLength]) + } + + if hc.cipher != nil { + hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:]) + hc.cipher.XORKeyStream(packet, packet) + hc.cipher.XORKeyStream(padding[:], padding[:paddingLength]) + } + + _, err = hc.out.Write(lengthBytes[:]) + if err != nil { + return err + } + _, err = hc.out.Write(packet) + if err != nil { + return err + } + _, err = hc.out.Write(padding[:paddingLength]) + if err != nil { + return err + } + + if hc.mac != nil { + _, err = hc.out.Write(hc.mac.Sum()) + } + + hc.seqNum++ + + return err +} + +const ( + serverKeys = iota + clientKeys +) + +// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error { + h := hashFunc.New() + + // We only support these algorithms for now. + if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 { + return os.NewError("ssh: setupServerKeys internal error") + } + + blockSize := 16 + keySize := 16 + macKeySize := 20 + + var ivTag, keyTag, macKeyTag byte + if direction == serverKeys { + ivTag, keyTag, macKeyTag = 'B', 'D', 'F' + } else { + ivTag, keyTag, macKeyTag = 'A', 'C', 'E' + } + + iv := make([]byte, blockSize) + key := make([]byte, keySize) + macKey := make([]byte, macKeySize) + generateKeyMaterial(iv, ivTag, K, H, sessionId, h) + generateKeyMaterial(key, keyTag, K, H, sessionId, h) + generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h) + + hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)} + aes, err := aes.NewCipher(key) + if err != nil { + return err + } + hc.cipher = cipher.NewCTR(aes, iv) + hc.paddingMultiple = 16 + return nil +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) { + var digestsSoFar []byte + + for len(out) > 0 { + h.Reset() + h.Write(K) + h.Write(H) + + if len(digestsSoFar) == 0 { + h.Write([]byte{tag}) + h.Write(sessionId) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum() + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, os.Error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum() []byte { + digest := t.hmac.Sum() + return digest[:t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +// maxVersionStringBytes is the maximum number of bytes that we'll accept as a +// version string. In the event that the client is talking a different protocol +// we need to set a limit otherwise we will keep using more and more memory +// while searching for the end of the version handshake. +const maxVersionStringBytes = 1024 + +func readVersion(r *bufio.Reader) (versionString []byte, ok bool) { + versionString = make([]byte, 0, 64) + seenCR := false + +forEachByte: + for len(versionString) < maxVersionStringBytes { + b, err := r.ReadByte() + if err != nil { + return + } + + if !seenCR { + if b == '\r' { + seenCR = true + } + } else { + if b == '\n' { + ok = true + break forEachByte + } else { + seenCR = false + } + } + versionString = append(versionString, b) + } + + if ok { + // We need to remove the CR from versionString + versionString = versionString[:len(versionString)-1] + } + + return +}