diff --git a/src/context/context.go b/src/context/context.go index e40b63ef3cd..3afa3e90d2c 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -252,9 +252,9 @@ func propagateCancel(parent Context, child canceler) { child.cancel(false, p.err) } else { if p.children == nil { - p.children = make(map[canceler]bool) + p.children = make(map[canceler]struct{}) } - p.children[child] = true + p.children[child] = struct{}{} } p.mu.Unlock() } else { @@ -314,8 +314,8 @@ type cancelCtx struct { done chan struct{} // closed by the first cancel call. mu sync.Mutex - children map[canceler]bool // set to nil by the first cancel call - err error // set to non-nil by the first cancel call + children map[canceler]struct{} // set to nil by the first cancel call + err error // set to non-nil by the first cancel call } func (c *cancelCtx) Done() <-chan struct{} { diff --git a/src/context/context_test.go b/src/context/context_test.go index c31c4d8718d..d305db50dc0 100644 --- a/src/context/context_test.go +++ b/src/context/context_test.go @@ -92,6 +92,11 @@ func TestWithCancel(t *testing.T) { } } +func contains(m map[canceler]struct{}, key canceler) bool { + _, ret := m[key] + return ret +} + func TestParentFinishesChild(t *testing.T) { // Context tree: // parent -> cancelChild @@ -120,7 +125,7 @@ func TestParentFinishesChild(t *testing.T) { cc := cancelChild.(*cancelCtx) tc := timerChild.(*timerCtx) pc.mu.Lock() - if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { + if len(pc.children) != 2 || !contains(pc.children, cc) || !contains(pc.children, tc) { t.Errorf("bad linkage: pc.children = %v, want %v and %v", pc.children, cc, tc) } @@ -191,7 +196,7 @@ func TestChildFinishesFirst(t *testing.T) { if pcok { pc.mu.Lock() - if len(pc.children) != 1 || !pc.children[cc] { + if len(pc.children) != 1 || !contains(pc.children, cc) { t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) } pc.mu.Unlock() @@ -627,3 +632,36 @@ func TestDeadlineExceededSupportsTimeout(t *testing.T) { t.Fatal("wrong value for timeout") } } + +func BenchmarkContextCancelTree(b *testing.B) { + depths := []int{1, 10, 100, 1000} + for _, d := range depths { + b.Run(fmt.Sprintf("depth=%d", d), func(b *testing.B) { + b.Run("Root=Background", func(b *testing.B) { + for i := 0; i < b.N; i++ { + buildContextTree(Background(), d) + } + }) + b.Run("Root=OpenCanceler", func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx, cancel := WithCancel(Background()) + buildContextTree(ctx, d) + cancel() + } + }) + b.Run("Root=ClosedCanceler", func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx, cancel := WithCancel(Background()) + cancel() + buildContextTree(ctx, d) + } + }) + }) + } +} + +func buildContextTree(root Context, depth int) { + for d := 0; d < depth; d++ { + root, _ = WithCancel(root) + } +}