diff --git a/ast/check.go b/ast/check.go index 3356e3910a..57652686b7 100644 --- a/ast/check.go +++ b/ast/check.go @@ -27,6 +27,8 @@ type typeChecker struct { errs Errors exprCheckers map[string]exprChecker varRewriter rewriteVars + ss *SchemaSet + input types.Type } // newTypeChecker returns a new typeChecker object that has no errors. @@ -38,23 +40,57 @@ func newTypeChecker() *typeChecker { return tc } +func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv { + if exist != nil { + return exist.wrap() + } + env := newTypeEnv(tc.copy) + if tc.input != nil { + env.tree.Put(InputRootRef, tc.input) + } + return env +} + +func (tc *typeChecker) copy() *typeChecker { + return newTypeChecker(). + WithVarRewriter(tc.varRewriter). + WithSchemaSet(tc.ss). + WithInputType(tc.input) +} + +func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker { + tc.ss = ss + return tc +} + func (tc *typeChecker) WithVarRewriter(f rewriteVars) *typeChecker { tc.varRewriter = f return tc } +func (tc *typeChecker) WithInputType(tpe types.Type) *typeChecker { + tc.input = tpe + return tc +} + +// Env returns a type environment for the specified built-ins with any other +// global types configured on the checker. In practice, this is the default +// environment that other statements will be checked against. +func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv { + env := tc.newEnv(nil) + for _, bi := range builtins { + env.tree.Put(bi.Ref(), bi.Decl) + } + return env +} + // CheckBody runs type checking on the body and returns a TypeEnv if no errors // are found. The resulting TypeEnv wraps the provided one. The resulting // TypeEnv will be able to resolve types of vars contained in the body. func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { errors := []*Error{} - - if env == nil { - env = NewTypeEnv() - } else { - env = env.wrap() - } + env = tc.newEnv(env) WalkExprs(body, func(expr *Expr) bool { @@ -94,11 +130,7 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { // are found. The resulting TypeEnv wraps the provided one. The resulting // TypeEnv will be able to resolve types of refs that refer to rules. func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T) (*TypeEnv, Errors) { - if env == nil { - env = NewTypeEnv() - } else { - env = env.wrap() - } + env = tc.newEnv(env) for _, s := range sorted { tc.checkRule(env, s.(*Rule)) } @@ -111,19 +143,19 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { WalkClosures(expr, func(x interface{}) bool { switch x := x.(type) { case *ArrayComprehension: - _, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body) + _, errs := tc.copy().CheckBody(env, x.Body) if len(errs) > 0 { result = errs return true } case *SetComprehension: - _, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body) + _, errs := tc.copy().CheckBody(env, x.Body) if len(errs) > 0 { result = errs return true } case *ObjectComprehension: - _, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body) + _, errs := tc.copy().CheckBody(env, x.Body) if len(errs) > 0 { result = errs return true @@ -134,25 +166,13 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { return result } -func (tc *typeChecker) checkLanguageBuiltins(env *TypeEnv, builtins map[string]*Builtin) *TypeEnv { - if env == nil { - env = NewTypeEnv() - } else { - env = env.wrap() - } - for _, bi := range builtins { - env.tree.Put(bi.Ref(), bi.Decl) - } - return env -} - func (tc *typeChecker) checkRule(env *TypeEnv, rule *Rule) { env = env.wrap() if schemaAnnots := getRuleAnnotation(rule); schemaAnnots != nil { for _, schemaAnnot := range schemaAnnots { - ref, refType, err := processAnnotation(schemaAnnot, env, rule) + ref, refType, err := processAnnotation(tc.ss, schemaAnnot, env, rule) if err != nil { tc.err([]*Error{err}) continue @@ -1121,8 +1141,8 @@ func getRuleAnnotation(rule *Rule) (sannots []SchemaAnnotation) { // NOTE: Currently, annotations must preceed the rule. In the future, this // restriction could be relaxed with other kinds of annotation scopes. -func processAnnotation(annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { - if env.schemaSet == nil { +func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { + if ss == nil { return nil, nil, NewError(TypeErr, rule.Location, "schemas need to be supplied for the annotation: %s", annot.Schema) } @@ -1131,7 +1151,7 @@ func processAnnotation(annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, t return nil, nil, NewError(TypeErr, rule.Location, "schema is not well formed in annotation: %s", annot.Schema) } - schema := env.schemaSet.Get(schemaRef) + schema := ss.Get(schemaRef) if schema == nil { return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", schemaRef.String()) } diff --git a/ast/check_test.go b/ast/check_test.go index a6e71b7041..4dbe7a3487 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -286,7 +286,7 @@ func TestCheckInference(t *testing.T) { t.Run(tc.note, func(t *testing.T) { body := MustParseBody(tc.query) checker := newTypeChecker() - env := checker.checkLanguageBuiltins(nil, BuiltinMap) + env := checker.Env(BuiltinMap) env, err := checker.CheckBody(env, body) if len(err) != 0 { t.Fatalf("Unexpected error: %v", err) @@ -528,7 +528,7 @@ func TestCheckErrorSuppression(t *testing.T) { query = `_ = [true | count(1)]` - _, errs = newTypeChecker().CheckBody(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), MustParseBody(query)) + _, errs = newTypeChecker().CheckBody(newTypeChecker().Env(BuiltinMap), MustParseBody(query)) if len(errs) != 1 { t.Fatalf("Expected exactly one error but got: %v", errs) } @@ -557,7 +557,7 @@ func TestCheckBadCardinality(t *testing.T) { for _, test := range tests { body := MustParseBody(test.body) tc := newTypeChecker() - env := tc.checkLanguageBuiltins(nil, BuiltinMap) + env := tc.Env(BuiltinMap) _, err := tc.CheckBody(env, body) if len(err) != 1 || err[0].Code != TypeErr { t.Fatalf("Expected 1 type error from %v but got: %v", body, err) @@ -965,7 +965,7 @@ func TestFunctionTypeInferenceUnappliedWithObjectVarKey(t *testing.T) { f(x) = y { y = {x: 1} } `) - env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), []util.T{ + env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), []util.T{ module.Rules[0], }) @@ -1208,7 +1208,7 @@ func newTestEnv(rs []string) *TypeEnv { } } - env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), elems) + env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems) if len(err) > 0 { panic(err) } @@ -1863,8 +1863,8 @@ whocan[user] { } } - oldTypeEnv := newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap).WithSchemas(tc.schemaSet) - typeenv, errors := newTypeChecker().CheckTypes(oldTypeEnv, elems) + oldTypeEnv := newTypeChecker().WithSchemaSet(tc.schemaSet).Env(BuiltinMap) + typeenv, errors := newTypeChecker().WithSchemaSet(tc.schemaSet).CheckTypes(oldTypeEnv, elems) if len(errors) > 0 { for _, e := range errors { if tc.err == "" || !strings.Contains(e.Error(), tc.err) { diff --git a/ast/compile.go b/ast/compile.go index 513dfa5650..a2d089a93d 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -107,6 +107,7 @@ type Compiler struct { initialized bool // indicates if init() has been called debug debug.Debug // emits debug information produced during compilation schemaSet *SchemaSet // user-supplied schemas for input and data documents + inputType types.Type // global input type retrieved from schema set } // CompilerStage defines the interface for stages in the compiler. @@ -220,7 +221,6 @@ func NewCompiler() *Compiler { c := &Compiler{ Modules: map[string]*Module{}, - TypeEnv: NewTypeEnv(), RewrittenVars: map[Var]Var{}, ruleIndices: util.NewHashMap(func(a, b util.T) bool { r1, r2 := a.(Ref), b.(Ref) @@ -630,7 +630,7 @@ func (c *Compiler) RuleIndex(path Ref) RuleIndex { // PassesTypeCheck determines whether the given body passes type checking func (c *Compiler) PassesTypeCheck(body Body) bool { - checker := newTypeChecker() + checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType) env := c.TypeEnv _, errs := checker.CheckBody(env, body) return len(errs) == 0 @@ -941,7 +941,10 @@ func parseSchema(schema interface{}) (types.Type, error) { func (c *Compiler) checkTypes() { // Recursion is caught in earlier step, so this cannot fail. sorted, _ := c.Graph.Sort() - checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) + checker := newTypeChecker(). + WithSchemaSet(c.schemaSet). + WithInputType(c.inputType). + WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) env, errs := checker.CheckTypes(c.TypeEnv, sorted) for _, err := range errs { c.err(err) @@ -1016,33 +1019,24 @@ func (c *Compiler) init() { c.builtins[name] = bi } - tc := newTypeChecker() - c.TypeEnv = tc.checkLanguageBuiltins(nil, c.builtins) - c.setSchemas() - - c.initialized = true -} - -func (c *Compiler) setSchemas() { + // Load the global input schema if one was provided. if c.schemaSet != nil { - - // First, set the schemaSet in the type environment - c.TypeEnv.WithSchemas(c.schemaSet) - - // Second, set the schema for the input globally if it exists - schema := c.schemaSet.Get(InputRootRef) - if schema == nil { - return + if schema := c.schemaSet.Get(InputRootRef); schema != nil { + tpe, err := loadSchema(schema) + if err != nil { + c.err(NewError(TypeErr, nil, err.Error())) + } else { + c.inputType = tpe + } } + } - tpe, err := loadSchema(schema) - if err != nil { - c.err(NewError(TypeErr, nil, err.Error())) - return - } + c.TypeEnv = newTypeChecker(). + WithSchemaSet(c.schemaSet). + WithInputType(c.inputType). + Env(c.builtins) - c.TypeEnv.tree.Put(InputRootRef, tpe) - } + c.initialized = true } func (c *Compiler) err(err *Error) { @@ -1658,7 +1652,10 @@ func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) { var errs Errors - checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) + checker := newTypeChecker(). + WithSchemaSet(qc.compiler.schemaSet). + WithInputType(qc.compiler.inputType). + WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body) if len(errs) > 0 { return nil, errs diff --git a/ast/env.go b/ast/env.go index ca5d851d26..60006baafd 100644 --- a/ast/env.go +++ b/ast/env.go @@ -11,24 +11,20 @@ import ( // TypeEnv contains type info for static analysis such as type checking. type TypeEnv struct { - tree *typeTreeNode - next *TypeEnv - schemaSet *SchemaSet + tree *typeTreeNode + next *TypeEnv + newChecker func() *typeChecker } -// NewTypeEnv returns an empty TypeEnv. -func NewTypeEnv() *TypeEnv { +// newTypeEnv returns an empty TypeEnv. The constructor is not exported because +// type environments should only be created by the type checker. +func newTypeEnv(f func() *typeChecker) *TypeEnv { return &TypeEnv{ - tree: newTypeTree(), + tree: newTypeTree(), + newChecker: f, } } -// WithSchemas sets the user-provided schemas -func (env *TypeEnv) WithSchemas(schemas *SchemaSet) *TypeEnv { - env.schemaSet = schemas - return env -} - // Get returns the type of x. func (env *TypeEnv) Get(x interface{}) types.Type { @@ -101,22 +97,19 @@ func (env *TypeEnv) Get(x interface{}) types.Type { // Comprehensions. case *ArrayComprehension: - checker := newTypeChecker() - cpy, errs := checker.CheckBody(env, x.Body) + cpy, errs := env.newChecker().CheckBody(env, x.Body) if len(errs) == 0 { return types.NewArray(nil, cpy.Get(x.Term)) } return nil case *ObjectComprehension: - checker := newTypeChecker() - cpy, errs := checker.CheckBody(env, x.Body) + cpy, errs := env.newChecker().CheckBody(env, x.Body) if len(errs) == 0 { return types.NewObject(nil, types.NewDynamicProperty(cpy.Get(x.Key), cpy.Get(x.Value))) } return nil case *SetComprehension: - checker := newTypeChecker() - cpy, errs := checker.CheckBody(env, x.Body) + cpy, errs := env.newChecker().CheckBody(env, x.Body) if len(errs) == 0 { return types.NewSet(cpy.Get(x.Term)) }