Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

B006 and B008: Cover additional test cases #239

Merged
merged 7 commits into from Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -8,7 +8,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 21.10b0
rev: 22.1.0
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
hooks:
- id: black
args:
Expand Down
5 changes: 5 additions & 0 deletions README.rst
Expand Up @@ -275,6 +275,11 @@ MIT
Change Log
----------

<release-tbd>
~~~~~~~~~~

* B006 and B008: Detect function calls at any level of the default expression.

22.3.20
~~~~~~~~~~

Expand Down
143 changes: 84 additions & 59 deletions bugbear.py
Expand Up @@ -2,8 +2,8 @@
import builtins
import itertools
import logging
import math
import re
import sys
from collections import namedtuple
from contextlib import suppress
from functools import lru_cache, partial
Expand Down Expand Up @@ -354,13 +354,13 @@ def visit_Assert(self, node):

def visit_AsyncFunctionDef(self, node):
self.check_for_b902(node)
self.check_for_b006(node)
self.check_for_b006_and_b008(node)
self.generic_visit(node)

def visit_FunctionDef(self, node):
self.check_for_b901(node)
self.check_for_b902(node)
self.check_for_b006(node)
self.check_for_b006_and_b008(node)
self.check_for_b018(node)
self.check_for_b019(node)
self.check_for_b021(node)
Expand Down Expand Up @@ -390,23 +390,14 @@ def visit_With(self, node):
self.check_for_b022(node)
self.generic_visit(node)

def compose_call_path(self, node):
if isinstance(node, ast.Attribute):
yield from self.compose_call_path(node.value)
yield node.attr
elif isinstance(node, ast.Call):
yield from self.compose_call_path(node.func)
elif isinstance(node, ast.Name):
yield node.id
Comment on lines -393 to -400
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was useful in my other visitor and doesn't need to be in this class so separated it out


def check_for_b005(self, node):
if node.func.attr not in B005.methods:
return # method name doesn't match

if len(node.args) != 1 or not isinstance(node.args[0], ast.Str):
return # used arguments don't match the builtin strip

call_path = ".".join(self.compose_call_path(node.func.value))
call_path = ".".join(compose_call_path(node.func.value))
if call_path in B005.valid_paths:
return # path is exempt

Expand All @@ -419,48 +410,10 @@ def check_for_b005(self, node):

self.errors.append(B005(node.lineno, node.col_offset))

def check_for_b006(self, node):
for default in node.args.defaults + node.args.kw_defaults:
if isinstance(
default, (*B006.mutable_literals, *B006.mutable_comprehensions)
):
self.errors.append(B006(default.lineno, default.col_offset))
elif isinstance(default, ast.Call):
call_path = ".".join(self.compose_call_path(default.func))
if call_path in B006.mutable_calls:
self.errors.append(B006(default.lineno, default.col_offset))
elif (
call_path
not in B008.immutable_calls | self.b008_extend_immutable_calls
):
# Check if function call is actually a float infinity/NaN literal
if call_path == "float" and len(default.args) == 1:
float_arg = default.args[0]
if sys.version_info < (3, 8, 0):
# NOTE: pre-3.8, string literals are represented with ast.Str
if isinstance(float_arg, ast.Str):
str_val = float_arg.s
else:
str_val = ""
else:
# NOTE: post-3.8, string literals are represented with ast.Constant
if isinstance(float_arg, ast.Constant):
str_val = float_arg.value
if not isinstance(str_val, str):
str_val = ""
else:
str_val = ""

# NOTE: regex derived from documentation at:
# https://docs.python.org/3/library/functions.html#float
inf_nan_regex = r"^[+-]?(inf|infinity|nan)$"
re_result = re.search(inf_nan_regex, str_val.lower())
is_float_literal = re_result is not None
else:
is_float_literal = False

if not is_float_literal:
self.errors.append(B008(default.lineno, default.col_offset))
def check_for_b006_and_b008(self, node):
visitor = FuntionDefDefaultsVisitor(self.b008_extend_immutable_calls)
visitor.visit(node.args.defaults + node.args.kw_defaults)
self.errors.extend(visitor.errors)
Comment on lines +413 to +416
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rules now use a visitor to handle functions at any depth


def check_for_b007(self, node):
targets = NameFinder()
Expand Down Expand Up @@ -536,8 +489,7 @@ def check_for_b019(self, node):
# Preserve decorator order so we can get the lineno from the decorator node
# rather than the function node (this location definition changes in Python 3.8)
resolved_decorators = (
".".join(self.compose_call_path(decorator))
for decorator in node.decorator_list
".".join(compose_call_path(decorator)) for decorator in node.decorator_list
)
for idx, decorator in enumerate(resolved_decorators):
if decorator in {"classmethod", "staticmethod"}:
Expand Down Expand Up @@ -755,6 +707,16 @@ def check_for_b022(self, node):
self.errors.append(B022(node.lineno, node.col_offset))


def compose_call_path(node):
if isinstance(node, ast.Attribute):
yield from compose_call_path(node.value)
yield node.attr
elif isinstance(node, ast.Call):
yield from compose_call_path(node.func)
elif isinstance(node, ast.Name):
yield node.id


@attr.s
class NameFinder(ast.NodeVisitor):
"""Finds a name within a tree of nodes.
Expand All @@ -778,6 +740,69 @@ def visit(self, node):
return node


class FuntionDefDefaultsVisitor(ast.NodeVisitor):
def __init__(self, b008_extend_immutable_calls=None):
self.b008_extend_immutable_calls = b008_extend_immutable_calls or set()
for node in B006.mutable_literals + B006.mutable_comprehensions:
setattr(self, f"visit_{node}", self.visit_mutable_literal_or_comprehension)
self.errors = []
self.arg_depth = 0
super().__init__()

def visit_mutable_literal_or_comprehension(self, node):
# Flag B006 iff mutable literal/comprehension is not nested.
# We only flag these at the top level of the expression as we
# cannot easily guarantee that nested mutable structures are not
# made immutable by outer operations, so we prefer no false positives.
# e.g.
# >>> def this_is_fine(a=frozenset({"a", "b", "c"})): ...
#
# >>> def this_is_not_fine_but_hard_to_detect(a=(lambda x: x)([1, 2, 3]))
#
# We do still search for cases of B008 within mutable structures though.
if self.arg_depth == 1:
self.errors.append(B006(node.lineno, node.col_offset))
# Check for nested functions.
self.generic_visit(node)

def visit_Call(self, node):
call_path = ".".join(compose_call_path(node.func))
if call_path in B006.mutable_calls:
self.errors.append(B006(node.lineno, node.col_offset))
self.generic_visit(node)
return

if call_path in B008.immutable_calls | self.b008_extend_immutable_calls:
self.generic_visit(node)
return

# Check if function call is actually a float infinity/NaN literal
if call_path == "float" and len(node.args) == 1:
try:
value = float(ast.literal_eval(node.args[0]))
except Exception:
pass
else:
if math.isfinite(value):
self.errors.append(B008(node.lineno, node.col_offset))
else:
self.errors.append(B008(node.lineno, node.col_offset))

# Check for nested functions.
self.generic_visit(node)

def visit(self, node):
"""Like super-visit but supports iteration over lists."""
self.arg_depth += 1
if isinstance(node, list):
for elem in node:
if elem is not None:
super().visit(elem)
else:
super().visit(node)
self.arg_depth -= 1


class B020NameFinder(NameFinder):
"""Ignore names defined within the local scope of a comprehension."""

Expand Down Expand Up @@ -851,8 +876,8 @@ def visit_comprehension(self, node):
"between them."
)
)
B006.mutable_literals = (ast.Dict, ast.List, ast.Set)
B006.mutable_comprehensions = (ast.ListComp, ast.DictComp, ast.SetComp)
B006.mutable_literals = ("Dict", "List", "Set")
B006.mutable_comprehensions = ("ListComp", "DictComp", "SetComp")
B006.mutable_calls = {
"Counter",
"OrderedDict",
Expand Down