1
0
mirror of https://github.com/golang/go synced 2024-11-21 20:14:52 -07:00

web socket: fix short Read

Fixes #1145.

R=rsc
CC=golang-dev
https://golang.org/cl/2302042
This commit is contained in:
Fumitoshi Ukai 2010-10-20 22:36:06 -04:00 committed by Russ Cox
parent f57f8b6f68
commit 64cc5be4ad
2 changed files with 116 additions and 24 deletions

View File

@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" }
// String returns the network address for a Web Socket.
func (addr WebSocketAddr) String() string { return string(addr) }
const (
stateFrameByte = iota
stateFrameLength
stateFrameData
stateFrameTextData
)
// Conn is a channel to communicate to a Web Socket.
// It implements the net.Conn interface.
type Conn struct {
@ -39,6 +46,10 @@ type Conn struct {
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
// It holds text data in previous Read() that failed with small buffer.
data []byte
reading bool
}
// newConn creates a new Web Socket.
@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
ws := &Conn{origin, location, protocol, buf, rwc}
ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc}
return ws
}
// Read implements the io.Reader interface for a Conn.
func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
for {
frameByte, err := ws.buf.ReadByte()
Frame:
for !ws.reading && len(ws.data) == 0 {
// Beginning of frame, possibly.
b, err := ws.buf.ReadByte()
if err != nil {
return n, err
return 0, err
}
if (frameByte & 0x80) == 0x80 {
if b&0x80 == 0x80 {
// Skip length frame.
length := 0
for {
c, err := ws.buf.ReadByte()
if err != nil {
return n, err
return 0, err
}
length = length*128 + int(c&0x7f)
if (c & 0x80) == 0 {
if c&0x80 == 0 {
break
}
}
for length > 0 {
_, err := ws.buf.ReadByte()
if err != nil {
return n, err
return 0, err
}
length--
}
} else {
continue Frame
}
// In text mode
if b != 0 {
// Skip this frame
for {
c, err := ws.buf.ReadByte()
if err != nil {
return n, err
return 0, err
}
if c == '\xff' {
break
}
}
continue Frame
}
ws.reading = true
}
if len(ws.data) == 0 {
ws.data, err = ws.buf.ReadSlice('\xff')
if err == nil {
ws.reading = false
ws.data = ws.data[:len(ws.data)-1] // trim \xff
}
}
n = copy(msg, ws.data)
ws.data = ws.data[n:]
return n, err
}
if frameByte == 0 {
if n+1 <= cap(msg) {
msg = msg[0 : n+1]
}
msg[n] = c
n++
}
if n >= cap(msg) {
return n, os.E2BIG
}
}
}
}
panic("unreachable")
}
// Write implements the io.Writer interface for a Conn.

View File

@ -5,6 +5,7 @@
package websocket
import (
"bufio"
"bytes"
"fmt"
"http"
@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) {
}
}
}
func TestSmallBuffer(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=1145
// Read should be able to handle reading a fragment of a frame.
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", "", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
ws, err := newClient("/echo", "localhost", "http://localhost",
"ws://localhost/echo", "", client, handshake)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := ws.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var small_msg = make([]byte, 8)
n, err := ws.Read(small_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
}
var second_msg = make([]byte, len(msg))
n, err = ws.Read(second_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
second_msg = second_msg[0:n]
if !bytes.Equal(msg[len(small_msg):], second_msg) {
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
}
ws.Close()
}
func testSkipLengthFrame(t *testing.T) {
b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
buf := bytes.NewBuffer(b)
br := bufio.NewReader(buf)
bw := bufio.NewWriter(buf)
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
msg := make([]byte, 5)
n, err := ws.Read(msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(b[4:8], msg[0:n]) {
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
}
}
func testSkipNoUTF8Frame(t *testing.T) {
b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
buf := bytes.NewBuffer(b)
br := bufio.NewReader(buf)
bw := bufio.NewWriter(buf)
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
msg := make([]byte, 5)
n, err := ws.Read(msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(b[4:8], msg[0:n]) {
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
}
}