diff --git a/ast/compile.go b/ast/compile.go index 8768c05cb3..81f4082b1f 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2465,11 +2465,19 @@ func outputVarsForTerms(expr *Expr, safe VarSet) VarSet { case *SetComprehension, *ArrayComprehension, *ObjectComprehension: return true case Ref: - if safe.Contains(r[0].Value.(Var)) { - output.Update(r.OutputVars()) - return false + if v, ok := r[0].Value.(Var); ok { + if !safe.Contains(v) { + return true + } + } else { + for k := range r[0].Vars() { + if !safe.Contains(k) { + return true + } + } } - return true + output.Update(r.OutputVars()) + return false } return false }) diff --git a/ast/compile_test.go b/ast/compile_test.go index a9df875b96..dbd36e6fbf 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -178,6 +178,21 @@ func TestOutputVarsForNode(t *testing.T) { query: "x = 1; y = x; z = y", exp: "{x, y, z}", }, + { + note: "composite head", + query: "{1, 2}[1] = x", + exp: `{x}`, + }, + { + note: "composite head", + query: "x = 1; {x, 2}[1] = y", + exp: `{x, y}`, + }, + { + note: "composite head", + query: "{x, 2}[1] = y", + exp: `set()`, + }, } for _, tc := range tests { diff --git a/ast/unify.go b/ast/unify.go index 6207c1cec1..7a87ea75df 100644 --- a/ast/unify.go +++ b/ast/unify.go @@ -27,6 +27,18 @@ func (u *unifier) isSafe(x Var) bool { return u.safe.Contains(x) || u.unified.Contains(x) } +func (u *unifier) isHeadSafe(r Ref) bool { + if v, ok := r[0].Value.(Var); ok { + return u.isSafe(v) + } + for v := range r[0].Vars() { + if !u.isSafe(v) { + return false + } + } + return true +} + func (u *unifier) unify(a *Term, b *Term) { switch a := a.Value.(type) { @@ -45,7 +57,7 @@ func (u *unifier) unify(a *Term, b *Term) { case *Array, Object: u.unifyAll(a, b) case Ref: - if u.isSafe(b[0].Value.(Var)) { + if u.isHeadSafe(b) { u.markSafe(a) } default: @@ -53,7 +65,7 @@ func (u *unifier) unify(a *Term, b *Term) { } case Ref: - if u.isSafe(a[0].Value.(Var)) { + if u.isHeadSafe(a) { switch b := b.Value.(type) { case Var: u.markSafe(b)