mirror of
https://github.com/golang/go
synced 2024-11-21 20:04:44 -07:00
http: avoid server crash on malformed client request
R=r, rsc CC=golang-dev https://golang.org/cl/206079
This commit is contained in:
parent
c2dea2196c
commit
2161e3e23e
@ -38,20 +38,34 @@ type Handler func(*Conn)
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a Web Socket.
|
||||
func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
|
||||
if req.Method != "GET" || req.Proto != "HTTP/1.1" ||
|
||||
req.Header["Upgrade"] != "WebSocket" ||
|
||||
req.Header["Connection"] != "Upgrade" {
|
||||
c.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(c, "must use websocket to connect here")
|
||||
if req.Method != "GET" || req.Proto != "HTTP/1.1" {
|
||||
c.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(c, "Unexpected request")
|
||||
return
|
||||
}
|
||||
if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" {
|
||||
c.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(c, "missing Upgrade: WebSocket header")
|
||||
return
|
||||
}
|
||||
if v, present := req.Header["Connection"]; !present || v != "Upgrade" {
|
||||
c.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(c, "missing Connection: Upgrade header")
|
||||
return
|
||||
}
|
||||
origin, present := req.Header["Origin"]
|
||||
if !present {
|
||||
c.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(c, "missing Origin header")
|
||||
return
|
||||
}
|
||||
|
||||
rwc, buf, err := c.Hijack()
|
||||
if err != nil {
|
||||
panic("Hijack failed: ", err.String())
|
||||
return
|
||||
}
|
||||
defer rwc.Close()
|
||||
origin := req.Header["Origin"]
|
||||
location := "ws://" + req.Host + req.URL.Path
|
||||
|
||||
// TODO(ukai): verify origin,location,protocol.
|
||||
@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
|
||||
buf.WriteString("Connection: Upgrade\r\n")
|
||||
buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
|
||||
buf.WriteString("WebSocket-Location: " + location + "\r\n")
|
||||
protocol := ""
|
||||
protocol, present := req.Header["Websocket-Protocol"]
|
||||
// canonical header key of WebSocket-Protocol.
|
||||
if protocol, found := req.Header["Websocket-Protocol"]; found {
|
||||
if present {
|
||||
buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
|
@ -6,6 +6,7 @@ package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"http"
|
||||
"io"
|
||||
"log"
|
||||
@ -59,3 +60,17 @@ func TestEcho(t *testing.T) {
|
||||
}
|
||||
ws.Close()
|
||||
}
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
once.Do(startServer)
|
||||
|
||||
r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
|
||||
if err != nil {
|
||||
t.Errorf("Get: error %v", err)
|
||||
return
|
||||
}
|
||||
if r.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Get: got status %d", r.StatusCode)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user