From ea4db7eef470a6c198c2fca623d1047c1de65ad8 Mon Sep 17 00:00:00 2001 From: joshcarp Date: Mon, 6 Dec 2021 21:23:45 +1100 Subject: [PATCH 1/4] Add Command.SetContext --- command.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/command.go b/command.go index 2cc18891d..c973d3222 100644 --- a/command.go +++ b/command.go @@ -230,6 +230,12 @@ func (c *Command) Context() context.Context { return c.ctx } +// SetContext sets context for the command. It is set to context.Background by default and will be overwritten by +// Command.ExecuteContext or Command.ExecuteContextC +func (c *Command) SetContext(ctx context.Context) { + c.ctx = ctx +} + // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // particularly useful when testing. func (c *Command) SetArgs(a []string) { From 9439d8696c3f563f32f293da70e01417730c1444 Mon Sep 17 00:00:00 2001 From: joshcarp Date: Thu, 9 Dec 2021 06:51:07 +1100 Subject: [PATCH 2/4] Add SetContext tests --- command_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/command_test.go b/command_test.go index 583cb0235..2730f82ec 100644 --- a/command_test.go +++ b/command_test.go @@ -2058,3 +2058,101 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { } checkStringContains(t, output, "unknown flag: --unknown") } + +func TestSetContext(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + ctx := context.WithValue(context.Background(), key, val) + root.SetContext(ctx) + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRun(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + PreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key, val) + cmd.SetContext(ctx) + }, + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRunOverwrite(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + _, ok := key.(string) + if ok { + t.Error("key found in context when not expected") + } + }, + } + ctx := context.WithValue(context.Background(), key, val) + root.SetContext(ctx) + err := root.ExecuteContext(context.Background()) + if err != nil { + t.Error(err) + } +} + +func TestSetContextPersistentPreRun(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + PersistentPreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key, val) + cmd.SetContext(ctx) + }, + } + child := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + root.AddCommand(child) + root.SetArgs([]string{"child"}) + err := root.Execute() + if err != nil { + t.Error(err) + } +} From d969ce0e3887ef0fd46879c7ad59244a6e7405b6 Mon Sep 17 00:00:00 2001 From: joshcarp Date: Wed, 16 Mar 2022 17:26:58 +1100 Subject: [PATCH 3/4] Change key to struct type --- command_test.go | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/command_test.go b/command_test.go index 2730f82ec..1b6b7bf7e 100644 --- a/command_test.go +++ b/command_test.go @@ -2060,11 +2060,12 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { } func TestSetContext(t *testing.T) { - key, val := "foo", "bar" + type key struct{} + val := "val" root := &Command{ Use: "root", Run: func(cmd *Command, args []string) { - key := cmd.Context().Value(key) + key := cmd.Context().Value(key{}) got, ok := key.(string) if !ok { t.Error("key not found in context") @@ -2074,7 +2075,8 @@ func TestSetContext(t *testing.T) { } }, } - ctx := context.WithValue(context.Background(), key, val) + + ctx := context.WithValue(context.Background(), key{}, val) root.SetContext(ctx) err := root.Execute() if err != nil { @@ -2083,16 +2085,17 @@ func TestSetContext(t *testing.T) { } func TestSetContextPreRun(t *testing.T) { - key, val := "foo", "bar" + type key struct{} + val := "bar" root := &Command{ Use: "root", PreRun: func(cmd *Command, args []string) { - ctx := context.WithValue(cmd.Context(), key, val) + ctx := context.WithValue(cmd.Context(), key{}, val) cmd.SetContext(ctx) }, Run: func(cmd *Command, args []string) { - key := cmd.Context().Value(key) - got, ok := key.(string) + val := cmd.Context().Value(key{}) + got, ok := val.(string) if !ok { t.Error("key not found in context") } @@ -2108,18 +2111,19 @@ func TestSetContextPreRun(t *testing.T) { } func TestSetContextPreRunOverwrite(t *testing.T) { - key, val := "foo", "bar" + type key struct{} + val := "bar" root := &Command{ Use: "root", Run: func(cmd *Command, args []string) { - key := cmd.Context().Value(key) + key := cmd.Context().Value(key{}) _, ok := key.(string) if ok { t.Error("key found in context when not expected") } }, } - ctx := context.WithValue(context.Background(), key, val) + ctx := context.WithValue(context.Background(), key{}, val) root.SetContext(ctx) err := root.ExecuteContext(context.Background()) if err != nil { @@ -2128,18 +2132,19 @@ func TestSetContextPreRunOverwrite(t *testing.T) { } func TestSetContextPersistentPreRun(t *testing.T) { - key, val := "foo", "bar" + type key struct{} + val := "bar" root := &Command{ Use: "root", PersistentPreRun: func(cmd *Command, args []string) { - ctx := context.WithValue(cmd.Context(), key, val) + ctx := context.WithValue(cmd.Context(), key{}, val) cmd.SetContext(ctx) }, } child := &Command{ Use: "child", Run: func(cmd *Command, args []string) { - key := cmd.Context().Value(key) + key := cmd.Context().Value(key{}) got, ok := key.(string) if !ok { t.Error("key not found in context") From 645be2c44793b83581c09f7b30e24630ac86324f Mon Sep 17 00:00:00 2001 From: joshcarp Date: Wed, 16 Mar 2022 17:50:04 +1100 Subject: [PATCH 4/4] Fix linting --- command_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/command_test.go b/command_test.go index 1b6b7bf7e..cedcb66b9 100644 --- a/command_test.go +++ b/command_test.go @@ -2061,7 +2061,7 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { func TestSetContext(t *testing.T) { type key struct{} - val := "val" + val := "foobar" root := &Command{ Use: "root", Run: func(cmd *Command, args []string) { @@ -2076,7 +2076,7 @@ func TestSetContext(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), key{}, val) + ctx := context.WithValue(context.Background(), key{}, val) root.SetContext(ctx) err := root.Execute() if err != nil { @@ -2086,7 +2086,7 @@ func TestSetContext(t *testing.T) { func TestSetContextPreRun(t *testing.T) { type key struct{} - val := "bar" + val := "barr" root := &Command{ Use: "root", PreRun: func(cmd *Command, args []string) { @@ -2112,7 +2112,7 @@ func TestSetContextPreRun(t *testing.T) { func TestSetContextPreRunOverwrite(t *testing.T) { type key struct{} - val := "bar" + val := "blah" root := &Command{ Use: "root", Run: func(cmd *Command, args []string) { @@ -2133,7 +2133,7 @@ func TestSetContextPreRunOverwrite(t *testing.T) { func TestSetContextPersistentPreRun(t *testing.T) { type key struct{} - val := "bar" + val := "barbar" root := &Command{ Use: "root", PersistentPreRun: func(cmd *Command, args []string) {