Skip to content

Commit

Permalink
Merge pull request #7417 from klimick/partially-applied-closure-infer…
Browse files Browse the repository at this point in the history
…ence

Contextual type inference for high order function arg
  • Loading branch information
orklah committed Jan 20, 2022
2 parents 2f052a8 + 5fb1df8 commit 6f1a5e8
Show file tree
Hide file tree
Showing 11 changed files with 581 additions and 27 deletions.
Expand Up @@ -19,6 +19,7 @@
use Psalm\Internal\DataFlow\TaintSink;
use Psalm\Internal\MethodIdentifier;
use Psalm\Internal\Stubs\Generator\StubsGenerator;
use Psalm\Internal\Type\Comparator\CallableTypeComparator;
use Psalm\Internal\Type\Comparator\UnionTypeComparator;
use Psalm\Internal\Type\TemplateInferredTypeReplacer;
use Psalm\Internal\Type\TemplateResult;
Expand Down Expand Up @@ -196,7 +197,21 @@ public static function analyze(
$toggled_class_exists = true;
}

if (($arg->value instanceof PhpParser\Node\Expr\Closure
$high_order_template_result = null;

if (($arg->value instanceof PhpParser\Node\Expr\FuncCall
|| $arg->value instanceof PhpParser\Node\Expr\MethodCall
|| $arg->value instanceof PhpParser\Node\Expr\StaticCall)
&& $param
&& $function_storage = self::getHighOrderFuncStorage($context, $statements_analyzer, $arg->value)
) {
$high_order_template_result = self::handleHighOrderFuncCallArg(
$statements_analyzer,
$template_result ?? new TemplateResult([], []),
$function_storage,
$param
);
} elseif (($arg->value instanceof PhpParser\Node\Expr\Closure
|| $arg->value instanceof PhpParser\Node\Expr\ArrowFunction)
&& $param
&& !$arg->value->getDocComment()
Expand All @@ -217,7 +232,15 @@ public static function analyze(

$context->inside_call = true;

if (ExpressionAnalyzer::analyze($statements_analyzer, $arg->value, $context) === false) {
if (ExpressionAnalyzer::analyze(
$statements_analyzer,
$arg->value,
$context,
false,
null,
false,
$high_order_template_result
) === false) {
$context->inside_call = $was_inside_call;

return false;
Expand Down Expand Up @@ -315,6 +338,172 @@ private static function handleArrayMapFilterArrayArg(
}
}

private static function getHighOrderFuncStorage(
Context $context,
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\CallLike $function_like_call
): ?FunctionLikeStorage {
$codebase = $statements_analyzer->getCodebase();

try {
if ($function_like_call instanceof PhpParser\Node\Expr\FuncCall) {
$function_id = strtolower((string) $function_like_call->name->getAttribute('resolvedName'));

if (empty($function_id)) {
return null;
}

return $codebase->functions->getStorage($statements_analyzer, $function_id);
}

if ($function_like_call instanceof PhpParser\Node\Expr\MethodCall &&
$function_like_call->var instanceof PhpParser\Node\Expr\Variable &&
$function_like_call->name instanceof PhpParser\Node\Identifier &&
is_string($function_like_call->var->name) &&
isset($context->vars_in_scope['$' . $function_like_call->var->name])
) {
$lhs_type = $context->vars_in_scope['$' . $function_like_call->var->name]->getSingleAtomic();

if (!$lhs_type instanceof Type\Atomic\TNamedObject) {
return null;
}

$method_id = new MethodIdentifier(
$lhs_type->value,
strtolower((string)$function_like_call->name)
);

return $codebase->methods->getStorage($method_id);
}

if ($function_like_call instanceof PhpParser\Node\Expr\StaticCall &&
$function_like_call->name instanceof PhpParser\Node\Identifier
) {
$method_id = new MethodIdentifier(
(string)$function_like_call->class->getAttribute('resolvedName'),
strtolower($function_like_call->name->name)
);

return $codebase->methods->getStorage($method_id);
}
} catch (UnexpectedValueException $e) {
return null;
}

return null;
}

/**
* Compiles TemplateResult for high-order functions ($func_call)
* by previous template args ($inferred_template_result).
*
* It's need for proper template replacement:
*
* ```
* * template T
* * return Closure(T): T
* function id(): Closure { ... }
*
* * template A
* * template B
* *
* * param list<A> $_items
* * param callable(A): B $_ab
* * return list<B>
* function map(array $items, callable $ab): array { ... }
*
* // list<int>
* $numbers = [1, 2, 3];
*
* $result = map($numbers, id());
* // $result is list<int> because template T of id() was inferred by previous arg.
* ```
*/
private static function handleHighOrderFuncCallArg(
StatementsAnalyzer $statements_analyzer,
TemplateResult $inferred_template_result,
FunctionLikeStorage $storage,
FunctionLikeParameter $actual_func_param
): ?TemplateResult {
$codebase = $statements_analyzer->getCodebase();

$input_hof_atomic = $storage->return_type && $storage->return_type->isSingle()
? $storage->return_type->getSingleAtomic()
: null;

// Try upcast invokable to callable type.
if ($input_hof_atomic instanceof Type\Atomic\TNamedObject &&
$input_hof_atomic->value !== 'Closure' &&
$codebase->classExists($input_hof_atomic->value)
) {
$callable_from_invokable = CallableTypeComparator::getCallableFromAtomic(
$codebase,
$input_hof_atomic
);

if ($callable_from_invokable) {
$invoke_id = new MethodIdentifier($input_hof_atomic->value, '__invoke');
$declaring_invoke_id = $codebase->methods->getDeclaringMethodId($invoke_id);

$storage = $codebase->methods->getStorage($declaring_invoke_id ?? $invoke_id);
$input_hof_atomic = $callable_from_invokable;
}
}

if (!$input_hof_atomic instanceof TClosure && !$input_hof_atomic instanceof TCallable) {
return null;
}

$container_hof_atomic = $actual_func_param->type && $actual_func_param->type->isSingle()
? $actual_func_param->type->getSingleAtomic()
: null;

if (!$container_hof_atomic instanceof TClosure && !$container_hof_atomic instanceof TCallable) {
return null;
}

$replaced_container_hof_atomic = new Union([clone $container_hof_atomic]);

// Replaces all input args in container function.
//
// For example:
// The map function expects callable(A):B as second param
// We know that previous arg type is list<int> where the int is the A template.
// Then we can replace callable(A): B to callable(int):B using $inferred_template_result.
TemplateInferredTypeReplacer::replace(
$replaced_container_hof_atomic,
$inferred_template_result,
$codebase
);

/** @var TClosure|TCallable $container_hof_atomic */
$container_hof_atomic = $replaced_container_hof_atomic->getSingleAtomic();
$high_order_template_result = new TemplateResult($storage->template_types ?: [], []);

// We can replace each templated param for the input function.
// Example:
// map($numbers, id());
// We know that map expects callable(int):B because the $numbers is list<int>.
// We know that id() returns callable(T):T.
// Then we can replace templated params sequentially using the expected callable(int):B.
foreach ($input_hof_atomic->params ?? [] as $offset => $actual_func_param) {
if ($actual_func_param->type &&
$actual_func_param->type->getTemplateTypes() &&
isset($container_hof_atomic->params[$offset])
) {
TemplateStandinTypeReplacer::replace(
clone $actual_func_param->type,
$high_order_template_result,
$codebase,
null,
$container_hof_atomic->params[$offset]->type
);
}
}

return $high_order_template_result;
}

/**
* @param array<int, PhpParser\Node\Arg> $args
*/
Expand Down
Expand Up @@ -82,7 +82,8 @@ class FunctionCallAnalyzer extends CallAnalyzer
public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\FuncCall $stmt,
Context $context
Context $context,
?TemplateResult $template_result = null
): bool {
$function_name = $stmt->name;

Expand Down Expand Up @@ -166,10 +167,12 @@ public static function analyze(
}

if (!$is_first_class_callable) {
$template_result = null;

if (isset($function_call_info->function_storage->template_types)) {
$template_result = new TemplateResult($function_call_info->function_storage->template_types ?: [], []);
if (!$template_result) {
$template_result = new TemplateResult([], []);
}

$template_result->template_types += $function_call_info->function_storage->template_types ?: [];
}

ArgumentsAnalyzer::analyze(
Expand Down Expand Up @@ -205,6 +208,10 @@ public static function analyze(
}
}

$already_inferred_lower_bounds = $template_result
? $template_result->lower_bounds
: [];

$template_result = new TemplateResult([], []);

// do this here to allow closure param checks
Expand All @@ -229,6 +236,8 @@ public static function analyze(
$function_call_info->function_id
);

$template_result->lower_bounds += $already_inferred_lower_bounds;

if ($function_name instanceof PhpParser\Node\Name && $function_call_info->function_id) {
$stmt_type = FunctionCallReturnTypeFetcher::fetch(
$statements_analyzer,
Expand Down
Expand Up @@ -76,7 +76,8 @@ public static function analyze(
?Atomic $static_type,
bool $is_intersection,
?string $lhs_var_id,
AtomicMethodCallAnalysisResult $result
AtomicMethodCallAnalysisResult $result,
?TemplateResult $inferred_template_result = null
): void {
if ($lhs_type_part instanceof TTemplateParam
&& !$lhs_type_part->as->isMixed()
Expand Down Expand Up @@ -440,7 +441,8 @@ public static function analyze(
$static_type,
$lhs_var_id,
$method_id,
$result
$result,
$inferred_template_result
);

$statements_analyzer->node_data = $old_node_data;
Expand Down
Expand Up @@ -68,7 +68,8 @@ public static function analyze(
?Atomic $static_type,
?string $lhs_var_id,
MethodIdentifier $method_id,
AtomicMethodCallAnalysisResult $result
AtomicMethodCallAnalysisResult $result,
?TemplateResult $inferred_template_result = null
): Union {
$config = $codebase->config;

Expand Down Expand Up @@ -220,6 +221,10 @@ public static function analyze(
$template_result = new TemplateResult([], $class_template_params ?: []);
$template_result->lower_bounds += $method_template_params;

if ($inferred_template_result) {
$template_result->lower_bounds += $inferred_template_result->lower_bounds;
}

if ($codebase->store_node_types
&& !$context->collect_initializations
&& !$context->collect_mutations
Expand Down
Expand Up @@ -11,6 +11,7 @@
use Psalm\Internal\Analyzer\Statements\Expression\ExpressionIdentifier;
use Psalm\Internal\Analyzer\Statements\ExpressionAnalyzer;
use Psalm\Internal\Analyzer\StatementsAnalyzer;
use Psalm\Internal\Type\TemplateResult;
use Psalm\Issue\InvalidMethodCall;
use Psalm\Issue\InvalidScope;
use Psalm\Issue\NullReference;
Expand Down Expand Up @@ -43,7 +44,8 @@ public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\MethodCall $stmt,
Context $context,
bool $real_method_call = true
bool $real_method_call = true,
?TemplateResult $template_result = null
): bool {
$was_inside_call = $context->inside_call;

Expand Down Expand Up @@ -194,7 +196,8 @@ public static function analyze(
: null,
false,
$lhs_var_id,
$result
$result,
$template_result
);
if (isset($context->vars_in_scope[$lhs_var_id])
&& ($possible_new_class_type = $context->vars_in_scope[$lhs_var_id]) instanceof Union
Expand Down
Expand Up @@ -41,7 +41,8 @@ class StaticCallAnalyzer extends CallAnalyzer
public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\StaticCall $stmt,
Context $context
Context $context,
?TemplateResult $template_result = null
): bool {
$method_id = null;

Expand Down Expand Up @@ -219,7 +220,8 @@ public static function analyze(
$lhs_type->ignore_nullable_issues,
$moved_call,
$has_mock,
$has_existing_method
$has_existing_method,
$template_result
);
}

Expand Down

0 comments on commit 6f1a5e8

Please sign in to comment.