Skip to content

Commit

Permalink
ast/compiler: allow retaining parsed modules
Browse files Browse the repository at this point in the history
The included change resetting `c.sorted` in `(*Compiler).Compile(...)` was
necessary to test the repeated compilation without a panic.

Fixes open-policy-agent#4910.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Jul 25, 2022
1 parent bfa04b3 commit c0fdef5
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 0 deletions.
31 changes: 31 additions & 0 deletions ast/compile.go
Expand Up @@ -113,6 +113,8 @@ type Compiler struct {
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
}

// CompilerStage defines the interface for stages in the compiler.
Expand Down Expand Up @@ -385,6 +387,22 @@ func (c *Compiler) WithStrict(strict bool) *Compiler {
return c
}

// WithKeepModules enables retaining unprocessed modules in the compiler.
// Note that the modules aren't copied on the way in or out -- so when
// accessing them via ParsedModules(), mutations will occur in the module
// map that was passed into Compile().`
func (c *Compiler) WithKeepModules(y bool) *Compiler {
c.keepModules = y
return c
}

// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
func (c *Compiler) ParsedModules() map[string]*Module {
return c.parsedModules
}

// QueryCompiler returns a new QueryCompiler object.
func (c *Compiler) QueryCompiler() QueryCompiler {
c.init()
Expand All @@ -400,10 +418,20 @@ func (c *Compiler) Compile(modules map[string]*Module) {
c.init()

c.Modules = make(map[string]*Module, len(modules))
c.sorted = make([]string, 0, len(modules))

if c.keepModules {
c.parsedModules = make(map[string]*Module, len(modules))
} else {
c.parsedModules = nil
}

for k, v := range modules {
c.Modules[k] = v.Copy()
c.sorted = append(c.sorted, k)
if c.parsedModules != nil {
c.parsedModules[k] = v
}
}

sort.Strings(c.sorted)
Expand Down Expand Up @@ -1441,6 +1469,9 @@ func (c *Compiler) resolveAllRefs() {
for id, module := range parsed {
c.Modules[id] = module.Copy()
c.sorted = append(c.sorted, id)
if c.parsedModules != nil {
c.parsedModules[id] = module
}
}

sort.Strings(c.sorted)
Expand Down
152 changes: 152 additions & 0 deletions ast/compile_test.go
Expand Up @@ -7003,3 +7003,155 @@ func TestCompilerPassesTypeCheckNegative(t *testing.T) {
t.Fatal("Incorrectly detected a type-checking violation")
}
}

func TestKeepModules(t *testing.T) {

t.Run("no keep", func(t *testing.T) {
c := NewCompiler() // no keep is default

// This one is overwritten by c.Compile()
c.Modules["foo.rego"] = MustParseModule("package foo\np = true")

c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")})

if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if mods != nil {
t.Errorf("expected ParsedModules == nil, got %v", mods)
}
})

t.Run("keep", func(t *testing.T) {

c := NewCompiler().WithKeepModules(true)

// This one is overwritten by c.Compile()
c.Modules["foo.rego"] = MustParseModule("package foo\np = true")

c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")})
if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if exp, act := 1, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
for k := range mods {
if k != "bar.rego" {
t.Errorf("unexpected key: %v, want 'bar.rego'", k)
}
}

for k := range mods {
compiled := c.Modules[k]
if compiled.Equal(mods[k]) {
t.Errorf("expected module %v to not be compiled: %v", k, mods[k])
}
}

// expect ParsedModules to be reset
c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")})
mods = c.ParsedModules()
if exp, act := 1, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
for k := range mods {
if k != "baz.rego" {
t.Errorf("unexpected key: %v, want 'baz.rego'", k)
}
}

for k := range mods {
compiled := c.Modules[k]
if compiled.Equal(mods[k]) {
t.Errorf("expected module %v to not be compiled: %v", k, mods[k])
}
}

// expect ParsedModules to be reset to nil
c = c.WithKeepModules(false)
c.Compile(map[string]*Module{"baz.rego": MustParseModule("package baz\np = input")})
mods = c.ParsedModules()
if mods != nil {
t.Errorf("expected ParsedModules == nil, got %v", mods)
}
})

t.Run("no copies", func(t *testing.T) {
extra := MustParseModule("package extra\np = input")
done := false
testLoader := func(map[string]*Module) (map[string]*Module, error) {
if done {
return nil, nil
}
done = true
return map[string]*Module{"extra.rego": extra}, nil
}

c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true)

mod := MustParseModule("package bar\np = input")
c.Compile(map[string]*Module{"bar.rego": mod})
if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if exp, act := 2, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
newName := Var("q")
mods["bar.rego"].Rules[0].Head.Name = newName
if exp, act := newName, mod.Rules[0].Head.Name; !exp.Equal(act) {
t.Errorf("expected modified rule name %v, found %v", exp, act)
}
mods["extra.rego"].Rules[0].Head.Name = newName
if exp, act := newName, extra.Rules[0].Head.Name; !exp.Equal(act) {
t.Errorf("expected modified rule name %v, found %v", exp, act)
}
})

t.Run("keep, with loader", func(t *testing.T) {
extra := MustParseModule("package extra\np = input")
done := false
testLoader := func(map[string]*Module) (map[string]*Module, error) {
if done {
return nil, nil
}
done = true
return map[string]*Module{"extra.rego": extra}, nil
}

c := NewCompiler().WithModuleLoader(testLoader).WithKeepModules(true)

// This one is overwritten by c.Compile()
c.Modules["foo.rego"] = MustParseModule("package foo\np = true")

c.Compile(map[string]*Module{"bar.rego": MustParseModule("package bar\np = input")})

if len(c.Errors) != 0 {
t.Fatalf("expected no error; got %v", c.Errors)
}

mods := c.ParsedModules()
if exp, act := 2, len(mods); exp != act {
t.Errorf("expected %d modules, found %d: %v", exp, act, mods)
}
for k := range mods {
if k != "bar.rego" && k != "extra.rego" {
t.Errorf("unexpected key: %v, want 'extra.rego' and 'bar.rego'", k)
}
}

for k := range mods {
compiled := c.Modules[k]
if compiled.Equal(mods[k]) {
t.Errorf("expected module %v to not be compiled: %v", k, mods[k])
}
}
})
}

0 comments on commit c0fdef5

Please sign in to comment.