Skip to content

Commit

Permalink
ast: Refactor type checker and environment state
Browse files Browse the repository at this point in the history
This commit does not change any functionality it just updates the
implementation so that type environment mutation only occurs within
the type checker. This makes it easier to reason about changes to the
type environment data structure(s). Specifically:

* The type environment constructor is now private. Callers do not need
  to instantiate type environments. This should only be done by the
  checker. No one appears to be using this constructor so while this
  is backwards incompatible, it should be safe.

* The schema set is now held by the checker as opposed to the type
  environment. The environment should not have to know anything about
  the schemas.

* The global input schema is now loaded by the compiler on init() on
  and provided as input to the checker along with other global options
  like the schema set. This avoids having the compiler reach into the
  type environment to perform updates.

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Apr 2, 2021
1 parent ed75f77 commit 8b2e9ff
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 82 deletions.
80 changes: 50 additions & 30 deletions ast/check.go
Expand Up @@ -27,6 +27,8 @@ type typeChecker struct {
errs Errors
exprCheckers map[string]exprChecker
varRewriter rewriteVars
ss *SchemaSet
input types.Type
}

// newTypeChecker returns a new typeChecker object that has no errors.
Expand All @@ -38,23 +40,57 @@ func newTypeChecker() *typeChecker {
return tc
}

func (tc *typeChecker) newEnv(exist *TypeEnv) *TypeEnv {
if exist != nil {
return exist.wrap()
}
env := newTypeEnv(tc.copy)
if tc.input != nil {
env.tree.Put(InputRootRef, tc.input)
}
return env
}

func (tc *typeChecker) copy() *typeChecker {
return newTypeChecker().
WithVarRewriter(tc.varRewriter).
WithSchemaSet(tc.ss).
WithInputType(tc.input)
}

func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker {
tc.ss = ss
return tc
}

func (tc *typeChecker) WithVarRewriter(f rewriteVars) *typeChecker {
tc.varRewriter = f
return tc
}

func (tc *typeChecker) WithInputType(tpe types.Type) *typeChecker {
tc.input = tpe
return tc
}

// Env returns a type environment for the specified built-ins with any other
// global types configured on the checker. In practice, this is the default
// environment that other statements will be checked against.
func (tc *typeChecker) Env(builtins map[string]*Builtin) *TypeEnv {
env := tc.newEnv(nil)
for _, bi := range builtins {
env.tree.Put(bi.Ref(), bi.Decl)
}
return env
}

// CheckBody runs type checking on the body and returns a TypeEnv if no errors
// are found. The resulting TypeEnv wraps the provided one. The resulting
// TypeEnv will be able to resolve types of vars contained in the body.
func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {

errors := []*Error{}

if env == nil {
env = NewTypeEnv()
} else {
env = env.wrap()
}
env = tc.newEnv(env)

WalkExprs(body, func(expr *Expr) bool {

Expand Down Expand Up @@ -94,11 +130,7 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {
// are found. The resulting TypeEnv wraps the provided one. The resulting
// TypeEnv will be able to resolve types of refs that refer to rules.
func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T) (*TypeEnv, Errors) {
if env == nil {
env = NewTypeEnv()
} else {
env = env.wrap()
}
env = tc.newEnv(env)
for _, s := range sorted {
tc.checkRule(env, s.(*Rule))
}
Expand All @@ -111,19 +143,19 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors {
WalkClosures(expr, func(x interface{}) bool {
switch x := x.(type) {
case *ArrayComprehension:
_, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body)
_, errs := tc.copy().CheckBody(env, x.Body)
if len(errs) > 0 {
result = errs
return true
}
case *SetComprehension:
_, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body)
_, errs := tc.copy().CheckBody(env, x.Body)
if len(errs) > 0 {
result = errs
return true
}
case *ObjectComprehension:
_, errs := newTypeChecker().WithVarRewriter(tc.varRewriter).CheckBody(env, x.Body)
_, errs := tc.copy().CheckBody(env, x.Body)
if len(errs) > 0 {
result = errs
return true
Expand All @@ -134,25 +166,13 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors {
return result
}

func (tc *typeChecker) checkLanguageBuiltins(env *TypeEnv, builtins map[string]*Builtin) *TypeEnv {
if env == nil {
env = NewTypeEnv()
} else {
env = env.wrap()
}
for _, bi := range builtins {
env.tree.Put(bi.Ref(), bi.Decl)
}
return env
}

func (tc *typeChecker) checkRule(env *TypeEnv, rule *Rule) {

env = env.wrap()

if schemaAnnots := getRuleAnnotation(rule); schemaAnnots != nil {
for _, schemaAnnot := range schemaAnnots {
ref, refType, err := processAnnotation(schemaAnnot, env, rule)
ref, refType, err := processAnnotation(tc.ss, schemaAnnot, env, rule)
if err != nil {
tc.err([]*Error{err})
continue
Expand Down Expand Up @@ -1121,8 +1141,8 @@ func getRuleAnnotation(rule *Rule) (sannots []SchemaAnnotation) {

// NOTE: Currently, annotations must preceed the rule. In the future, this
// restriction could be relaxed with other kinds of annotation scopes.
func processAnnotation(annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) {
if env.schemaSet == nil {
func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) {
if ss == nil {
return nil, nil, NewError(TypeErr, rule.Location, "schemas need to be supplied for the annotation: %s", annot.Schema)
}

Expand All @@ -1131,7 +1151,7 @@ func processAnnotation(annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, t
return nil, nil, NewError(TypeErr, rule.Location, "schema is not well formed in annotation: %s", annot.Schema)
}

schema := env.schemaSet.Get(schemaRef)
schema := ss.Get(schemaRef)
if schema == nil {
return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", schemaRef.String())
}
Expand Down
14 changes: 7 additions & 7 deletions ast/check_test.go
Expand Up @@ -286,7 +286,7 @@ func TestCheckInference(t *testing.T) {
t.Run(tc.note, func(t *testing.T) {
body := MustParseBody(tc.query)
checker := newTypeChecker()
env := checker.checkLanguageBuiltins(nil, BuiltinMap)
env := checker.Env(BuiltinMap)
env, err := checker.CheckBody(env, body)
if len(err) != 0 {
t.Fatalf("Unexpected error: %v", err)
Expand Down Expand Up @@ -528,7 +528,7 @@ func TestCheckErrorSuppression(t *testing.T) {

query = `_ = [true | count(1)]`

_, errs = newTypeChecker().CheckBody(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), MustParseBody(query))
_, errs = newTypeChecker().CheckBody(newTypeChecker().Env(BuiltinMap), MustParseBody(query))
if len(errs) != 1 {
t.Fatalf("Expected exactly one error but got: %v", errs)
}
Expand Down Expand Up @@ -557,7 +557,7 @@ func TestCheckBadCardinality(t *testing.T) {
for _, test := range tests {
body := MustParseBody(test.body)
tc := newTypeChecker()
env := tc.checkLanguageBuiltins(nil, BuiltinMap)
env := tc.Env(BuiltinMap)
_, err := tc.CheckBody(env, body)
if len(err) != 1 || err[0].Code != TypeErr {
t.Fatalf("Expected 1 type error from %v but got: %v", body, err)
Expand Down Expand Up @@ -965,7 +965,7 @@ func TestFunctionTypeInferenceUnappliedWithObjectVarKey(t *testing.T) {
f(x) = y { y = {x: 1} }
`)

env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), []util.T{
env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), []util.T{
module.Rules[0],
})

Expand Down Expand Up @@ -1208,7 +1208,7 @@ func newTestEnv(rs []string) *TypeEnv {
}
}

env, err := newTypeChecker().CheckTypes(newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap), elems)
env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems)
if len(err) > 0 {
panic(err)
}
Expand Down Expand Up @@ -1863,8 +1863,8 @@ whocan[user] {
}
}

oldTypeEnv := newTypeChecker().checkLanguageBuiltins(nil, BuiltinMap).WithSchemas(tc.schemaSet)
typeenv, errors := newTypeChecker().CheckTypes(oldTypeEnv, elems)
oldTypeEnv := newTypeChecker().WithSchemaSet(tc.schemaSet).Env(BuiltinMap)
typeenv, errors := newTypeChecker().WithSchemaSet(tc.schemaSet).CheckTypes(oldTypeEnv, elems)
if len(errors) > 0 {
for _, e := range errors {
if tc.err == "" || !strings.Contains(e.Error(), tc.err) {
Expand Down
51 changes: 24 additions & 27 deletions ast/compile.go
Expand Up @@ -107,6 +107,7 @@ type Compiler struct {
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
}

// CompilerStage defines the interface for stages in the compiler.
Expand Down Expand Up @@ -220,7 +221,6 @@ func NewCompiler() *Compiler {

c := &Compiler{
Modules: map[string]*Module{},
TypeEnv: NewTypeEnv(),
RewrittenVars: map[Var]Var{},
ruleIndices: util.NewHashMap(func(a, b util.T) bool {
r1, r2 := a.(Ref), b.(Ref)
Expand Down Expand Up @@ -630,7 +630,7 @@ func (c *Compiler) RuleIndex(path Ref) RuleIndex {

// PassesTypeCheck determines whether the given body passes type checking
func (c *Compiler) PassesTypeCheck(body Body) bool {
checker := newTypeChecker()
checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType)
env := c.TypeEnv
_, errs := checker.CheckBody(env, body)
return len(errs) == 0
Expand Down Expand Up @@ -941,7 +941,10 @@ func parseSchema(schema interface{}) (types.Type, error) {
func (c *Compiler) checkTypes() {
// Recursion is caught in earlier step, so this cannot fail.
sorted, _ := c.Graph.Sort()
checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
checker := newTypeChecker().
WithSchemaSet(c.schemaSet).
WithInputType(c.inputType).
WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
env, errs := checker.CheckTypes(c.TypeEnv, sorted)
for _, err := range errs {
c.err(err)
Expand Down Expand Up @@ -1016,33 +1019,24 @@ func (c *Compiler) init() {
c.builtins[name] = bi
}

tc := newTypeChecker()
c.TypeEnv = tc.checkLanguageBuiltins(nil, c.builtins)
c.setSchemas()

c.initialized = true
}

func (c *Compiler) setSchemas() {
// Load the global input schema if one was provided.
if c.schemaSet != nil {

// First, set the schemaSet in the type environment
c.TypeEnv.WithSchemas(c.schemaSet)

// Second, set the schema for the input globally if it exists
schema := c.schemaSet.Get(InputRootRef)
if schema == nil {
return
if schema := c.schemaSet.Get(InputRootRef); schema != nil {
tpe, err := loadSchema(schema)
if err != nil {
c.err(NewError(TypeErr, nil, err.Error()))
} else {
c.inputType = tpe
}
}
}

tpe, err := loadSchema(schema)
if err != nil {
c.err(NewError(TypeErr, nil, err.Error()))
return
}
c.TypeEnv = newTypeChecker().
WithSchemaSet(c.schemaSet).
WithInputType(c.inputType).
Env(c.builtins)

c.TypeEnv.tree.Put(InputRootRef, tpe)
}
c.initialized = true
}

func (c *Compiler) err(err *Error) {
Expand Down Expand Up @@ -1658,7 +1652,10 @@ func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {

func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) {
var errs Errors
checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars))
checker := newTypeChecker().
WithSchemaSet(qc.compiler.schemaSet).
WithInputType(qc.compiler.inputType).
WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars))
qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body)
if len(errs) > 0 {
return nil, errs
Expand Down
29 changes: 11 additions & 18 deletions ast/env.go
Expand Up @@ -11,24 +11,20 @@ import (

// TypeEnv contains type info for static analysis such as type checking.
type TypeEnv struct {
tree *typeTreeNode
next *TypeEnv
schemaSet *SchemaSet
tree *typeTreeNode
next *TypeEnv
newChecker func() *typeChecker
}

// NewTypeEnv returns an empty TypeEnv.
func NewTypeEnv() *TypeEnv {
// newTypeEnv returns an empty TypeEnv. The constructor is not exported because
// type environments should only be created by the type checker.
func newTypeEnv(f func() *typeChecker) *TypeEnv {
return &TypeEnv{
tree: newTypeTree(),
tree: newTypeTree(),
newChecker: f,
}
}

// WithSchemas sets the user-provided schemas
func (env *TypeEnv) WithSchemas(schemas *SchemaSet) *TypeEnv {
env.schemaSet = schemas
return env
}

// Get returns the type of x.
func (env *TypeEnv) Get(x interface{}) types.Type {

Expand Down Expand Up @@ -101,22 +97,19 @@ func (env *TypeEnv) Get(x interface{}) types.Type {

// Comprehensions.
case *ArrayComprehension:
checker := newTypeChecker()
cpy, errs := checker.CheckBody(env, x.Body)
cpy, errs := env.newChecker().CheckBody(env, x.Body)
if len(errs) == 0 {
return types.NewArray(nil, cpy.Get(x.Term))
}
return nil
case *ObjectComprehension:
checker := newTypeChecker()
cpy, errs := checker.CheckBody(env, x.Body)
cpy, errs := env.newChecker().CheckBody(env, x.Body)
if len(errs) == 0 {
return types.NewObject(nil, types.NewDynamicProperty(cpy.Get(x.Key), cpy.Get(x.Value)))
}
return nil
case *SetComprehension:
checker := newTypeChecker()
cpy, errs := checker.CheckBody(env, x.Body)
cpy, errs := env.newChecker().CheckBody(env, x.Body)
if len(errs) == 0 {
return types.NewSet(cpy.Get(x.Term))
}
Expand Down

0 comments on commit 8b2e9ff

Please sign in to comment.