diff --git a/bytes.go b/bytes.go index 12c58db9..67d53045 100644 --- a/bytes.go +++ b/bytes.go @@ -1,6 +1,7 @@ package pflag import ( + "encoding/base64" "encoding/hex" "fmt" "strings" @@ -9,10 +10,12 @@ import ( // BytesHex adapts []byte for use as a flag. Value of flag is HEX encoded type bytesHexValue []byte +// String implements pflag.Value.String. func (bytesHex bytesHexValue) String() string { return fmt.Sprintf("%X", []byte(bytesHex)) } +// Set implements pflag.Value.Set. func (bytesHex *bytesHexValue) Set(value string) error { bin, err := hex.DecodeString(strings.TrimSpace(value)) @@ -25,6 +28,7 @@ func (bytesHex *bytesHexValue) Set(value string) error { return nil } +// Type implements pflag.Value.Type. func (*bytesHexValue) Type() string { return "bytesHex" } @@ -103,3 +107,103 @@ func BytesHex(name string, value []byte, usage string) *[]byte { func BytesHexP(name, shorthand string, value []byte, usage string) *[]byte { return CommandLine.BytesHexP(name, shorthand, value, usage) } + +// BytesBase64 adapts []byte for use as a flag. Value of flag is Base64 encoded +type bytesBase64Value []byte + +// String implements pflag.Value.String. +func (bytesBase64 bytesBase64Value) String() string { + return base64.StdEncoding.EncodeToString([]byte(bytesBase64)) +} + +// Set implements pflag.Value.Set. +func (bytesBase64 *bytesBase64Value) Set(value string) error { + bin, err := base64.StdEncoding.DecodeString(strings.TrimSpace(value)) + + if err != nil { + return err + } + + *bytesBase64 = bin + + return nil +} + +// Type implements pflag.Value.Type. +func (*bytesBase64Value) Type() string { + return "bytesBase64" +} + +func newBytesBase64Value(val []byte, p *[]byte) *bytesBase64Value { + *p = val + return (*bytesBase64Value)(p) +} + +func bytesBase64ValueConv(sval string) (interface{}, error) { + + bin, err := base64.StdEncoding.DecodeString(sval) + if err == nil { + return bin, nil + } + + return nil, fmt.Errorf("invalid string being converted to Bytes: %s %s", sval, err) +} + +// GetBytesBase64 return the []byte value of a flag with the given name +func (f *FlagSet) GetBytesBase64(name string) ([]byte, error) { + val, err := f.getFlagType(name, "bytesBase64", bytesBase64ValueConv) + + if err != nil { + return []byte{}, err + } + + return val.([]byte), nil +} + +// BytesBase64Var defines an []byte flag with specified name, default value, and usage string. +// The argument p points to an []byte variable in which to store the value of the flag. +func (f *FlagSet) BytesBase64Var(p *[]byte, name string, value []byte, usage string) { + f.VarP(newBytesBase64Value(value, p), name, "", usage) +} + +// BytesBase64VarP is like BytesBase64Var, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BytesBase64VarP(p *[]byte, name, shorthand string, value []byte, usage string) { + f.VarP(newBytesBase64Value(value, p), name, shorthand, usage) +} + +// BytesBase64Var defines an []byte flag with specified name, default value, and usage string. +// The argument p points to an []byte variable in which to store the value of the flag. +func BytesBase64Var(p *[]byte, name string, value []byte, usage string) { + CommandLine.VarP(newBytesBase64Value(value, p), name, "", usage) +} + +// BytesBase64VarP is like BytesBase64Var, but accepts a shorthand letter that can be used after a single dash. +func BytesBase64VarP(p *[]byte, name, shorthand string, value []byte, usage string) { + CommandLine.VarP(newBytesBase64Value(value, p), name, shorthand, usage) +} + +// BytesBase64 defines an []byte flag with specified name, default value, and usage string. +// The return value is the address of an []byte variable that stores the value of the flag. +func (f *FlagSet) BytesBase64(name string, value []byte, usage string) *[]byte { + p := new([]byte) + f.BytesBase64VarP(p, name, "", value, usage) + return p +} + +// BytesBase64P is like BytesBase64, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BytesBase64P(name, shorthand string, value []byte, usage string) *[]byte { + p := new([]byte) + f.BytesBase64VarP(p, name, shorthand, value, usage) + return p +} + +// BytesBase64 defines an []byte flag with specified name, default value, and usage string. +// The return value is the address of an []byte variable that stores the value of the flag. +func BytesBase64(name string, value []byte, usage string) *[]byte { + return CommandLine.BytesBase64P(name, "", value, usage) +} + +// BytesBase64P is like BytesBase64, but accepts a shorthand letter that can be used after a single dash. +func BytesBase64P(name, shorthand string, value []byte, usage string) *[]byte { + return CommandLine.BytesBase64P(name, shorthand, value, usage) +} diff --git a/bytes_test.go b/bytes_test.go index cc4a769d..5251f347 100644 --- a/bytes_test.go +++ b/bytes_test.go @@ -1,6 +1,7 @@ package pflag import ( + "encoding/base64" "fmt" "os" "testing" @@ -61,7 +62,7 @@ func TestBytesHex(t *testing.T) { } else if tc.success { bytesHex, err := f.GetBytesHex("bytes") if err != nil { - t.Errorf("Got error trying to fetch the IP flag: %v", err) + t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err) } if fmt.Sprintf("%X", bytesHex) != tc.expected { t.Errorf("expected %q, got '%X'", tc.expected, bytesHex) @@ -70,3 +71,64 @@ func TestBytesHex(t *testing.T) { } } } + +func setUpBytesBase64(bytesBase64 *[]byte) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.BytesBase64Var(bytesBase64, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64") + f.BytesBase64VarP(bytesBase64, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64") + return f +} + +func TestBytesBase64(t *testing.T) { + testCases := []struct { + input string + success bool + expected string + }{ + /// Positive cases + {"", true, ""}, // Is empty string OK ? + {"AQ==", true, "AQ=="}, + + // Negative cases + {"AQ", false, ""}, // Padding removed + {"ï", false, ""}, // non-base64 characters + } + + devnull, _ := os.Open(os.DevNull) + os.Stderr = devnull + + for i := range testCases { + var bytesBase64 []byte + f := setUpBytesBase64(&bytesBase64) + + tc := &testCases[i] + + // --bytes + args := []string{ + fmt.Sprintf("--bytes=%s", tc.input), + fmt.Sprintf("-B %s", tc.input), + fmt.Sprintf("--bytes2=%s", tc.input), + } + + for _, arg := range args { + err := f.Parse([]string{arg}) + + if err != nil && tc.success == true { + t.Errorf("expected success, got %q", err) + continue + } else if err == nil && tc.success == false { + // bytesBase64, err := f.GetBytesBase64("bytes") + t.Errorf("expected failure while processing %q", tc.input) + continue + } else if tc.success { + bytesBase64, err := f.GetBytesBase64("bytes") + if err != nil { + t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err) + } + if base64.StdEncoding.EncodeToString(bytesBase64) != tc.expected { + t.Errorf("expected %q, got '%X'", tc.expected, bytesBase64) + } + } + } + } +}