Blob Blame History Raw
From 81378bbcd8a1005f72b1e8d7579e5dd7b2d612ab Mon Sep 17 00:00:00 2001
From: Ryan McKern <344926+mckern@users.noreply.github.com>
Date: Mon, 4 May 2020 07:38:53 -0700
Subject: [PATCH] Add exported functions to preserve `pkg/flag` compatibility
 (#220)

* Rename out() to Output()

This brings behavior inline with go's flag library, and allows for
printing output directly to whatever the current FlagSet is using for
output. This change will make it easier to correctly emit output to
stdout or stderr (e.g. a user has requested a help screen, which
should emit to stdout since it's the desired outcome).

* improve compat. with pkg/flag by adding Name()

pkg/flag has a public `Name()` function, which returns the name of the
flag set when called. This commit adds that function, as well as a
test for it.

* Streamline testing Name()

Testing `Name()` will move into its own explicit test, instead of
running inline during `TestAddFlagSet()`.

Co-authored-by: Chloe Kudryavtsev <toast@toast.cafe>

Co-authored-by: Chloe Kudryavtsev <toast@toast.cafe>
---
 flag.go      | 29 ++++++++++++++++++-----------
 flag_test.go | 21 +++++++++++++++++++++
 2 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/flag.go b/flag.go
index 24a5036..7c058de 100644
--- a/flag.go
+++ b/flag.go
@@ -160,7 +160,7 @@ type FlagSet struct {
 	args              []string // arguments after flags
 	argsLenAtDash     int      // len(args) when a '--' was located when parsing, or -1 if no --
 	errorHandling     ErrorHandling
-	output            io.Writer // nil means stderr; use out() accessor
+	output            io.Writer // nil means stderr; use Output() accessor
 	interspersed      bool      // allow interspersed option/non-option args
 	normalizeNameFunc func(f *FlagSet, name string) NormalizedName
 
@@ -255,13 +255,20 @@ func (f *FlagSet) normalizeFlagName(name string) NormalizedName {
 	return n(f, name)
 }
 
-func (f *FlagSet) out() io.Writer {
+// Output returns the destination for usage and error messages. os.Stderr is returned if
+// output was not set or was set to nil.
+func (f *FlagSet) Output() io.Writer {
 	if f.output == nil {
 		return os.Stderr
 	}
 	return f.output
 }
 
+// Name returns the name of the flag set.
+func (f *FlagSet) Name() string {
+	return f.name
+}
+
 // SetOutput sets the destination for usage and error messages.
 // If output is nil, os.Stderr is used.
 func (f *FlagSet) SetOutput(output io.Writer) {
@@ -358,7 +365,7 @@ func (f *FlagSet) ShorthandLookup(name string) *Flag {
 	}
 	if len(name) > 1 {
 		msg := fmt.Sprintf("can not look up shorthand which is more than one ASCII character: %q", name)
-		fmt.Fprintf(f.out(), msg)
+		fmt.Fprintf(f.Output(), msg)
 		panic(msg)
 	}
 	c := name[0]
@@ -482,7 +489,7 @@ func (f *FlagSet) Set(name, value string) error {
 	}
 
 	if flag.Deprecated != "" {
-		fmt.Fprintf(f.out(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
+		fmt.Fprintf(f.Output(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
 	}
 	return nil
 }
@@ -523,7 +530,7 @@ func Set(name, value string) error {
 // otherwise, the default values of all defined flags in the set.
 func (f *FlagSet) PrintDefaults() {
 	usages := f.FlagUsages()
-	fmt.Fprint(f.out(), usages)
+	fmt.Fprint(f.Output(), usages)
 }
 
 // defaultIsZeroValue returns true if the default value for this flag represents
@@ -758,7 +765,7 @@ func PrintDefaults() {
 
 // defaultUsage is the default function to print a usage message.
 func defaultUsage(f *FlagSet) {
-	fmt.Fprintf(f.out(), "Usage of %s:\n", f.name)
+	fmt.Fprintf(f.Output(), "Usage of %s:\n", f.name)
 	f.PrintDefaults()
 }
 
@@ -844,7 +851,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
 	_, alreadyThere := f.formal[normalizedFlagName]
 	if alreadyThere {
 		msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
-		fmt.Fprintln(f.out(), msg)
+		fmt.Fprintln(f.Output(), msg)
 		panic(msg) // Happens only if flags are declared with identical names
 	}
 	if f.formal == nil {
@@ -860,7 +867,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
 	}
 	if len(flag.Shorthand) > 1 {
 		msg := fmt.Sprintf("%q shorthand is more than one ASCII character", flag.Shorthand)
-		fmt.Fprintf(f.out(), msg)
+		fmt.Fprintf(f.Output(), msg)
 		panic(msg)
 	}
 	if f.shorthands == nil {
@@ -870,7 +877,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
 	used, alreadyThere := f.shorthands[c]
 	if alreadyThere {
 		msg := fmt.Sprintf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name)
-		fmt.Fprintf(f.out(), msg)
+		fmt.Fprintf(f.Output(), msg)
 		panic(msg)
 	}
 	f.shorthands[c] = flag
@@ -909,7 +916,7 @@ func VarP(value Value, name, shorthand, usage string) {
 func (f *FlagSet) failf(format string, a ...interface{}) error {
 	err := fmt.Errorf(format, a...)
 	if f.errorHandling != ContinueOnError {
-		fmt.Fprintln(f.out(), err)
+		fmt.Fprintln(f.Output(), err)
 		f.usage()
 	}
 	return err
@@ -1060,7 +1067,7 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
 	}
 
 	if flag.ShorthandDeprecated != "" {
-		fmt.Fprintf(f.out(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated)
+		fmt.Fprintf(f.Output(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated)
 	}
 
 	err = fn(flag, value)
diff --git a/flag_test.go b/flag_test.go
index 7d02dbc..58a5d25 100644
--- a/flag_test.go
+++ b/flag_test.go
@@ -159,6 +159,16 @@ func TestAnnotation(t *testing.T) {
 	}
 }
 
+func TestName(t *testing.T) {
+	flagSetName := "bob"
+	f := NewFlagSet(flagSetName, ContinueOnError)
+
+	givenName := f.Name()
+	if givenName != flagSetName {
+		t.Errorf("Unexpected result when retrieving a FlagSet's name: expected %s, but found %s", flagSetName, givenName)
+	}
+}
+
 func testParse(f *FlagSet, t *testing.T) {
 	if f.Parsed() {
 		t.Error("f.Parse() = true before Parse")
@@ -854,6 +864,17 @@ func TestSetOutput(t *testing.T) {
 	}
 }
 
+func TestOutput(t *testing.T) {
+	var flags FlagSet
+	var buf bytes.Buffer
+	expect := "an example string"
+	flags.SetOutput(&buf)
+	fmt.Fprint(flags.Output(), expect)
+	if out := buf.String(); !strings.Contains(out, expect) {
+		t.Errorf("expected output %q; got %q", expect, out)
+	}
+}
+
 // This tests that one can reset the flags. This still works but not well, and is
 // superseded by FlagSet.
 func TestChangingArgs(t *testing.T) {