diff --git a/src/pkg/flag/flag.go b/src/pkg/flag/flag.go index 964f5541b8..1719af89a1 100644 --- a/src/pkg/flag/flag.go +++ b/src/pkg/flag/flag.go @@ -62,6 +62,7 @@ package flag import ( "errors" "fmt" + "io" "os" "sort" "strconv" @@ -228,6 +229,7 @@ type FlagSet struct { args []string // arguments after flags exitOnError bool // does the program exit if there's an error? errorHandling ErrorHandling + output io.Writer // nil means stderr; use out() accessor } // A Flag represents the state of a flag. @@ -254,6 +256,19 @@ func sortFlags(flags map[string]*Flag) []*Flag { return result } +func (f *FlagSet) out() io.Writer { + if f.output == nil { + return os.Stderr + } + return f.output +} + +// SetOutput sets the destination for usage and error messages. +// If output is nil, os.Stderr is used. +func (f *FlagSet) SetOutput(output io.Writer) { + f.output = output +} + // VisitAll visits the flags in lexicographical order, calling fn for each. // It visits all flags, even those not set. func (f *FlagSet) VisitAll(fn func(*Flag)) { @@ -315,15 +330,16 @@ func Set(name, value string) error { return commandLine.Set(name, value) } -// PrintDefaults prints to standard error the default values of all defined flags in the set. +// PrintDefaults prints, to standard error unless configured +// otherwise, the default values of all defined flags in the set. func (f *FlagSet) PrintDefaults() { - f.VisitAll(func(f *Flag) { + f.VisitAll(func(flag *Flag) { format := " -%s=%s: %s\n" - if _, ok := f.Value.(*stringValue); ok { + if _, ok := flag.Value.(*stringValue); ok { // put quotes on the value format = " -%s=%q: %s\n" } - fmt.Fprintf(os.Stderr, format, f.Name, f.DefValue, f.Usage) + fmt.Fprintf(f.out(), format, flag.Name, flag.DefValue, flag.Usage) }) } @@ -334,7 +350,7 @@ func PrintDefaults() { // defaultUsage is the default function to print a usage message. func defaultUsage(f *FlagSet) { - fmt.Fprintf(os.Stderr, "Usage of %s:\n", f.name) + fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) f.PrintDefaults() } @@ -601,7 +617,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) { flag := &Flag{name, usage, value, value.String()} _, alreadythere := f.formal[name] if alreadythere { - fmt.Fprintf(os.Stderr, "%s flag redefined: %s\n", f.name, name) + fmt.Fprintf(f.out(), "%s flag redefined: %s\n", f.name, name) panic("flag redefinition") // Happens only if flags are declared with identical names } if f.formal == nil { @@ -624,7 +640,7 @@ func Var(value Value, name string, usage string) { // returns the error. func (f *FlagSet) failf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) - fmt.Fprintln(os.Stderr, err) + fmt.Fprintln(f.out(), err) f.usage() return err } diff --git a/src/pkg/flag/flag_test.go b/src/pkg/flag/flag_test.go index 698c15f2c5..a9561f269f 100644 --- a/src/pkg/flag/flag_test.go +++ b/src/pkg/flag/flag_test.go @@ -5,10 +5,12 @@ package flag_test import ( + "bytes" . "flag" "fmt" "os" "sort" + "strings" "testing" "time" ) @@ -206,6 +208,17 @@ func TestUserDefined(t *testing.T) { } } +func TestSetOutput(t *testing.T) { + var flags FlagSet + var buf bytes.Buffer + flags.SetOutput(&buf) + flags.Init("test", ContinueOnError) + flags.Parse([]string{"-unknown"}) + if out := buf.String(); !strings.Contains(out, "-unknown") { + t.Logf("expected output mentioning unknown; got %q", out) + } +} + // This tests that one can reset the flags. This still works but not well, and is // superseded by FlagSet. func TestChangingArgs(t *testing.T) {