From 2f702d2fee263dd8d133f8cc762b5a6a499434ae Mon Sep 17 00:00:00 2001 From: joshcarp Date: Tue, 8 Feb 2022 21:57:18 +0000 Subject: [PATCH] Add Command.SetContext Basically the same as https://github.com/spf13/cobra/pull/1517 but uses the `Set` naming convention instead of `WithContext` Context setting without execution is important because it means that more design patterns can be achieved. Currently I am using functional options in a project and I can add behaviour through functional options as such: ```go type GrptlOption func(*cobra.Command) error func WithFileDescriptors(descriptors ...protoreflect.FileDescriptor) GrptlOption { return func(cmd *cobra.Command) error { err := CommandFromFileDescriptors(cmd, descriptors...) if err != nil { return err } return nil } } ``` I've got a lot more options and this pattern allows me to have nice abstracted pieces of logic that interact with the cobra command. This Pattern also allows for adding extra information to a call through `PreRun` functions: ```go cmd := &cobra.Command{ Use: "Foobar", PersistentPreRun: func(cmd *cobra.Command, args []string) { err :=cmd.SetContext(metadata.AppendToOutgoingContext(context.Background(), "pre", "run")) }, } ``` This is a veer nice abstraction and allows for these functions to be very modular The issue I'm facing at the moment is that I can't write a nifty option (something like `WithAuthentication`) because that needs access to reading and setting the context. Currently I can only read the context. Needing to use `ExecuteContext` breaks this abstraction because I need to run it right at the end. Merge https://github.com/spf13/cobra/pull/1551 Fixes https://github.com/spf13/cobra/pull/1517 Fixes https://github.com/spf13/cobra/pull/1118 Fixes https://github.com/spf13/cobra/issues/563 --- command.go | 6 +++ command_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/command.go b/command.go index d262c39..f4bbdd5 100644 --- a/command.go +++ b/command.go @@ -233,6 +233,12 @@ func (c *Command) Context() context.Context { return c.ctx } +// SetContext sets context for the command. It is set to context.Background by default and will be overwritten by +// Command.ExecuteContext or Command.ExecuteContextC +func (c *Command) SetContext(ctx context.Context) { + c.ctx = ctx +} + // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // particularly useful when testing. func (c *Command) SetArgs(a []string) { diff --git a/command_test.go b/command_test.go index 3e16cad..52b5724 100644 --- a/command_test.go +++ b/command_test.go @@ -2087,3 +2087,101 @@ func TestContext(t *testing.T){ t.Error("expected root.Context() != nil") } } + +func TestSetContext(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + ctx := context.WithValue(context.Background(), key, val) + root.SetContext(ctx) + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRun(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + PreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key, val) + cmd.SetContext(ctx) + }, + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRunOverwrite(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + _, ok := key.(string) + if ok { + t.Error("key found in context when not expected") + } + }, + } + ctx := context.WithValue(context.Background(), key, val) + root.SetContext(ctx) + err := root.ExecuteContext(context.Background()) + if err != nil { + t.Error(err) + } +} + +func TestSetContextPersistentPreRun(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + PersistentPreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key, val) + cmd.SetContext(ctx) + }, + } + child := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + root.AddCommand(child) + root.SetArgs([]string{"child"}) + err := root.Execute() + if err != nil { + t.Error(err) + } +}