mirror of
https://github.com/golang/go
synced 2024-11-20 04:14:49 -07:00
net/http/httputil: Clean up ReverseProxy maxLatencyWriter goroutines.
When FlushInterval is specified on ReverseProxy, the ResponseWriter is wrapped with a maxLatencyWriter that periodically flushes in a goroutine. That goroutine was not being cleaned up at the end of the request. This resulted in a panic when Flush() was being called on a ResponseWriter that was closed. The code was updated to always send the done message to the flushLoop() goroutine after copying the body. Futhermore, the code was refactored to allow the test to verify the maxLatencyWriter behavior. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/6033043
This commit is contained in:
parent
6742d0a085
commit
5694ebf057
@ -17,6 +17,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// beforeCopyResponse is a callback set by tests to intercept the state of the
|
||||
// output io.Writer before the data is copied to it.
|
||||
var beforeCopyResponse func(dst io.Writer)
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
@ -112,20 +116,32 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
p.copyResponse(rw, res.Body)
|
||||
}
|
||||
|
||||
if res.Body != nil {
|
||||
var dst io.Writer = rw
|
||||
if p.FlushInterval != 0 {
|
||||
if wf, ok := rw.(writeFlusher); ok {
|
||||
dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
|
||||
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
if p.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
dst: wf,
|
||||
latency: p.FlushInterval,
|
||||
done: make(chan bool),
|
||||
}
|
||||
go mlw.flushLoop()
|
||||
defer mlw.stop()
|
||||
dst = mlw
|
||||
}
|
||||
io.Copy(dst, res.Body)
|
||||
}
|
||||
|
||||
if beforeCopyResponse != nil {
|
||||
beforeCopyResponse(dst)
|
||||
}
|
||||
io.Copy(dst, src)
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
@ -137,22 +153,14 @@ type maxLatencyWriter struct {
|
||||
dst writeFlusher
|
||||
latency time.Duration
|
||||
|
||||
lk sync.Mutex // protects init of done, as well Write + Flush
|
||||
lk sync.Mutex // protects Write + Flush
|
||||
done chan bool
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
|
||||
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
|
||||
m.lk.Lock()
|
||||
defer m.lk.Unlock()
|
||||
if m.done == nil {
|
||||
m.done = make(chan bool)
|
||||
go m.flushLoop()
|
||||
}
|
||||
n, err = m.dst.Write(p)
|
||||
if err != nil {
|
||||
m.done <- true
|
||||
}
|
||||
return
|
||||
return m.dst.Write(p)
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) flushLoop() {
|
||||
@ -160,13 +168,15 @@ func (m *maxLatencyWriter) flushLoop() {
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
return
|
||||
case <-t.C:
|
||||
m.lk.Lock()
|
||||
m.dst.Flush()
|
||||
m.lk.Unlock()
|
||||
case <-m.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
panic("unreached")
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() { m.done <- true }
|
||||
|
@ -7,11 +7,14 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReverseProxy(t *testing.T) {
|
||||
@ -107,3 +110,58 @@ func TestReverseProxyQuery(t *testing.T) {
|
||||
frontend.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyFlushInterval(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
const expected = "hi"
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(expected))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
proxyHandler := NewSingleHostReverseProxy(backendURL)
|
||||
proxyHandler.FlushInterval = time.Microsecond
|
||||
|
||||
dstChan := make(chan io.Writer, 1)
|
||||
beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
|
||||
defer func() { beforeCopyResponse = nil }()
|
||||
|
||||
frontend := httptest.NewServer(proxyHandler)
|
||||
defer frontend.Close()
|
||||
|
||||
initGoroutines := runtime.NumGoroutine()
|
||||
for i := 0; i < 100; i++ {
|
||||
req, _ := http.NewRequest("GET", frontend.URL, nil)
|
||||
req.Close = true
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
|
||||
t.Errorf("got body %q; expected %q", bodyBytes, expected)
|
||||
}
|
||||
|
||||
select {
|
||||
case dst := <-dstChan:
|
||||
if _, ok := dst.(*maxLatencyWriter); !ok {
|
||||
t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
|
||||
}
|
||||
default:
|
||||
t.Error("maxLatencyWriter Write() was never called")
|
||||
}
|
||||
|
||||
res.Body.Close()
|
||||
}
|
||||
// Allow up to 50 additional goroutines over 100 requests.
|
||||
if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
|
||||
t.Errorf("grew %d goroutines; leak?", delta)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user