diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 4c9025641d..c25d37bc0b 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -187,33 +187,25 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer if debug.Traced { region = trace.StartRegion(ctx, "codegen") } - var files map[string]string - var resp *plugin.CodeGenResponse + var genfunc func(req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) var out string switch { case sql.Gen.Go != nil: out = combo.Go.Out - resp, err = golang.Generate(codeGenRequest(result, combo)) + genfunc = golang.Generate case sql.Gen.Kotlin != nil: out = combo.Kotlin.Out - resp, err = kotlin.Generate(codeGenRequest(result, combo)) + genfunc = kotlin.Generate case sql.Gen.Python != nil: out = combo.Python.Out - resp, err = python.Generate(codeGenRequest(result, combo)) + genfunc = python.Generate default: panic("missing language backend") } + resp, err := genfunc(codeGenRequest(result, combo)) if region != nil { region.End() } - - if resp != nil { - files = map[string]string{} - for _, file := range resp.Files { - files[file.Name] = string(file.Contents) - } - } - if err != nil { fmt.Fprintf(stderr, "# package %s\n", name) fmt.Fprintf(stderr, "error generating code: %s\n", err) @@ -223,6 +215,10 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer } continue } + files := map[string]string{} + for _, file := range resp.Files { + files[file.Name] = string(file.Contents) + } for n, source := range files { filename := filepath.Join(dir, out, n) output[filename] = source