mirror of
https://github.com/golang/go
synced 2024-09-29 18:24:29 -06:00
net/http: add MaxBytesHandler
This commit is contained in:
parent
144e0b1f6e
commit
6435fd5881
@ -6609,3 +6609,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesHandler(t *testing.T) {
|
||||
setParallel(t)
|
||||
defer afterTest(t)
|
||||
|
||||
for _, maxSize := range []int64{100, 1_000, 1_000_000} {
|
||||
for _, requestSize := range []int64{100, 1_000, 1_000_000} {
|
||||
t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
|
||||
func(t *testing.T) {
|
||||
testMaxBytesHandler(t, maxSize, requestSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
|
||||
var (
|
||||
handlerN int64
|
||||
handlerErr error
|
||||
)
|
||||
echo := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
var buf bytes.Buffer
|
||||
handlerN, handlerErr = io.Copy(&buf, r.Body)
|
||||
io.Copy(w, &buf)
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
|
||||
defer ts.Close()
|
||||
|
||||
c := ts.Client()
|
||||
var buf strings.Builder
|
||||
body := strings.NewReader(strings.Repeat("a", int(requestSize)))
|
||||
res, err := c.Post(ts.URL, "text/plain", body)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected connection error: %v", err)
|
||||
} else {
|
||||
_, err = io.Copy(&buf, res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected read error: %v", err)
|
||||
}
|
||||
}
|
||||
if handlerN > maxSize {
|
||||
t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
|
||||
}
|
||||
if requestSize > maxSize && handlerErr == nil {
|
||||
t.Error("expected error on handler side; got nil")
|
||||
}
|
||||
if requestSize <= maxSize {
|
||||
if handlerErr != nil {
|
||||
t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
|
||||
}
|
||||
if handlerN != requestSize {
|
||||
t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
|
||||
}
|
||||
}
|
||||
if buf.Len() != int(handlerN) {
|
||||
t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
|
||||
}
|
||||
}
|
||||
|
@ -3567,3 +3567,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
|
||||
func MaxBytesHandler(h Handler, n int64) Handler {
|
||||
return HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
r2 := *r
|
||||
r2.Body = MaxBytesReader(w, r.Body, n)
|
||||
h.ServeHTTP(w, &r2)
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user