diff --git a/playground/socket/socket.go b/playground/socket/socket.go index 66689a4f710..51c2a49ee04 100644 --- a/playground/socket/socket.go +++ b/playground/socket/socket.go @@ -22,6 +22,8 @@ import ( "io" "io/ioutil" "log" + "net/http" + "net/url" "os" "os/exec" "path/filepath" @@ -39,9 +41,6 @@ import ( // (snippets that start with a shebang). var RunScripts = true -// Handler implements a WebSocket handler for a client connection. -var Handler = websocket.Handler(socketHandler) - // Environ provides an environment when a binary, such as the go tool, is // invoked. var Environ func() []string = os.Environ @@ -69,6 +68,30 @@ type Options struct { Race bool // use -race flag when building code (for "run" only) } +// NewHandler returns a websocket server which checks the origin of requests. +func NewHandler(origin *url.URL) websocket.Server { + return websocket.Server{ + Config: websocket.Config{Origin: origin}, + Handshake: handshake, + Handler: websocket.Handler(socketHandler), + } +} + +// handshake checks the origin of a request during the websocket handshake. +func handshake(c *websocket.Config, req *http.Request) error { + o, err := websocket.Origin(c, req) + if err != nil { + log.Println("bad websocket origin:", err) + return websocket.ErrBadWebSocketOrigin + } + ok := c.Origin.Scheme == o.Scheme && c.Origin.Host == o.Host + if !ok { + log.Println("bad websocket origin:", o) + return websocket.ErrBadWebSocketOrigin + } + return nil +} + // socketHandler handles the websocket connection for a given present session. // It handles transcoding Messages to and from JSON format, and starting // and killing processes.