Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions bool_func.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package pflag

// -- func Value
type boolfuncValue func(string) error

func (f boolfuncValue) Set(s string) error { return f(s) }

func (f boolfuncValue) Type() string { return "func" }

func (f boolfuncValue) String() string { return "" } // same behavior as stdlib 'flag' package

func (f boolfuncValue) IsBoolFlag() bool { return true }

// BoolFunc defines a func flag with specified name, callback function and usage string.
//
// The callback function will be called every time "--{name}" (or any form that matches the flag) is parsed
// on the command line.
func (f *FlagSet) BoolFunc(name string, usage string, fn func(string) error) {
f.BoolFuncP(name, "", usage, fn)
}

// BoolFuncP is like BoolFunc, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BoolFuncP(name, shorthand string, usage string, fn func(string) error) {
var val Value = boolfuncValue(fn)
flag := f.VarPF(val, name, shorthand, usage)
flag.NoOptDefVal = "true"
}

// BoolFunc defines a func flag with specified name, callback function and usage string.
//
// The callback function will be called every time "--{name}" (or any form that matches the flag) is parsed
// on the command line.
func BoolFunc(name string, usage string, fn func(string) error) {
CommandLine.BoolFuncP(name, "", usage, fn)
}

// BoolFuncP is like BoolFunc, but accepts a shorthand letter that can be used after a single dash.
func BoolFuncP(name, shorthand string, fn func(string) error, usage string) {
CommandLine.BoolFuncP(name, shorthand, usage, fn)
}
147 changes: 147 additions & 0 deletions bool_func_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package pflag

import (
"errors"
"flag"
"io"
"strings"
"testing"
)

func TestBoolFunc(t *testing.T) {
var count int
fn := func(_ string) error {
count++
return nil
}

fset := NewFlagSet("test", ContinueOnError)
fset.BoolFunc("func", "Callback function", fn)

err := fset.Parse([]string{"--func", "--func=1", "--func=false"})
if err != nil {
t.Fatal("expected no error; got", err)
}

if count != 3 {
t.Fatalf("expected 3 calls to the callback, got %d calls", count)
}
}

func TestBoolFuncP(t *testing.T) {
var count int
fn := func(_ string) error {
count++
return nil
}

fset := NewFlagSet("test", ContinueOnError)
fset.BoolFuncP("bfunc", "b", "Callback function", fn)

err := fset.Parse([]string{"--bfunc", "--bfunc=0", "--bfunc=false", "-b", "-b=0"})
if err != nil {
t.Fatal("expected no error; got", err)
}

if count != 5 {
t.Fatalf("expected 5 calls to the callback, got %d calls", count)
}
}

func TestBoolFuncCompat(t *testing.T) {
// compare behavior with the stdlib 'flag' package
type BoolFuncFlagSet interface {
BoolFunc(name string, usage string, fn func(string) error)
Parse([]string) error
}

unitTestErr := errors.New("unit test error")
runCase := func(f BoolFuncFlagSet, name string, args []string) (values []string, err error) {
fn := func(s string) error {
values = append(values, s)
if s == "err" {
return unitTestErr
}
return nil
}
f.BoolFunc(name, "Callback function", fn)

err = f.Parse(args)
return values, err
}

t.Run("regular parsing", func(t *testing.T) {
flagName := "bflag"
args := []string{"--bflag", "--bflag=false", "--bflag=1", "--bflag=bar", "--bflag="}

// It turns out that, even though the function is called "BoolFunc",
// the stanard flag package does not try to parse the value assigned to
// that cli flag as a boolean. The string provided on the command line is
// passed as is to the callback.
// e.g: with "--bflag=not_a_bool" on the command line, the FlagSet does not
// generate an error stating "invalid boolean value", and `fn` will be called
// with "not_a_bool" as an argument.

stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
stdValues, err := runCase(stdFSet, flagName, args)
if err != nil {
t.Fatalf("std flag: expected no error, got %v", err)
}
expected := []string{"true", "false", "1", "bar", ""}
if !cmpLists(expected, stdValues) {
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
}

fset := NewFlagSet("pflag test", ContinueOnError)
pflagValues, err := runCase(fset, flagName, args)
if err != nil {
t.Fatalf("pflag: expected no error, got %v", err)
}
if !cmpLists(stdValues, pflagValues) {
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
}
})

t.Run("error triggered by callback", func(t *testing.T) {
flagName := "bflag"
args := []string{"--bflag", "--bflag=err", "--bflag=after"}

// test behavior of standard flag.Fset with an error triggere by the callback:
// (note: as can be seen in 'runCase()', if the callback sees "err" as a value
// for the bool flag, it will return an error)
stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
stdFSet.SetOutput(io.Discard) // suppress output

// run test case with standard flag.Fset
stdValues, err := runCase(stdFSet, flagName, args)

// double check the standard behavior:
// - .Parse() should return an error, which contains the error message
if err == nil {
t.Fatalf("std flag: expected an error triggered by callback, got no error instead")
}
if !strings.HasSuffix(err.Error(), unitTestErr.Error()) {
t.Fatalf("std flag: expected unittest error, got unexpected error value: %T %v", err, err)
}
// - the function should have been called twice, with the first two values,
// the final "=after" should not be recorded
expected := []string{"true", "err"}
if !cmpLists(expected, stdValues) {
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
}

// now run the test case on a pflag FlagSet:
fset := NewFlagSet("pflag test", ContinueOnError)
pflagValues, err := runCase(fset, flagName, args)

// check that there is a similar error (note: pflag will _wrap_ the error, while the stdlib
// currently keeps the original message but creates a flat errors.Error)
if !errors.Is(err, unitTestErr) {
t.Fatalf("pflag: got unexpected error value: %T %v", err, err)
}
// the callback should be called the same number of times, with the same values:
if !cmpLists(stdValues, pflagValues) {
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
}
})
}
37 changes: 37 additions & 0 deletions func.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package pflag

// -- func Value
type funcValue func(string) error

func (f funcValue) Set(s string) error { return f(s) }

func (f funcValue) Type() string { return "func" }

func (f funcValue) String() string { return "" } // same behavior as stdlib 'flag' package

// Func defines a func flag with specified name, callback function and usage string.
//
// The callback function will be called every time "--{name}={value}" (or equivalent) is
// parsed on the command line, with "{value}" as an argument.
func (f *FlagSet) Func(name string, usage string, fn func(string) error) {
f.FuncP(name, "", usage, fn)
}

// FuncP is like Func, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) FuncP(name string, shorthand string, usage string, fn func(string) error) {
var val Value = funcValue(fn)
f.VarP(val, name, shorthand, usage)
}

// Func defines a func flag with specified name, callback function and usage string.
//
// The callback function will be called every time "--{name}={value}" (or equivalent) is
// parsed on the command line, with "{value}" as an argument.
func Func(name string, fn func(string) error, usage string) {
CommandLine.FuncP(name, "", usage, fn)
}

// FuncP is like Func, but accepts a shorthand letter that can be used after a single dash.
func FuncP(name, shorthand string, fn func(string) error, usage string) {
CommandLine.FuncP(name, shorthand, usage, fn)
}
153 changes: 153 additions & 0 deletions func_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package pflag

import (
"errors"
"flag"
"io"
"strings"
"testing"
)

func cmpLists(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

func TestFunc(t *testing.T) {
var values []string
fn := func(s string) error {
values = append(values, s)
return nil
}

fset := NewFlagSet("test", ContinueOnError)
fset.Func("fnflag", "Callback function", fn)

err := fset.Parse([]string{"--fnflag=aa", "--fnflag", "bb"})
if err != nil {
t.Fatal("expected no error; got", err)
}

expected := []string{"aa", "bb"}
if !cmpLists(expected, values) {
t.Fatalf("expected %v, got %v", expected, values)
}
}

func TestFuncP(t *testing.T) {
var values []string
fn := func(s string) error {
values = append(values, s)
return nil
}

fset := NewFlagSet("test", ContinueOnError)
fset.FuncP("fnflag", "f", "Callback function", fn)

err := fset.Parse([]string{"--fnflag=a", "--fnflag", "b", "-fc", "-f=d", "-f", "e"})
if err != nil {
t.Fatal("expected no error; got", err)
}

expected := []string{"a", "b", "c", "d", "e"}
if !cmpLists(expected, values) {
t.Fatalf("expected %v, got %v", expected, values)
}
}

func TestFuncCompat(t *testing.T) {
// compare behavior with the stdlib 'flag' package
type FuncFlagSet interface {
Func(name string, usage string, fn func(string) error)
Parse([]string) error
}

unitTestErr := errors.New("unit test error")
runCase := func(f FuncFlagSet, name string, args []string) (values []string, err error) {
fn := func(s string) error {
values = append(values, s)
if s == "err" {
return unitTestErr
}
return nil
}
f.Func(name, "Callback function", fn)

err = f.Parse(args)
return values, err
}

t.Run("regular parsing", func(t *testing.T) {
flagName := "fnflag"
args := []string{"--fnflag=xx", "--fnflag", "yy", "--fnflag=zz"}

stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
stdValues, err := runCase(stdFSet, flagName, args)
if err != nil {
t.Fatalf("std flag: expected no error, got %v", err)
}
expected := []string{"xx", "yy", "zz"}
if !cmpLists(expected, stdValues) {
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
}

fset := NewFlagSet("pflag test", ContinueOnError)
pflagValues, err := runCase(fset, flagName, args)
if err != nil {
t.Fatalf("pflag: expected no error, got %v", err)
}
if !cmpLists(stdValues, pflagValues) {
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
}
})

t.Run("error triggered by callback", func(t *testing.T) {
flagName := "fnflag"
args := []string{"--fnflag", "before", "--fnflag", "err", "--fnflag", "after"}

// test behavior of standard flag.Fset with an error triggere by the callback:
// (note: as can be seen in 'runCase()', if the callback sees "err" as a value
// for the bool flag, it will return an error)
stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
stdFSet.SetOutput(io.Discard) // suppress output

// run test case with standard flag.Fset
stdValues, err := runCase(stdFSet, flagName, args)

// double check the standard behavior:
// - .Parse() should return an error, which contains the error message
if err == nil {
t.Fatalf("std flag: expected an error triggered by callback, got no error instead")
}
if !strings.HasSuffix(err.Error(), unitTestErr.Error()) {
t.Fatalf("std flag: expected unittest error, got unexpected error value: %T %v", err, err)
}
// - the function should have been called twice, with the first two values,
// the final "=after" should not be recorded
expected := []string{"before", "err"}
if !cmpLists(expected, stdValues) {
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
}

// now run the test case on a pflag FlagSet:
fset := NewFlagSet("pflag test", ContinueOnError)
pflagValues, err := runCase(fset, flagName, args)

// check that there is a similar error (note: pflag will _wrap_ the error, while the stdlib
// currently keeps the original message but creates a flat errors.Error)
if !errors.Is(err, unitTestErr) {
t.Fatalf("pflag: got unexpected error value: %T %v", err, err)
}
// the callback should be called the same number of times, with the same values:
if !cmpLists(stdValues, pflagValues) {
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
}
})
}
Loading