From 4a0a2b33dfa3c99250efa222439f2c27d6780e4a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 22 Sep 2022 10:43:26 -0700 Subject: [PATCH] errors, fmt: add support for wrapping multiple errors An error which implements an "Unwrap() []error" method wraps all the non-nil errors in the returned []error. We replace the concept of the "error chain" inspected by errors.Is and errors.As with the "error tree". Is and As perform a pre-order, depth-first traversal of an error's tree. As returns the first matching result, if any. The new errors.Join function returns an error wrapping a list of errors. The fmt.Errorf function now supports multiple instances of the %w verb. For #53435. Change-Id: Ib7402e70b68e28af8f201d2b66bd8e87ccfb5283 Reviewed-on: https://go-review.googlesource.com/c/go/+/432898 Reviewed-by: Cherry Mui Reviewed-by: Rob Pike Run-TryBot: Damien Neil Reviewed-by: Joseph Tsai --- api/next/53435.txt | 1 + src/errors/errors.go | 31 +++++++++++---------- src/errors/errors_test.go | 18 +++++++++++++ src/errors/join.go | 51 +++++++++++++++++++++++++++++++++++ src/errors/join_test.go | 49 +++++++++++++++++++++++++++++++++ src/errors/wrap.go | 57 ++++++++++++++++++++++++++++++--------- src/errors/wrap_test.go | 52 ++++++++++++++++++++++++++++++++++- src/fmt/errors.go | 51 +++++++++++++++++++++++++++++------ src/fmt/errors_test.go | 38 +++++++++++++++++++++++--- src/fmt/print.go | 28 +++++++++++-------- 10 files changed, 325 insertions(+), 51 deletions(-) create mode 100644 api/next/53435.txt create mode 100644 src/errors/join.go create mode 100644 src/errors/join_test.go diff --git a/api/next/53435.txt b/api/next/53435.txt new file mode 100644 index 00000000000..8f295fc96b1 --- /dev/null +++ b/api/next/53435.txt @@ -0,0 +1 @@ +pkg errors, func Join(...error) error #53435 diff --git a/src/errors/errors.go b/src/errors/errors.go index f2fabacd4e9..8436f812a6e 100644 --- a/src/errors/errors.go +++ b/src/errors/errors.go @@ -6,26 +6,29 @@ // // The New function creates errors whose only content is a text message. // -// The Unwrap, Is and As functions work on errors that may wrap other errors. -// An error wraps another error if its type has the method +// An error e wraps another error if e's type has one of the methods // // Unwrap() error +// Unwrap() []error // -// If e.Unwrap() returns a non-nil error w, then we say that e wraps w. +// If e.Unwrap() returns a non-nil error w or a slice containing w, +// then we say that e wraps w. A nil error returned from e.Unwrap() +// indicates that e does not wrap any error. It is invalid for an +// Unwrap method to return an []error containing a nil error value. // -// Unwrap unpacks wrapped errors. If its argument's type has an -// Unwrap method, it calls the method once. Otherwise, it returns nil. +// An easy way to create wrapped errors is to call fmt.Errorf and apply +// the %w verb to the error argument: // -// A simple way to create wrapped errors is to call fmt.Errorf and apply the %w verb -// to the error argument: +// wrapsErr := fmt.Errorf("... %w ...", ..., err, ...) // -// errors.Unwrap(fmt.Errorf("... %w ...", ..., err, ...)) +// Successive unwrapping of an error creates a tree. The Is and As +// functions inspect an error's tree by examining first the error +// itself followed by the tree of each of its children in turn +// (pre-order, depth-first traversal). // -// returns err. -// -// Is unwraps its first argument sequentially looking for an error that matches the -// second. It reports whether it finds a match. It should be used in preference to -// simple equality checks: +// Is examines the tree of its first argument looking for an error that +// matches the second. It reports whether it finds a match. It should be +// used in preference to simple equality checks: // // if errors.Is(err, fs.ErrExist) // @@ -35,7 +38,7 @@ // // because the former will succeed if err wraps fs.ErrExist. // -// As unwraps its first argument sequentially looking for an error that can be +// As examines the tree of its first argument looking for an error that can be // assigned to its second argument, which must be a pointer. If it succeeds, it // performs the assignment and returns true. Otherwise, it returns false. The form // diff --git a/src/errors/errors_test.go b/src/errors/errors_test.go index cf4df90b695..8b93f530d59 100644 --- a/src/errors/errors_test.go +++ b/src/errors/errors_test.go @@ -51,3 +51,21 @@ func ExampleNew_errorf() { } // Output: user "bimmler" (id 17) not found } + +func ExampleJoin() { + err1 := errors.New("err1") + err2 := errors.New("err2") + err := errors.Join(err1, err2) + fmt.Println(err) + if errors.Is(err, err1) { + fmt.Println("err is err1") + } + if errors.Is(err, err2) { + fmt.Println("err is err2") + } + // Output: + // err1 + // err2 + // err is err1 + // err is err2 +} diff --git a/src/errors/join.go b/src/errors/join.go new file mode 100644 index 00000000000..dc5a716aa6f --- /dev/null +++ b/src/errors/join.go @@ -0,0 +1,51 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errors + +// Join returns an error that wraps the given errors. +// Any nil error values are discarded. +// Join returns nil if errs contains no non-nil values. +// The error formats as the concatenation of the strings obtained +// by calling the Error method of each element of errs, with a newline +// between each string. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/src/errors/join_test.go b/src/errors/join_test.go new file mode 100644 index 00000000000..ee69314529f --- /dev/null +++ b/src/errors/join_test.go @@ -0,0 +1,49 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errors_test + +import ( + "errors" + "reflect" + "testing" +) + +func TestJoinReturnsNil(t *testing.T) { + if err := errors.Join(); err != nil { + t.Errorf("errors.Join() = %v, want nil", err) + } + if err := errors.Join(nil); err != nil { + t.Errorf("errors.Join(nil) = %v, want nil", err) + } + if err := errors.Join(nil, nil); err != nil { + t.Errorf("errors.Join(nil, nil) = %v, want nil", err) + } +} + +func TestJoin(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + for _, test := range []struct { + errs []error + want []error + }{{ + errs: []error{err1}, + want: []error{err1}, + }, { + errs: []error{err1, err2}, + want: []error{err1, err2}, + }, { + errs: []error{err1, nil, err2}, + want: []error{err1, err2}, + }} { + got := errors.Join(test.errs...).(interface{ Unwrap() []error }).Unwrap() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("Join(%v) = %v; want %v", test.errs, got, test.want) + } + if len(got) != cap(got) { + t.Errorf("Join(%v) returns errors with len=%v, cap=%v; want len==cap", test.errs, len(got), cap(got)) + } + } +} diff --git a/src/errors/wrap.go b/src/errors/wrap.go index 263ae16b48d..a719655b10d 100644 --- a/src/errors/wrap.go +++ b/src/errors/wrap.go @@ -11,6 +11,8 @@ import ( // Unwrap returns the result of calling the Unwrap method on err, if err's // type contains an Unwrap method returning error. // Otherwise, Unwrap returns nil. +// +// Unwrap returns nil if the Unwrap method returns []error. func Unwrap(err error) error { u, ok := err.(interface { Unwrap() error @@ -21,10 +23,11 @@ func Unwrap(err error) error { return u.Unwrap() } -// Is reports whether any error in err's chain matches target. +// Is reports whether any error in err's tree matches target. // -// The chain consists of err itself followed by the sequence of errors obtained by -// repeatedly calling Unwrap. +// The tree consists of err itself, followed by the errors obtained by repeatedly +// calling Unwrap. When err wraps multiple errors, Is examines err followed by a +// depth-first traversal of its children. // // An error is considered to match a target if it is equal to that target or if // it implements a method Is(error) bool such that Is(target) returns true. @@ -50,20 +53,31 @@ func Is(err, target error) bool { if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) { return true } - // TODO: consider supporting target.Is(err). This would allow - // user-definable predicates, but also may allow for coping with sloppy - // APIs, thereby making it easier to get away with them. - if err = Unwrap(err); err == nil { + switch x := err.(type) { + case interface{ Unwrap() error }: + err = x.Unwrap() + if err == nil { + return false + } + case interface{ Unwrap() []error }: + for _, err := range x.Unwrap() { + if Is(err, target) { + return true + } + } + return false + default: return false } } } -// As finds the first error in err's chain that matches target, and if one is found, sets +// As finds the first error in err's tree that matches target, and if one is found, sets // target to that error value and returns true. Otherwise, it returns false. // -// The chain consists of err itself followed by the sequence of errors obtained by -// repeatedly calling Unwrap. +// The tree consists of err itself, followed by the errors obtained by repeatedly +// calling Unwrap. When err wraps multiple errors, As examines err followed by a +// depth-first traversal of its children. // // An error matches target if the error's concrete value is assignable to the value // pointed to by target, or if the error has a method As(interface{}) bool such that @@ -76,6 +90,9 @@ func Is(err, target error) bool { // As panics if target is not a non-nil pointer to either a type that implements // error, or to any interface type. func As(err error, target any) bool { + if err == nil { + return false + } if target == nil { panic("errors: target cannot be nil") } @@ -88,7 +105,7 @@ func As(err error, target any) bool { if targetType.Kind() != reflectlite.Interface && !targetType.Implements(errorType) { panic("errors: *target must be interface or implement error") } - for err != nil { + for { if reflectlite.TypeOf(err).AssignableTo(targetType) { val.Elem().Set(reflectlite.ValueOf(err)) return true @@ -96,9 +113,23 @@ func As(err error, target any) bool { if x, ok := err.(interface{ As(any) bool }); ok && x.As(target) { return true } - err = Unwrap(err) + switch x := err.(type) { + case interface{ Unwrap() error }: + err = x.Unwrap() + if err == nil { + return false + } + case interface{ Unwrap() []error }: + for _, err := range x.Unwrap() { + if As(err, target) { + return true + } + } + return false + default: + return false + } } - return false } var errorType = reflectlite.TypeOf((*error)(nil)).Elem() diff --git a/src/errors/wrap_test.go b/src/errors/wrap_test.go index eb8314b04bb..9efbe45ee0b 100644 --- a/src/errors/wrap_test.go +++ b/src/errors/wrap_test.go @@ -47,6 +47,17 @@ func TestIs(t *testing.T) { {&errorUncomparable{}, &errorUncomparable{}, false}, {errorUncomparable{}, err1, false}, {&errorUncomparable{}, err1, false}, + {multiErr{}, err1, false}, + {multiErr{err1, err3}, err1, true}, + {multiErr{err3, err1}, err1, true}, + {multiErr{err1, err3}, errors.New("x"), false}, + {multiErr{err3, errb}, errb, true}, + {multiErr{err3, errb}, erra, true}, + {multiErr{err3, errb}, err1, true}, + {multiErr{errb, err3}, err1, true}, + {multiErr{poser}, err1, true}, + {multiErr{poser}, err3, true}, + {multiErr{nil}, nil, false}, } for _, tc := range testCases { t.Run("", func(t *testing.T) { @@ -148,6 +159,41 @@ func TestAs(t *testing.T) { &timeout, true, errF, + }, { + multiErr{}, + &errT, + false, + nil, + }, { + multiErr{errors.New("a"), errorT{"T"}}, + &errT, + true, + errorT{"T"}, + }, { + multiErr{errorT{"T"}, errors.New("a")}, + &errT, + true, + errorT{"T"}, + }, { + multiErr{errorT{"a"}, errorT{"b"}}, + &errT, + true, + errorT{"a"}, + }, { + multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}, + &errT, + true, + errorT{"a"}, + }, { + multiErr{wrapped{"path error", errF}}, + &timeout, + true, + errF, + }, { + multiErr{nil}, + &errT, + false, + nil, }} for i, tc := range testCases { name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target) @@ -223,9 +269,13 @@ type wrapped struct { } func (e wrapped) Error() string { return e.msg } - func (e wrapped) Unwrap() error { return e.err } +type multiErr []error + +func (m multiErr) Error() string { return "multiError" } +func (m multiErr) Unwrap() []error { return []error(m) } + type errorUncomparable struct { f []string } diff --git a/src/fmt/errors.go b/src/fmt/errors.go index 4f4daf19e14..1fbd39f8f17 100644 --- a/src/fmt/errors.go +++ b/src/fmt/errors.go @@ -4,26 +4,48 @@ package fmt -import "errors" +import ( + "errors" + "sort" +) // Errorf formats according to a format specifier and returns the string as a // value that satisfies error. // // If the format specifier includes a %w verb with an error operand, -// the returned error will implement an Unwrap method returning the operand. It is -// invalid to include more than one %w verb or to supply it with an operand -// that does not implement the error interface. The %w verb is otherwise -// a synonym for %v. +// the returned error will implement an Unwrap method returning the operand. +// If there is more than one %w verb, the returned error will implement an +// Unwrap method returning a []error containing all the %w operands in the +// order they appear in the arguments. +// It is invalid to supply the %w verb with an operand that does not implement +// the error interface. The %w verb is otherwise a synonym for %v. func Errorf(format string, a ...any) error { p := newPrinter() p.wrapErrs = true p.doPrintf(format, a) s := string(p.buf) var err error - if p.wrappedErr == nil { + switch len(p.wrappedErrs) { + case 0: err = errors.New(s) - } else { - err = &wrapError{s, p.wrappedErr} + case 1: + w := &wrapError{msg: s} + w.err, _ = a[p.wrappedErrs[0]].(error) + err = w + default: + if p.reordered { + sort.Ints(p.wrappedErrs) + } + var errs []error + for i, argNum := range p.wrappedErrs { + if i > 0 && p.wrappedErrs[i-1] == argNum { + continue + } + if e, ok := a[argNum].(error); ok { + errs = append(errs, e) + } + } + err = &wrapErrors{s, errs} } p.free() return err @@ -41,3 +63,16 @@ func (e *wrapError) Error() string { func (e *wrapError) Unwrap() error { return e.err } + +type wrapErrors struct { + msg string + errs []error +} + +func (e *wrapErrors) Error() string { + return e.msg +} + +func (e *wrapErrors) Unwrap() []error { + return e.errs +} diff --git a/src/fmt/errors_test.go b/src/fmt/errors_test.go index 481a7b84030..4eb55faffe7 100644 --- a/src/fmt/errors_test.go +++ b/src/fmt/errors_test.go @@ -7,6 +7,7 @@ package fmt_test import ( "errors" "fmt" + "reflect" "testing" ) @@ -20,6 +21,7 @@ func TestErrorf(t *testing.T) { err error wantText string wantUnwrap error + wantSplit []error }{{ err: fmt.Errorf("%w", wrapped), wantText: "inner error", @@ -53,11 +55,29 @@ func TestErrorf(t *testing.T) { err: noVetErrorf("%w is not an error", "not-an-error"), wantText: "%!w(string=not-an-error) is not an error", }, { - err: noVetErrorf("wrapped two errors: %w %w", errString("1"), errString("2")), - wantText: "wrapped two errors: 1 %!w(fmt_test.errString=2)", + err: noVetErrorf("wrapped two errors: %w %w", errString("1"), errString("2")), + wantText: "wrapped two errors: 1 2", + wantSplit: []error{errString("1"), errString("2")}, }, { - err: noVetErrorf("wrapped three errors: %w %w %w", errString("1"), errString("2"), errString("3")), - wantText: "wrapped three errors: 1 %!w(fmt_test.errString=2) %!w(fmt_test.errString=3)", + err: noVetErrorf("wrapped three errors: %w %w %w", errString("1"), errString("2"), errString("3")), + wantText: "wrapped three errors: 1 2 3", + wantSplit: []error{errString("1"), errString("2"), errString("3")}, + }, { + err: noVetErrorf("wrapped nil error: %w %w %w", errString("1"), nil, errString("2")), + wantText: "wrapped nil error: 1 %!w() 2", + wantSplit: []error{errString("1"), errString("2")}, + }, { + err: noVetErrorf("wrapped one non-error: %w %w %w", errString("1"), "not-an-error", errString("3")), + wantText: "wrapped one non-error: 1 %!w(string=not-an-error) 3", + wantSplit: []error{errString("1"), errString("3")}, + }, { + err: fmt.Errorf("wrapped errors out of order: %[3]w %[2]w %[1]w", errString("1"), errString("2"), errString("3")), + wantText: "wrapped errors out of order: 3 2 1", + wantSplit: []error{errString("1"), errString("2"), errString("3")}, + }, { + err: fmt.Errorf("wrapped several times: %[1]w %[1]w %[2]w %[1]w", errString("1"), errString("2")), + wantText: "wrapped several times: 1 1 2 1", + wantSplit: []error{errString("1"), errString("2")}, }, { err: fmt.Errorf("%w", nil), wantText: "%!w()", @@ -66,12 +86,22 @@ func TestErrorf(t *testing.T) { if got, want := errors.Unwrap(test.err), test.wantUnwrap; got != want { t.Errorf("Formatted error: %v\nerrors.Unwrap() = %v, want %v", test.err, got, want) } + if got, want := splitErr(test.err), test.wantSplit; !reflect.DeepEqual(got, want) { + t.Errorf("Formatted error: %v\nUnwrap() []error = %v, want %v", test.err, got, want) + } if got, want := test.err.Error(), test.wantText; got != want { t.Errorf("err.Error() = %q, want %q", got, want) } } } +func splitErr(err error) []error { + if e, ok := err.(interface{ Unwrap() []error }); ok { + return e.Unwrap() + } + return nil +} + type errString string func (e errString) Error() string { return string(e) } diff --git a/src/fmt/print.go b/src/fmt/print.go index 4eabda1ce86..b3dd43ce04e 100644 --- a/src/fmt/print.go +++ b/src/fmt/print.go @@ -139,8 +139,8 @@ type pp struct { erroring bool // wrapErrs is set when the format string may contain a %w verb. wrapErrs bool - // wrappedErr records the target of the %w verb. - wrappedErr error + // wrappedErrs records the targets of the %w verb. + wrappedErrs []int } var ppFree = sync.Pool{ @@ -171,10 +171,13 @@ func (p *pp) free() { } else { p.buf = p.buf[:0] } + if cap(p.wrappedErrs) > 8 { + p.wrappedErrs = nil + } p.arg = nil p.value = reflect.Value{} - p.wrappedErr = nil + p.wrappedErrs = p.wrappedErrs[:0] ppFree.Put(p) } @@ -620,16 +623,12 @@ func (p *pp) handleMethods(verb rune) (handled bool) { return } if verb == 'w' { - // It is invalid to use %w other than with Errorf, more than once, - // or with a non-error arg. - err, ok := p.arg.(error) - if !ok || !p.wrapErrs || p.wrappedErr != nil { - p.wrappedErr = nil - p.wrapErrs = false + // It is invalid to use %w other than with Errorf or with a non-error arg. + _, ok := p.arg.(error) + if !ok || !p.wrapErrs { p.badVerb(verb) return true } - p.wrappedErr = err // If the arg is a Formatter, pass 'v' as the verb to it. verb = 'v' } @@ -1063,7 +1062,11 @@ formatLoop: // Fast path for common case of ascii lower case simple verbs // without precision or width or argument indices. if 'a' <= c && c <= 'z' && argNum < len(a) { - if c == 'v' { + switch c { + case 'w': + p.wrappedErrs = append(p.wrappedErrs, argNum) + fallthrough + case 'v': // Go syntax p.fmt.sharpV = p.fmt.sharp p.fmt.sharp = false @@ -1158,6 +1161,9 @@ formatLoop: p.badArgNum(verb) case argNum >= len(a): // No argument left over to print for the current verb. p.missingArg(verb) + case verb == 'w': + p.wrappedErrs = append(p.wrappedErrs, argNum) + fallthrough case verb == 'v': // Go syntax p.fmt.sharpV = p.fmt.sharp