diff --git a/ast/compile.go b/ast/compile.go index 5417119a2f..3697229dda 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -113,6 +113,8 @@ type Compiler struct { inputType types.Type // global input type retrieved from schema set annotationSet *AnnotationSet // hierarchical set of annotations strict bool // enforce strict compilation checks + keepModules bool // whether to keep the unprocessed, parse modules (below) + parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true } // CompilerStage defines the interface for stages in the compiler. @@ -385,6 +387,22 @@ func (c *Compiler) WithStrict(strict bool) *Compiler { return c } +// WithKeepModules enables retaining unprocessed modules in the compiler. +// Note that the modules aren't copied on the way in or out -- so when +// accessing them via ParsedModules(), mutations will occur in the module +// map that was passed into Compile().` +func (c *Compiler) WithKeepModules(y bool) *Compiler { + c.keepModules = y + return c +} + +// ParsedModules returns the parsed, unprocessed modules from the compiler. +// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`. +// The map includes all modules loaded via the ModuleLoader, if one was used. +func (c *Compiler) ParsedModules() map[string]*Module { + return c.parsedModules +} + // QueryCompiler returns a new QueryCompiler object. func (c *Compiler) QueryCompiler() QueryCompiler { c.init() @@ -400,10 +418,20 @@ func (c *Compiler) Compile(modules map[string]*Module) { c.init() c.Modules = make(map[string]*Module, len(modules)) + c.sorted = make([]string, 0, len(modules)) + + if c.keepModules { + c.parsedModules = make(map[string]*Module, len(modules)) + } else { + c.parsedModules = nil + } for k, v := range modules { c.Modules[k] = v.Copy() c.sorted = append(c.sorted, k) + if c.parsedModules != nil { + c.parsedModules[k] = v + } } sort.Strings(c.sorted) @@ -1441,6 +1469,9 @@ func (c *Compiler) resolveAllRefs() { for id, module := range parsed { c.Modules[id] = module.Copy() c.sorted = append(c.sorted, id) + if c.parsedModules != nil { + c.parsedModules[id] = module + } } sort.Strings(c.sorted) diff --git a/ast/compile_test.go b/ast/compile_test.go index feca0fd6dc..627b148044 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -7003,3 +7003,155 @@ func TestCompilerPassesTypeCheckNegative(t *testing.T) { t.Fatal("Incorrectly detected a type-checking violation") } } + +func TestKeepModules(t *testing.T) { + + t.Run("no keep", func(t *testing.T) { + c := NewCompiler() // no keep is default + + // This one is overwritten by c.Compile() + c.Modules["foo.rego"] = MustParseModule("package foo\np = true") + + c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")}) + + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) + } + + mods := c.ParsedModules() + if mods != nil { + t.Errorf("expected ParsedModules == nil, got %v", mods) + } + }) + + t.Run("keep", func(t *testing.T) { + + c := NewCompiler().WithKeepModules(true) + + // This one is overwritten by c.Compile() + c.Modules["foo.rego"] = MustParseModule("package foo\np = true") + + c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")}) + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) + } + + mods := c.ParsedModules() + if exp, act := 1, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) + } + for k := range mods { + if k != "bar.rego" { + t.Errorf("unexpected key: %v, want 'bar.rego'", k) + } + } + + for k := range mods { + compiled := c.Modules[k] + if compiled.Equal(mods[k]) { + t.Errorf("expected module %v to not be compiled: %v", k, mods[k]) + } + } + + // expect ParsedModules to be reset + c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")}) + mods = c.ParsedModules() + if exp, act := 1, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) + } + for k := range mods { + if k != "baz.rego" { + t.Errorf("unexpected key: %v, want 'baz.rego'", k) + } + } + + for k := range mods { + compiled := c.Modules[k] + if compiled.Equal(mods[k]) { + t.Errorf("expected module %v to not be compiled: %v", k, mods[k]) + } + } + + // expect ParsedModules to be reset to nil + c = c.WithKeepModules(false) + c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")}) + mods = c.ParsedModules() + if mods != nil { + t.Errorf("expected ParsedModules == nil, got %v", mods) + } + }) + + t.Run("no copies", func(t *testing.T) { + extra := MustParseModule("package extra\np = input") + done := false + testLoader := func(map[string]*Module) (map[string]*Module, error) { + if done { + return nil, nil + } + done = true + return map[string]*Module{"extra.rego": extra}, nil + } + + c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true) + + mod := MustParseModule("package bar\np = input") + c.Compile(map[string]*Module{"bar.rego": mod}) + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) + } + + mods := c.ParsedModules() + if exp, act := 2, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) + } + newName := Var("q") + mods["bar.rego"].Rules[0].Head.Name = newName + if exp, act := newName, mod.Rules[0].Head.Name; !exp.Equal(act) { + t.Errorf("expected modified rule name %v, found %v", exp, act) + } + mods["extra.rego"].Rules[0].Head.Name = newName + if exp, act := newName, extra.Rules[0].Head.Name; !exp.Equal(act) { + t.Errorf("expected modified rule name %v, found %v", exp, act) + } + }) + + t.Run("keep, with loader", func(t *testing.T) { + extra := MustParseModule("package extra\np = input") + done := false + testLoader := func(map[string]*Module) (map[string]*Module, error) { + if done { + return nil, nil + } + done = true + return map[string]*Module{"extra.rego": extra}, nil + } + + c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true) + + // This one is overwritten by c.Compile() + c.Modules["foo.rego"] = MustParseModule("package foo\np = true") + + c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")}) + + if len(c.Errors) != 0 { + t.Fatalf("expected no error; got %v", c.Errors) + } + + mods := c.ParsedModules() + if exp, act := 2, len(mods); exp != act { + t.Errorf("expected %d modules, found %d: %v", exp, act, mods) + } + for k := range mods { + if k != "bar.rego" && k != "extra.rego" { + t.Errorf("unexpected key: %v, want 'extra.rego' and 'bar.rego'", k) + } + } + + for k := range mods { + compiled := c.Modules[k] + if compiled.Equal(mods[k]) { + t.Errorf("expected module %v to not be compiled: %v", k, mods[k]) + } + } + }) +}