diff --git a/safety/cli.py b/safety/cli.py index 41130e1c..42eb9ff6 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -18,7 +18,8 @@ from safety.safety import get_packages, read_vulnerabilities, fetch_policy, post_results from safety.util import get_proxy_dict, get_packages_licenses, output_exception, \ MutuallyExclusiveOption, DependentOption, transform_ignore, SafetyPolicyFile, active_color_if_needed, \ - get_processed_options, get_safety_version, json_alias, bare_alias, SafetyContext, is_a_remote_mirror + get_processed_options, get_safety_version, json_alias, bare_alias, SafetyContext, is_a_remote_mirror, \ + filter_announcements LOG = logging.getLogger(__name__) @@ -46,6 +47,11 @@ def cli(ctx, debug, telemetry, disable_optional_telemetry_data): LOG.info(f'Telemetry enabled: {ctx.telemetry}') + @ctx.call_on_close + def clean_up_on_close(): + LOG.debug('Calling clean up on close function.') + safety.close_session() + @cli.command() @click.option("--key", default="", envvar="SAFETY_API_KEY", @@ -105,11 +111,6 @@ def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, packages = get_packages(files, stdin) proxy_dictionary = get_proxy_dict(proxy_protocol, proxy_host, proxy_port) - announcements = [] - if not db or is_a_remote_mirror(db): - LOG.info('Not local DB used, Getting announcements') - announcements = safety.get_announcements(key=key, proxy=proxy_dictionary, telemetry=ctx.parent.telemetry) - if key: server_policies = fetch_policy(key=key, proxy=proxy_dictionary) server_audit_and_monitor = server_policies["audit_and_monitor"] @@ -151,6 +152,11 @@ def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, LOG.info('Safety is going to calculate remediations') remediations = safety.calculate_remediations(vulns, db_full) + announcements = [] + if not db or is_a_remote_mirror(db): + LOG.info('Not local DB used, Getting announcements') + announcements = safety.get_announcements(key=key, proxy=proxy_dictionary, telemetry=ctx.parent.telemetry) + json_report = None if save_json or (server_audit_and_monitor and audit_and_monitor): default_name = 'safety-report.json' @@ -177,12 +183,14 @@ def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, output_report = json_report else: output_report = SafetyFormatter(output=output).render_vulnerabilities(announcements, vulns, remediations, - full_report, packages) + full_report, packages) # Announcements are send to stderr if not terminal, it doesn't depend on "exit_code" value - if announcements and (not sys.stdout.isatty() and os.environ.get("SAFETY_OS_DESCRIPTION", None) != 'run'): - LOG.info('sys.stdout is not a tty, announcements are going to be send to stderr') - click.secho(SafetyFormatter(output='text').render_announcements(announcements), fg="red", file=sys.stderr) + stderr_announcements = filter_announcements(announcements=announcements, by_type='error') + if stderr_announcements and (not sys.stdout.isatty() and os.environ.get("SAFETY_OS_DESCRIPTION", None) != 'run'): + LOG.info('sys.stdout is not a tty, error announcements are going to be send to stderr') + click.secho(SafetyFormatter(output='text').render_announcements(stderr_announcements), fg="red", + file=sys.stderr) found_vulns = list(filter(lambda v: not v.ignored, vulns)) LOG.info('Vulnerabilities found (Not ignored): %s', len(found_vulns)) @@ -219,7 +227,6 @@ def review(ctx, full_report, output, file): Show an output from a previous exported JSON report. """ LOG.info('Running check command') - announcements = safety.get_announcements(key=None, proxy=None, telemetry=ctx.parent.telemetry) report = {} try: @@ -235,6 +242,7 @@ def review(ctx, full_report, output, file): params = {'file': file} vulns, remediations, packages = safety.review(report, params=params) + announcements = safety.get_announcements(key=None, proxy=None, telemetry=ctx.parent.telemetry) output_report = SafetyFormatter(output=output).render_vulnerabilities(announcements, vulns, remediations, full_report, packages) @@ -271,14 +279,11 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp packages = get_packages(files, False) proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport) - announcements = [] - if not db: - announcements = safety.get_announcements(key=key, proxy=proxy_dictionary, telemetry=ctx.parent.telemetry) - licenses_db = {} try: - licenses_db = safety.get_licenses(key, db, cache, proxy_dictionary, telemetry=ctx.parent.telemetry) + licenses_db = safety.get_licenses(key=key, db_mirror=db, cached=cache, proxy=proxy_dictionary, + telemetry=ctx.parent.telemetry) except SafetyError as e: LOG.exception('Expected SafetyError happened: %s', e) output_exception(e, exit_code_output=False) @@ -289,6 +294,10 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp filtered_packages_licenses = get_packages_licenses(packages=packages, licenses_db=licenses_db) + announcements = [] + if not db: + announcements = safety.get_announcements(key=key, proxy=proxy_dictionary, telemetry=ctx.parent.telemetry) + output_report = SafetyFormatter(output=output).render_licenses(announcements, filtered_packages_licenses) click.secho(output_report, nl=True) @@ -367,5 +376,6 @@ def validate(ctx, name, path): cli.add_command(alert) + if __name__ == "__main__": cli() diff --git a/safety/output_utils.py b/safety/output_utils.py index 529860ff..78e7abd4 100644 --- a/safety/output_utils.py +++ b/safety/output_utils.py @@ -495,18 +495,16 @@ def build_using_sentence(key, db): key_sentence = [{'style': True, 'value': 'an API KEY'}, {'style': False, 'value': ' and the '}] db_name = 'PyUp Commercial' - elif db and custom_integration and is_a_remote_mirror(db): - return [] + elif db: + if is_a_remote_mirror(db): + if custom_integration: + return [] + db_name = f"remote URL {db}" + else: + db_name = f"local file {db}" else: db_name = 'non-commercial' - if db: - db_type = 'local file' - if is_a_remote_mirror(db): - db_type = 'remote URL' - - db_name = f"{db_type} {db}" - database_sentence = [{'style': True, 'value': db_name + ' database'}] return [{'style': False, 'value': 'Using '}] + key_sentence + database_sentence diff --git a/safety/safety.py b/safety/safety.py index 67a2b414..f39f6718 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -208,10 +208,12 @@ def fetch_database_file(path, db_name): def fetch_database(full=False, key=False, db=False, cached=0, proxy=None, telemetry=True): - if db: + if key: + mirrors = API_MIRRORS + elif db: mirrors = [db] else: - mirrors = API_MIRRORS if key else OPEN_MIRRORS + mirrors = OPEN_MIRRORS db_name = "insecure_full.json" if full else "insecure.json" for mirror in mirrors: @@ -346,7 +348,7 @@ def check(packages, key=False, db_mirror=False, cached=0, ignore_vulns=None, ign ignore_vuln_if_needed(vuln_id, cve, ignore_vulns, ignore_severity_rules) - vulnerability = get_vulnerability_from(vuln_id, cve, data, specifier, db, name, pkg, + vulnerability = get_vulnerability_from(vuln_id, cve, data, specifier, db_full, name, pkg, ignore_vulns) should_add_vuln = not (vulnerability.is_transitive and is_env_scan) @@ -548,7 +550,7 @@ def get_announcements(key, proxy, telemetry=True): url = source method = 'get' data = { - 'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry))} + 'telemetry': json.dumps(data)} data_keyword = 'params' request_kwargs[data_keyword] = data @@ -608,3 +610,8 @@ def read_vulnerabilities(fh): raise MalformedDatabase(reason=e, fetched_from=fh.name) return data + + +def close_session(): + LOG.debug('Closing requests session.') + session.close() diff --git a/safety/util.py b/safety/util.py index 5af070b3..aa6683ab 100644 --- a/safety/util.py +++ b/safety/util.py @@ -152,6 +152,11 @@ def get_basic_announcements(announcements): announcement.get('type', '').lower() != 'primary_announcement'] +def filter_announcements(announcements, by_type='error'): + return [announcement for announcement in announcements if + announcement.get('type', '').lower() == by_type] + + def build_telemetry_data(telemetry=True): context = SafetyContext() @@ -388,11 +393,13 @@ def __init__( mode: str = "r", encoding: str = None, errors: str = "strict", + pure: bool = os.environ.get('SAFETY_PURE_YAML', 'false').lower() == 'true' ) -> None: self.mode = mode self.encoding = encoding self.errors = errors self.basic_msg = '\n' + click.style('Unable to load the Safety Policy file "{name}".', fg='red') + self.pure = pure def to_info_dict(self): info_dict = super().to_info_dict() @@ -429,16 +436,17 @@ def convert(self, value, param, ctx): msg = self.basic_msg.format(name=value) + '\n' + click.style('HINT:', fg='yellow') + ' {hint}' - f, should_close = click.types.open_stream( + f, _ = click.types.open_stream( value, self.mode, self.encoding, self.errors, atomic=False ) filename = '' try: raw = f.read() - yaml = YAML(typ='safe', pure=False) + yaml = YAML(typ='safe', pure=self.pure) safety_policy = yaml.load(raw) filename = f.name + f.close() except Exception as e: show_parsed_hint = isinstance(e, MarkedYAMLError) hint = str(e)