Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer lambda return types #7741

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 23 additions & 6 deletions core/types/calls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ struct GuessOverloadCandidate {
// Guess overload. The way we guess is only arity based - we will return the overload that has the smallest number of
// arguments that is >= args.size()
MethodRef guessOverload(const GlobalState &gs, ClassOrModuleRef inClass, MethodRef primary, uint16_t numPosArgs,
InlinedVector<const TypeAndOrigins *, 2> &args, const vector<TypePtr> &targs, bool hasBlock) {
InlinedVector<const TypeAndOrigins *, 2> &args, const vector<TypePtr> &targs,
const shared_ptr<const SendAndBlockLink> &block) {
counterInc("calls.overloaded_invocations");
MethodRef fallback = primary;
vector<MethodRef> allCandidates;
Expand Down Expand Up @@ -478,11 +479,21 @@ MethodRef guessOverload(const GlobalState &gs, ClassOrModuleRef inClass, MethodR
ENFORCE(!args.empty(), "Should at least have a block argument.");
const auto &lastArg = args.back();
auto mentionsBlockArg = !lastArg.isSyntheticBlockArgument();
if (hasBlock) {
if (block != nullptr) {
if (!mentionsBlockArg || lastArg.type == Types::nilClass()) {
it = leftCandidates.erase(it);
continue;
}
if (auto blockParamType = cast_type<AppliedType>(lastArg.type)) {
if (auto procArity = Types::getProcArity(*blockParamType)) {
if (auto fixedArity = block->fixedArity()) {
if (procArity != block->fixedArity()) {
it = leftCandidates.erase(it);
continue;
}
}
}
}
} else {
if (mentionsBlockArg && lastArg.type != nullptr &&
(!lastArg.type.isFullyDefined() || !Types::isSubType(gs, Types::nilClass(), lastArg.type))) {
Expand Down Expand Up @@ -881,10 +892,9 @@ DispatchResult dispatchCallSymbol(const GlobalState &gs, const DispatchArgs &arg
return result;
}

auto method =
mayBeOverloaded.data(gs)->flags.isOverloaded
? guessOverload(gs, symbol, mayBeOverloaded, args.numPosArgs, args.args, targs, args.block != nullptr)
: mayBeOverloaded;
auto method = mayBeOverloaded.data(gs)->flags.isOverloaded
? guessOverload(gs, symbol, mayBeOverloaded, args.numPosArgs, args.args, targs, args.block)
: mayBeOverloaded;

if (method.data(gs)->flags.isPrivate && !args.isPrivateOk) {
if (auto e = gs.beginError(errLoc, core::errors::Infer::PrivateMethod)) {
Expand Down Expand Up @@ -4105,6 +4115,13 @@ class Kernel_proc : public IntrinsicMethod {
res.returnType = core::Types::procClass();
return;
}

if (args.name == core::Names::lambda()) {
// Handled by the Kernel#lambda overload, generated by generate_procs
// TODO(jez) Can we expand to proc?
return;
}

auto untypedWithBlame = core::Types::untyped(Symbols::Magic_UntypedSource_proc());
vector<core::TypePtr> targs(*numberOfPositionalBlockParams + 1, untypedWithBlame);
auto procClass = core::Symbols::Proc(*numberOfPositionalBlockParams);
Expand Down
13 changes: 2 additions & 11 deletions rbi/core/kernel.rbi
Original file line number Diff line number Diff line change
Expand Up @@ -1882,17 +1882,8 @@ module Kernel
end
def proc(&blk); end

# Equivalent to
# [`Proc.new`](https://docs.ruby-lang.org/en/2.7.0/Proc.html#method-c-new),
# except the resulting [`Proc`](https://docs.ruby-lang.org/en/2.7.0/Proc.html)
# objects check the number of parameters passed when called.
sig do
params(
blk: T.untyped,
)
.returns(Proc)
end
def lambda(&blk); end
### The signature for this method is defined in procs.rbi, generated by generate_procs.cc
### def lambda(&blk); end

# Equivalent to:
#
Expand Down
49 changes: 49 additions & 0 deletions rbi/tools/generate_procs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,59 @@ void emitProc(ofstream &out, int arity) {
out << "end" << '\n' << '\n';
}

void emitLambdaOverload(ofstream &out, int arity) {
out << " sig do\n";
out << " type_parameters(:Return)\n"
" .params(\n"
" blk: T.proc";
if (arity != 0) {
out << ".params(\n";
for (int i = 0; i < arity; ++i) {
out << " arg" << i << ": T.untyped";
if (i + 1 != arity) {
out << ",";
}
out << "\n";
}
out << " )\n ";
}
out << ".returns(T.type_parameter(:Return))\n";
out << " )\n"
" .returns(\n"
" T.proc";
if (arity != 0) {
out << ".params(\n";
for (int i = 0; i < arity; ++i) {
out << " arg" << i << ": T.untyped";
if (i + 1 != arity) {
out << ",";
}
out << "\n";
}
out << " )\n ";
}
out << ".returns(T.type_parameter(:Return))\n"
" )\n"
" end\n";
}

int main(int argc, char **argv) {
ofstream rb(argv[1], ios::trunc);
rb << "# typed: true" << '\n';
for (int arity = 0; arity <= MAX_PROC_ARITY; ++arity) {
emitProc(rb, arity);
}

rb << "module Kernel\n"
" # Equivalent to\n"
" # [`Proc.new`](https://docs.ruby-lang.org/en/2.7.0/Proc.html#method-c-new),\n"
" # except the resulting [`Proc`](https://docs.ruby-lang.org/en/2.7.0/Proc.html)\n"
" # objects check the number of parameters passed when called.\n";

for (int arity = 0; arity <= MAX_PROC_ARITY; ++arity) {
emitLambdaOverload(rb, arity);
}

rb << " def lambda(&blk); end\n"
"end\n\n";
}
19 changes: 19 additions & 0 deletions test/testdata/infer/lambda_return.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# typed: true
extend T::Sig

sig { params(x: T.anything).returns(String) }
def returns_string(x)
''
end

f = -> () { 0 }
T.reveal_type(f) # error: T.proc.returns(Integer)

f = -> (x) { 0 }
T.reveal_type(f) # error: T.proc.params(arg0: T.untyped).returns(Integer)

f = -> (x) { x }
T.reveal_type(f) # error: T.proc.params(arg0: T.untyped).returns(T.untyped)

f = -> (x) { returns_string(x) }
T.reveal_type(f) # error: T.proc.params(arg0: T.untyped).returns(String)