Skip to content

Commit

Permalink
ast: Importing rego.v1 in v0 support modules when applicable (#6698)
Browse files Browse the repository at this point in the history
Prioritizing generating v0 Rego with `rego.v1` import when producing support modules for non-`--v1-compatible` optimized builds.

Affects `opa build` when the `-O` flag is used for optimization, and `opa eval` for partial evaluation with the `-p` flag.

Fixes: #6450
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Apr 24, 2024
1 parent 44fa8ad commit b58e87f
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 19 deletions.
5 changes: 5 additions & 0 deletions ast/parser.go
Expand Up @@ -2570,6 +2570,11 @@ var futureKeywords = map[string]tokens.Token{
"if": tokens.If,
}

func IsFutureKeyword(s string) bool {
_, ok := futureKeywords[s]
return ok
}

func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) {
path := imp.Path.Value.(Ref)

Expand Down
6 changes: 6 additions & 0 deletions ast/policy.go
Expand Up @@ -407,6 +407,12 @@ func (mod *Module) RegoVersion() RegoVersion {
return mod.regoVersion
}

// SetRegoVersion sets the RegoVersion for the module.
// Note: Setting a rego-version that does not match the module's rego-version might have unintended consequences.
func (mod *Module) SetRegoVersion(v RegoVersion) {
mod.regoVersion = v
}

// NewComment returns a new Comment object.
func NewComment(text []byte) *Comment {
return &Comment{
Expand Down
16 changes: 8 additions & 8 deletions bundle/bundle.go
Expand Up @@ -1082,8 +1082,15 @@ func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveMo
var err error

for i, module := range b.Modules {
opts := format.Opts{}
if preserveModuleRegoVersion {
opts.RegoVersion = module.Parsed.RegoVersion()
} else {
opts.RegoVersion = version
}

if module.Raw == nil {
module.Raw, err = format.AstWithOpts(module.Parsed, format.Opts{RegoVersion: version})
module.Raw, err = format.AstWithOpts(module.Parsed, opts)
if err != nil {
return err
}
Expand All @@ -1093,13 +1100,6 @@ func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveMo
path = module.Path
}

opts := format.Opts{}
if preserveModuleRegoVersion {
opts.RegoVersion = module.Parsed.RegoVersion()
} else {
opts.RegoVersion = version
}

module.Raw, err = format.SourceWithOpts(path, module.Raw, opts)
if err != nil {
return err
Expand Down
86 changes: 78 additions & 8 deletions cmd/build_test.go
Expand Up @@ -1792,14 +1792,70 @@ allow if {
}
}

func TestBuildWithV1CompatibleFlagOptimized(t *testing.T) {
func TestBuildOptimizedWithRegoVersion(t *testing.T) {
tests := []struct {
note string
files map[string]string
expectedFiles map[string]string
note string
v1Compatible bool
regoV1ImportCapable bool
files map[string]string
expectedFiles map[string]string
}{
{
note: "No imports",
note: "v0, no future keywords",
v1Compatible: false,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
# METADATA
# entrypoint: true
p[v] {
v := input.v
}
`,
},
expectedFiles: map[string]string{
"/.manifest": `{"revision":"","roots":[""],"rego_version":0}
`,
// rego.v1 import added to optimized support module
"/optimized/test.rego": `package test
import rego.v1
p contains __local0__1 if {
__local0__1 = input.v
}
`,
},
},
{
note: "v0, No future keywords, not rego.v1 import capable",
v1Compatible: false,
regoV1ImportCapable: false,
files: map[string]string{
"test.rego": `package test
# METADATA
# entrypoint: true
p[v] {
v := input.v
}
`,
},
expectedFiles: map[string]string{
"/.manifest": `{"revision":"","roots":[""],"rego_version":0}
`,
// rego.v1 import NOT added to optimized support module
"/optimized/test.rego": `package test
p[__local0__1] {
__local0__1 = input.v
}
`,
},
},
{
note: "v1, No imports",
v1Compatible: true,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
# METADATA
Expand All @@ -1822,7 +1878,9 @@ foo contains __local1__1 if {
},
},
{
note: "rego.v1 imported",
note: "v1, rego.v1 imported",
v1Compatible: true,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
import rego.v1
Expand All @@ -1849,7 +1907,9 @@ foo contains __local1__1 if {
},
},
{
note: "future.keywords imported",
note: "v1, future.keywords imported",
v1Compatible: true,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
import future.keywords
Expand Down Expand Up @@ -1879,9 +1939,19 @@ foo contains __local1__1 if {
test.WithTempFS(tc.files, func(root string) {
params := newBuildParams()
params.outputFile = path.Join(root, "bundle.tar.gz")
params.v1Compatible = true
params.v1Compatible = tc.v1Compatible
params.optimizationLevel = 1

if !tc.regoV1ImportCapable {
caps := newcapabilitiesFlag()
caps.C = ast.CapabilitiesForThisVersion()
caps.C.Features = []string{
ast.FeatureRefHeadStringPrefixes,
ast.FeatureRefHeads,
}
params.capabilities = caps
}

err := dobuild(params, []string{root})

if err != nil {
Expand Down
137 changes: 137 additions & 0 deletions cmd/eval_test.go
Expand Up @@ -1313,6 +1313,143 @@ time.clock(input.y, time.clock(input.x))
}
}

func TestEvalPartialRegoVersionOutput(t *testing.T) {
tests := []struct {
note string
regoV1ImportCapable bool
v1Compatible bool
query string
module string
expected string
}{
{
note: "v0, no future keywords",
regoV1ImportCapable: true,
query: "data.test.p",
module: `package test
p[v] {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0
# Module 1
package partial.test
import rego.v1
p contains __local0__1 if __local0__1 = input.v
`,
},
{
note: "v0, no future keywords, not rego.v1 import capable",
regoV1ImportCapable: false,
query: "data.test.p",
module: `package test
p[v] {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0
# Module 1
package partial.test
p[__local0__1] {
__local0__1 = input.v
}
`,
},
{
note: "v0, future keywords",
regoV1ImportCapable: true,
query: "data.test.p",
module: `package test
import rego.v1
p contains v if {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0
# Module 1
package partial.test
import rego.v1
p contains __local0__1 if __local0__1 = input.v
`,
},
{
note: "v1",
regoV1ImportCapable: true,
v1Compatible: true,
query: "data.test.p",
module: `package test
p contains v if {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0
# Module 1
package partial.test
p contains __local0__1 if __local0__1 = input.v
`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
files := map[string]string{
"test.rego": tc.module,
}

test.WithTempFS(files, func(path string) {
params := newEvalCommandParams()
_ = params.dataPaths.Set(filepath.Join(path, "test.rego"))
params.partial = true
params.shallowInlining = true
params.v1Compatible = tc.v1Compatible
_ = params.outputFormat.Set(evalSourceOutput)

if !tc.regoV1ImportCapable {
caps := newcapabilitiesFlag()
caps.C = ast.CapabilitiesForThisVersion()
caps.C.Features = []string{
ast.FeatureRefHeadStringPrefixes,
ast.FeatureRefHeads,
}
params.capabilities = caps
}

buf := new(bytes.Buffer)
_, err := eval([]string{tc.query}, params, buf)
if err != nil {
t.Fatal("unexpected error:", err)
}
if actual := buf.String(); actual != tc.expected {
t.Errorf("expected output %q\ngot %q", tc.expected, actual)
}
})
})
}
}

func TestEvalDiscardOutput(t *testing.T) {
tests := map[string]struct {
query, format, expected string
Expand Down
11 changes: 10 additions & 1 deletion compile/compile.go
Expand Up @@ -544,7 +544,8 @@ func (c *Compiler) optimize(ctx context.Context) error {
WithEntrypoints(c.entrypointrefs).
WithDebug(c.debug.Writer()).
WithShallowInlining(c.optimizationLevel <= 1).
WithEnablePrintStatements(c.enablePrintStatements)
WithEnablePrintStatements(c.enablePrintStatements).
WithRegoVersion(c.regoVersion)

if c.ns != "" {
o = o.WithPartialNamespace(c.ns)
Expand Down Expand Up @@ -869,6 +870,7 @@ type optimizer struct {
shallow bool
debug debug.Debug
enablePrintStatements bool
regoVersion ast.RegoVersion
}

func newOptimizer(c *ast.Capabilities, b *bundle.Bundle) *optimizer {
Expand Down Expand Up @@ -909,6 +911,11 @@ func (o *optimizer) WithPartialNamespace(ns string) *optimizer {
return o
}

func (o *optimizer) WithRegoVersion(regoVersion ast.RegoVersion) *optimizer {
o.regoVersion = regoVersion
return o
}

func (o *optimizer) Do(ctx context.Context) error {

// NOTE(tsandall): if there are multiple entrypoints, copy the bundle because
Expand Down Expand Up @@ -958,6 +965,8 @@ func (o *optimizer) Do(ctx context.Context) error {
rego.ParsedUnknowns(unknowns),
rego.Compiler(o.compiler),
rego.Store(store),
rego.Capabilities(o.capabilities),
rego.SetRegoVersion(o.regoVersion),
)

o.debug.Printf("optimizer: entrypoint: %v", e)
Expand Down

0 comments on commit b58e87f

Please sign in to comment.