diff --git a/src/pkg/websocket/websocket.go b/src/pkg/websocket/websocket.go index 99e1d14485..d5996abe1a 100644 --- a/src/pkg/websocket/websocket.go +++ b/src/pkg/websocket/websocket.go @@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" } // String returns the network address for a Web Socket. func (addr WebSocketAddr) String() string { return string(addr) } +const ( + stateFrameByte = iota + stateFrameLength + stateFrameData + stateFrameTextData +) + // Conn is a channel to communicate to a Web Socket. // It implements the net.Conn interface. type Conn struct { @@ -39,6 +46,10 @@ type Conn struct { buf *bufio.ReadWriter rwc io.ReadWriteCloser + + // It holds text data in previous Read() that failed with small buffer. + data []byte + reading bool } // newConn creates a new Web Socket. @@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re bw := bufio.NewWriter(rwc) buf = bufio.NewReadWriter(br, bw) } - ws := &Conn{origin, location, protocol, buf, rwc} + ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc} return ws } // Read implements the io.Reader interface for a Conn. func (ws *Conn) Read(msg []byte) (n int, err os.Error) { - for { - frameByte, err := ws.buf.ReadByte() +Frame: + for !ws.reading && len(ws.data) == 0 { + // Beginning of frame, possibly. + b, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } - if (frameByte & 0x80) == 0x80 { + if b&0x80 == 0x80 { + // Skip length frame. length := 0 for { c, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } length = length*128 + int(c&0x7f) - if (c & 0x80) == 0 { + if c&0x80 == 0 { break } } for length > 0 { _, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } - length-- } - } else { + continue Frame + } + // In text mode + if b != 0 { + // Skip this frame for { c, err := ws.buf.ReadByte() if err != nil { - return n, err + return 0, err } if c == '\xff' { - return n, err - } - if frameByte == 0 { - if n+1 <= cap(msg) { - msg = msg[0 : n+1] - } - msg[n] = c - n++ - } - if n >= cap(msg) { - return n, os.E2BIG + break } } + continue Frame + } + ws.reading = true + } + if len(ws.data) == 0 { + ws.data, err = ws.buf.ReadSlice('\xff') + if err == nil { + ws.reading = false + ws.data = ws.data[:len(ws.data)-1] // trim \xff } } - - panic("unreachable") + n = copy(msg, ws.data) + ws.data = ws.data[n:] + return n, err } // Write implements the io.Writer interface for a Conn. diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go index 9639d8f88b..c66c114589 100644 --- a/src/pkg/websocket/websocket_test.go +++ b/src/pkg/websocket/websocket_test.go @@ -5,6 +5,7 @@ package websocket import ( + "bufio" "bytes" "fmt" "http" @@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) { } } } + +func TestSmallBuffer(t *testing.T) { + // http://code.google.com/p/go/issues/detail?id=1145 + // Read should be able to handle reading a fragment of a frame. + once.Do(startServer) + + // websocket.Dial() + client, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing", err) + } + ws, err := newClient("/echo", "localhost", "http://localhost", + "ws://localhost/echo", "", client, handshake) + if err != nil { + t.Errorf("WebSocket handshake error: %v", err) + return + } + + msg := []byte("hello, world\n") + if _, err := ws.Write(msg); err != nil { + t.Errorf("Write: %v", err) + } + var small_msg = make([]byte, 8) + n, err := ws.Read(small_msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(msg[:len(small_msg)], small_msg) { + t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) + } + var second_msg = make([]byte, len(msg)) + n, err = ws.Read(second_msg) + if err != nil { + t.Errorf("Read: %v", err) + } + second_msg = second_msg[0:n] + if !bytes.Equal(msg[len(small_msg):], second_msg) { + t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) + } + ws.Close() + +} + +func testSkipLengthFrame(t *testing.T) { + b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(b[4:8], msg[0:n]) { + t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n]) + } +} + +func testSkipNoUTF8Frame(t *testing.T) { + b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(b[4:8], msg[0:n]) { + t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n]) + } +}