Skip to content

Commit

Permalink
Teach guessOverload to respect block arity
Browse files Browse the repository at this point in the history
Sometimes it's useful to be able to use the arity of the block to guess
an overload.

This isn't perfect for all the reasons that overload checking isn't
perfect, but there are some places where this is useful, especially in
abstractions that check the proc's arity when deciding how to call the
block.

This is also a pre-requisite for doing something like #7741, which is a
partial fix for #3914 / #4149, where we infer the types of
`Kernel#lambda` blocks by codegenerating overloaded signatures.
  • Loading branch information
jez committed Apr 25, 2024
1 parent 950e441 commit 05007af
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
22 changes: 16 additions & 6 deletions core/types/calls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ struct GuessOverloadCandidate {
// Guess overload. The way we guess is only arity based - we will return the overload that has the smallest number of
// arguments that is >= args.size()
MethodRef guessOverload(const GlobalState &gs, ClassOrModuleRef inClass, MethodRef primary, uint16_t numPosArgs,
InlinedVector<const TypeAndOrigins *, 2> &args, const vector<TypePtr> &targs, bool hasBlock) {
InlinedVector<const TypeAndOrigins *, 2> &args, const vector<TypePtr> &targs,
const shared_ptr<const SendAndBlockLink> &block) {
counterInc("calls.overloaded_invocations");
MethodRef fallback = primary;
vector<MethodRef> allCandidates;
Expand Down Expand Up @@ -481,11 +482,21 @@ MethodRef guessOverload(const GlobalState &gs, ClassOrModuleRef inClass, MethodR
ENFORCE(!args.empty(), "Should at least have a block argument.");
const auto &lastArg = args.back();
auto mentionsBlockArg = !lastArg.isSyntheticBlockArgument();
if (hasBlock) {
if (block != nullptr) {
if (!mentionsBlockArg || lastArg.type == Types::nilClass()) {
it = leftCandidates.erase(it);
continue;
}
if (auto blockParamType = cast_type<AppliedType>(lastArg.type)) {
if (auto procArity = Types::getProcArity(*blockParamType)) {
if (auto fixedArity = block->fixedArity()) {
if (procArity != block->fixedArity()) {
it = leftCandidates.erase(it);
continue;
}
}
}
}
} else {
if (mentionsBlockArg && lastArg.type != nullptr &&
(!lastArg.type.isFullyDefined() || !Types::isSubType(gs, Types::nilClass(), lastArg.type))) {
Expand Down Expand Up @@ -899,10 +910,9 @@ DispatchResult dispatchCallSymbol(const GlobalState &gs, const DispatchArgs &arg
return result;
}

auto method =
mayBeOverloaded.data(gs)->flags.isOverloaded
? guessOverload(gs, symbol, mayBeOverloaded, args.numPosArgs, args.args, targs, args.block != nullptr)
: mayBeOverloaded;
auto method = mayBeOverloaded.data(gs)->flags.isOverloaded
? guessOverload(gs, symbol, mayBeOverloaded, args.numPosArgs, args.args, targs, args.block)
: mayBeOverloaded;

if (method.data(gs)->flags.isPrivate && !args.isPrivateOk) {
if (auto e = gs.beginError(errLoc, core::errors::Infer::PrivateMethod)) {
Expand Down
4 changes: 2 additions & 2 deletions test/testdata/infer/overloads_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def test
end
T.reveal_type(x) # error: `Integer`
x = block_arity_overload do |y|
T.reveal_type(y) # error: `NilClass`
T.reveal_type(y) # error: `String`
end
T.reveal_type(x) # error: `Integer`
T.reveal_type(x) # error: `String`
end
end

0 comments on commit 05007af

Please sign in to comment.