From 6435fd5881fc70a276d04df5a60440e365924b49 Mon Sep 17 00:00:00 2001 From: Carl Johnson Date: Tue, 31 Aug 2021 16:31:45 -0400 Subject: [PATCH] net/http: add MaxBytesHandler --- src/net/http/serve_test.go | 60 ++++++++++++++++++++++++++++++++++++++ src/net/http/server.go | 9 ++++++ 2 files changed, 69 insertions(+) diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 6394da3bb7..a425f66e8d 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6609,3 +6609,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon } } } + +func TestMaxBytesHandler(t *testing.T) { + setParallel(t) + defer afterTest(t) + + for _, maxSize := range []int64{100, 1_000, 1_000_000} { + for _, requestSize := range []int64{100, 1_000, 1_000_000} { + t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), + func(t *testing.T) { + testMaxBytesHandler(t, maxSize, requestSize) + }) + } + } +} + +func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { + var ( + handlerN int64 + handlerErr error + ) + echo := HandlerFunc(func(w ResponseWriter, r *Request) { + var buf bytes.Buffer + handlerN, handlerErr = io.Copy(&buf, r.Body) + io.Copy(w, &buf) + }) + + ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + defer ts.Close() + + c := ts.Client() + var buf strings.Builder + body := strings.NewReader(strings.Repeat("a", int(requestSize))) + res, err := c.Post(ts.URL, "text/plain", body) + if err != nil { + t.Errorf("unexpected connection error: %v", err) + } else { + _, err = io.Copy(&buf, res.Body) + res.Body.Close() + if err != nil { + t.Errorf("unexpected read error: %v", err) + } + } + if handlerN > maxSize { + t.Errorf("expected max request body %d; got %d", maxSize, handlerN) + } + if requestSize > maxSize && handlerErr == nil { + t.Error("expected error on handler side; got nil") + } + if requestSize <= maxSize { + if handlerErr != nil { + t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr) + } + if handlerN != requestSize { + t.Errorf("expected request of size %d; got %d", requestSize, handlerN) + } + } + if buf.Len() != int(handlerN) { + t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) + } +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 5b113cff97..efddf2980b 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -3567,3 +3567,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { } return false } + +// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader. +func MaxBytesHandler(h Handler, n int64) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + r2 := *r + r2.Body = MaxBytesReader(w, r.Body, n) + h.ServeHTTP(w, &r2) + }) +}