1
0
mirror of https://github.com/golang/go synced 2024-11-18 06:14:46 -07:00

net/http/httputil: forward 1xx responses in ReverseProxy

Support for 1xx responses has recently been merged in
net/http (CL 269997).

As discussed in this CL
(https://go-review.googlesource.com/c/go/+/269997/comments/1ff70bef_c25a829a),
support for forwarding 1xx responses in ReverseProxy has been extracted
in this separate patch.

According to RFC 7231, "a proxy MUST forward 1xx responses unless the
proxy itself requested the generation of the 1xx response".
Consequently, all received 1xx responses are automatically forwarded as long as the
underlying transport supports ClientTrace.Got1xxResponse.

Fixes #26088
Fixes #51914

Change-Id: I3a35ea023b798bfe56b7fb8696d5a49695229cfd
GitHub-Last-Rev: dab8a461fb
GitHub-Pull-Request: golang/go#53164
Reviewed-on: https://go-review.googlesource.com/c/go/+/409536
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Rhys Hiltner <rhys@justin.tv>
Run-TryBot: hopehook <hopehook@golangcn.org>
This commit is contained in:
Kévin Dunglas 2022-08-30 10:18:50 +00:00 committed by Gopher Robot
parent 4baa486983
commit 972870da11
2 changed files with 102 additions and 0 deletions

View File

@ -15,6 +15,7 @@ import (
"mime"
"net"
"net/http"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
@ -96,6 +97,9 @@ func (r *ProxyRequest) SetXForwarded() {
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Rewrite must be a function which modifies
// the request into a new request to be sent
@ -449,6 +453,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Set("User-Agent", "")
}
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
for k := range h {
delete(h, k)
}
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
if err != nil {
p.getErrorHandler()(rw, outreq, err)

View File

@ -16,7 +16,9 @@ import (
"log"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
"os"
"reflect"
@ -1671,3 +1673,83 @@ func TestReverseProxyRewriteReplacesOut(t *testing.T) {
t.Errorf("got response %q, want %q", got, want)
}
}
func Test1xxResponses(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Add("Link", "</style.css>; rel=preload; as=style")
h.Add("Link", "</script.js>; rel=preload; as=script")
w.WriteHeader(http.StatusEarlyHints)
h.Add("Link", "</foo.js>; rel=preload; as=script")
w.WriteHeader(http.StatusProcessing)
w.Write([]byte("Hello"))
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
frontendClient := frontend.Client()
checkLinkHeaders := func(t *testing.T, expected, got []string) {
t.Helper()
if len(expected) != len(got) {
t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
}
for i := range expected {
if i >= len(got) {
t.Errorf("Expected %q link header; got nothing", expected[i])
continue
}
if expected[i] != got[i] {
t.Errorf("Expected %q link header; got %q", expected[i], got[i])
}
}
}
var respCounter uint8
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
switch code {
case http.StatusEarlyHints:
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
case http.StatusProcessing:
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
default:
t.Error("Unexpected 1xx response")
}
respCounter++
return nil
},
}
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
res, err := frontendClient.Do(req)
if err != nil {
t.Fatalf("Get: %v", err)
}
defer res.Body.Close()
if respCounter != 2 {
t.Errorf("Expected 2 1xx responses; got %d", respCounter)
}
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
body, _ := io.ReadAll(res.Body)
if string(body) != "Hello" {
t.Errorf("Read body %q; want Hello", body)
}
}