Skip to content

Commit

Permalink
Infer the length argument of the random.sample function. (#1862)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbyrnepr2 authored and Pierre-Sassoulas committed Nov 19, 2022
1 parent 7dab003 commit e1aabb7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
3 changes: 3 additions & 0 deletions ChangeLog
Expand Up @@ -24,6 +24,9 @@ What's New in astroid 2.12.13?
==============================
Release date: TBA

* Infer the `length` argument of the ``random.sample`` function.

Refs PyCQA/pylint#7706


What's New in astroid 2.12.12?
Expand Down
10 changes: 5 additions & 5 deletions astroid/brain/brain_random.py
Expand Up @@ -42,10 +42,10 @@ def infer_random_sample(node, context=None):
if len(node.args) != 2:
raise UseInferenceDefault

length = node.args[1]
if not isinstance(length, Const):
inferred_length = helpers.safe_infer(node.args[1], context=context)
if not isinstance(inferred_length, Const):
raise UseInferenceDefault
if not isinstance(length.value, int):
if not isinstance(inferred_length.value, int):
raise UseInferenceDefault

inferred_sequence = helpers.safe_infer(node.args[0], context=context)
Expand All @@ -55,12 +55,12 @@ def infer_random_sample(node, context=None):
if not isinstance(inferred_sequence, ACCEPTED_ITERABLES_FOR_SAMPLE):
raise UseInferenceDefault

if length.value > len(inferred_sequence.elts):
if inferred_length.value > len(inferred_sequence.elts):
# In this case, this will raise a ValueError
raise UseInferenceDefault

try:
elts = random.sample(inferred_sequence.elts, length.value)
elts = random.sample(inferred_sequence.elts, inferred_length.value)
except ValueError as exc:
raise UseInferenceDefault from exc

Expand Down
23 changes: 23 additions & 0 deletions tests/unittest_brain.py
Expand Up @@ -2413,6 +2413,29 @@ def test_inferred_successfully(self) -> None:
elems = sorted(elem.value for elem in inferred.elts)
self.assertEqual(elems, [1, 2])

def test_arguments_inferred_successfully(self) -> None:
"""Test inference of `random.sample` when both arguments are of type `nodes.Call`."""
node = astroid.extract_node(
"""
import random
def sequence():
return [1, 2]
random.sample(sequence(), len([1,2])) #@
"""
)
# Check that arguments are of type `nodes.Call`.
sequence, length = node.args
self.assertIsInstance(sequence, astroid.Call)
self.assertIsInstance(length, astroid.Call)

# Check the inference of `random.sample` call.
inferred = next(node.infer())
self.assertIsInstance(inferred, astroid.List)
elems = sorted(elem.value for elem in inferred.elts)
self.assertEqual(elems, [1, 2])

def test_no_crash_on_evaluatedobject(self) -> None:
node = astroid.extract_node(
"""
Expand Down

0 comments on commit e1aabb7

Please sign in to comment.