diff --git a/command.go b/command.go index 2cc18891d..c973d3222 100644 --- a/command.go +++ b/command.go @@ -230,6 +230,12 @@ func (c *Command) Context() context.Context { return c.ctx } +// SetContext sets context for the command. It is set to context.Background by default and will be overwritten by +// Command.ExecuteContext or Command.ExecuteContextC +func (c *Command) SetContext(ctx context.Context) { + c.ctx = ctx +} + // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // particularly useful when testing. func (c *Command) SetArgs(a []string) { diff --git a/command_test.go b/command_test.go index 583cb0235..cedcb66b9 100644 --- a/command_test.go +++ b/command_test.go @@ -2058,3 +2058,106 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { } checkStringContains(t, output, "unknown flag: --unknown") } + +func TestSetContext(t *testing.T) { + type key struct{} + val := "foobar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key{}) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + + ctx := context.WithValue(context.Background(), key{}, val) + root.SetContext(ctx) + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRun(t *testing.T) { + type key struct{} + val := "barr" + root := &Command{ + Use: "root", + PreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key{}, val) + cmd.SetContext(ctx) + }, + Run: func(cmd *Command, args []string) { + val := cmd.Context().Value(key{}) + got, ok := val.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRunOverwrite(t *testing.T) { + type key struct{} + val := "blah" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key{}) + _, ok := key.(string) + if ok { + t.Error("key found in context when not expected") + } + }, + } + ctx := context.WithValue(context.Background(), key{}, val) + root.SetContext(ctx) + err := root.ExecuteContext(context.Background()) + if err != nil { + t.Error(err) + } +} + +func TestSetContextPersistentPreRun(t *testing.T) { + type key struct{} + val := "barbar" + root := &Command{ + Use: "root", + PersistentPreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key{}, val) + cmd.SetContext(ctx) + }, + } + child := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key{}) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + root.AddCommand(child) + root.SetArgs([]string{"child"}) + err := root.Execute() + if err != nil { + t.Error(err) + } +}