1
0
mirror of https://github.com/golang/go synced 2024-11-20 04:14:49 -07:00

net/http/httputil: Clean up ReverseProxy maxLatencyWriter goroutines.

When FlushInterval is specified on ReverseProxy, the ResponseWriter is
wrapped with a maxLatencyWriter that periodically flushes in a
goroutine. That goroutine was not being cleaned up at the end of the
request. This resulted in a panic when Flush() was being called on a
ResponseWriter that was closed.

The code was updated to always send the done message to the flushLoop()
goroutine after copying the body. Futhermore, the code was refactored to
allow the test to verify the maxLatencyWriter behavior.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/6033043
This commit is contained in:
Colby Ranger 2012-04-18 11:33:02 -07:00 committed by Brad Fitzpatrick
parent 6742d0a085
commit 5694ebf057
2 changed files with 87 additions and 19 deletions

View File

@ -17,6 +17,10 @@ import (
"time"
)
// beforeCopyResponse is a callback set by tests to intercept the state of the
// output io.Writer before the data is copied to it.
var beforeCopyResponse func(dst io.Writer)
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
@ -112,20 +116,32 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
return
}
defer res.Body.Close()
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
p.copyResponse(rw, res.Body)
}
if res.Body != nil {
var dst io.Writer = rw
if p.FlushInterval != 0 {
if wf, ok := rw.(writeFlusher); ok {
dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
io.Copy(dst, res.Body)
}
if beforeCopyResponse != nil {
beforeCopyResponse(dst)
}
io.Copy(dst, src)
}
type writeFlusher interface {
@ -137,22 +153,14 @@ type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration
lk sync.Mutex // protects init of done, as well Write + Flush
lk sync.Mutex // protects Write + Flush
done chan bool
}
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
m.lk.Lock()
defer m.lk.Unlock()
if m.done == nil {
m.done = make(chan bool)
go m.flushLoop()
}
n, err = m.dst.Write(p)
if err != nil {
m.done <- true
}
return
return m.dst.Write(p)
}
func (m *maxLatencyWriter) flushLoop() {
@ -160,13 +168,15 @@ func (m *maxLatencyWriter) flushLoop() {
defer t.Stop()
for {
select {
case <-m.done:
return
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
case <-m.done:
return
}
}
panic("unreached")
}
func (m *maxLatencyWriter) stop() { m.done <- true }

View File

@ -7,11 +7,14 @@
package httputil
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"testing"
"time"
)
func TestReverseProxy(t *testing.T) {
@ -107,3 +110,58 @@ func TestReverseProxyQuery(t *testing.T) {
frontend.Close()
}
}
func TestReverseProxyFlushInterval(t *testing.T) {
if testing.Short() {
return
}
const expected = "hi"
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(expected))
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.FlushInterval = time.Microsecond
dstChan := make(chan io.Writer, 1)
beforeCopyResponse = func(dst io.Writer) { dstChan <- dst }
defer func() { beforeCopyResponse = nil }()
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
initGoroutines := runtime.NumGoroutine()
for i := 0; i < 100; i++ {
req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Get: %v", err)
}
if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
t.Errorf("got body %q; expected %q", bodyBytes, expected)
}
select {
case dst := <-dstChan:
if _, ok := dst.(*maxLatencyWriter); !ok {
t.Errorf("got writer %T; expected %T", dst, &maxLatencyWriter{})
}
default:
t.Error("maxLatencyWriter Write() was never called")
}
res.Body.Close()
}
// Allow up to 50 additional goroutines over 100 requests.
if delta := runtime.NumGoroutine() - initGoroutines; delta > 50 {
t.Errorf("grew %d goroutines; leak?", delta)
}
}