diff --git a/ChangeLog b/ChangeLog index faed12c04..9f4856dad 100644 --- a/ChangeLog +++ b/ChangeLog @@ -24,6 +24,9 @@ What's New in astroid 2.12.13? ============================== Release date: TBA +* Infer the `length` argument of the ``random.sample`` function. + + Refs PyCQA/pylint#7706 What's New in astroid 2.12.12? diff --git a/astroid/brain/brain_random.py b/astroid/brain/brain_random.py index e66aa81a0..a580ff704 100644 --- a/astroid/brain/brain_random.py +++ b/astroid/brain/brain_random.py @@ -42,10 +42,10 @@ def infer_random_sample(node, context=None): if len(node.args) != 2: raise UseInferenceDefault - length = node.args[1] - if not isinstance(length, Const): + inferred_length = helpers.safe_infer(node.args[1], context=context) + if not isinstance(inferred_length, Const): raise UseInferenceDefault - if not isinstance(length.value, int): + if not isinstance(inferred_length.value, int): raise UseInferenceDefault inferred_sequence = helpers.safe_infer(node.args[0], context=context) @@ -55,12 +55,12 @@ def infer_random_sample(node, context=None): if not isinstance(inferred_sequence, ACCEPTED_ITERABLES_FOR_SAMPLE): raise UseInferenceDefault - if length.value > len(inferred_sequence.elts): + if inferred_length.value > len(inferred_sequence.elts): # In this case, this will raise a ValueError raise UseInferenceDefault try: - elts = random.sample(inferred_sequence.elts, length.value) + elts = random.sample(inferred_sequence.elts, inferred_length.value) except ValueError as exc: raise UseInferenceDefault from exc diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 2b0b7230c..93cb220e1 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -2413,6 +2413,29 @@ def test_inferred_successfully(self) -> None: elems = sorted(elem.value for elem in inferred.elts) self.assertEqual(elems, [1, 2]) + def test_arguments_inferred_successfully(self) -> None: + """Test inference of `random.sample` when both arguments are of type `nodes.Call`.""" + node = astroid.extract_node( + """ + import random + + def sequence(): + return [1, 2] + + random.sample(sequence(), len([1,2])) #@ + """ + ) + # Check that arguments are of type `nodes.Call`. + sequence, length = node.args + self.assertIsInstance(sequence, astroid.Call) + self.assertIsInstance(length, astroid.Call) + + # Check the inference of `random.sample` call. + inferred = next(node.infer()) + self.assertIsInstance(inferred, astroid.List) + elems = sorted(elem.value for elem in inferred.elts) + self.assertEqual(elems, [1, 2]) + def test_no_crash_on_evaluatedobject(self) -> None: node = astroid.extract_node( """