mirror of
https://github.com/golang/go
synced 2024-11-23 16:30:06 -07:00
net/http/httputil: TestReverseProxyWebSocketCancelation
This commit is contained in:
parent
ae44b425aa
commit
3629c78493
@ -1158,6 +1158,128 @@ func TestReverseProxyWebSocket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyWebSocketCancelation(t *testing.T) {
|
||||
n := 5
|
||||
triggerCancelCh := make(chan bool, n)
|
||||
nthResponse := func(i int) string {
|
||||
return fmt.Sprintf("backend response #%d\n", i)
|
||||
}
|
||||
terminalMsg := "final message"
|
||||
|
||||
cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if g, ws := upgradeType(r.Header), "websocket"; g != ws {
|
||||
t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
|
||||
http.Error(w, "Unexpected request", 400)
|
||||
return
|
||||
}
|
||||
conn, bufrw, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
|
||||
if _, err := io.WriteString(conn, upgradeMsg); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if _, _, err := bufrw.ReadLine(); err != nil {
|
||||
t.Errorf("Failed to read line from client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
|
||||
t.Errorf("Writing response #%d failed: %v", i, err)
|
||||
}
|
||||
bufrw.Flush()
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
if _, err := bufrw.WriteString(terminalMsg); err != nil {
|
||||
t.Errorf("Failed to write terminal message: %v", err)
|
||||
}
|
||||
bufrw.Flush()
|
||||
}))
|
||||
defer cst.Close()
|
||||
|
||||
backendURL, _ := url.Parse(cst.URL)
|
||||
rproxy := NewSingleHostReverseProxy(backendURL)
|
||||
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
|
||||
rproxy.ModifyResponse = func(res *http.Response) error {
|
||||
res.Header.Add("X-Modified", "true")
|
||||
return nil
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("X-Header", "X-Value")
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
go func() {
|
||||
<-triggerCancelCh
|
||||
cancel()
|
||||
}()
|
||||
rproxy.ServeHTTP(rw, req.WithContext(ctx))
|
||||
})
|
||||
|
||||
frontendProxy := httptest.NewServer(handler)
|
||||
defer frontendProxy.Close()
|
||||
|
||||
req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
|
||||
res, err := frontendProxy.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Dialing to frontend proxy: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if g, w := res.StatusCode, 101; g != w {
|
||||
t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
|
||||
}
|
||||
|
||||
if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
|
||||
t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
|
||||
}
|
||||
|
||||
if g, w := upgradeType(res.Header), "websocket"; g != w {
|
||||
t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
|
||||
}
|
||||
|
||||
rwc, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
|
||||
}
|
||||
|
||||
if got, want := res.Header.Get("X-Modified"), "true"; got != want {
|
||||
t.Errorf("response X-Modified header = %q; want %q", got, want)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
|
||||
t.Fatalf("Failed to write first message: %v", err)
|
||||
}
|
||||
|
||||
// Read loop.
|
||||
|
||||
br := bufio.NewReader(rwc)
|
||||
for {
|
||||
line, err := br.ReadString('\n')
|
||||
switch {
|
||||
case line == terminalMsg: // this case before "err == io.EOF"
|
||||
t.Fatalf("The websocket request was not canceled, unfortunately!")
|
||||
|
||||
case err == io.EOF:
|
||||
return
|
||||
|
||||
case err != nil:
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
|
||||
case line == nthResponse(0): // We've gotten the first response back
|
||||
// Let's trigger a cancel.
|
||||
close(triggerCancelCh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnannouncedTrailer(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
Loading…
Reference in New Issue
Block a user