From dbebb08601ae43566ed19748a838b2a36481f61a Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Wed, 18 Jan 2012 15:04:17 -0500 Subject: [PATCH] exp/ssh: handle versions with just '\n' djm recommend that we do this because OpenSSL was only fixed in 2008: http://anoncvs.mindrot.org/index.cgi/openssh/sshd.c?revision=1.380&view=markup R=dave, jonathan.mark.pittman CC=golang-dev https://golang.org/cl/5555044 --- src/pkg/exp/ssh/transport.go | 31 +++++++++++++------------------ src/pkg/exp/ssh/transport_test.go | 18 ++++++++++++++++-- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go index 5c15fe8505b..60a636f0a4b 100644 --- a/src/pkg/exp/ssh/transport.go +++ b/src/pkg/exp/ssh/transport.go @@ -339,7 +339,7 @@ const maxVersionStringBytes = 1024 // Read version string as specified by RFC 4253, section 4.2. func readVersion(r io.Reader) ([]byte, error) { versionString := make([]byte, 0, 64) - var ok, seenCR bool + var ok bool var buf [1]byte forEachByte: for len(versionString) < maxVersionStringBytes { @@ -347,27 +347,22 @@ forEachByte: if err != nil { return nil, err } - b := buf[0] - - if !seenCR { - if b == '\r' { - seenCR = true - } - } else { - if b == '\n' { - ok = true - break forEachByte - } else { - seenCR = false - } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + ok = true + break forEachByte } - versionString = append(versionString, b) + versionString = append(versionString, buf[0]) } if !ok { - return nil, errors.New("failed to read version string") + return nil, errors.New("ssh: failed to read version string") } - // We need to remove the CR from versionString - return versionString[:len(versionString)-1], nil + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil } diff --git a/src/pkg/exp/ssh/transport_test.go b/src/pkg/exp/ssh/transport_test.go index b2e2a7fc92a..ab9177f0d11 100644 --- a/src/pkg/exp/ssh/transport_test.go +++ b/src/pkg/exp/ssh/transport_test.go @@ -11,7 +11,7 @@ import ( ) func TestReadVersion(t *testing.T) { - buf := []byte(serverVersion) + buf := serverVersion result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))) if err != nil { t.Errorf("readVersion didn't read version correctly: %s", err) @@ -21,6 +21,20 @@ func TestReadVersion(t *testing.T) { } } +func TestReadVersionWithJustLF(t *testing.T) { + var buf []byte + buf = append(buf, serverVersion...) + buf = buf[:len(buf)-1] + buf[len(buf)-1] = '\n' + result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))) + if err != nil { + t.Error("readVersion failed to handle just a \n") + } + if !bytes.Equal(buf[:len(buf)-1], result) { + t.Errorf("version read did not match expected: got %x, want %x", result, buf[:len(buf)-1]) + } +} + func TestReadVersionTooLong(t *testing.T) { buf := make([]byte, maxVersionStringBytes+1) if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil { @@ -29,7 +43,7 @@ func TestReadVersionTooLong(t *testing.T) { } func TestReadVersionWithoutCRLF(t *testing.T) { - buf := []byte(serverVersion) + buf := serverVersion buf = buf[:len(buf)-1] if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil { t.Error("readVersion did not notice \\n was missing")