diff --git a/src/context/context.go b/src/context/context.go index f39abe91e2..f3fe1a474e 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -285,6 +285,8 @@ func withCancel(parent Context) *cancelCtx { // Cause returns nil if c has not been canceled yet. func Cause(c Context) error { if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok { + cc.mu.Lock() + defer cc.mu.Unlock() return cc.cause } return nil diff --git a/src/context/context_test.go b/src/context/context_test.go index 593a7b1521..eb5a86b3c6 100644 --- a/src/context/context_test.go +++ b/src/context/context_test.go @@ -5,6 +5,7 @@ package context import ( + "errors" "fmt" "math/rand" "runtime" @@ -934,3 +935,22 @@ func XTestCause(t testingT) { } } } + +func XTestCauseRace(t testingT) { + cause := errors.New("TestCauseRace") + ctx, cancel := WithCancelCause(Background()) + go func() { + cancel(cause) + }() + for { + // Poll Cause, rather than waiting for Done, to test that + // access to the underlying cause is synchronized properly. + if err := Cause(ctx); err != nil { + if err != cause { + t.Errorf("Cause returned %v, want %v", err, cause) + } + break + } + runtime.Gosched() + } +} diff --git a/src/context/x_test.go b/src/context/x_test.go index d3adb381d6..a2d814f8ea 100644 --- a/src/context/x_test.go +++ b/src/context/x_test.go @@ -30,3 +30,4 @@ func TestInvalidDerivedFail(t *testing.T) { XTestInvalidDerivedFail func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) } func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) } func TestCause(t *testing.T) { XTestCause(t) } +func TestCauseRace(t *testing.T) { XTestCauseRace(t) }