From a0f48fcea4955f45d650c686d04d68e59291fe37 Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Mon, 19 Sep 2022 21:54:19 -0700 Subject: [PATCH 1/2] Allow defining template user functions Resolves #73 --- pkg/template/template.go | 11 ++++++ pkg/template/template_test.go | 73 +++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/pkg/template/template.go b/pkg/template/template.go index 90770f2..ba9ee3b 100644 --- a/pkg/template/template.go +++ b/pkg/template/template.go @@ -29,6 +29,7 @@ type Template struct { tmpl *template.Template tp tableprinter.TablePrinter width int + funcs template.FuncMap } // New initializes a Template. @@ -38,9 +39,16 @@ func New(w io.Writer, width int, colorEnabled bool) Template { output: w, tp: tableprinter.New(w, true, width), width: width, + funcs: template.FuncMap{}, } } +// RegisterFunc registers a named function or overwrites a built-in function. +// Call this before Parse. +func (t *Template) RegisterFunc(name string, f func(fields ...interface{}) (string, error)) { + t.funcs[name] = f +} + // Parse the given template string for use with Execute. func (t *Template) Parse(tmpl string) error { now := time.Now() @@ -70,6 +78,9 @@ func (t *Template) Parse(tmpl string) error { if !t.colorEnabled { templateFuncs["autocolor"] = autoColorFunc } + for name, f := range t.funcs { + templateFuncs[name] = f + } var err error t.tmpl, err = template.New("").Funcs(templateFuncs).Parse(tmpl) return err diff --git a/pkg/template/template_test.go b/pkg/template/template_test.go index 7922d21..f61ab97 100644 --- a/pkg/template/template_test.go +++ b/pkg/template/template_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/MakeNowJust/heredoc" + "github.com/cli/go-gh/pkg/text" "github.com/stretchr/testify/assert" ) @@ -39,6 +40,42 @@ func ExampleTemplate() { // FOOTER } +func ExampleTemplate_RegisterFunc() { + // Information about the terminal can be obtained using the [pkg/term] package. + colorEnabled := true + termWidth := 14 + json := strings.NewReader(heredoc.Doc(`[ + {"num": 1, "thing": "apple"}, + {"num": 2, "thing": "orange"} + ]`)) + template := "{{range .}}* {{pluralize .num .thing}}\n{{end}}" + tmpl := New(os.Stdout, termWidth, colorEnabled) + tmpl.RegisterFunc("pluralize", func(fields ...interface{}) (string, error) { + if l := len(fields); l != 2 { + return "", fmt.Errorf("wrong number of args for pluralize: want 2 got %d", l) + } + var ok bool + var num float64 + var thing string + if num, ok = fields[0].(float64); !ok && num == float64(int(num)) { + return "", fmt.Errorf("invalid value; expected int") + } + if thing, ok = fields[1].(string); !ok { + return "", fmt.Errorf("invalid value; expected string") + } + return text.Pluralize(int(num), thing), nil + }) + if err := tmpl.Parse(template); err != nil { + log.Fatal(err) + } + if err := tmpl.Execute(json); err != nil { + log.Fatal(err) + } + // Output: + // * 1 apple + // * 2 oranges +} + func TestJsonScalarToString(t *testing.T) { tests := []struct { name string @@ -432,3 +469,39 @@ func TestTruncateMultiline(t *testing.T) { }) } } + +func TestRegisterFunc(t *testing.T) { + w := &bytes.Buffer{} + tmpl := New(w, 80, false) + + // Override "truncate" and define a new "foo" function. + tmpl.RegisterFunc("truncate", func(fields ...interface{}) (string, error) { + if l := len(fields); l != 2 { + return "", fmt.Errorf("wrong number of args for truncate: want 2 got %d", l) + } + var ok bool + var width int + var input string + if width, ok = fields[0].(int); !ok { + return "", fmt.Errorf("invalid value; expected int") + } + if input, ok = fields[1].(string); !ok { + return "", fmt.Errorf("invalid value; expected string") + } + return input[:width], nil + }) + tmpl.RegisterFunc("foo", func(fields ...interface{}) (string, error) { + return "test", nil + }) + + err := tmpl.Parse(`{{ .text | truncate 5 }} {{ .status | color "green" }} {{ foo }}`) + assert.NoError(t, err) + + r := strings.NewReader(`{"text":"truncated","status":"open"}`) + err = tmpl.Execute(r) + assert.NoError(t, err) + + err = tmpl.Flush() + assert.NoError(t, err) + assert.Equal(t, "trunc \x1b[0;32mopen\x1b[0m test", w.String()) +} From fa49c1cb745a6d1d656d521e30e684f6de7a2c3a Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Wed, 12 Oct 2022 16:09:03 +0300 Subject: [PATCH 2/2] Change RegisterFunc to Funcs --- pkg/template/template.go | 13 +++++-- pkg/template/template_test.go | 70 ++++++++++++++++++----------------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/pkg/template/template.go b/pkg/template/template.go index ba9ee3b..f698915 100644 --- a/pkg/template/template.go +++ b/pkg/template/template.go @@ -43,10 +43,15 @@ func New(w io.Writer, width int, colorEnabled bool) Template { } } -// RegisterFunc registers a named function or overwrites a built-in function. -// Call this before Parse. -func (t *Template) RegisterFunc(name string, f func(fields ...interface{}) (string, error)) { - t.funcs[name] = f +// Funcs adds the elements of the argument map to the template's function map. +// It must be called before the template is parsed. +// It is legal to overwrite elements of the map including default functions. +// The return value is the template, so calls can be chained. +func (t *Template) Funcs(funcMap map[string]interface{}) *Template { + for name, f := range funcMap { + t.funcs[name] = f + } + return t } // Parse the given template string for use with Execute. diff --git a/pkg/template/template_test.go b/pkg/template/template_test.go index f61ab97..f17e66d 100644 --- a/pkg/template/template_test.go +++ b/pkg/template/template_test.go @@ -40,7 +40,7 @@ func ExampleTemplate() { // FOOTER } -func ExampleTemplate_RegisterFunc() { +func ExampleTemplate_Funcs() { // Information about the terminal can be obtained using the [pkg/term] package. colorEnabled := true termWidth := 14 @@ -50,20 +50,22 @@ func ExampleTemplate_RegisterFunc() { ]`)) template := "{{range .}}* {{pluralize .num .thing}}\n{{end}}" tmpl := New(os.Stdout, termWidth, colorEnabled) - tmpl.RegisterFunc("pluralize", func(fields ...interface{}) (string, error) { - if l := len(fields); l != 2 { - return "", fmt.Errorf("wrong number of args for pluralize: want 2 got %d", l) - } - var ok bool - var num float64 - var thing string - if num, ok = fields[0].(float64); !ok && num == float64(int(num)) { - return "", fmt.Errorf("invalid value; expected int") - } - if thing, ok = fields[1].(string); !ok { - return "", fmt.Errorf("invalid value; expected string") - } - return text.Pluralize(int(num), thing), nil + tmpl.Funcs(map[string]interface{}{ + "pluralize": func(fields ...interface{}) (string, error) { + if l := len(fields); l != 2 { + return "", fmt.Errorf("wrong number of args for pluralize: want 2 got %d", l) + } + var ok bool + var num float64 + var thing string + if num, ok = fields[0].(float64); !ok && num == float64(int(num)) { + return "", fmt.Errorf("invalid value; expected int") + } + if thing, ok = fields[1].(string); !ok { + return "", fmt.Errorf("invalid value; expected string") + } + return text.Pluralize(int(num), thing), nil + }, }) if err := tmpl.Parse(template); err != nil { log.Fatal(err) @@ -470,28 +472,30 @@ func TestTruncateMultiline(t *testing.T) { } } -func TestRegisterFunc(t *testing.T) { +func TestFuncs(t *testing.T) { w := &bytes.Buffer{} tmpl := New(w, 80, false) // Override "truncate" and define a new "foo" function. - tmpl.RegisterFunc("truncate", func(fields ...interface{}) (string, error) { - if l := len(fields); l != 2 { - return "", fmt.Errorf("wrong number of args for truncate: want 2 got %d", l) - } - var ok bool - var width int - var input string - if width, ok = fields[0].(int); !ok { - return "", fmt.Errorf("invalid value; expected int") - } - if input, ok = fields[1].(string); !ok { - return "", fmt.Errorf("invalid value; expected string") - } - return input[:width], nil - }) - tmpl.RegisterFunc("foo", func(fields ...interface{}) (string, error) { - return "test", nil + tmpl.Funcs(map[string]interface{}{ + "truncate": func(fields ...interface{}) (string, error) { + if l := len(fields); l != 2 { + return "", fmt.Errorf("wrong number of args for truncate: want 2 got %d", l) + } + var ok bool + var width int + var input string + if width, ok = fields[0].(int); !ok { + return "", fmt.Errorf("invalid value; expected int") + } + if input, ok = fields[1].(string); !ok { + return "", fmt.Errorf("invalid value; expected string") + } + return input[:width], nil + }, + "foo": func(fields ...interface{}) (string, error) { + return "test", nil + }, }) err := tmpl.Parse(`{{ .text | truncate 5 }} {{ .status | color "green" }} {{ foo }}`)