mirror of
https://github.com/golang/go
synced 2024-11-22 02:34:40 -07:00
web socket: fix short Read
Fixes #1145. R=rsc CC=golang-dev https://golang.org/cl/2302042
This commit is contained in:
parent
f57f8b6f68
commit
64cc5be4ad
@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" }
|
|||||||
// String returns the network address for a Web Socket.
|
// String returns the network address for a Web Socket.
|
||||||
func (addr WebSocketAddr) String() string { return string(addr) }
|
func (addr WebSocketAddr) String() string { return string(addr) }
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateFrameByte = iota
|
||||||
|
stateFrameLength
|
||||||
|
stateFrameData
|
||||||
|
stateFrameTextData
|
||||||
|
)
|
||||||
|
|
||||||
// Conn is a channel to communicate to a Web Socket.
|
// Conn is a channel to communicate to a Web Socket.
|
||||||
// It implements the net.Conn interface.
|
// It implements the net.Conn interface.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
@ -39,6 +46,10 @@ type Conn struct {
|
|||||||
|
|
||||||
buf *bufio.ReadWriter
|
buf *bufio.ReadWriter
|
||||||
rwc io.ReadWriteCloser
|
rwc io.ReadWriteCloser
|
||||||
|
|
||||||
|
// It holds text data in previous Read() that failed with small buffer.
|
||||||
|
data []byte
|
||||||
|
reading bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// newConn creates a new Web Socket.
|
// newConn creates a new Web Socket.
|
||||||
@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re
|
|||||||
bw := bufio.NewWriter(rwc)
|
bw := bufio.NewWriter(rwc)
|
||||||
buf = bufio.NewReadWriter(br, bw)
|
buf = bufio.NewReadWriter(br, bw)
|
||||||
}
|
}
|
||||||
ws := &Conn{origin, location, protocol, buf, rwc}
|
ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc}
|
||||||
return ws
|
return ws
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read implements the io.Reader interface for a Conn.
|
// Read implements the io.Reader interface for a Conn.
|
||||||
func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
|
func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
|
||||||
for {
|
Frame:
|
||||||
frameByte, err := ws.buf.ReadByte()
|
for !ws.reading && len(ws.data) == 0 {
|
||||||
|
// Beginning of frame, possibly.
|
||||||
|
b, err := ws.buf.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if (frameByte & 0x80) == 0x80 {
|
if b&0x80 == 0x80 {
|
||||||
|
// Skip length frame.
|
||||||
length := 0
|
length := 0
|
||||||
for {
|
for {
|
||||||
c, err := ws.buf.ReadByte()
|
c, err := ws.buf.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return 0, err
|
||||||
}
|
}
|
||||||
length = length*128 + int(c&0x7f)
|
length = length*128 + int(c&0x7f)
|
||||||
if (c & 0x80) == 0 {
|
if c&0x80 == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for length > 0 {
|
for length > 0 {
|
||||||
_, err := ws.buf.ReadByte()
|
_, err := ws.buf.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return 0, err
|
||||||
}
|
}
|
||||||
length--
|
|
||||||
}
|
}
|
||||||
} else {
|
continue Frame
|
||||||
|
}
|
||||||
|
// In text mode
|
||||||
|
if b != 0 {
|
||||||
|
// Skip this frame
|
||||||
for {
|
for {
|
||||||
c, err := ws.buf.ReadByte()
|
c, err := ws.buf.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if c == '\xff' {
|
if c == '\xff' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue Frame
|
||||||
|
}
|
||||||
|
ws.reading = true
|
||||||
|
}
|
||||||
|
if len(ws.data) == 0 {
|
||||||
|
ws.data, err = ws.buf.ReadSlice('\xff')
|
||||||
|
if err == nil {
|
||||||
|
ws.reading = false
|
||||||
|
ws.data = ws.data[:len(ws.data)-1] // trim \xff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n = copy(msg, ws.data)
|
||||||
|
ws.data = ws.data[n:]
|
||||||
return n, err
|
return n, err
|
||||||
}
|
|
||||||
if frameByte == 0 {
|
|
||||||
if n+1 <= cap(msg) {
|
|
||||||
msg = msg[0 : n+1]
|
|
||||||
}
|
|
||||||
msg[n] = c
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
if n >= cap(msg) {
|
|
||||||
return n, os.E2BIG
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write implements the io.Writer interface for a Conn.
|
// Write implements the io.Writer interface for a Conn.
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"http"
|
"http"
|
||||||
@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSmallBuffer(t *testing.T) {
|
||||||
|
// http://code.google.com/p/go/issues/detail?id=1145
|
||||||
|
// Read should be able to handle reading a fragment of a frame.
|
||||||
|
once.Do(startServer)
|
||||||
|
|
||||||
|
// websocket.Dial()
|
||||||
|
client, err := net.Dial("tcp", "", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("dialing", err)
|
||||||
|
}
|
||||||
|
ws, err := newClient("/echo", "localhost", "http://localhost",
|
||||||
|
"ws://localhost/echo", "", client, handshake)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("WebSocket handshake error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := []byte("hello, world\n")
|
||||||
|
if _, err := ws.Write(msg); err != nil {
|
||||||
|
t.Errorf("Write: %v", err)
|
||||||
|
}
|
||||||
|
var small_msg = make([]byte, 8)
|
||||||
|
n, err := ws.Read(small_msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Read: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
|
||||||
|
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
|
||||||
|
}
|
||||||
|
var second_msg = make([]byte, len(msg))
|
||||||
|
n, err = ws.Read(second_msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Read: %v", err)
|
||||||
|
}
|
||||||
|
second_msg = second_msg[0:n]
|
||||||
|
if !bytes.Equal(msg[len(small_msg):], second_msg) {
|
||||||
|
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
|
||||||
|
}
|
||||||
|
ws.Close()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSkipLengthFrame(t *testing.T) {
|
||||||
|
b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
|
||||||
|
buf := bytes.NewBuffer(b)
|
||||||
|
br := bufio.NewReader(buf)
|
||||||
|
bw := bufio.NewWriter(buf)
|
||||||
|
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
|
||||||
|
msg := make([]byte, 5)
|
||||||
|
n, err := ws.Read(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Read: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b[4:8], msg[0:n]) {
|
||||||
|
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSkipNoUTF8Frame(t *testing.T) {
|
||||||
|
b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
|
||||||
|
buf := bytes.NewBuffer(b)
|
||||||
|
br := bufio.NewReader(buf)
|
||||||
|
bw := bufio.NewWriter(buf)
|
||||||
|
ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
|
||||||
|
msg := make([]byte, 5)
|
||||||
|
n, err := ws.Read(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Read: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b[4:8], msg[0:n]) {
|
||||||
|
t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user