1
0
mirror of https://github.com/golang/go synced 2024-09-23 17:20:13 -06:00

context: add APIs for writing and reading cancelation cause

Extend the context package to allow users to specify why a context was
canceled in the form of an error, the "cause". Users write the cause
by calling WithCancelCause to construct a derived context, then
calling cancel(cause) to cancel the context with the provided cause.
Users retrieve the cause by calling context.Cause(ctx), which returns
the cause of the first cancelation for ctx or any of its parents.

The cause is implemented as a field of cancelCtx, since only cancelCtx
can be canceled. Calling cancel copies the cause to all derived (child)
cancelCtxs. Calling Cause(ctx) finds the nearest parent cancelCtx by
looking up the context value keyed by cancelCtxKey.

API changes:
+pkg context, func Cause(Context) error
+pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc)
+pkg context, type CancelCauseFunc func(error)

Fixes #26356
Fixes #51365

Change-Id: I15b62bd454c014db3f4f1498b35204451509e641
Reviewed-on: https://go-review.googlesource.com/c/go/+/375977
Reviewed-by: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Sameer Ajmani <sameer@golang.org>
Auto-Submit: Sameer Ajmani <sameer@golang.org>
This commit is contained in:
Sameer Ajmani 2022-01-06 13:57:05 -05:00 committed by Gopher Robot
parent 0b7aa9fa5b
commit 5b42f89e39
4 changed files with 231 additions and 20 deletions

3
api/next/51365.txt Normal file
View File

@ -0,0 +1,3 @@
pkg context, func Cause(Context) error #51365
pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc) #51365
pkg context, type CancelCauseFunc func(error) #51365

View File

@ -22,6 +22,12 @@
// fires. The go vet tool checks that CancelFuncs are used on all
// control-flow paths.
//
// The WithCancelCause function returns a CancelCauseFunc, which
// takes an error and records it as the cancelation cause. Calling
// Cause on the canceled context or any of its children retrieves
// the cause. If no cause is specified, Cause(ctx) returns the same
// value as ctx.Err().
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
@ -230,17 +236,63 @@ type CancelFunc func()
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := withCancel(parent)
return c, func() { c.cancel(true, Canceled, nil) }
}
// A CancelCauseFunc behaves like a CancelFunc but additionally sets the cancelation cause.
// This cause can be retrieved by calling Cause on the canceled Context or on
// any of its derived Contexts.
//
// If the context has already been canceled, CancelCauseFunc does not set the cause.
// For example, if childContext is derived from parentContext:
// - if parentContext is canceled with cause1 before childContext is canceled with cause2,
// then Cause(parentContext) == Cause(childContext) == cause1
// - if childContext is canceled with cause2 before parentContext is canceled with cause1,
// then Cause(parentContext) == cause1 and Cause(childContext) == cause2
type CancelCauseFunc func(cause error)
// WithCancelCause behaves like WithCancel but returns a CancelCauseFunc instead of a CancelFunc.
// Calling cancel with a non-nil error (the "cause") records that error in ctx;
// it can then be retrieved using Cause(ctx).
// Calling cancel with nil sets the cause to Canceled.
//
// Example use:
//
// ctx, cancel := context.WithCancelCause(parent)
// cancel(myError)
// ctx.Err() // returns context.Canceled
// context.Cause(ctx) // returns myError
func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
c := withCancel(parent)
return c, func(cause error) { c.cancel(true, Canceled, cause) }
}
func withCancel(parent Context) *cancelCtx {
if parent == nil {
panic("cannot create context from nil parent")
}
c := newCancelCtx(parent)
propagateCancel(parent, &c)
return &c, func() { c.cancel(true, Canceled) }
propagateCancel(parent, c)
return c
}
// Cause returns a non-nil error explaining why c was canceled.
// The first cancelation of c or one of its parents sets the cause.
// If that cancelation happened via a call to CancelCauseFunc(err),
// then Cause returns err.
// Otherwise Cause(c) returns the same value as c.Err().
// Cause returns nil if c has not been canceled yet.
func Cause(c Context) error {
if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok {
return cc.cause
}
return nil
}
// newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) cancelCtx {
return cancelCtx{Context: parent}
func newCancelCtx(parent Context) *cancelCtx {
return &cancelCtx{Context: parent}
}
// goroutines counts the number of goroutines ever created; for testing.
@ -256,7 +308,7 @@ func propagateCancel(parent Context, child canceler) {
select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err())
child.cancel(false, parent.Err(), Cause(parent))
return
default:
}
@ -265,7 +317,7 @@ func propagateCancel(parent Context, child canceler) {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err)
child.cancel(false, p.err, p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
@ -278,7 +330,7 @@ func propagateCancel(parent Context, child canceler) {
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
@ -326,7 +378,7 @@ func removeChild(parent Context, child canceler) {
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err error)
cancel(removeFromParent bool, err, cause error)
Done() <-chan struct{}
}
@ -346,6 +398,7 @@ type cancelCtx struct {
done atomic.Value // of chan struct{}, created lazily, closed by first cancel call
children map[canceler]struct{} // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
cause error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Value(key any) any {
@ -394,16 +447,21 @@ func (c *cancelCtx) String() string {
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
// cancel sets c.cause to cause if this is the first time c is canceled.
func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
if cause == nil {
cause = err
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
c.cause = cause
d, _ := c.done.Load().(chan struct{})
if d == nil {
c.done.Store(closedchan)
@ -412,7 +470,7 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) {
}
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
child.cancel(false, err, cause)
}
c.children = nil
c.mu.Unlock()
@ -446,24 +504,24 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(false, Canceled) }
c.cancel(true, DeadlineExceeded, nil) // deadline has already passed
return c, func() { c.cancel(false, Canceled, nil) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded)
c.cancel(true, DeadlineExceeded, nil)
})
}
return c, func() { c.cancel(true, Canceled) }
return c, func() { c.cancel(true, Canceled, nil) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
*cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
@ -479,8 +537,8 @@ func (c *timerCtx) String() string {
time.Until(c.deadline).String() + "])"
}
func (c *timerCtx) cancel(removeFromParent bool, err error) {
c.cancelCtx.cancel(false, err)
func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
c.cancelCtx.cancel(false, err, cause)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
@ -581,7 +639,7 @@ func value(c Context, key any) any {
c = ctx.Context
case *timerCtx:
if key == &cancelCtxKey {
return &ctx.cancelCtx
return ctx.cancelCtx
}
c = ctx.Context
case *emptyCtx:

View File

@ -650,8 +650,9 @@ func XTestCancelRemoves(t testingT) {
}
func XTestWithCancelCanceledParent(t testingT) {
parent, pcancel := WithCancel(Background())
pcancel()
parent, pcancel := WithCancelCause(Background())
cause := fmt.Errorf("Because!")
pcancel(cause)
c, _ := WithCancel(parent)
select {
@ -662,6 +663,9 @@ func XTestWithCancelCanceledParent(t testingT) {
if got, want := c.Err(), Canceled; got != want {
t.Errorf("child not canceled; got = %v, want = %v", got, want)
}
if got, want := Cause(c), cause; got != want {
t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
}
}
func XTestWithValueChecksKey(t testingT) {
@ -785,3 +789,148 @@ func XTestCustomContextGoroutines(t testingT) {
defer cancel7()
checkNoGoroutine()
}
func XTestCause(t testingT) {
var (
parentCause = fmt.Errorf("parentCause")
childCause = fmt.Errorf("childCause")
)
for _, test := range []struct {
name string
ctx Context
err error
cause error
}{
{
name: "Background",
ctx: Background(),
err: nil,
cause: nil,
},
{
name: "TODO",
ctx: TODO(),
err: nil,
cause: nil,
},
{
name: "WithCancel",
ctx: func() Context {
ctx, cancel := WithCancel(Background())
cancel()
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause",
ctx: func() Context {
ctx, cancel := WithCancelCause(Background())
cancel(parentCause)
return ctx
}(),
err: Canceled,
cause: parentCause,
},
{
name: "WithCancelCause nil",
ctx: func() Context {
ctx, cancel := WithCancelCause(Background())
cancel(nil)
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause: parent cause before child",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelParent(parentCause)
cancelChild(childCause)
return ctx
}(),
err: Canceled,
cause: parentCause,
},
{
name: "WithCancelCause: parent cause after child",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelChild(childCause)
cancelParent(parentCause)
return ctx
}(),
err: Canceled,
cause: childCause,
},
{
name: "WithCancelCause: parent cause before nil",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancel(ctx)
cancelParent(parentCause)
cancelChild()
return ctx
}(),
err: Canceled,
cause: parentCause,
},
{
name: "WithCancelCause: parent cause after nil",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancel(ctx)
cancelChild()
cancelParent(parentCause)
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause: child cause after nil",
ctx: func() Context {
ctx, cancelParent := WithCancel(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelParent()
cancelChild(childCause)
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause: child cause before nil",
ctx: func() Context {
ctx, cancelParent := WithCancel(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelChild(childCause)
cancelParent()
return ctx
}(),
err: Canceled,
cause: childCause,
},
{
name: "WithTimeout",
ctx: func() Context {
ctx, cancel := WithTimeout(Background(), 0)
cancel()
return ctx
}(),
err: DeadlineExceeded,
cause: DeadlineExceeded,
},
} {
if got, want := test.ctx.Err(), test.err; want != got {
t.Errorf("%s: ctx.Err() = %v want %v", test.name, got, want)
}
if got, want := Cause(test.ctx), test.cause; want != got {
t.Errorf("%s: Cause(ctx) = %v want %v", test.name, got, want)
}
}
}

View File

@ -29,3 +29,4 @@ func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey
func TestInvalidDerivedFail(t *testing.T) { XTestInvalidDerivedFail(t) }
func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) }
func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) }
func TestCause(t *testing.T) { XTestCause(t) }