From 8b2e9ffc3aaea778812ee8416f888fe4027620ae Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Fri, 2 Apr 2021 10:06:15 -0400 Subject: [PATCH] ast: Refactor type checker and environment state This commit does not change any functionality it just updates the implementation so that type environment mutation only occurs within the type checker. This makes it easier to reason about changes to the type environment data structure(s). Specifically: * The type environment constructor is now private. Callers do not need to instantiate type environments. This should only be done by the checker. No one appears to be using this constructor so while this is backwards incompatible, it should be safe. * The schema set is now held by the checker as opposed to the type environment. The environment should not have to know anything about the schemas. * The global input schema is now loaded by the compiler on init() on and provided as input to the checker along with other global options like the schema set. This avoids having the compiler reach into the type environment to perform updates. Signed-off-by: Torin Sandall --- ast/check.go | 80 +++++++++++++++++++++++++++++------------------ ast/check_test.go | 14 ++++----- ast/compile.go | 51 ++++++++++++++---------------- ast/env.go | 29 +++++++---------- 4 files changed, 92 insertions(+), 82 deletions(-) diff --git a/ast/check.go b/ast/check.go index 3356e3910a..57652686b7 100644 --- a/ast/check.go +++ b/ast/check.go @@ -27,6 +27,8 @@ type typeChecker struct { errs Errors exprCheckers map[string]exprChecker varRewriter rewriteVars + ss *SchemaSet + input types.Type } // newTypeChecker returns a new typeChecker object that has no errors. @@ -38,23 +40,57 @@ func newTypeChecker() *typeChecker { return tc } +func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv { + if exist != nil { + return exist.wrap() + } + env := newTypeEnv(tc.copy) + if tc.input != nil { + env.tree.Put(InputRootRef, tc.input) + } + return env +} + +func (tc *typeChecker) copy() *typeChecker { + return newTypeChecker(). + WithVarRewriter(tc.varRewriter). + WithSchemaSet(tc.ss). + WithInputType(tc.input) +} + +func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker { + tc.ss = ss + return tc +} + func (tc *typeChecker) WithVarRewriter(f rewriteVars) *typeChecker { tc.varRewriter = f return tc } +func (tc *typeChecker) WithInputType(tpe types.Type) *typeChecker { + tc.input = tpe + return tc +} + +// Env returns a type environment for the specified built-ins with any other +// global types configured on the checker. In practice, this is the default +// environment that other statements will be checked against. +func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv { + env := tc.newEnv(nil) + for _, bi := range builtins { + env.tree.Put(bi.Ref(), bi.Decl) + } + return env +} + // CheckBody runs type checking on the body and returns a TypeEnv if no errors // are found. The resulting TypeEnv wraps the provided one. The resulting // TypeEnv will be able to resolve types of vars contained in the body. func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { errors := []*Error{} - - if env == nil { - env = NewTypeEnv() - } else { - env = env.wrap() - } + env = tc.newEnv(env) WalkExprs(body, func(expr *Expr) bool { @@ -94,11 +130,7 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) { // are found. The resulting TypeEnv wraps the provided one. The resulting // TypeEnv will be able to resolve types of refs that refer to rules. func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T) (*TypeEnv, Errors) { - if env == nil { - env = NewTypeEnv() - } else { - env = env.wrap() - } + env = tc.newEnv(env) for _, s := range sorted { tc.checkRule(env, s.(*Rule)) } @@ -111,19 +143,19 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { WalkClosures(expr, func(x interface{}) bool { switch x := x.(type) { case *ArrayComprehension: - _, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body) + _, errs := tc.copy().CheckBody(env, x.Body) if len(errs) > 0 { result = errs return true } case *SetComprehension: - _, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body) + _, errs := tc.copy().CheckBody(env, x.Body) if len(errs) > 0 { result = errs return true } case *ObjectComprehension: - _, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body) + _, errs := tc.copy().CheckBody(env, x.Body) if len(errs) > 0 { result = errs return true @@ -134,25 +166,13 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors { return result } -func (tc *typeChecker) checkLanguageBuiltins(env *TypeEnv, builtins map[string]*Builtin) *TypeEnv { - if env == nil { - env = NewTypeEnv() - } else { - env = env.wrap() - } - for _, bi := range builtins { - env.tree.Put(bi.Ref(), bi.Decl) - } - return env -} - func (tc *typeChecker) checkRule(env *TypeEnv, rule *Rule) { env = env.wrap() if schemaAnnots := getRuleAnnotation(rule); schemaAnnots != nil { for _, schemaAnnot := range schemaAnnots { - ref, refType, err := processAnnotation(schemaAnnot, env, rule) + ref, refType, err := processAnnotation(tc.ss, schemaAnnot, env, rule) if err != nil { tc.err([]*Error{err}) continue @@ -1121,8 +1141,8 @@ func getRuleAnnotation(rule *Rule) (sannots []SchemaAnnotation) { // NOTE: Currently, annotations must preceed the rule. In the future, this // restriction could be relaxed with other kinds of annotation scopes. -func processAnnotation(annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { - if env.schemaSet == nil { +func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { + if ss == nil { return nil, nil, NewError(TypeErr, rule.Location, "schemas need to be supplied for the annotation: %s", annot.Schema) } @@ -1131,7 +1151,7 @@ func processAnnotation(annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, t return nil, nil, NewError(TypeErr, rule.Location, "schema is not well formed in annotation: %s", annot.Schema) } - schema := env.schemaSet.Get(schemaRef) + schema := ss.Get(schemaRef) if schema == nil { return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", schemaRef.String()) } diff --git a/ast/check_test.go b/ast/check_test.go index a6e71b7041..4dbe7a3487 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -286,7 +286,7 @@ func TestCheckInference(t *testing.T) { t.Run(tc.note, func(t *testing.T) { body := MustParseBody(tc.query) checker := newTypeChecker() - env := checker.checkLanguageBuiltins(nil, BuiltinMap) + env := checker.Env(BuiltinMap) env, err := checker.CheckBody(env, body) if len(err) != 0 { t.Fatalf("Unexpected error: %v", err) @@ -528,7 +528,7 @@ func TestCheckErrorSuppression(t *testing.T) { query = `_ = [true | count(1)]` - _, errs = newTypeChecker().CheckBody(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), MustParseBody(query)) + _, errs = newTypeChecker().CheckBody(newTypeChecker().Env(BuiltinMap), MustParseBody(query)) if len(errs) != 1 { t.Fatalf("Expected exactly one error but got: %v", errs) } @@ -557,7 +557,7 @@ func TestCheckBadCardinality(t *testing.T) { for _, test := range tests { body := MustParseBody(test.body) tc := newTypeChecker() - env := tc.checkLanguageBuiltins(nil, BuiltinMap) + env := tc.Env(BuiltinMap) _, err := tc.CheckBody(env, body) if len(err) != 1 || err[0].Code != TypeErr { t.Fatalf("Expected 1 type error from %v but got: %v", body, err) @@ -965,7 +965,7 @@ func TestFunctionTypeInferenceUnappliedWithObjectVarKey(t *testing.T) { f(x) = y { y = {x: 1} } `) - env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), []util.T{ + env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), []util.T{ module.Rules[0], }) @@ -1208,7 +1208,7 @@ func newTestEnv(rs []string) *TypeEnv { } } - env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), elems) + env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems) if len(err) > 0 { panic(err) } @@ -1863,8 +1863,8 @@ whocan[user] { } } - oldTypeEnv := newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap).WithSchemas(tc.schemaSet) - typeenv, errors := newTypeChecker().CheckTypes(oldTypeEnv, elems) + oldTypeEnv := newTypeChecker().WithSchemaSet(tc.schemaSet).Env(BuiltinMap) + typeenv, errors := newTypeChecker().WithSchemaSet(tc.schemaSet).CheckTypes(oldTypeEnv, elems) if len(errors) > 0 { for _, e := range errors { if tc.err == "" || !strings.Contains(e.Error(), tc.err) { diff --git a/ast/compile.go b/ast/compile.go index 513dfa5650..a2d089a93d 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -107,6 +107,7 @@ type Compiler struct { initialized bool // indicates if init() has been called debug debug.Debug // emits debug information produced during compilation schemaSet *SchemaSet // user-supplied schemas for input and data documents + inputType types.Type // global input type retrieved from schema set } // CompilerStage defines the interface for stages in the compiler. @@ -220,7 +221,6 @@ func NewCompiler() *Compiler { c := &Compiler{ Modules: map[string]*Module{}, - TypeEnv: NewTypeEnv(), RewrittenVars: map[Var]Var{}, ruleIndices: util.NewHashMap(func(a, b util.T) bool { r1, r2 := a.(Ref), b.(Ref) @@ -630,7 +630,7 @@ func (c *Compiler) RuleIndex(path Ref) RuleIndex { // PassesTypeCheck determines whether the given body passes type checking func (c *Compiler) PassesTypeCheck(body Body) bool { - checker := newTypeChecker() + checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType) env := c.TypeEnv _, errs := checker.CheckBody(env, body) return len(errs) == 0 @@ -941,7 +941,10 @@ func parseSchema(schema interface{}) (types.Type, error) { func (c *Compiler) checkTypes() { // Recursion is caught in earlier step, so this cannot fail. sorted, _ := c.Graph.Sort() - checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) + checker := newTypeChecker(). + WithSchemaSet(c.schemaSet). + WithInputType(c.inputType). + WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) env, errs := checker.CheckTypes(c.TypeEnv, sorted) for _, err := range errs { c.err(err) @@ -1016,33 +1019,24 @@ func (c *Compiler) init() { c.builtins[name] = bi } - tc := newTypeChecker() - c.TypeEnv = tc.checkLanguageBuiltins(nil, c.builtins) - c.setSchemas() - - c.initialized = true -} - -func (c *Compiler) setSchemas() { + // Load the global input schema if one was provided. if c.schemaSet != nil { - - // First, set the schemaSet in the type environment - c.TypeEnv.WithSchemas(c.schemaSet) - - // Second, set the schema for the input globally if it exists - schema := c.schemaSet.Get(InputRootRef) - if schema == nil { - return + if schema := c.schemaSet.Get(InputRootRef); schema != nil { + tpe, err := loadSchema(schema) + if err != nil { + c.err(NewError(TypeErr, nil, err.Error())) + } else { + c.inputType = tpe + } } + } - tpe, err := loadSchema(schema) - if err != nil { - c.err(NewError(TypeErr, nil, err.Error())) - return - } + c.TypeEnv = newTypeChecker(). + WithSchemaSet(c.schemaSet). + WithInputType(c.inputType). + Env(c.builtins) - c.TypeEnv.tree.Put(InputRootRef, tpe) - } + c.initialized = true } func (c *Compiler) err(err *Error) { @@ -1658,7 +1652,10 @@ func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) { var errs Errors - checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) + checker := newTypeChecker(). + WithSchemaSet(qc.compiler.schemaSet). + WithInputType(qc.compiler.inputType). + WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body) if len(errs) > 0 { return nil, errs diff --git a/ast/env.go b/ast/env.go index ca5d851d26..60006baafd 100644 --- a/ast/env.go +++ b/ast/env.go @@ -11,24 +11,20 @@ import ( // TypeEnv contains type info for static analysis such as type checking. type TypeEnv struct { - tree *typeTreeNode - next *TypeEnv - schemaSet *SchemaSet + tree *typeTreeNode + next *TypeEnv + newChecker func() *typeChecker } -// NewTypeEnv returns an empty TypeEnv. -func NewTypeEnv() *TypeEnv { +// newTypeEnv returns an empty TypeEnv. The constructor is not exported because +// type environments should only be created by the type checker. +func newTypeEnv(f func() *typeChecker) *TypeEnv { return &TypeEnv{ - tree: newTypeTree(), + tree: newTypeTree(), + newChecker: f, } } -// WithSchemas sets the user-provided schemas -func (env *TypeEnv) WithSchemas(schemas *SchemaSet) *TypeEnv { - env.schemaSet = schemas - return env -} - // Get returns the type of x. func (env *TypeEnv) Get(x interface{}) types.Type { @@ -101,22 +97,19 @@ func (env *TypeEnv) Get(x interface{}) types.Type { // Comprehensions. case *ArrayComprehension: - checker := newTypeChecker() - cpy, errs := checker.CheckBody(env, x.Body) + cpy, errs := env.newChecker().CheckBody(env, x.Body) if len(errs) == 0 { return types.NewArray(nil, cpy.Get(x.Term)) } return nil case *ObjectComprehension: - checker := newTypeChecker() - cpy, errs := checker.CheckBody(env, x.Body) + cpy, errs := env.newChecker().CheckBody(env, x.Body) if len(errs) == 0 { return types.NewObject(nil, types.NewDynamicProperty(cpy.Get(x.Key), cpy.Get(x.Value))) } return nil case *SetComprehension: - checker := newTypeChecker() - cpy, errs := checker.CheckBody(env, x.Body) + cpy, errs := env.newChecker().CheckBody(env, x.Body) if len(errs) == 0 { return types.NewSet(cpy.Get(x.Term)) }