mirror of
https://github.com/golang/go
synced 2024-11-21 14:04:41 -07:00
web socket: fix short Read
Fixes #1145. R=rsc CC=golang-dev https://golang.org/cl/2302042
This commit is contained in:
parent
f57f8b6f68
commit
64cc5be4ad
@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" }
|
||||
// String returns the network address for a Web Socket.
|
||||
func (addr WebSocketAddr) String() string { return string(addr) }
|
||||
|
||||
const (
|
||||
stateFrameByte = iota
|
||||
stateFrameLength
|
||||
stateFrameData
|
||||
stateFrameTextData
|
||||
)
|
||||
|
||||
// Conn is a channel to communicate to a Web Socket.
|
||||
// It implements the net.Conn interface.
|
||||
type Conn struct {
|
||||
@ -39,6 +46,10 @@ type Conn struct {
|
||||
|
||||
buf *bufio.ReadWriter
|
||||
rwc io.ReadWriteCloser
|
||||
|
||||
// It holds text data in previous Read() that failed with small buffer.
|
||||
data []byte
|
||||
reading bool
|
||||
}
|
||||
|
||||
// newConn creates a new Web Socket.
|
||||
@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re
|
||||
bw := bufio.NewWriter(rwc)
|
||||
buf = bufio.NewReadWriter(br, bw)
|
||||
}
|
||||
ws := &Conn{origin, location, protocol, buf, rwc}
|
||||
ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc}
|
||||
return ws
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface for a Conn.
|
||||
func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
|
||||
for {
|
||||
frameByte, err := ws.buf.ReadByte()
|
||||
Frame:
|
||||
for !ws.reading && len(ws.data) == 0 {
|
||||
// Beginning of frame, possibly.
|
||||
b, err := ws.buf.ReadByte()
|
||||
if err != nil {
|
||||
return n, err
|
||||
return 0, err
|
||||
}
|
||||
if (frameByte & 0x80) == 0x80 {
|
||||
if b&0x80 == 0x80 {
|
||||
// Skip length frame.
|
||||
length := 0
|
||||
for {
|
||||
c, err := ws.buf.ReadByte()
|
||||
if err != nil {
|
||||
return n, err
|
||||
return 0, err
|
||||
}
|
||||
length = length*128 + int(c&0x7f)
|
||||
if (c & 0x80) == 0 {
|
||||
if c&0x80 == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
for length > 0 {
|
||||
_, err := ws.buf.ReadByte()
|
||||
if err != nil {
|
||||
return n, err
|
||||
return 0, err
|
||||
}
|
||||
length--
|
||||
}
|
||||
} else {
|
||||
continue Frame
|
||||
}
|
||||
// In text mode
|
||||
if b != 0 {
|
||||
// Skip this frame
|
||||
for {
|
||||
c, err := ws.buf.ReadByte()
|
||||
if err != nil {
|
||||
return n, err
|
||||
return 0, err
|
||||
}
|
||||
if c == '\xff' {
|
||||
return n, err
|
||||
}
|
||||
if frameByte == 0 {
|
||||
if n+1 <= cap(msg) {
|
||||
msg = msg[0 : n+1]
|
||||
}
|
||||
msg[n] = c
|
||||
n++
|
||||
}
|
||||
if n >= cap(msg) {
|
||||
return n, os.E2BIG
|
||||
break
|
||||
}
|
||||
}
|
||||
continue Frame
|
||||
}
|
||||
ws.reading = true
|
||||
}
|
||||
if len(ws.data) == 0 {
|
||||
ws.data, err = ws.buf.ReadSlice('\xff')
|
||||
if err == nil {
|
||||
ws.reading = false
|
||||
ws.data = ws.data[:len(ws.data)-1] // trim \xff
|
||||
}
|
||||
}
|
||||
|
||||
panic("unreachable")
|
||||
n = copy(msg, ws.data)
|
||||
ws.data = ws.data[n:]
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface for a Conn.
|
||||
|
@ -5,6 +5,7 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"http"
|
||||
@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmallBuffer(t *testing.T) {
|
||||
// http://code.google.com/p/go/issues/detail?id=1145
|
||||
// Read should be able to handle reading a fragment of a frame.
|
||||
once.Do(startServer)
|
||||
|
||||
// websocket.Dial()
|
||||
client, err := net.Dial("tcp", "", serverAddr)
|
||||
if err != nil {
|
||||
t.Fatal("dialing", err)
|
||||
}
|
||||
ws, err := newClient("/echo", "localhost", "http://localhost",
|
||||
"ws://localhost/echo", "", client, handshake)
|
||||
if err != nil {
|
||||
t.Errorf("WebSocket handshake error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
msg := []byte("hello, world\n")
|
||||
if _, err := ws.Write(msg); err != nil {
|
||||
t.Errorf("Write: %v", err)
|
||||
}
|
||||
var small_msg = make([]byte, 8)
|
||||
n, err := ws.Read(small_msg)
|
||||
if err != nil {
|
||||
t.Errorf("Read: %v", err)
|
||||
}
|
||||
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
|
||||
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
|
||||
}
|
||||
var second_msg = make([]byte, len(msg))
|
||||
n, err = ws.Read(second_msg)
|
||||
if err != nil {
|
||||
t.Errorf("Read: %v", err)
|
||||
}
|
||||
second_msg = second_msg[0:n]
|
||||
if !bytes.Equal(msg[len(small_msg):], second_msg) {
|
||||
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
|
||||
}
|
||||
ws.Close()
|
||||
|
||||
}
|
||||
|
||||
func testSkipLengthFrame(t *testing.T) {
|
||||
b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
|
||||
buf := bytes.NewBuffer(b)
|
||||
br := bufio.NewReader(buf)
|
||||
bw := bufio.NewWriter(buf)
|
||||
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
|
||||
msg := make([]byte, 5)
|
||||
n, err := ws.Read(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Read: %v", err)
|
||||
}
|
||||
if !bytes.Equal(b[4:8], msg[0:n]) {
|
||||
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
|
||||
}
|
||||
}
|
||||
|
||||
func testSkipNoUTF8Frame(t *testing.T) {
|
||||
b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
|
||||
buf := bytes.NewBuffer(b)
|
||||
br := bufio.NewReader(buf)
|
||||
bw := bufio.NewWriter(buf)
|
||||
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
|
||||
msg := make([]byte, 5)
|
||||
n, err := ws.Read(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Read: %v", err)
|
||||
}
|
||||
if !bytes.Equal(b[4:8], msg[0:n]) {
|
||||
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user