1
0
mirror of https://github.com/golang/go synced 2024-11-23 15:20:03 -07:00

net/http: await state traces earlier in TestServerConnState

This approach attempts to ensure that the log for each connection is
complete before the next sequence of states begins.

Updates #32329

Change-Id: I25150d3ceab6568af56a40d2b14b5f544dc87f61
Reviewed-on: https://go-review.googlesource.com/c/go/+/210717
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Bryan C. Mills 2019-12-10 10:02:27 -05:00
parent ecde0bfa1f
commit 931fe39400

View File

@ -34,7 +34,6 @@ import (
"regexp" "regexp"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -4116,14 +4115,49 @@ func TestServerConnState(t *testing.T) {
panic("intentional panic") panic("intentional panic")
}, },
} }
// A stateLog is a log of states over the lifetime of a connection.
type stateLog struct {
active net.Conn // The connection for which the log is recorded; set to the first connection seen in StateNew.
got []ConnState
want []ConnState
complete chan<- struct{} // If non-nil, closed when either 'got' is equal to 'want', or 'got' is no longer a prefix of 'want'.
}
activeLog := make(chan *stateLog, 1)
// wantLog invokes doRequests, then waits for the resulting connection to
// either pass through the sequence of states in want or enter a state outside
// of that sequence.
wantLog := func(doRequests func(), want ...ConnState) {
t.Helper()
complete := make(chan struct{})
activeLog <- &stateLog{want: want, complete: complete}
doRequests()
timer := time.NewTimer(5 * time.Second)
select {
case <-timer.C:
t.Errorf("Timed out waiting for connection to change state.")
case <-complete:
timer.Stop()
}
sl := <-activeLog
if !reflect.DeepEqual(sl.got, sl.want) {
t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
}
// Don't return sl to activeLog: we don't expect any further states after
// this point, and want to keep the ConnState callback blocked until the
// next call to wantLog.
}
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
handler[r.URL.Path](w, r) handler[r.URL.Path](w, r)
})) }))
defer ts.Close() defer func() {
activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete.
var mu sync.Mutex // guard stateLog and connID ts.Close()
var stateLog = map[int][]ConnState{} }()
var connID = map[net.Conn]int{}
ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
ts.Config.ConnState = func(c net.Conn, state ConnState) { ts.Config.ConnState = func(c net.Conn, state ConnState) {
@ -4131,20 +4165,27 @@ func TestServerConnState(t *testing.T) {
t.Errorf("nil conn seen in state %s", state) t.Errorf("nil conn seen in state %s", state)
return return
} }
mu.Lock() sl := <-activeLog
defer mu.Unlock() if sl.active == nil && state == StateNew {
id, ok := connID[c] sl.active = c
if !ok { } else if sl.active != c {
id = len(connID) + 1 t.Errorf("unexpected conn in state %s", state)
connID[c] = id activeLog <- sl
return
} }
stateLog[id] = append(stateLog[id], state) sl.got = append(sl.got, state)
if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
close(sl.complete)
sl.complete = nil
}
activeLog <- sl
} }
ts.Start()
ts.Start()
c := ts.Client() c := ts.Client()
mustGet := func(url string, headers ...string) { mustGet := func(url string, headers ...string) {
t.Helper()
req, err := NewRequest("GET", url, nil) req, err := NewRequest("GET", url, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -4165,26 +4206,33 @@ func TestServerConnState(t *testing.T) {
} }
} }
wantLog(func() {
mustGet(ts.URL + "/") mustGet(ts.URL + "/")
mustGet(ts.URL + "/close") mustGet(ts.URL + "/close")
}, StateNew, StateActive, StateIdle, StateActive, StateClosed)
wantLog(func() {
mustGet(ts.URL + "/") mustGet(ts.URL + "/")
mustGet(ts.URL+"/", "Connection", "close") mustGet(ts.URL+"/", "Connection", "close")
}, StateNew, StateActive, StateIdle, StateActive, StateClosed)
wantLog(func() {
mustGet(ts.URL + "/hijack") mustGet(ts.URL + "/hijack")
mustGet(ts.URL + "/hijack-panic") }, StateNew, StateActive, StateHijacked)
// New->Closed wantLog(func() {
{ mustGet(ts.URL + "/hijack-panic")
}, StateNew, StateActive, StateHijacked)
wantLog(func() {
c, err := net.Dial("tcp", ts.Listener.Addr().String()) c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.Close() c.Close()
} }, StateNew, StateClosed)
// New->Active->Closed wantLog(func() {
{
c, err := net.Dial("tcp", ts.Listener.Addr().String()) c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -4194,10 +4242,9 @@ func TestServerConnState(t *testing.T) {
} }
c.Read(make([]byte, 1)) // block until server hangs up on us c.Read(make([]byte, 1)) // block until server hangs up on us
c.Close() c.Close()
} }, StateNew, StateActive, StateClosed)
// New->Idle->Closed wantLog(func() {
{
c, err := net.Dial("tcp", ts.Listener.Addr().String()) c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -4213,47 +4260,7 @@ func TestServerConnState(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
c.Close() c.Close()
} }, StateNew, StateActive, StateIdle, StateClosed)
want := map[int][]ConnState{
1: {StateNew, StateActive, StateIdle, StateActive, StateClosed},
2: {StateNew, StateActive, StateIdle, StateActive, StateClosed},
3: {StateNew, StateActive, StateHijacked},
4: {StateNew, StateActive, StateHijacked},
5: {StateNew, StateClosed},
6: {StateNew, StateActive, StateClosed},
7: {StateNew, StateActive, StateIdle, StateClosed},
}
logString := func(m map[int][]ConnState) string {
var b bytes.Buffer
var keys []int
for id := range m {
keys = append(keys, id)
}
sort.Ints(keys)
for _, id := range keys {
fmt.Fprintf(&b, "Conn %d: ", id)
for _, s := range m[id] {
fmt.Fprintf(&b, "%s ", s)
}
b.WriteString("\n")
}
return b.String()
}
for i := 0; i < 5; i++ {
time.Sleep(time.Duration(i) * 50 * time.Millisecond)
mu.Lock()
match := reflect.DeepEqual(stateLog, want)
mu.Unlock()
if match {
return
}
}
mu.Lock()
t.Errorf("Unexpected events.\nGot log:\n%s\n Want:\n%s\n", logString(stateLog), logString(want))
mu.Unlock()
} }
func TestServerKeepAlivesEnabled(t *testing.T) { func TestServerKeepAlivesEnabled(t *testing.T) {