diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index fb1aa0f3e4d..190279ca007 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -113,6 +113,14 @@ type ReverseProxy struct { // outbound request before Rewrite is called. See also // the ProxyRequest.SetXForwarded method. // + // Unparsable query parameters are removed from the + // outbound request before Rewrite is called. + // The Rewrite function may copy the inbound URL's + // RawQuery to the outbound URL to preserve the original + // parameter string. Note that this can lead to security + // issues if the proxy's interpretation of query parameters + // does not match that of the downstream server. + // // At most one of Rewrite or Director may be set. Rewrite func(*ProxyRequest) @@ -140,6 +148,9 @@ type ReverseProxy struct { // Director. Use a Rewrite function instead to ensure // modifications to the request are preserved. // + // Unparsable query parameters are removed from the outbound + // request if Request.Form is set after Director returns. + // // At most one of Rewrite or Director may be set. Director func(*http.Request) @@ -374,6 +385,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if p.Director != nil { p.Director(outreq) + if outreq.Form != nil { + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + } } outreq.Close = false @@ -409,6 +423,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del("X-Forwarded-Host") outreq.Header.Del("X-Forwarded-Proto") + // Remove unparsable query parameters from the outbound request. + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + pr := &ProxyRequest{ In: req, Out: outreq, @@ -797,3 +814,36 @@ func (c switchProtocolCopier) copyToBackend(errc chan<- error) { _, err := io.Copy(c.backend, c.user) errc <- err } + +func cleanQueryParams(s string) string { + reencode := func(s string) string { + v, _ := url.ParseQuery(s) + return v.Encode() + } + for i := 0; i < len(s); { + switch s[i] { + case ';': + return reencode(s) + case '%': + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + return reencode(s) + } + i += 3 + default: + i++ + } + } + return s +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 68052627550..5b882d3a456 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -1753,3 +1753,98 @@ func Test1xxResponses(t *testing.T) { t.Errorf("Read body %q; want Hello", body) } } + +const ( + testWantsCleanQuery = true + testWantsRawQuery = false +) + +func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + // Parsing the form causes ReverseProxy to remove unparsable + // query parameters before forwarding. + r.FormValue("a") + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + }, + } + }) +} + +func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + r.Out.URL.RawQuery = r.In.URL.RawQuery + }, + } + }) +} + +func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.URL.RawQuery)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := newProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + // Don't spam output with logs of queries containing semicolons. + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + frontend.Config.ErrorLog = log.New(io.Discard, "", 0) + + for _, test := range []struct { + rawQuery string + cleanQuery string + }{{ + rawQuery: "a=1&a=2;b=3", + cleanQuery: "a=1", + }, { + rawQuery: "a=1&a=%zz&b=3", + cleanQuery: "a=1&b=3", + }} { + res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + wantQuery := test.rawQuery + if wantCleanQuery { + wantQuery = test.cleanQuery + } + if got, want := string(body), wantQuery; got != want { + t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want) + } + } +}