diff --git a/bandit/core/context.py b/bandit/core/context.py index cf5fe42db..2874f66ea 100644 --- a/bandit/core/context.py +++ b/bandit/core/context.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -import _ast +import ast import six @@ -57,18 +57,19 @@ def call_args(self): :return: A list of function args ''' args = [] - for arg in self._context['call'].args: - if hasattr(arg, 'attr'): - args.append(arg.attr) - else: - args.append(self._get_literal_value(arg)) + if 'call' in self._context and hasattr(self._context['call'], 'args'): + for arg in self._context['call'].args: + if hasattr(arg, 'attr'): + args.append(arg.attr) + else: + args.append(self._get_literal_value(arg)) return args @property def call_args_count(self): '''Get the number of args a function call has - :return: The number of args a function call has + :return: The number of args a function call has or None ''' if 'call' in self._context and hasattr(self._context['call'], 'args'): return len(self._context['call'].args) @@ -81,10 +82,7 @@ def call_function_name(self): :return: The name (not FQ) of a function call ''' - if 'name' in self._context: - return self._context['name'] - else: - return None + return self._context.get('name') @property def call_function_name_qual(self): @@ -92,10 +90,7 @@ def call_function_name_qual(self): :return: The FQ name of a function call ''' - if 'qualname' in self._context: - return self._context['qualname'] - else: - return None + return self._context.get('qualname') @property def call_keywords(self): @@ -121,10 +116,7 @@ def node(self): :return: The raw AST node associated with the context ''' - if 'node' in self._context: - return self._context['node'] - else: - return None + return self._context.get('node') @property def string_val(self): @@ -132,10 +124,7 @@ def string_val(self): :return: value of a standalone unicode or string object ''' - if 'str' in self._context: - return self._context['str'] - else: - return None + return self._context.get('str') @property def bytes_val(self): @@ -174,10 +163,7 @@ def statement(self): :return: The raw AST for the current statement ''' - if 'statement' in self._context: - return self._context['statement'] - else: - return None + return self._context.get('statement') @property def function_def_defaults_qual(self): @@ -186,7 +172,9 @@ def function_def_defaults_qual(self): :return: List of defaults ''' defaults = [] - if 'node' in self._context: + if ('node' in self._context and + hasattr(self._context['node'], 'args') and + hasattr(self._context['node'].args, 'defaults')): for default in self._context['node'].args.defaults: defaults.append(utils.get_qual_attr( default, @@ -199,48 +187,48 @@ def _get_literal_value(self, literal): :param literal: The AST literal to convert :return: The value of the AST literal ''' - if isinstance(literal, _ast.Num): + if isinstance(literal, ast.Num): literal_value = literal.n - elif isinstance(literal, _ast.Str): + elif isinstance(literal, ast.Str): literal_value = literal.s - elif isinstance(literal, _ast.List): + elif isinstance(literal, ast.List): return_list = list() for li in literal.elts: return_list.append(self._get_literal_value(li)) literal_value = return_list - elif isinstance(literal, _ast.Tuple): + elif isinstance(literal, ast.Tuple): return_tuple = tuple() for ti in literal.elts: return_tuple = return_tuple + (self._get_literal_value(ti),) literal_value = return_tuple - elif isinstance(literal, _ast.Set): + elif isinstance(literal, ast.Set): return_set = set() for si in literal.elts: return_set.add(self._get_literal_value(si)) literal_value = return_set - elif isinstance(literal, _ast.Dict): + elif isinstance(literal, ast.Dict): literal_value = dict(zip(literal.keys, literal.values)) - elif isinstance(literal, _ast.Ellipsis): + elif isinstance(literal, ast.Ellipsis): # what do we want to do with this? literal_value = None - elif isinstance(literal, _ast.Name): + elif isinstance(literal, ast.Name): literal_value = literal.id # NOTE(sigmavirus24): NameConstants are only part of the AST in Python # 3. NameConstants tend to refer to things like True and False. This # prevents them from being re-assigned in Python 3. - elif six.PY3 and isinstance(literal, _ast.NameConstant): + elif six.PY3 and isinstance(literal, ast.NameConstant): literal_value = str(literal.value) # NOTE(sigmavirus24): Bytes are only part of the AST in Python 3 - elif six.PY3 and isinstance(literal, _ast.Bytes): + elif six.PY3 and isinstance(literal, ast.Bytes): literal_value = literal.s else: @@ -288,9 +276,10 @@ def get_lineno_for_call_arg(self, argument_name): :param argument_name: A string - name of the argument to look for :return: Integer - the line number of the found argument, or -1 ''' - for key in self.node.keywords: - if key.arg == argument_name: - return key.value.lineno + if hasattr(self.node, 'keywords'): + for key in self.node.keywords: + if key.arg == argument_name: + return key.value.lineno def get_call_arg_at_position(self, position_num): '''Returns positional argument at the specified position (if it exists) @@ -298,9 +287,8 @@ def get_call_arg_at_position(self, position_num): :param position_num: The index of the argument to return the value for :return: Value of the argument at the specified position if it exists ''' - if ('call' in self._context and - hasattr(self._context['call'], 'args') and - position_num < len(self._context['call'].args)): + max_args = self.call_args_count + if max_args and position_num < max_args: return self._get_literal_value( self._context['call'].args[position_num] ) @@ -313,7 +301,7 @@ def is_module_being_imported(self, module): :param module: The module name to look for :return: True if the module is found, False otherwise ''' - return 'module' in self._context and self._context['module'] == module + return self._context.get('module') == module def is_module_imported_exact(self, module): '''Check if a specified module has been imported; only exact matches. @@ -321,8 +309,7 @@ def is_module_imported_exact(self, module): :param module: The module name to look for :return: True if the module is found, False otherwise ''' - return ('imports' in self._context and - module in self._context['imports']) + return module in self._context.get('imports', []) def is_module_imported_like(self, module): '''Check if a specified module has been imported