diff --git a/entc/gen/graph.go b/entc/gen/graph.go index e20e683a72..c431dd5d24 100644 --- a/entc/gen/graph.go +++ b/entc/gen/graph.go @@ -660,6 +660,7 @@ func (g *Graph) templates() (*Template, []GraphTemplate) { external = append(external, GraphTemplate{ Name: name, Format: snake(name) + ".go", + Skip: rootT.condition, }) roots[name] = struct{}{} } diff --git a/entc/gen/graph_test.go b/entc/gen/graph_test.go index 61bea6c4a3..16152f5ccb 100644 --- a/entc/gen/graph_test.go +++ b/entc/gen/graph_test.go @@ -307,11 +307,12 @@ func TestGraph_Gen(t *testing.T) { require.NoError(os.MkdirAll(target, os.ModePerm), "creating tmpdir") defer os.RemoveAll(target) external := MustParse(NewTemplate("external").Parse("package external")) + skipped := MustParse(NewTemplate("skipped").SkipIf(func(*Graph) bool { return true }).Parse("package external")) graph, err := NewGraph(&Config{ Package: "entc/gen", Target: target, Storage: drivers[0], - Templates: []*Template{external}, + Templates: []*Template{external, skipped}, IDType: &field.TypeInfo{Type: field.TypeInt}, Features: AllFeatures, }, &load.Schema{ @@ -340,6 +341,8 @@ func TestGraph_Gen(t *testing.T) { } _, err = os.Stat(filepath.Join(target, "external.go")) require.NoError(err) + _, err = os.Stat(filepath.Join(target, "skipped.go")) + require.True(os.IsNotExist(err)) // Generated feature templates. _, err = os.Stat(filepath.Join(target, "internal", "schema.go")) diff --git a/entc/gen/template.go b/entc/gen/template.go index 563812b5bf..0b3e3ce44c 100644 --- a/entc/gen/template.go +++ b/entc/gen/template.go @@ -231,7 +231,8 @@ func initTemplates() { // provide additional functionality for ent extensions. type Template struct { *template.Template - FuncMap template.FuncMap + FuncMap template.FuncMap + condition func(*Graph) bool } // NewTemplate creates an empty template with the standard codegen functions. @@ -254,6 +255,12 @@ func (t *Template) Funcs(funcMap template.FuncMap) *Template { return t } +// SkipIf allows registering a function to determine if the template needs to be skipped or not. +func (t *Template) SkipIf(cond func(*Graph) bool) *Template { + t.condition = cond + return t +} + // Parse parses text as a template body for t. func (t *Template) Parse(text string) (*Template, error) { if _, err := t.Template.Parse(text); err != nil {