Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid calls to ast in plugins #918

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions bandit/core/blacklisting.py
Expand Up @@ -2,10 +2,10 @@
# Copyright 2016 Hewlett-Packard Development Company, L.P.
#
# SPDX-License-Identifier: Apache-2.0
import ast
import fnmatch

from bandit.core import issue
from bandit.core import utils


def report_issue(check, name):
Expand Down Expand Up @@ -34,9 +34,9 @@ def blacklist(context, config):

if node_type == "Call":
func = context.node.func
if isinstance(func, ast.Name) and func.id == "__import__":
if utils.is_instance(func, "Name") and func.id == "__import__":
if len(context.node.args):
if isinstance(context.node.args[0], ast.Str):
if utils.is_instance(context.node.args[0], "Str"):
name = context.node.args[0].s
else:
# TODO(??): import through a variable, need symbol tab
Expand Down
22 changes: 10 additions & 12 deletions bandit/core/context.py
Expand Up @@ -2,8 +2,6 @@
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# SPDX-License-Identifier: Apache-2.0
import ast

from bandit.core import utils


Expand Down Expand Up @@ -178,44 +176,44 @@ def _get_literal_value(self, literal):
:param literal: The AST literal to convert
:return: The value of the AST literal
"""
if isinstance(literal, ast.Num):
if utils.is_instance(literal, "Num"):
literal_value = literal.n

elif isinstance(literal, ast.Str):
elif utils.is_instance(literal, "Str"):
literal_value = literal.s

elif isinstance(literal, ast.List):
elif utils.is_instance(literal, "List"):
return_list = list()
for li in literal.elts:
return_list.append(self._get_literal_value(li))
literal_value = return_list

elif isinstance(literal, ast.Tuple):
elif utils.is_instance(literal, "Tuple"):
return_tuple = tuple()
for ti in literal.elts:
return_tuple = return_tuple + (self._get_literal_value(ti),)
literal_value = return_tuple

elif isinstance(literal, ast.Set):
elif utils.is_instance(literal, "Set"):
return_set = set()
for si in literal.elts:
return_set.add(self._get_literal_value(si))
literal_value = return_set

elif isinstance(literal, ast.Dict):
elif utils.is_instance(literal, "Dict"):
literal_value = dict(zip(literal.keys, literal.values))

elif isinstance(literal, ast.Ellipsis):
elif utils.is_instance(literal, "Ellipsis"):
# what do we want to do with this?
literal_value = None

elif isinstance(literal, ast.Name):
elif utils.is_instance(literal, "Name"):
literal_value = literal.id

elif isinstance(literal, ast.NameConstant):
elif utils.is_instance(literal, "NameConstant"):
literal_value = str(literal.value)

elif isinstance(literal, ast.Bytes):
elif utils.is_instance(literal, "Bytes"):
literal_value = literal.s

else:
Expand Down
11 changes: 11 additions & 0 deletions bandit/core/utils.py
Expand Up @@ -6,6 +6,7 @@
import logging
import os.path
import sys
from operator import attrgetter

try:
import configparser
Expand Down Expand Up @@ -370,3 +371,13 @@ def check_ast_node(name):
pass

raise TypeError("Error: %s is not a valid node type in AST" % name)


def is_instance(node, type_name):
"Check if the given node is an instance AST type."
if isinstance(type_name, tuple):
f = attrgetter(*type_name)
return isinstance(node, f(ast))
else:
node_type = getattr(ast, type_name)
return isinstance(node, node_type)
17 changes: 8 additions & 9 deletions bandit/plugins/django_sql_injection.py
Expand Up @@ -2,17 +2,16 @@
# Copyright (C) 2018 [Victor Torre](https://github.com/ehooo)
#
# SPDX-License-Identifier: Apache-2.0
import ast

import bandit
from bandit.core import issue
from bandit.core import test_properties as test
from bandit.core import utils


def keywords2dict(keywords):
kwargs = {}
for node in keywords:
if isinstance(node, ast.keyword):
if utils.is_instance(node, "keyword"):
kwargs[node.arg] = node.value
return kwargs

Expand Down Expand Up @@ -66,23 +65,23 @@ def django_extra_used(context):
insecure = False
for key in ["where", "tables"]:
if key in kwargs:
if isinstance(kwargs[key], ast.List):
if utils.is_instance(kwargs[key], "List"):
for val in kwargs[key].elts:
if not isinstance(val, ast.Str):
if not utils.is_instance(val, "Str"):
insecure = True
break
else:
insecure = True
break
if not insecure and "select" in kwargs:
if isinstance(kwargs["select"], ast.Dict):
if utils.is_instance(kwargs["select"], "Dict"):
for k in kwargs["select"].keys:
if not isinstance(k, ast.Str):
if not utils.is_instance(k, "Str"):
insecure = True
break
if not insecure:
for v in kwargs["select"].values:
if not isinstance(v, ast.Str):
if not utils.is_instance(v, "Str"):
insecure = True
break
else:
Expand Down Expand Up @@ -130,7 +129,7 @@ def django_rawsql_used(context):
if context.is_module_imported_like("django.db.models"):
if context.call_function_name == "RawSQL":
sql = context.node.args[0]
if not isinstance(sql, ast.Str):
if not utils.is_instance(sql, "Str"):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.MEDIUM,
Expand Down
86 changes: 46 additions & 40 deletions bandit/plugins/django_xss.py
Expand Up @@ -7,6 +7,7 @@
import bandit
from bandit.core import issue
from bandit.core import test_properties as test
from bandit.core import utils


class DeepAssignation:
Expand All @@ -32,45 +33,45 @@ def is_assigned(self, node):
if isinstance(node, self.ignore_nodes):
return assigned

if isinstance(node, ast.Expr):
if utils.is_instance(node, "Expr"):
assigned = self.is_assigned(node.value)
elif isinstance(node, ast.FunctionDef):
elif utils.is_instance(node, "FunctionDef"):
for name in node.args.args:
if isinstance(name, ast.Name):
if utils.is_instance(name, "Name"):
if name.id == self.var_name.id:
# If is param the assignations are not affected
return assigned
assigned = self.is_assigned_in(node.body)
elif isinstance(node, ast.With):
elif utils.is_instance(node, "With"):
for withitem in node.items:
var_id = getattr(withitem.optional_vars, "id", None)
if var_id == self.var_name.id:
assigned = node
else:
assigned = self.is_assigned_in(node.body)
elif isinstance(node, ast.Try):
elif utils.is_instance(node, "Try"):
assigned = []
assigned.extend(self.is_assigned_in(node.body))
assigned.extend(self.is_assigned_in(node.handlers))
assigned.extend(self.is_assigned_in(node.orelse))
assigned.extend(self.is_assigned_in(node.finalbody))
elif isinstance(node, ast.ExceptHandler):
elif utils.is_instance(node, "ExceptHandler"):
assigned = []
assigned.extend(self.is_assigned_in(node.body))
elif isinstance(node, (ast.If, ast.For, ast.While)):
elif utils.is_instance(node, ("If", "For", "While")):
assigned = []
assigned.extend(self.is_assigned_in(node.body))
assigned.extend(self.is_assigned_in(node.orelse))
elif isinstance(node, ast.AugAssign):
if isinstance(node.target, ast.Name):
elif utils.is_instance(node, "AugAssign"):
if utils.is_instance(node.target, "Name"):
if node.target.id == self.var_name.id:
assigned = node.value
elif isinstance(node, ast.Assign) and node.targets:
elif utils.is_instance(node, "Assign") and node.targets:
target = node.targets[0]
if isinstance(target, ast.Name):
if utils.is_instance(target, "Name"):
if target.id == self.var_name.id:
assigned = node.value
elif isinstance(target, ast.Tuple):
elif utils.is_instance(target, "Tuple"):
pos = 0
for name in target.elts:
if name.id == self.var_name.id:
Expand All @@ -82,8 +83,8 @@ def is_assigned(self, node):

def evaluate_var(xss_var, parent, until, ignore_nodes=None):
secure = False
if isinstance(xss_var, ast.Name):
if isinstance(parent, ast.FunctionDef):
if utils.is_instance(xss_var, "Name"):
if utils.is_instance(parent, "FunctionDef"):
for name in parent.args.args:
if name.arg == xss_var.id:
return False # Params are not secure
Expand All @@ -94,18 +95,18 @@ def evaluate_var(xss_var, parent, until, ignore_nodes=None):
break
to = analyser.is_assigned(node)
if to:
if isinstance(to, ast.Str):
if utils.is_instance(to, "Str"):
secure = True
elif isinstance(to, ast.Name):
elif utils.is_instance(to, "Name"):
secure = evaluate_var(to, parent, to.lineno, ignore_nodes)
elif isinstance(to, ast.Call):
elif utils.is_instance(to, "Call"):
secure = evaluate_call(to, parent, ignore_nodes)
elif isinstance(to, (list, tuple)):
num_secure = 0
for some_to in to:
if isinstance(some_to, ast.Str):
if utils.is_instance(some_to, "Str"):
num_secure += 1
elif isinstance(some_to, ast.Name):
elif utils.is_instance(some_to, "Name"):
if evaluate_var(
some_to, parent, node.lineno, ignore_nodes
):
Expand All @@ -128,8 +129,13 @@ def evaluate_var(xss_var, parent, until, ignore_nodes=None):
def evaluate_call(call, parent, ignore_nodes=None):
secure = False
evaluate = False
if isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute):
if isinstance(call.func.value, ast.Str) and call.func.attr == "format":
if utils.is_instance(call, "Call") and utils.is_instance(
call.func, "Attribute"
):
if (
utils.is_instance(call.func.value, "Str")
and call.func.attr == "format"
):
evaluate = True
if call.keywords:
evaluate = False # TODO(??) get support for this
Expand All @@ -138,20 +144,20 @@ def evaluate_call(call, parent, ignore_nodes=None):
args = list(call.args)
num_secure = 0
for arg in args:
if isinstance(arg, ast.Str):
if utils.is_instance(arg, "Str"):
num_secure += 1
elif isinstance(arg, ast.Name):
elif utils.is_instance(arg, "Name"):
if evaluate_var(arg, parent, call.lineno, ignore_nodes):
num_secure += 1
else:
break
elif isinstance(arg, ast.Call):
elif utils.is_instance(arg, "Call"):
if evaluate_call(arg, parent, ignore_nodes):
num_secure += 1
else:
break
elif isinstance(arg, ast.Starred) and isinstance(
arg.value, (ast.List, ast.Tuple)
elif utils.is_instance(arg, "Starred") and utils.is_instance(
arg.value, ("List", "Tuple")
):
args.extend(arg.value.elts)
num_secure += 1
Expand All @@ -163,9 +169,9 @@ def evaluate_call(call, parent, ignore_nodes=None):


def transform2call(var):
if isinstance(var, ast.BinOp):
is_mod = isinstance(var.op, ast.Mod)
is_left_str = isinstance(var.left, ast.Str)
if utils.is_instance(var, "BinOp"):
is_mod = utils.is_instance(var.op, "Mod")
is_left_str = utils.is_instance(var.left, "Str")
if is_mod and is_left_str:
new_call = ast.Call()
new_call.args = []
Expand All @@ -175,7 +181,7 @@ def transform2call(var):
new_call.func = ast.Attribute()
new_call.func.value = var.left
new_call.func.attr = "format"
if isinstance(var.right, ast.Tuple):
if utils.is_instance(var.right, "Tuple"):
new_call.args = var.right.elts
else:
new_call.args = [var.right]
Expand All @@ -188,32 +194,32 @@ def check_risk(node):

secure = False

if isinstance(xss_var, ast.Name):
if utils.is_instance(xss_var, "Name"):
# Check if the var are secure
parent = node._bandit_parent
while not isinstance(parent, (ast.Module, ast.FunctionDef)):
while not utils.is_instance(parent, ("Module", "FunctionDef")):
parent = parent._bandit_parent

is_param = False
if isinstance(parent, ast.FunctionDef):
if utils.is_instance(parent, "FunctionDef"):
for name in parent.args.args:
if name.arg == xss_var.id:
is_param = True
break

if not is_param:
secure = evaluate_var(xss_var, parent, node.lineno)
elif isinstance(xss_var, ast.Call):
elif utils.is_instance(xss_var, "Call"):
parent = node._bandit_parent
while not isinstance(parent, (ast.Module, ast.FunctionDef)):
while not utils.is_instance(parent, ("Module", "FunctionDef")):
parent = parent._bandit_parent
secure = evaluate_call(xss_var, parent)
elif isinstance(xss_var, ast.BinOp):
is_mod = isinstance(xss_var.op, ast.Mod)
is_left_str = isinstance(xss_var.left, ast.Str)
elif utils.is_instance(xss_var, "BinOp"):
is_mod = utils.is_instance(xss_var.op, "Mod")
is_left_str = utils.is_instance(xss_var.left, "Str")
if is_mod and is_left_str:
parent = node._bandit_parent
while not isinstance(parent, (ast.Module, ast.FunctionDef)):
while not utils.is_instance(parent, ("Module", "FunctionDef")):
parent = parent._bandit_parent
new_call = transform2call(xss_var)
secure = evaluate_call(new_call, parent)
Expand Down Expand Up @@ -270,5 +276,5 @@ def django_mark_safe(context):
]
if context.call_function_name in affected_functions:
xss = context.node.args[0]
if not isinstance(xss, ast.Str):
if not utils.is_instance(xss, "Str"):
return check_risk(context.node)