forked from open-policy-agent/opa
/
copypropagation.go
475 lines (404 loc) · 12.7 KB
/
copypropagation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package copypropagation
import (
"fmt"
"sort"
"github.com/open-policy-agent/opa/ast"
)
// CopyPropagator implements a simple copy propagation optimization to remove
// intermediate variables in partial evaluation results.
//
// For example, given the query: input.x > 1 where 'input' is unknown, the
// compiled query would become input.x = a; a > 1 which would remain in the
// partial evaluation result. The CopyPropagator will remove the variable
// assignment so that partial evaluation simply outputs input.x > 1.
//
// In many cases, copy propagation can remove all variables from the result of
// partial evaluation which simplifies evaluation for non-OPA consumers.
//
// In some cases, copy propagation cannot remove all variables. If the output of
// a built-in call is subsequently used as a ref head, the output variable must
// be kept. For example. sort(input, x); x[0] == 1. In this case, copy
// propagation cannot replace x[0] == 1 with sort(input, x)[0] == 1 as this is
// not legal.
type CopyPropagator struct {
livevars ast.VarSet // vars that must be preserved in the resulting query
sorted []ast.Var // sorted copy of vars to ensure deterministic result
ensureNonEmptyBody bool
compiler *ast.Compiler
localvargen *localVarGenerator
}
type localVarGenerator struct {
next int
}
func (l *localVarGenerator) Generate() ast.Var {
result := ast.Var(fmt.Sprintf("__localcp%d__", l.next))
l.next++
return result
}
// New returns a new CopyPropagator that optimizes queries while preserving vars
// in the livevars set.
func New(livevars ast.VarSet) *CopyPropagator {
sorted := make([]ast.Var, 0, len(livevars))
for v := range livevars {
sorted = append(sorted, v)
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Compare(sorted[j]) < 0
})
return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}}
}
// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator {
p.ensureNonEmptyBody = yes
return p
}
// WithCompiler configures the compiler to read from while processing the query. This
// should be the same compiler used to compile the original policy.
func (p *CopyPropagator) WithCompiler(c *ast.Compiler) *CopyPropagator {
p.compiler = c
return p
}
// Apply executes the copy propagation optimization and returns a new query.
func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
result := ast.NewBody()
uf, ok := makeDisjointSets(p.livevars, query)
if !ok {
return query
}
// Compute set of vars that appear in the head of refs in the query. If a var
// is dereferenced, we can plug it with a constant value, but it is not always
// optimal to do so.
// TODO: Improve the algorithm for when we should plug constants/calls/etc
headvars := ast.NewVarSet()
ast.WalkRefs(query, func(x ast.Ref) bool {
if v, ok := x[0].Value.(ast.Var); ok {
if root, ok := uf.Find(v); ok {
root.constant = nil
headvars.Add(root.key.(ast.Var))
} else {
headvars.Add(v)
}
}
return false
})
removedEqs := ast.NewValueMap()
for _, expr := range query {
pctx := &plugContext{
removedEqs: removedEqs,
uf: uf,
negated: expr.Negated,
headvars: headvars,
}
expr = p.plugBindings(pctx, expr)
if p.updateBindings(pctx, expr) {
result.Append(expr)
}
}
// Run post-processing step on the query to ensure that all live vars are bound
// in the result. The plugging that happens above substitutes all vars in the
// same set with the root.
//
// This step should run before the next step to prevent unnecessary bindings
// from being added to the result. For example:
//
// - Given the following result: <empty>
// - Given the following removed equalities: "x = input.x" and "y = input"
// - Given the following liveset: {x}
//
// If this step were to run AFTER the following step, the output would be:
//
// x = input.x; y = input
//
// Even though y = input is not required.
for _, v := range p.sorted {
if root, ok := uf.Find(v); ok {
if root.constant != nil {
result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant))
} else if b := removedEqs.Get(root.key); b != nil {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b)))
} else if root.key != v {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key)))
}
}
}
// Run post-processing step on query to ensure that all killed exprs are
// accounted for. There are several cases we look for:
//
// * If an expr is killed but the binding is never used, the query
// must still include the expr. For example, given the query 'input.x = a' and
// an empty livevar set, the result must include the ref input.x otherwise the
// query could be satisfied without input.x being defined.
//
// * If an expr is killed that provided safety to vars which are not
// otherwise being made safe by the current result.
//
// For any of these cases we re-add the removed equality expression
// to the current result.
// Invariant: Live vars are bound (above) and reserved vars are implicitly ground.
safe := ast.ReservedVars.Copy()
safe.Update(p.livevars)
safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe))
unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe)
for _, b := range sortbindings(removedEqs) {
removedEq := ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v))
providesSafety := false
outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe)
diff := unsafe.Diff(outputVars)
if len(diff) < len(unsafe) {
unsafe = diff
providesSafety = true
}
if providesSafety || !containedIn(b.v, result) {
result.Append(removedEq)
safe.Update(outputVars)
}
}
if len(unsafe) > 0 {
// NOTE(tsandall): This should be impossible but if it does occur, throw
// away the result rather than generating unsafe output.
return query
}
if p.ensureNonEmptyBody && len(result) == 0 {
result = append(result, ast.NewExpr(ast.BooleanTerm(true)))
}
return result
}
// plugBindings applies the binding list and union-find to x. This process
// removes as many variables as possible.
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) *ast.Expr {
xform := bindingPlugTransform{
pctx: pctx,
}
// Deep copy the expression as it may be mutated during the transform and
// the caller running copy propagation may have references to the
// expression. Note, the transform does not contain any error paths and
// should never return a non-expression value for the root so consider
// errors unreachable.
x, err := ast.Transform(xform, expr.Copy())
if expr, ok := x.(*ast.Expr); !ok || err != nil {
panic("unreachable")
} else {
return expr
}
}
type bindingPlugTransform struct {
pctx *plugContext
}
func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) {
switch x := x.(type) {
case ast.Var:
return t.plugBindingsVar(t.pctx, x), nil
case ast.Ref:
return t.plugBindingsRef(t.pctx, x), nil
default:
return x, nil
}
}
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) ast.Value {
var result ast.Value = v
// Apply union-find to remove redundant variables from input.
root, ok := pctx.uf.Find(v)
if ok {
result = root.Value()
}
// Apply binding list to substitute remaining vars.
v, ok = result.(ast.Var)
if !ok {
return result
}
b := pctx.removedEqs.Get(v)
if b == nil {
return result
}
if pctx.negated && !b.IsGround() {
return result
}
if r, ok := b.(ast.Ref); ok && r.OutputVars().Contains(v) {
return result
}
return b
}
func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {
// Apply union-find to remove redundant variables from input.
if root, ok := pctx.uf.Find(v[0].Value); ok {
v[0].Value = root.Value()
}
result := v
// Refs require special handling. If the head of the ref was killed, then
// the rest of the ref must be concatenated with the new base.
if b := pctx.removedEqs.Get(v[0].Value); b != nil {
if !pctx.negated || b.IsGround() {
var base ast.Ref
switch x := b.(type) {
case ast.Ref:
base = x
default:
base = ast.Ref{ast.NewTerm(x)}
}
result = base.Concat(v[1:])
}
}
return result
}
// updateBindings returns false if the expression can be killed. If the
// expression is killed, the binding list is updated to map a var to value.
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
switch {
case pctx.negated || len(expr.With) > 0:
return true
case expr.IsEquality():
a, b := expr.Operand(0), expr.Operand(1)
if a.Equal(b) {
if p.livevarRef(a) {
pctx.removedEqs.Put(p.localvargen.Generate(), a.Value)
}
return false
}
k, v, keep := p.updateBindingsEq(a, b)
if !keep {
if v != nil {
pctx.removedEqs.Put(k, v)
}
return false
}
case expr.IsCall():
terms := expr.Terms.([]*ast.Term)
if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output
output := terms[len(terms)-1]
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
pctx.removedEqs.Put(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
return false
}
}
}
return !isNoop(expr)
}
func (p *CopyPropagator) livevarRef(a *ast.Term) bool {
ref, ok := a.Value.(ast.Ref)
if !ok {
return false
}
for _, v := range p.sorted {
if ref[0].Value.Compare(v) == 0 {
return true
}
}
return false
}
func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
if !keep {
return k, v, keep
}
return p.updateBindingsEqAsymmetric(b, a)
}
func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, ok := a.Value.(ast.Var)
if !ok || p.livevars.Contains(k) {
return "", nil, true
}
switch b.Value.(type) {
case ast.Ref, ast.Call:
return k, b.Value, false
}
return "", nil, true
}
type plugContext struct {
removedEqs *ast.ValueMap
uf *unionFind
headvars ast.VarSet
negated bool
}
type binding struct {
k, v ast.Value
}
func containedIn(value ast.Value, x interface{}) bool {
var stop bool
switch v := value.(type) {
case ast.Ref:
ast.WalkRefs(x, func(other ast.Ref) bool {
if stop || other.HasPrefix(v) {
stop = true
return stop
}
return false
})
default:
ast.WalkTerms(x, func(other *ast.Term) bool {
if stop || other.Value.Compare(v) == 0 {
stop = true
return stop
}
return false
})
}
return stop
}
func sortbindings(bindings *ast.ValueMap) []*binding {
sorted := make([]*binding, 0, bindings.Len())
bindings.Iter(func(k ast.Value, v ast.Value) bool {
sorted = append(sorted, &binding{k, v})
return false
})
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].k.Compare(sorted[j].k) > 0
})
return sorted
}
// makeDisjointSets builds the union-find structure for the query. The structure
// is built by processing all of the equality exprs in the query. Sets represent
// vars that must be equal to each other. In addition to vars, each set can have
// at most one constant. If the query contains expressions that cannot be
// satisfied (e.g., because a set has multiple constants) this function returns
// false.
func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
if v, ok := r1.key.(ast.Var); ok && livevars.Contains(v) {
return r1, r2
}
return r2, r1
})
for _, expr := range query {
if expr.IsEquality() && !expr.Negated && len(expr.With) == 0 {
a, b := expr.Operand(0), expr.Operand(1)
varA, ok1 := a.Value.(ast.Var)
varB, ok2 := b.Value.(ast.Var)
switch {
case ok1 && ok2:
if _, ok := uf.Merge(varA, varB); !ok {
return nil, false
}
case ok1 && ast.IsConstant(b.Value):
root := uf.MakeSet(varA)
if root.constant != nil && !root.constant.Equal(b) {
return nil, false
}
root.constant = b
case ok2 && ast.IsConstant(a.Value):
root := uf.MakeSet(varB)
if root.constant != nil && !root.constant.Equal(a) {
return nil, false
}
root.constant = a
}
}
}
return uf, true
}
func isNoop(expr *ast.Expr) bool {
if !expr.IsCall() && !expr.IsEvery() {
term := expr.Terms.(*ast.Term)
if !ast.IsConstant(term.Value) {
return false
}
return !ast.Boolean(false).Equal(term.Value)
}
// A==A can be ignored
if expr.Operator().Equal(ast.Equal.Ref()) {
return expr.Operand(0).Equal(expr.Operand(1))
}
return false
}