diff --git a/src/flag/flag.go b/src/flag/flag.go index fde7411f82..d638e49b42 100644 --- a/src/flag/flag.go +++ b/src/flag/flag.go @@ -308,13 +308,25 @@ func sortFlags(flags map[string]*Flag) []*Flag { return result } -func (f *FlagSet) out() io.Writer { +// Output returns the destination for usage and error messages. os.Stderr is returned if +// output was not set or was set to nil. +func (f *FlagSet) Output() io.Writer { if f.output == nil { return os.Stderr } return f.output } +// Name returns the name of the flag set. +func (f *FlagSet) Name() string { + return f.name +} + +// ErrorHandling returns the error handling behavior of the flag set. +func (f *FlagSet) ErrorHandling() ErrorHandling { + return f.errorHandling +} + // SetOutput sets the destination for usage and error messages. // If output is nil, os.Stderr is used. func (f *FlagSet) SetOutput(output io.Writer) { @@ -474,7 +486,7 @@ func (f *FlagSet) PrintDefaults() { s += fmt.Sprintf(" (default %v)", flag.DefValue) } } - fmt.Fprint(f.out(), s, "\n") + fmt.Fprint(f.Output(), s, "\n") }) } @@ -504,9 +516,9 @@ func PrintDefaults() { // defaultUsage is the default function to print a usage message. func (f *FlagSet) defaultUsage() { if f.name == "" { - fmt.Fprintf(f.out(), "Usage:\n") + fmt.Fprintf(f.Output(), "Usage:\n") } else { - fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) + fmt.Fprintf(f.Output(), "Usage of %s:\n", f.name) } f.PrintDefaults() } @@ -525,7 +537,7 @@ func (f *FlagSet) defaultUsage() { // happens anyway as the command line's error handling strategy is set to // ExitOnError. var Usage = func() { - fmt.Fprintf(CommandLine.out(), "Usage of %s:\n", os.Args[0]) + fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0]) PrintDefaults() } @@ -793,7 +805,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) { } else { msg = fmt.Sprintf("%s flag redefined: %s", f.name, name) } - fmt.Fprintln(f.out(), msg) + fmt.Fprintln(f.Output(), msg) panic(msg) // Happens only if flags are declared with identical names } if f.formal == nil { @@ -816,7 +828,7 @@ func Var(value Value, name string, usage string) { // returns the error. func (f *FlagSet) failf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) - fmt.Fprintln(f.out(), err) + fmt.Fprintln(f.Output(), err) f.usage() return err } diff --git a/src/flag/flag_test.go b/src/flag/flag_test.go index 4c6db96ba0..67c409f29b 100644 --- a/src/flag/flag_test.go +++ b/src/flag/flag_test.go @@ -8,6 +8,7 @@ import ( "bytes" . "flag" "fmt" + "io" "os" "sort" "strconv" @@ -454,3 +455,36 @@ func TestUsageOutput(t *testing.T) { t.Errorf("output = %q; want %q", got, want) } } + +func TestGetters(t *testing.T) { + expectedName := "flag set" + expectedErrorHandling := ContinueOnError + expectedOutput := io.Writer(os.Stderr) + fs := NewFlagSet(expectedName, expectedErrorHandling) + + if fs.Name() != expectedName { + t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName) + } + if fs.ErrorHandling() != expectedErrorHandling { + t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling) + } + if fs.Output() != expectedOutput { + t.Errorf("unexpected output: got %#v, expected %#v", fs.Output(), expectedOutput) + } + + expectedName = "gopher" + expectedErrorHandling = ExitOnError + expectedOutput = os.Stdout + fs.Init(expectedName, expectedErrorHandling) + fs.SetOutput(expectedOutput) + + if fs.Name() != expectedName { + t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName) + } + if fs.ErrorHandling() != expectedErrorHandling { + t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling) + } + if fs.Output() != expectedOutput { + t.Errorf("unexpected output: got %v, expected %v", fs.Output(), expectedOutput) + } +}