Skip to content

Commit

Permalink
Fix context class (#449)
Browse files Browse the repository at this point in the history
* fix bugs
* improve _context access
* change _ast to ast
* fix typo
  • Loading branch information
ehooo authored and ericwb committed Feb 20, 2019
1 parent 3371a6d commit 09b0207
Showing 1 changed file with 34 additions and 47 deletions.
81 changes: 34 additions & 47 deletions bandit/core/context.py
Expand Up @@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.

import _ast
import ast

import six

Expand Down Expand Up @@ -57,18 +57,19 @@ def call_args(self):
:return: A list of function args
'''
args = []
for arg in self._context['call'].args:
if hasattr(arg, 'attr'):
args.append(arg.attr)
else:
args.append(self._get_literal_value(arg))
if 'call' in self._context and hasattr(self._context['call'], 'args'):
for arg in self._context['call'].args:
if hasattr(arg, 'attr'):
args.append(arg.attr)
else:
args.append(self._get_literal_value(arg))
return args

@property
def call_args_count(self):
'''Get the number of args a function call has
:return: The number of args a function call has
:return: The number of args a function call has or None
'''
if 'call' in self._context and hasattr(self._context['call'], 'args'):
return len(self._context['call'].args)
Expand All @@ -81,21 +82,15 @@ def call_function_name(self):
:return: The name (not FQ) of a function call
'''
if 'name' in self._context:
return self._context['name']
else:
return None
return self._context.get('name')

@property
def call_function_name_qual(self):
'''Get the FQ name of a function call
:return: The FQ name of a function call
'''
if 'qualname' in self._context:
return self._context['qualname']
else:
return None
return self._context.get('qualname')

@property
def call_keywords(self):
Expand All @@ -121,21 +116,15 @@ def node(self):
:return: The raw AST node associated with the context
'''
if 'node' in self._context:
return self._context['node']
else:
return None
return self._context.get('node')

@property
def string_val(self):
'''Get the value of a standalone unicode or string object
:return: value of a standalone unicode or string object
'''
if 'str' in self._context:
return self._context['str']
else:
return None
return self._context.get('str')

@property
def bytes_val(self):
Expand Down Expand Up @@ -174,10 +163,7 @@ def statement(self):
:return: The raw AST for the current statement
'''
if 'statement' in self._context:
return self._context['statement']
else:
return None
return self._context.get('statement')

@property
def function_def_defaults_qual(self):
Expand All @@ -186,7 +172,9 @@ def function_def_defaults_qual(self):
:return: List of defaults
'''
defaults = []
if 'node' in self._context:
if ('node' in self._context and
hasattr(self._context['node'], 'args') and
hasattr(self._context['node'].args, 'defaults')):
for default in self._context['node'].args.defaults:
defaults.append(utils.get_qual_attr(
default,
Expand All @@ -199,48 +187,48 @@ def _get_literal_value(self, literal):
:param literal: The AST literal to convert
:return: The value of the AST literal
'''
if isinstance(literal, _ast.Num):
if isinstance(literal, ast.Num):
literal_value = literal.n

elif isinstance(literal, _ast.Str):
elif isinstance(literal, ast.Str):
literal_value = literal.s

elif isinstance(literal, _ast.List):
elif isinstance(literal, ast.List):
return_list = list()
for li in literal.elts:
return_list.append(self._get_literal_value(li))
literal_value = return_list

elif isinstance(literal, _ast.Tuple):
elif isinstance(literal, ast.Tuple):
return_tuple = tuple()
for ti in literal.elts:
return_tuple = return_tuple + (self._get_literal_value(ti),)
literal_value = return_tuple

elif isinstance(literal, _ast.Set):
elif isinstance(literal, ast.Set):
return_set = set()
for si in literal.elts:
return_set.add(self._get_literal_value(si))
literal_value = return_set

elif isinstance(literal, _ast.Dict):
elif isinstance(literal, ast.Dict):
literal_value = dict(zip(literal.keys, literal.values))

elif isinstance(literal, _ast.Ellipsis):
elif isinstance(literal, ast.Ellipsis):
# what do we want to do with this?
literal_value = None

elif isinstance(literal, _ast.Name):
elif isinstance(literal, ast.Name):
literal_value = literal.id

# NOTE(sigmavirus24): NameConstants are only part of the AST in Python
# 3. NameConstants tend to refer to things like True and False. This
# prevents them from being re-assigned in Python 3.
elif six.PY3 and isinstance(literal, _ast.NameConstant):
elif six.PY3 and isinstance(literal, ast.NameConstant):
literal_value = str(literal.value)

# NOTE(sigmavirus24): Bytes are only part of the AST in Python 3
elif six.PY3 and isinstance(literal, _ast.Bytes):
elif six.PY3 and isinstance(literal, ast.Bytes):
literal_value = literal.s

else:
Expand Down Expand Up @@ -288,19 +276,19 @@ def get_lineno_for_call_arg(self, argument_name):
:param argument_name: A string - name of the argument to look for
:return: Integer - the line number of the found argument, or -1
'''
for key in self.node.keywords:
if key.arg == argument_name:
return key.value.lineno
if hasattr(self.node, 'keywords'):
for key in self.node.keywords:
if key.arg == argument_name:
return key.value.lineno

def get_call_arg_at_position(self, position_num):
'''Returns positional argument at the specified position (if it exists)
:param position_num: The index of the argument to return the value for
:return: Value of the argument at the specified position if it exists
'''
if ('call' in self._context and
hasattr(self._context['call'], 'args') and
position_num < len(self._context['call'].args)):
max_args = self.call_args_count
if max_args and position_num < max_args:
return self._get_literal_value(
self._context['call'].args[position_num]
)
Expand All @@ -313,16 +301,15 @@ def is_module_being_imported(self, module):
:param module: The module name to look for
:return: True if the module is found, False otherwise
'''
return 'module' in self._context and self._context['module'] == module
return self._context.get('module') == module

def is_module_imported_exact(self, module):
'''Check if a specified module has been imported; only exact matches.
:param module: The module name to look for
:return: True if the module is found, False otherwise
'''
return ('imports' in self._context and
module in self._context['imports'])
return module in self._context.get('imports', [])

def is_module_imported_like(self, module):
'''Check if a specified module has been imported
Expand Down

0 comments on commit 09b0207

Please sign in to comment.