From 50b16f9de590822a04ec8d6cbac476366c1bde32 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 24 Oct 2020 00:40:41 +0000 Subject: [PATCH] net/http: allow upgrading non keepalive connections If one was using http.Transport with DisableKeepAlives and trying to upgrade a connection against net/http's Server, the Server would not allow a "Connection: Upgrade" header to be written and instead override it to "Connection: Close" which would break the handshake. This change ensures net/http's Server does not override the connection header for successful protocol switch responses. Fixes #36381. Change-Id: I882aad8539e6c87ff5f37c20e20b3a7fa1a30357 GitHub-Last-Rev: dc0de83201dc26236527b68bd49dffc53dd0389b GitHub-Pull-Request: golang/go#36382 Reviewed-on: https://go-review.googlesource.com/c/go/+/213277 Trust: Emmanuel Odeke Trust: Brad Fitzpatrick Reviewed-by: Brad Fitzpatrick --- src/net/http/response.go | 16 ++++++++---- src/net/http/serve_test.go | 53 ++++++++++++++++++++++++++++++++++++++ src/net/http/server.go | 8 +++++- 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/net/http/response.go b/src/net/http/response.go index 72812f06429..b95abae646b 100644 --- a/src/net/http/response.go +++ b/src/net/http/response.go @@ -352,10 +352,16 @@ func (r *Response) bodyIsWritable() bool { return ok } -// isProtocolSwitch reports whether r is a response to a successful -// protocol upgrade. +// isProtocolSwitch reports whether the response code and header +// indicate a successful protocol upgrade response. func (r *Response) isProtocolSwitch() bool { - return r.StatusCode == StatusSwitchingProtocols && - r.Header.Get("Upgrade") != "" && - httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") + return isProtocolSwitchResponse(r.StatusCode, r.Header) +} + +// isProtocolSwitchResponse reports whether the response code and +// response header indicate a successful protocol upgrade response. +func isProtocolSwitchResponse(code int, h Header) bool { + return code == StatusSwitchingProtocols && + h.Get("Upgrade") != "" && + httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") } diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index ba54b31a29d..b1bf8e6c5e8 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6448,3 +6448,56 @@ func BenchmarkResponseStatusLine(b *testing.B) { } }) } +func TestDisableKeepAliveUpgrade(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + setParallel(t) + defer afterTest(t) + + s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "someProto") + w.WriteHeader(StatusSwitchingProtocols) + c, _, err := w.(Hijacker).Hijack() + if err != nil { + return + } + defer c.Close() + + io.Copy(c, c) + })) + s.Config.SetKeepAlivesEnabled(false) + s.Start() + defer s.Close() + + cl := s.Client() + cl.Transport.(*Transport).DisableKeepAlives = true + + resp, err := cl.Get(s.URL) + if err != nil { + t.Fatalf("failed to perform request: %v", err) + } + defer resp.Body.Close() + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body) + } + + _, err = rwc.Write([]byte("hello")) + if err != nil { + t.Fatalf("failed to write to body: %v", err) + } + + b := make([]byte, 5) + _, err = io.ReadFull(rwc, b) + if err != nil { + t.Fatalf("failed to read from body: %v", err) + } + + if string(b) != "hello" { + t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello") + } +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 6c7d2817051..102e893d5f1 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1468,7 +1468,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { return } - if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) { + // Only override the Connection header if it is not a successful + // protocol switch response and if KeepAlives are not enabled. + // See https://golang.org/issue/36381. + delConnectionHeader := w.closeAfterReply && + (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) && + !isProtocolSwitchResponse(w.status, header) + if delConnectionHeader { delHeader("Connection") if w.req.ProtoAtLeast(1, 1) { setHeader.connection = "close"