diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go index 7497ac6508..b7eaafda16 100644 --- a/src/pkg/websocket/client.go +++ b/src/pkg/websocket/client.go @@ -26,7 +26,7 @@ func (e *DialError) String() string { // NewConfig creates a new WebSocket config for client connection. func NewConfig(server, origin string) (config *Config, err os.Error) { config = new(Config) - config.Version = ProtocolVersionHybi + config.Version = ProtocolVersionHybi13 config.Location, err = url.ParseRequest(server) if err != nil { return @@ -47,7 +47,7 @@ func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err os.Error) err = hixie75ClientHandshake(config, br, bw) case ProtocolVersionHixie76, ProtocolVersionHybi00: err = hixie76ClientHandshake(config, br, bw) - case ProtocolVersionHybi: + case ProtocolVersionHybi08, ProtocolVersionHybi13: err = hybiClientHandshake(config, br, bw) default: err = ErrBadProtocolVersion @@ -59,7 +59,7 @@ func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err os.Error) switch config.Version { case ProtocolVersionHixie75, ProtocolVersionHixie76, ProtocolVersionHybi00: ws = newHixieClientConn(config, buf, rwc) - case ProtocolVersionHybi: + case ProtocolVersionHybi08, ProtocolVersionHybi13: ws = newHybiClientConn(config, buf, rwc) } return diff --git a/src/pkg/websocket/hybi.go b/src/pkg/websocket/hybi.go index c4d990d6d8..c832dfc832 100644 --- a/src/pkg/websocket/hybi.go +++ b/src/pkg/websocket/hybi.go @@ -5,7 +5,7 @@ package websocket // This file implements a protocol of hybi draft. -// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 import ( "bufio" @@ -26,13 +26,17 @@ import ( const ( websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - closeStatusNormal = 1000 - closeStatusGoingAway = 1001 - closeStatusProtocolError = 1002 - closeStatusUnsupportedData = 1003 - closeStatusFrameTooLarge = 1004 - closeStatusNoStatusRcvd = 1005 - closeStatusAbnormalClosure = 1006 + closeStatusNormal = 1000 + closeStatusGoingAway = 1001 + closeStatusProtocolError = 1002 + closeStatusUnsupportedData = 1003 + closeStatusFrameTooLarge = 1004 + closeStatusNoStatusRcvd = 1005 + closeStatusAbnormalClosure = 1006 + closeStatusBadMessageData = 1007 + closeStatusPolicyViolation = 1008 + closeStatusTooBigData = 1009 + closeStatusExtensionMismatch = 1010 maxControlFramePayloadLength = 125 ) @@ -101,8 +105,8 @@ type hybiFrameReaderFactory struct { } // NewFrameReader reads a frame header from the connection, and creates new reader for the frame. -// See Section 4.2 Base Frameing protocol for detail. -// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10#section-4.2 +// See Section 5.2 Base Frameing protocol for detail. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err os.Error) { hybiFrame := new(hybiFrameReader) frame = hybiFrame @@ -258,6 +262,12 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, handler.WriteClose(closeStatusProtocolError) return nil, os.EOF } + } else { + // The server MUST NOT mask all frames. + if frame.(*hybiFrameReader).header.MaskingKey != nil { + handler.WriteClose(closeStatusProtocolError) + return nil, os.EOF + } } if header := frame.HeaderReader(); header != nil { io.Copy(ioutil.Discard, header) @@ -366,9 +376,18 @@ func getNonceAccept(nonce []byte) (expected []byte, err os.Error) { return } -// Client handhake described in draft-ietf-hybi-thewebsocket-protocol-09 +func isHybiVersion(version int) bool { + switch version { + case ProtocolVersionHybi08, ProtocolVersionHybi13: + return true + default: + } + return false +} + +// Client handhake described in draft-ietf-hybi-thewebsocket-protocol-17 func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { - if config.Version != ProtocolVersionHybi { + if !isHybiVersion(config.Version) { panic("wrong protocol version.") } @@ -382,7 +401,11 @@ func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (er nonce = []byte(config.handshakeData["key"]) } bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") - bw.WriteString("Sec-WebSocket-Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + if config.Version == ProtocolVersionHybi13 { + bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + } else if config.Version == ProtocolVersionHybi08 { + bw.WriteString("Sec-WebSocket-Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + } bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") if len(config.Protocol) > 0 { bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") @@ -446,7 +469,7 @@ type hybiServerHandshaker struct { } func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err os.Error) { - c.Version = ProtocolVersionHybi + c.Version = ProtocolVersionHybi13 if req.Method != "GET" { return http.StatusMethodNotAllowed, ErrBadRequestMethod } @@ -462,12 +485,20 @@ func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Reques return http.StatusBadRequest, ErrChallengeResponse } version := req.Header.Get("Sec-Websocket-Version") - if version != fmt.Sprintf("%d", c.Version) { + var origin string + switch version { + case "13": + c.Version = ProtocolVersionHybi13 + origin = req.Header.Get("Origin") + case "8": + c.Version = ProtocolVersionHybi08 + origin = req.Header.Get("Sec-Websocket-Origin") + default: return http.StatusBadRequest, ErrBadWebSocketVersion } - c.Origin, err = url.ParseRequest(req.Header.Get("Sec-Websocket-Origin")) + c.Origin, err = url.ParseRequest(origin) if err != nil { - return http.StatusBadRequest, err + return http.StatusForbidden, err } var scheme string if req.TLS != nil { diff --git a/src/pkg/websocket/hybi_test.go b/src/pkg/websocket/hybi_test.go index c437819639..0814c08015 100644 --- a/src/pkg/websocket/hybi_test.go +++ b/src/pkg/websocket/hybi_test.go @@ -16,7 +16,7 @@ import ( ) // Test the getNonceAccept function with values in -// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-09 +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 func TestSecWebSocketAccept(t *testing.T) { nonce := []byte("dGhlIHNhbXBsZSBub25jZQ==") expected := []byte("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=") @@ -52,7 +52,69 @@ Sec-WebSocket-Protocol: chat } config.Protocol = append(config.Protocol, "chat") config.Protocol = append(config.Protocol, "superchat") - config.Version = ProtocolVersionHybi + config.Version = ProtocolVersionHybi13 + + config.handshakeData = map[string]string{ + "key": "dGhlIHNhbXBsZSBub25jZQ==", + } + err = hybiClientHandshake(config, br, bw) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + req, err := http.ReadRequest(bufio.NewReader(b)) + if err != nil { + t.Errorf("read request: %v", err) + } + if req.Method != "GET" { + t.Errorf("request method expected GET, but got %q", req.Method) + } + if req.RawURL != "/chat" { + t.Errorf("request path expected /chat, but got %q", req.RawURL) + } + if req.Proto != "HTTP/1.1" { + t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) + } + if req.Host != "server.example.com" { + t.Errorf("request Host expected server.example.com, but got %v", req.Host) + } + var expectedHeader = map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-Websocket-Key": config.handshakeData["key"], + "Origin": config.Origin.String(), + "Sec-Websocket-Protocol": "chat, superchat", + "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13), + } + for k, v := range expectedHeader { + if req.Header.Get(k) != v { + t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) + } + } +} + +func TestHybiClientHandshakeHybi08(t *testing.T) { + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= +Sec-WebSocket-Protocol: chat + +`)) + var err os.Error + config := new(Config) + config.Location, err = url.ParseRequest("ws://server.example.com/chat") + if err != nil { + t.Fatal("location url", err) + } + config.Origin, err = url.ParseRequest("http://example.com") + if err != nil { + t.Fatal("origin url", err) + } + config.Protocol = append(config.Protocol, "chat") + config.Protocol = append(config.Protocol, "superchat") + config.Version = ProtocolVersionHybi08 config.handshakeData = map[string]string{ "key": "dGhlIHNhbXBsZSBub25jZQ==", @@ -83,7 +145,7 @@ Sec-WebSocket-Protocol: chat "Sec-Websocket-Key": config.handshakeData["key"], "Sec-Websocket-Origin": config.Origin.String(), "Sec-Websocket-Protocol": "chat, superchat", - "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi), + "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi08), } for k, v := range expectedHeader { if req.Header.Get(k) != v { @@ -100,6 +162,52 @@ Host: server.example.com Upgrade: websocket Connection: Upgrade Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +Origin: http://example.com +Sec-WebSocket-Protocol: chat, superchat +Sec-WebSocket-Version: 13 + +`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + if code != http.StatusSwitchingProtocols { + t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + } + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + + config.Protocol = []string{"chat"} + + err = handshaker.AcceptHandshake(bw) + if err != nil { + t.Errorf("handshake response failed: %v", err) + } + expectedResponse := strings.Join([]string{ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "Sec-WebSocket-Protocol: chat", + "", ""}, "\r\n") + + if b.String() != expectedResponse { + t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) + } +} + +func TestHybiServerHandshakeHybi08(t *testing.T) { + config := new(Config) + handshaker := &hybiServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 +Host: server.example.com +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== Sec-WebSocket-Origin: http://example.com Sec-WebSocket-Protocol: chat, superchat Sec-WebSocket-Version: 8 @@ -138,6 +246,32 @@ Sec-WebSocket-Version: 8 } } +func TestHybiServerHandshakeHybiBadVersion(t *testing.T) { + config := new(Config) + handshaker := &hybiServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 +Host: server.example.com +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +Sec-WebSocket-Origin: http://example.com +Sec-WebSocket-Protocol: chat, superchat +Sec-WebSocket-Version: 9 + +`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != ErrBadWebSocketVersion { + t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err) + } + if code != http.StatusBadRequest { + t.Errorf("status expected %q but got %q", http.StatusBadRequest, code) + } +} + func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []byte, frameHeader *hybiFrameHeader) { b := bytes.NewBuffer([]byte{}) frameWriterFactory := &hybiFrameWriterFactory{bufio.NewWriter(b), false} @@ -247,7 +381,7 @@ func TestHybiLongFrame(t *testing.T) { testHybiFrame(t, []byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, payload, payload, frameHeader) } -func TestHybiRead(t *testing.T) { +func TestHybiClientRead(t *testing.T) { wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o', 0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping 0x81, 0x05, 'w', 'o', 'r', 'l', 'd'} @@ -325,16 +459,17 @@ func TestHybiShortRead(t *testing.T) { } } -func TestHybiReadWithMasking(t *testing.T) { +func TestHybiServerRead(t *testing.T) { wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20, 0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello - 0x89, 0x05, 'h', 'e', 'l', 'l', 'o', + 0x89, 0x85, 0xcc, 0x55, 0x80, 0x20, + 0xa4, 0x30, 0xec, 0x4c, 0xa3, // ping: hello 0x81, 0x85, 0xed, 0x83, 0xb4, 0x24, 0x9a, 0xec, 0xc6, 0x48, 0x89, // world } br := bufio.NewReader(bytes.NewBuffer(wireData)) bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) - conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request)) expected := [][]byte{[]byte("hello"), []byte("world")} @@ -369,3 +504,32 @@ func TestHybiReadWithMasking(t *testing.T) { t.Errorf("expect read 0, got %d", n) } } + +func TestHybiServerReadWithoutMasking(t *testing.T) { + wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o'} + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request)) + // server MUST close the connection upon receiving a non-masked frame. + msg := make([]byte, 512) + _, err := conn.Read(msg) + if err != os.EOF { + t.Errorf("read 1st frame, expect %q, but got %q", os.EOF, err) + } +} + +func TestHybiClientReadWithMasking(t *testing.T) { + wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20, + 0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello + } + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) + + // client MUST close the connection upon receiving a masked frame. + msg := make([]byte, 512) + _, err := conn.Read(msg) + if err != os.EOF { + t.Errorf("read 1st frame, expect %q, but got %q", os.EOF, err) + } +} diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 8f6a6a94fb..a1d1d48600 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -16,6 +16,13 @@ func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Requ config := new(Config) var hs serverHandshaker = &hybiServerHandshaker{Config: config} code, err := hs.ReadHandshake(buf.Reader, req) + if err == ErrBadWebSocketVersion { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) + buf.WriteString("\r\n") + buf.WriteString(err.String()) + return + } if err != nil { hs = &hixie76ServerHandshaker{Config: config} code, err = hs.ReadHandshake(buf.Reader, req) diff --git a/src/pkg/websocket/websocket.go b/src/pkg/websocket/websocket.go index 1855705c99..a3750dde11 100644 --- a/src/pkg/websocket/websocket.go +++ b/src/pkg/websocket/websocket.go @@ -20,10 +20,13 @@ import ( ) const ( - ProtocolVersionHixie75 = -75 - ProtocolVersionHixie76 = -76 - ProtocolVersionHybi00 = 0 - ProtocolVersionHybi = 8 + ProtocolVersionHixie75 = -75 + ProtocolVersionHixie76 = -76 + ProtocolVersionHybi00 = 0 + ProtocolVersionHybi08 = 8 + ProtocolVersionHybi13 = 13 + ProtocolVersionHybi = ProtocolVersionHybi13 + SupportedProtocolVersion = "13, 8" ContinuationFrame = 0 TextFrame = 1 @@ -39,23 +42,23 @@ type ProtocolError struct { ErrorString string } -func (err ProtocolError) String() string { return err.ErrorString } +func (err *ProtocolError) String() string { return err.ErrorString } var ( - ErrBadProtocolVersion = ProtocolError{"bad protocol version"} - ErrBadScheme = ProtocolError{"bad scheme"} - ErrBadStatus = ProtocolError{"bad status"} - ErrBadUpgrade = ProtocolError{"missing or bad upgrade"} - ErrBadWebSocketOrigin = ProtocolError{"missing or bad WebSocket-Origin"} - ErrBadWebSocketLocation = ProtocolError{"missing or bad WebSocket-Location"} - ErrBadWebSocketProtocol = ProtocolError{"missing or bad WebSocket-Protocol"} - ErrBadWebSocketVersion = ProtocolError{"missing or bad WebSocket Version"} - ErrChallengeResponse = ProtocolError{"mismatch challenge/response"} - ErrBadFrame = ProtocolError{"bad frame"} - ErrBadFrameBoundary = ProtocolError{"not on frame boundary"} - ErrNotWebSocket = ProtocolError{"not websocket protocol"} - ErrBadRequestMethod = ProtocolError{"bad method"} - ErrNotSupported = ProtocolError{"not supported"} + ErrBadProtocolVersion = &ProtocolError{"bad protocol version"} + ErrBadScheme = &ProtocolError{"bad scheme"} + ErrBadStatus = &ProtocolError{"bad status"} + ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} + ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} + ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} + ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} + ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"} + ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} + ErrBadFrame = &ProtocolError{"bad frame"} + ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"} + ErrNotWebSocket = &ProtocolError{"not websocket protocol"} + ErrBadRequestMethod = &ProtocolError{"bad method"} + ErrNotSupported = &ProtocolError{"not supported"} ) // Addr is an implementation of net.Addr for WebSocket.