Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ast/compiler: allow retaining parsed modules #4921

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 31 additions & 0 deletions ast/compile.go
Expand Up @@ -113,6 +113,8 @@ type Compiler struct {
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
}

// CompilerStage defines the interface for stages in the compiler.
Expand Down Expand Up @@ -385,6 +387,22 @@ func (c *Compiler) WithStrict(strict bool) *Compiler {
return c
}

// WithKeepModules enables retaining unprocessed modules in the compiler.
// Note that the modules aren't copied on the way in or out -- so when
// accessing them via ParsedModules(), mutations will occur in the module
// map that was passed into Compile().`
func (c *Compiler) WithKeepModules(y bool) *Compiler {
c.keepModules = y
return c
}

// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
func (c *Compiler) ParsedModules() map[string]*Module {
return c.parsedModules
}

// QueryCompiler returns a new QueryCompiler object.
func (c *Compiler) QueryCompiler() QueryCompiler {
c.init()
Expand All @@ -400,10 +418,20 @@ func (c *Compiler) Compile(modules map[string]*Module) {
c.init()

c.Modules = make(map[string]*Module, len(modules))
c.sorted = make([]string, 0, len(modules))

if c.keepModules {
c.parsedModules = make(map[string]*Module, len(modules))
} else {
c.parsedModules = nil
}

for k, v := range modules {
c.Modules[k] = v.Copy()
c.sorted = append(c.sorted, k)
if c.parsedModules != nil {
c.parsedModules[k] = v
}
}

sort.Strings(c.sorted)
Expand Down Expand Up @@ -1441,6 +1469,9 @@ func (c *Compiler) resolveAllRefs() {
for id, module := range parsed {
c.Modules[id] = module.Copy()
c.sorted = append(c.sorted, id)
if c.parsedModules != nil {
c.parsedModules[id] = module
}
}

sort.Strings(c.sorted)
Expand Down
152 changes: 152 additions & 0 deletions ast/compile_test.go
Expand Up @@ -7003,3 +7003,155 @@ func TestCompilerPassesTypeCheckNegative(t *testing.T) {
t.Fatal("Incorrectly detected a type-checking violation")
}
}

func TestKeepModules(t *testing.T) {

t.Run("no keep", func(t *testing.T) {
c := NewCompiler() // no keep is default

// This one is overwritten by c.Compile()
c.Modules["foo.rego"] = MustParseModule("package foo\np = true")

c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")})

if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if mods != nil {
t.Errorf("expected ParsedModules == nil, got %v", mods)
}
})

t.Run("keep", func(t *testing.T) {

c := NewCompiler().WithKeepModules(true)

// This one is overwritten by c.Compile()
c.Modules["foo.rego"] = MustParseModule("package foo\np = true")

c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")})
if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if exp, act := 1, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
for k := range mods {
if k != "bar.rego" {
t.Errorf("unexpected key: %v, want 'bar.rego'", k)
}
}

for k := range mods {
compiled := c.Modules[k]
if compiled.Equal(mods[k]) {
t.Errorf("expected module %v to not be compiled: %v", k, mods[k])
}
}

// expect ParsedModules to be reset
c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")})
mods = c.ParsedModules()
if exp, act := 1, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
for k := range mods {
if k != "baz.rego" {
t.Errorf("unexpected key: %v, want 'baz.rego'", k)
}
}

for k := range mods {
compiled := c.Modules[k]
if compiled.Equal(mods[k]) {
t.Errorf("expected module %v to not be compiled: %v", k, mods[k])
}
}

// expect ParsedModules to be reset to nil
c = c.WithKeepModules(false)
c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")})
mods = c.ParsedModules()
if mods != nil {
t.Errorf("expected ParsedModules == nil, got %v", mods)
}
})

t.Run("no copies", func(t *testing.T) {
extra := MustParseModule("package extra\np = input")
done := false
testLoader := func(map[string]*Module) (map[string]*Module, error) {
if done {
return nil, nil
}
done = true
return map[string]*Module{"extra.rego": extra}, nil
}

c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true)

mod := MustParseModule("package bar\np = input")
c.Compile(map[string]*Module{"bar.rego": mod})
if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if exp, act := 2, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
newName := Var("q")
mods["bar.rego"].Rules[0].Head.Name = newName
if exp, act := newName, mod.Rules[0].Head.Name; !exp.Equal(act) {
t.Errorf("expected modified rule name %v, found %v", exp, act)
}
mods["extra.rego"].Rules[0].Head.Name = newName
if exp, act := newName, extra.Rules[0].Head.Name; !exp.Equal(act) {
t.Errorf("expected modified rule name %v, found %v", exp, act)
}
})

t.Run("keep, with loader", func(t *testing.T) {
extra := MustParseModule("package extra\np = input")
done := false
testLoader := func(map[string]*Module) (map[string]*Module, error) {
if done {
return nil, nil
}
done = true
return map[string]*Module{"extra.rego": extra}, nil
}

c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true)

// This one is overwritten by c.Compile()
c.Modules["foo.rego"] = MustParseModule("package foo\np = true")

c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")})

if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if exp, act := 2, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
for k := range mods {
if k != "bar.rego" && k != "extra.rego" {
t.Errorf("unexpected key: %v, want 'extra.rego' and 'bar.rego'", k)
}
}

for k := range mods {
compiled := c.Modules[k]
if compiled.Equal(mods[k]) {
t.Errorf("expected module %v to not be compiled: %v", k, mods[k])
}
}
})
}