Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

B006 and B008: Cover additional test cases #239

Merged
merged 7 commits into from Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -8,7 +8,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 21.10b0
rev: 22.1.0
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
hooks:
- id: black
args:
Expand Down
151 changes: 93 additions & 58 deletions bugbear.py
Expand Up @@ -354,13 +354,13 @@ def visit_Assert(self, node):

def visit_AsyncFunctionDef(self, node):
self.check_for_b902(node)
self.check_for_b006(node)
self.check_for_b006_and_b008(node)
self.generic_visit(node)

def visit_FunctionDef(self, node):
self.check_for_b901(node)
self.check_for_b902(node)
self.check_for_b006(node)
self.check_for_b006_and_b008(node)
self.check_for_b018(node)
self.check_for_b019(node)
self.check_for_b021(node)
Expand Down Expand Up @@ -390,23 +390,14 @@ def visit_With(self, node):
self.check_for_b022(node)
self.generic_visit(node)

def compose_call_path(self, node):
if isinstance(node, ast.Attribute):
yield from self.compose_call_path(node.value)
yield node.attr
elif isinstance(node, ast.Call):
yield from self.compose_call_path(node.func)
elif isinstance(node, ast.Name):
yield node.id
Comment on lines -393 to -400
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was useful in my other visitor and doesn't need to be in this class so separated it out


def check_for_b005(self, node):
if node.func.attr not in B005.methods:
return # method name doesn't match

if len(node.args) != 1 or not isinstance(node.args[0], ast.Str):
return # used arguments don't match the builtin strip

call_path = ".".join(self.compose_call_path(node.func.value))
call_path = ".".join(compose_call_path(node.func.value))
if call_path in B005.valid_paths:
return # path is exempt

Expand All @@ -419,48 +410,10 @@ def check_for_b005(self, node):

self.errors.append(B005(node.lineno, node.col_offset))

def check_for_b006(self, node):
for default in node.args.defaults + node.args.kw_defaults:
if isinstance(
default, (*B006.mutable_literals, *B006.mutable_comprehensions)
):
self.errors.append(B006(default.lineno, default.col_offset))
elif isinstance(default, ast.Call):
call_path = ".".join(self.compose_call_path(default.func))
if call_path in B006.mutable_calls:
self.errors.append(B006(default.lineno, default.col_offset))
elif (
call_path
not in B008.immutable_calls | self.b008_extend_immutable_calls
):
# Check if function call is actually a float infinity/NaN literal
if call_path == "float" and len(default.args) == 1:
float_arg = default.args[0]
if sys.version_info < (3, 8, 0):
# NOTE: pre-3.8, string literals are represented with ast.Str
if isinstance(float_arg, ast.Str):
str_val = float_arg.s
else:
str_val = ""
else:
# NOTE: post-3.8, string literals are represented with ast.Constant
if isinstance(float_arg, ast.Constant):
str_val = float_arg.value
if not isinstance(str_val, str):
str_val = ""
else:
str_val = ""

# NOTE: regex derived from documentation at:
# https://docs.python.org/3/library/functions.html#float
inf_nan_regex = r"^[+-]?(inf|infinity|nan)$"
re_result = re.search(inf_nan_regex, str_val.lower())
is_float_literal = re_result is not None
else:
is_float_literal = False

if not is_float_literal:
self.errors.append(B008(default.lineno, default.col_offset))
def check_for_b006_and_b008(self, node):
visitor = FuntionDefDefaultsVisitor(self.b008_extend_immutable_calls)
visitor.visit(node.args.defaults + node.args.kw_defaults)
self.errors.extend(visitor.errors)
Comment on lines +413 to +416
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rules now use a visitor to handle functions at any depth


def check_for_b007(self, node):
targets = NameFinder()
Expand Down Expand Up @@ -536,8 +489,7 @@ def check_for_b019(self, node):
# Preserve decorator order so we can get the lineno from the decorator node
# rather than the function node (this location definition changes in Python 3.8)
resolved_decorators = (
".".join(self.compose_call_path(decorator))
for decorator in node.decorator_list
".".join(compose_call_path(decorator)) for decorator in node.decorator_list
)
for idx, decorator in enumerate(resolved_decorators):
if decorator in {"classmethod", "staticmethod"}:
Expand Down Expand Up @@ -755,6 +707,16 @@ def check_for_b022(self, node):
self.errors.append(B022(node.lineno, node.col_offset))


def compose_call_path(node):
if isinstance(node, ast.Attribute):
yield from compose_call_path(node.value)
yield node.attr
elif isinstance(node, ast.Call):
yield from compose_call_path(node.func)
elif isinstance(node, ast.Name):
yield node.id


@attr.s
class NameFinder(ast.NodeVisitor):
"""Finds a name within a tree of nodes.
Expand All @@ -778,6 +740,79 @@ def visit(self, node):
return node


class FuntionDefDefaultsVisitor(ast.NodeVisitor):
def __init__(self, b008_extend_immutable_calls=None):
self.b008_extend_immutable_calls = b008_extend_immutable_calls or set()
for node in B006.mutable_literals + B006.mutable_comprehensions:
setattr(self, f"visit_{node}", self.visit_mutable_literal_or_comprehension)
self.errors = []
self.arg_depth = 0
super().__init__()

def visit_mutable_literal_or_comprehension(self, node):
# Flag B006 iff mutable literal/comprehension is not nested.
# We only flag these at the top level of the expression as we
# cannot easily guarantee that nested mutable structures are not
# made immutable by outer operations, so we prefer no false positives.
# e.g.
# >>> def this_is_fine(a=frozenset({"a", "b", "c"})): ...
#
# >>> def this_is_not_fine_but_hard_to_detect(a=(lambda x: x)([1, 2, 3]))
#
# We do still search for cases of B008 within mutable structures though.
if self.arg_depth == 1:
self.errors.append(B006(node.lineno, node.col_offset))
# Check for nested functions.
self.generic_visit(node)

def visit_Call(self, node):
call_path = ".".join(compose_call_path(node.func))
if call_path in B006.mutable_calls:
self.errors.append(B006(node.lineno, node.col_offset))
elif call_path not in B008.immutable_calls | self.b008_extend_immutable_calls:
# Check if function call is actually a float infinity/NaN literal
if call_path == "float" and len(node.args) == 1:
float_arg = node.args[0]
if sys.version_info < (3, 8, 0):
# NOTE: pre-3.8, string literals are represented with ast.Str
if isinstance(float_arg, ast.Str):
str_val = float_arg.s
else:
str_val = ""
else:
# NOTE: post-3.8, string literals are represented with ast.Constant
if isinstance(float_arg, ast.Constant):
str_val = float_arg.value
if not isinstance(str_val, str):
str_val = ""
else:
str_val = ""

# NOTE: regex derived from documentation at:
# https://docs.python.org/3/library/functions.html#float
inf_nan_regex = r"^[+-]?(inf|infinity|nan)$"
re_result = re.search(inf_nan_regex, str_val.lower())
is_float_literal = re_result is not None
else:
is_float_literal = False

if not is_float_literal:
self.errors.append(B008(node.lineno, node.col_offset))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this was previously existing code, but we can shorten it considerably by using ast.literal_eval():

Suggested change
float_arg = node.args[0]
if sys.version_info < (3, 8, 0):
# NOTE: pre-3.8, string literals are represented with ast.Str
if isinstance(float_arg, ast.Str):
str_val = float_arg.s
else:
str_val = ""
else:
# NOTE: post-3.8, string literals are represented with ast.Constant
if isinstance(float_arg, ast.Constant):
str_val = float_arg.value
if not isinstance(str_val, str):
str_val = ""
else:
str_val = ""
# NOTE: regex derived from documentation at:
# https://docs.python.org/3/library/functions.html#float
inf_nan_regex = r"^[+-]?(inf|infinity|nan)$"
re_result = re.search(inf_nan_regex, str_val.lower())
is_float_literal = re_result is not None
else:
is_float_literal = False
if not is_float_literal:
self.errors.append(B008(node.lineno, node.col_offset))
try:
value = float(ast.literal_eval(node.args[0]))
except Exception:
pass
else:
if math.isfinite(value):
self.errors.append(B008(node.lineno, node.col_offset))

(OK, I admit that this no longer warns on a redundant float() call wrapped around an "infinity literal" like 1e999, but IMO that's fine.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I've updated the code accordingly and it looks a lot nicer now 😄

# Check for nested functions.
self.generic_visit(node)

def visit(self, node):
"""Like super-visit but supports iteration over lists."""
self.arg_depth += 1
if isinstance(node, list):
for elem in node:
if elem is not None:
super().visit(elem)
else:
super().visit(node)
self.arg_depth -= 1


class B020NameFinder(NameFinder):
"""Ignore names defined within the local scope of a comprehension."""

Expand Down Expand Up @@ -851,8 +886,8 @@ def visit_comprehension(self, node):
"between them."
)
)
B006.mutable_literals = (ast.Dict, ast.List, ast.Set)
B006.mutable_comprehensions = (ast.ListComp, ast.DictComp, ast.SetComp)
B006.mutable_literals = ("Dict", "List", "Set")
B006.mutable_comprehensions = ("ListComp", "DictComp", "SetComp")
B006.mutable_calls = {
"Counter",
"OrderedDict",
Expand Down
103 changes: 74 additions & 29 deletions tests/b006_b008.py
@@ -1,18 +1,22 @@
import collections
import datetime as dt
import logging
import operator
import random
import re
import time
import types
from operator import attrgetter, itemgetter, methodcaller
from types import MappingProxyType


# B006
# Allow immutable literals/calls/comprehensions
def this_is_okay(value=(1, 2, 3)):
...


def and_this_also(value=tuple()):
async def and_this_also(value=tuple()):
jpy-git marked this conversation as resolved.
Show resolved Hide resolved
pass


Expand All @@ -26,6 +30,33 @@ def mappingproxytype_okay(
pass


def re_compile_ok(value=re.compile("foo")):
pass


def operators_ok(
v=operator.attrgetter("foo"),
v2=operator.itemgetter("foo"),
v3=operator.methodcaller("foo"),
):
pass


def operators_ok_unqualified(
v=attrgetter("foo"),
v2=itemgetter("foo"),
v3=methodcaller("foo"),
):
pass


def kwonlyargs_immutable(*, value=()):
...


# Flag mutable literals/comprehensions


def this_is_wrong(value=[1, 2, 3]):
...

Expand All @@ -42,35 +73,61 @@ def this_too(value=collections.OrderedDict()):
...


async def async_this_too(value=collections.OrderedDict()):
async def async_this_too(value=collections.defaultdict()):
...


def dont_forget_me(value=collections.deque()):
...


# N.B. we're also flagging the function call in the comprehension
def list_comprehension_also_not_okay(default=[i**2 for i in range(3)]):
pass


def dict_comprehension_also_not_okay(default={i: i**2 for i in range(3)}):
pass


def set_comprehension_also_not_okay(default={i**2 for i in range(3)}):
pass


def kwonlyargs_mutable(*, value=[]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cooperlees Re: the question on removing this. I rearranged this test file to structure it a bit more and help come up with extra test cases. The function is still here 😄

...


# Recommended approach for mutable defaults
def do_this_instead(value=None):
if value is None:
value = set()


# B008
# Flag function calls as default args (including if they are part of a sub-expression)
def in_fact_all_calls_are_wrong(value=time.time()):
...


LOGGER = logging.getLogger(__name__)
def f(when=dt.datetime.now() + dt.timedelta(days=7)):
pass


def do_this_instead_of_calls_in_defaults(logger=LOGGER):
# That makes it more obvious that this one value is reused.
def can_even_catch_lambdas(a=(lambda x: x)()):
...


def kwonlyargs_immutable(*, value=()):
...
# Recommended approach for function calls as default args
LOGGER = logging.getLogger(__name__)


def kwonlyargs_mutable(*, value=[]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need an empty list for coverage? I guess you removed this cause of all the list comprehension tests now ...

def do_this_instead_of_calls_in_defaults(logger=LOGGER):
# That makes it more obvious that this one value is reused.
...


# Handle inf/infinity/nan special case
def float_inf_okay(value=float("inf")):
pass

Expand All @@ -95,6 +152,7 @@ def float_minus_NaN_okay(value=float("-NaN")):
pass


# But don't allow standard floats
def float_int_is_wrong(value=float(3)):
pass

Expand All @@ -103,31 +161,18 @@ def float_str_not_inf_or_nan_is_wrong(value=float("3.14")):
pass


def re_compile_ok(value=re.compile("foo")):
pass


def operators_ok(
v=operator.attrgetter("foo"),
v2=operator.itemgetter("foo"),
v3=operator.methodcaller("foo"),
):
pass


def operators_ok_unqualified(
v=attrgetter("foo"), v2=itemgetter("foo"), v3=methodcaller("foo")
):
pass


def list_comprehension_also_not_okay(default=[i ** 2 for i in range(3)]):
# B006 and B008
# We should handle arbitrary nesting of these B008.
def nested_combo(a=[float(3), dt.datetime.now()]):
pass


def dict_comprehension_also_not_okay(default={i: i ** 2 for i in range(3)}):
# Don't flag nested B006 since we can't guarantee that
# it isn't made mutable by the outer operation.
def no_nested_b006(a=map(lambda s: s.upper(), ["a", "b", "c"])):
pass


def set_comprehension_also_not_okay(default={i ** 2 for i in range(3)}):
# B008-ception.
def nested_b008(a=random.randint(0, dt.datetime.now().year)):
pass