Skip to content

Commit

Permalink
Ensure overloads are searched in the order they are declared during d…
Browse files Browse the repository at this point in the history
…ynamic dispatch
  • Loading branch information
TristonianJones committed Jul 13, 2022
1 parent f3df06c commit 5b81ae9
Showing 1 changed file with 39 additions and 39 deletions.
78 changes: 39 additions & 39 deletions cel/decls.go
Expand Up @@ -221,7 +221,8 @@ func (t *Type) equals(other *Type) bool {
// - The from types are the same instance
// - The target type is dynamic
// - The fromType has the same kind and type name as the target type, and all parameters of the target type
// are IsAssignableType() from the parameters of the fromType.
//
// are IsAssignableType() from the parameters of the fromType.
func (t *Type) defaultIsAssignableType(fromType *Type) bool {
if t == fromType || t.isDyn() {
return true
Expand Down Expand Up @@ -333,7 +334,7 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
return func(e *Env) (*Env, error) {
fn := &functionDecl{
name: name,
overloads: map[string]*overloadDecl{},
overloads: []*o{},
options: opts,
}
err := fn.init()
Expand Down Expand Up @@ -445,12 +446,12 @@ func MemberOverload(overloadID string, args []*Type, resultType *Type, opts ...O
}

// OverloadOpt is a functional option for configuring a function overload.
type OverloadOpt func(*overloadDecl) (*overloadDecl, error)
type OverloadOpt func(*o) (*o, error)

// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
return func(o *overloadDecl) (*overloadDecl, error) {
return func(o *o) (*o, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.id)
}
Expand All @@ -465,7 +466,7 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
return func(o *overloadDecl) (*overloadDecl, error) {
return func(o *o) (*o, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.id)
}
Expand All @@ -480,7 +481,7 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return func(o *overloadDecl) (*overloadDecl, error) {
return func(o *o) (*o, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.id)
}
Expand All @@ -493,7 +494,7 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
func OverloadIsNonStrict() OverloadOpt {
return func(o *overloadDecl) (*overloadDecl, error) {
return func(o *o) (*o, error) {
o.nonStrict = true
return o, nil
}
Expand All @@ -502,15 +503,15 @@ func OverloadIsNonStrict() OverloadOpt {
// OverloadOperandTrait configures a set of traits which the first argument to the overload must implement in order to be
// successfully invoked.
func OverloadOperandTrait(trait int) OverloadOpt {
return func(o *overloadDecl) (*overloadDecl, error) {
return func(o *o) (*o, error) {
o.operandTrait = trait
return o, nil
}
}

type functionDecl struct {
name string
overloads map[string]*overloadDecl
overloads []*o
options []FunctionOpt
singleton *functions.Overload
initialized bool
Expand Down Expand Up @@ -591,22 +592,22 @@ func (f *functionDecl) bindings() ([]*functions.Overload, error) {
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
for _, overloadDecl := range f.overloads {
if !overloadDecl.matchesRuntimeSignature(args...) {
for _, o := range f.overloads {
if !o.matchesRuntimeSignature(args...) {
continue
}
switch len(args) {
case 1:
if overloadDecl.unaryOp != nil {
return overloadDecl.unaryOp(args[0])
if o.unaryOp != nil {
return o.unaryOp(args[0])
}
case 2:
if overloadDecl.binaryOp != nil {
return overloadDecl.binaryOp(args[0], args[1])
if o.binaryOp != nil {
return o.binaryOp(args[0], args[1])
}
}
if overloadDecl.functionOp != nil {
return overloadDecl.functionOp(args...)
if o.functionOp != nil {
return o.functionOp(args...)
}
// eventually this will fall through to the noSuchOverload below.
}
Expand Down Expand Up @@ -639,14 +640,12 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) {
}
merged := &functionDecl{
name: f.name,
overloads: map[string]*overloadDecl{},
overloads: make([]*o, len(f.overloads)),
options: []FunctionOpt{},
initialized: true,
singleton: f.singleton,
}
for id, o := range f.overloads {
merged.overloads[id] = o
}
copy(merged.overloads, f.overloads)
for _, o := range other.overloads {
err := merged.addOverload(o)
if err != nil {
Expand All @@ -665,21 +664,22 @@ func (f *functionDecl) merge(other *functionDecl) (*functionDecl, error) {
// addOverload ensures that the new overload does not collide with an existing overload signature;
// however, if the function signatures are identical, the implementation may be rewritten as its
// difficult to compare functions by object identity.
func (f *functionDecl) addOverload(overload *overloadDecl) error {
for id, o := range f.overloads {
if id != overload.id && o.signatureOverlaps(overload) {
func (f *functionDecl) addOverload(overload *o) error {
for index, o := range f.overloads {
if o.id != overload.id && o.signatureOverlaps(overload) {
return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.name, o.id, overload.id)
}
if id == overload.id {
if o.id == overload.id {
if o.signatureEquals(overload) && o.nonStrict == overload.nonStrict {
// Allow redefinition of an overload implementation so long as the signatures match.
f.overloads[id] = overload
f.overloads[index] = overload
return nil
} else {
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.name, o.id)
}
}
}
f.overloads[overload.id] = overload
f.overloads = append(f.overloads, overload)
return nil
}

Expand All @@ -692,8 +692,8 @@ func noSuchOverload(funcName string, args ...ref.Val) ref.Val {
return types.NewErr("no such overload: %s(%s)", funcName, signature)
}

// overloadDecl contains all of the relevant information regarding a specific function overload.
type overloadDecl struct {
// o contains all of the relevant information regarding a specific function overload.
type o struct {
id string
argTypes []*Type
resultType *Type
Expand All @@ -709,12 +709,12 @@ type overloadDecl struct {
operandTrait int
}

func (o *overloadDecl) hasBinding() bool {
func (o *o) hasBinding() bool {
return o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil
}

// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
func (o *overloadDecl) guardedUnaryOp(funcName string) functions.UnaryOp {
func (o *o) guardedUnaryOp(funcName string) functions.UnaryOp {
if o.unaryOp == nil {
return nil
}
Expand All @@ -727,7 +727,7 @@ func (o *overloadDecl) guardedUnaryOp(funcName string) functions.UnaryOp {
}

// guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined.
func (o *overloadDecl) guardedBinaryOp(funcName string) functions.BinaryOp {
func (o *o) guardedBinaryOp(funcName string) functions.BinaryOp {
if o.binaryOp == nil {
return nil
}
Expand All @@ -740,7 +740,7 @@ func (o *overloadDecl) guardedBinaryOp(funcName string) functions.BinaryOp {
}

// guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided.
func (o *overloadDecl) guardedFunctionOp(funcName string) functions.FunctionOp {
func (o *o) guardedFunctionOp(funcName string) functions.FunctionOp {
if o.functionOp == nil {
return nil
}
Expand All @@ -753,15 +753,15 @@ func (o *overloadDecl) guardedFunctionOp(funcName string) functions.FunctionOp {
}

// matchesRuntimeUnarySignature indicates whether the argument type is runtime assiganble to the overload's expected argument.
func (o *overloadDecl) matchesRuntimeUnarySignature(arg ref.Val) bool {
func (o *o) matchesRuntimeUnarySignature(arg ref.Val) bool {
if o.nonStrict && types.IsUnknownOrError(arg) {
return true
}
return o.argTypes[0].IsAssignableRuntimeType(arg.Type()) && (o.operandTrait == 0 || arg.Type().HasTrait(o.operandTrait))
}

// matchesRuntimeBinarySignature indicates whether the argument types are runtime assiganble to the overload's expected arguments.
func (o *overloadDecl) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool {
func (o *o) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool {
if o.nonStrict {
if types.IsUnknownOrError(arg1) {
return types.IsUnknownOrError(arg2) || o.argTypes[1].IsAssignableRuntimeType(arg2.Type())
Expand All @@ -773,7 +773,7 @@ func (o *overloadDecl) matchesRuntimeBinarySignature(arg1, arg2 ref.Val) bool {
}

// matchesRuntimeSignature indicates whether the argument types are runtime assiganble to the overload's expected arguments.
func (o *overloadDecl) matchesRuntimeSignature(args ...ref.Val) bool {
func (o *o) matchesRuntimeSignature(args ...ref.Val) bool {
if len(args) != len(o.argTypes) {
return false
}
Expand All @@ -795,7 +795,7 @@ func (o *overloadDecl) matchesRuntimeSignature(args ...ref.Val) bool {
// signatureEquals indicates whether one overload has an identical signature to another overload.
//
// Providing a duplicate signature is not an issue, but an overloapping signature is problematic.
func (o *overloadDecl) signatureEquals(other *overloadDecl) bool {
func (o *o) signatureEquals(other *o) bool {
if o.id != other.id || o.memberFunction != other.memberFunction || len(o.argTypes) != len(other.argTypes) {
return false
}
Expand All @@ -811,7 +811,7 @@ func (o *overloadDecl) signatureEquals(other *overloadDecl) bool {
// signatureOverlaps indicates whether one overload has an overlapping signature with another overload.
//
// The 'other' overload must first be checked for equality before determining whether it overlaps in order to be completely accurate.
func (o *overloadDecl) signatureOverlaps(other *overloadDecl) bool {
func (o *o) signatureOverlaps(other *o) bool {
if o.memberFunction != other.memberFunction || len(o.argTypes) != len(other.argTypes) {
return false
}
Expand All @@ -827,7 +827,7 @@ func (o *overloadDecl) signatureOverlaps(other *overloadDecl) bool {

func newOverload(overloadID string, memberFunction bool, args []*Type, resultType *Type, opts ...OverloadOpt) FunctionOpt {
return func(f *functionDecl) (*functionDecl, error) {
overload := &overloadDecl{
overload := &o{
id: overloadID,
argTypes: args,
resultType: resultType,
Expand Down

0 comments on commit 5b81ae9

Please sign in to comment.