1
0
mirror of https://github.com/golang/go synced 2024-09-29 16:14:28 -06:00

net/http: add MaxBytesHandler

This commit is contained in:
Carl Johnson 2021-08-31 16:31:45 -04:00
parent 144e0b1f6e
commit 6435fd5881
2 changed files with 69 additions and 0 deletions

View File

@ -6609,3 +6609,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
}
}
}
func TestMaxBytesHandler(t *testing.T) {
setParallel(t)
defer afterTest(t)
for _, maxSize := range []int64{100, 1_000, 1_000_000} {
for _, requestSize := range []int64{100, 1_000, 1_000_000} {
t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
func(t *testing.T) {
testMaxBytesHandler(t, maxSize, requestSize)
})
}
}
}
func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
var (
handlerN int64
handlerErr error
)
echo := HandlerFunc(func(w ResponseWriter, r *Request) {
var buf bytes.Buffer
handlerN, handlerErr = io.Copy(&buf, r.Body)
io.Copy(w, &buf)
})
ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
defer ts.Close()
c := ts.Client()
var buf strings.Builder
body := strings.NewReader(strings.Repeat("a", int(requestSize)))
res, err := c.Post(ts.URL, "text/plain", body)
if err != nil {
t.Errorf("unexpected connection error: %v", err)
} else {
_, err = io.Copy(&buf, res.Body)
res.Body.Close()
if err != nil {
t.Errorf("unexpected read error: %v", err)
}
}
if handlerN > maxSize {
t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
}
if requestSize > maxSize && handlerErr == nil {
t.Error("expected error on handler side; got nil")
}
if requestSize <= maxSize {
if handlerErr != nil {
t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
}
if handlerN != requestSize {
t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
}
}
if buf.Len() != int(handlerN) {
t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
}
}

View File

@ -3567,3 +3567,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
}
return false
}
// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
func MaxBytesHandler(h Handler, n int64) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
r2 := *r
r2.Body = MaxBytesReader(w, r.Body, n)
h.ServeHTTP(w, &r2)
})
}