Skip to content

Commit

Permalink
feat: add headless auth
Browse files Browse the repository at this point in the history
  • Loading branch information
yeisonvargasf committed Mar 26, 2024
1 parent 1958c25 commit f63b6fb
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 86 deletions.
31 changes: 20 additions & 11 deletions safety/auth/cli.py
Expand Up @@ -96,7 +96,7 @@ def render_successful_login(auth: Auth,


@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
def login(ctx: typer.Context):
def login(ctx: typer.Context, headless: bool = False):
"""
Authenticate Safety CLI with your safetycli.com account using your default browser.
"""
Expand All @@ -105,29 +105,38 @@ def login(ctx: typer.Context):
fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)

console.print()
brief_msg: str = "Redirecting your browser to log in; once authenticated, " \
"return here to start using Safety"

uri, initial_state = get_authorization_data(client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
organization=ctx.obj.auth.org)
info = None

if ctx.obj.auth.org:
brief_msg: str = "Redirecting your browser to log in; once authenticated, " \
"return here to start using Safety"

if ctx.obj.auth.org:
console.print(f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] " \
"organization.")


if headless:
brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"


uri, initial_state = get_authorization_data(client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
organization=ctx.obj.auth.org, headless=headless)
click.secho(brief_msg)
click.echo()

info = process_browser_callback(uri,
initial_state=initial_state, ctx=ctx)
info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx, headless=headless)


if info:
if info.get("email", None):
organization = None
if ctx.obj.auth.org and ctx.obj.auth.org.name:
organization = ctx.obj.auth.org.name
ctx.obj.auth.refresh_from(info)
if headless:
console.print()

render_successful_login(ctx.obj.auth, organization=organization)

console.print()
Expand All @@ -149,7 +158,7 @@ def login(ctx: typer.Context):
else:
msg += "Error logging into Safety."

msg += " Please try again, or use [bold]`safety auth –help`[/bold] " \
msg += " Please try again, or use [bold]`safety auth -–help`[/bold] " \
"for more information[/red]"

console.print(msg, emoji=True)
Expand Down
5 changes: 3 additions & 2 deletions safety/auth/main.py
Expand Up @@ -2,6 +2,7 @@
import json

from typing import Any, Dict, Optional, Tuple, Union
from urllib.parse import urlencode

from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
Expand All @@ -17,9 +18,9 @@

def get_authorization_data(client, code_verifier: str,
organization: Optional[Organization] = None,
sign_up: bool = False, ensure_auth: bool = False) -> Tuple[str, str]:
sign_up: bool = False, ensure_auth: bool = False, headless: bool = False) -> Tuple[str, str]:

kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth}
kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth, 'headless': headless}
if organization:
kwargs['organization'] = organization.id

Expand Down
133 changes: 85 additions & 48 deletions safety/auth/server.py
@@ -1,4 +1,5 @@
import http.server
import json
import logging
import socket
import sys
Expand All @@ -13,6 +14,8 @@

from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_SUCCESS, CLI_LOGOUT_SUCCESS, HOST
from safety.auth.main import save_auth_config
from authlib.integrations.base_client.errors import OAuthError
from rich.prompt import Prompt

LOG = logging.getLogger(__name__)

Expand All @@ -33,40 +36,49 @@ def find_available_port():

return None

def auth_process(code: str, state: str, initial_state: str, code_verifier, client):
err = None

if initial_state is None or initial_state != state:
err = "The state parameter value provided does not match the expected " \
"value. The state parameter is used to protect against Cross-Site " \
"Request Forgery (CSRF) attacks. For security reasons, the " \
"authorization process cannot proceed with an invalid state " \
"parameter value. Please try again, ensuring that the state " \
"parameter value provided in the authorization request matches " \
"the value returned in the callback."

if err:
click.secho(f'Error: {err}', fg='red')
sys.exit(1)

try:
tokens = client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token',
code_verifier=code_verifier,
client_id=client.client_id,
grant_type='authorization_code', code=code)

save_auth_config(access_token=tokens['access_token'],
id_token=tokens['id_token'],
refresh_token=tokens['refresh_token'])
return client.fetch_user_info()

except Exception as e:
LOG.exception(e)
sys.exit(1)

class CallbackHandler(http.server.BaseHTTPRequestHandler):
def auth(self, code: str, state: str, err, error_description):
initial_state = self.server.initial_state
ctx = self.server.ctx

if initial_state is None or initial_state != state:
err = "The state parameter value provided does not match the expected" \
"value. The state parameter is used to protect against Cross-Site " \
"Request Forgery (CSRF) attacks. For security reasons, the " \
"authorization process cannot proceed with an invalid state " \
"parameter value. Please try again, ensuring that the state " \
"parameter value provided in the authorization request matches " \
"the value returned in the callback."

if err:
click.secho(f'Error: {err}', fg='red')
sys.exit(1)
result = auth_process(code=code,
state=state,
initial_state=initial_state,
code_verifier=ctx.obj.auth.code_verifier,
client=ctx.obj.auth.client)

try:
tokens = ctx.obj.auth.client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token',
code_verifier=ctx.obj.auth.code_verifier,
client_id=ctx.obj.auth.client.client_id,
grant_type='authorization_code', code=code)

save_auth_config(access_token=tokens['access_token'],
id_token=tokens['id_token'],
refresh_token=tokens['refresh_token'])
self.server.callback = ctx.obj.auth.client.fetch_user_info()

except Exception as e:
LOG.exception(e)
sys.exit(1)

self.server.callback = result
self.do_redirect(location=CLI_AUTH_SUCCESS, params={})

def logout(self):
Expand Down Expand Up @@ -132,27 +144,52 @@ def handle_timeout(self) -> None:
sys.exit(1)

try:
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
server.initial_state = kwargs.get("initial_state", None)
server.timeout = kwargs.get("timeout", 600)
# timeout = kwargs.get("timeout", None)
# timeout = float(timeout) if timeout else None
server.ctx = kwargs.get("ctx", None)
server_thread = threading.Thread(target=server.handle_request)
server_thread.start()

target = f"{uri}&port={PORT}"
console.print(f"If the browser does not automatically open in 5 seconds, " \
"copy and paste this url into your browser: " \
f"[link={target}]{target}[/link]")
click.echo()

wait_msg = "waiting for browser authentication"

with console.status(wait_msg, spinner="bouncingBar"):
time.sleep(2)
click.launch(target)
server_thread.join()
headless = kwargs.get("headless", False)
initial_state = kwargs.get("initial_state", None)
ctx = kwargs.get("ctx", None)

message = "Copy and paste this url into your browser:"


if not headless:
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
server.initial_state = initial_state
server.timeout = kwargs.get("timeout", 600)
server.ctx = ctx
server_thread = threading.Thread(target=server.handle_request)
server_thread.start()
message = f"If the browser does not automatically open in 5 seconds, " \
"copy and paste this url into your browser:"

target = uri if headless else f"{uri}&port={PORT}"
console.print(f"{message} [link={target}]{target}[/link]")
console.print()

if headless:

exchange_data = None
while not exchange_data:
auth_code_text = Prompt.ask("Paste the response here", default=None, console=console)
try:
exchange_data = json.loads(auth_code_text)
state = exchange_data["state"]
code = exchange_data["code"]
except Exception as e:
code = state = None

return auth_process(code=code,
state=state,
initial_state=initial_state,
code_verifier=ctx.obj.auth.code_verifier,
client=ctx.obj.auth.client)
else:

wait_msg = "waiting for browser authentication"

with console.status(wait_msg, spinner="bouncingBar"):
time.sleep(2)
click.launch(target)
server_thread.join()

except OSError as e:
if e.errno == socket.errno.EADDRINUSE:
Expand Down
2 changes: 1 addition & 1 deletion tests/auth/test_cli.py
Expand Up @@ -28,7 +28,7 @@ def test_auth_calls_login(self, process_browser_callback,
get_authorization_data.assert_called_once()
process_browser_callback.assert_called_once_with(auth_data[0],
initial_state=auth_data[1],
ctx=ANY)
ctx=ANY, headless=False)

expected = [
"",
Expand Down
6 changes: 4 additions & 2 deletions tests/auth/test_main.py
Expand Up @@ -30,7 +30,8 @@ def test_get_authorization_data(self):
"sign_up": False,
"locale": "en",
"ensure_auth": False,
"organization": org_id
"organization": org_id,
"headless": False
}

client.create_authorization_url.assert_called_once_with(
Expand All @@ -42,7 +43,8 @@ def test_get_authorization_data(self):
kwargs = {
"sign_up": False,
"locale": "en",
"ensure_auth":False
"ensure_auth":False,
"headless": False
}

client.create_authorization_url.assert_called_once_with(
Expand Down
48 changes: 28 additions & 20 deletions tests/test_cli.py
Expand Up @@ -204,8 +204,7 @@ def test_validate_with_basic_policy_file(self):
result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '3.0', '--path', path])
cleaned_stdout = click.unstyle(result.stdout)
msg = 'The Safety policy (3.0) file (Used for scan and system-scan commands) was successfully parsed with the following values:\n'
parsed = json.dumps(
{
parsed = {
"version": "3.0",
"scan": {
"max_depth": 6,
Expand All @@ -230,19 +229,19 @@ def test_validate_with_basic_policy_file(self):
},
"fail_scan": {
"dependency_vulnerabilities": {
"enabled": True,
"fail_on_any_of": {
"cvss_severity": [
"critical",
"high",
"medium"
],
"exploitability": [
"critical",
"high",
"medium"
]
}
"enabled": True,
"fail_on_any_of": {
"cvss_severity": [
"critical",
"high",
"medium",
],
"exploitability": [
"critical",
"high",
"medium",
]
}
}
},
"security_updates": {
Expand All @@ -252,12 +251,21 @@ def test_validate_with_basic_policy_file(self):
]
}
}
},
indent=2
) + '\n'
}

self.assertEqual(msg + parsed, cleaned_stdout)
self.assertEqual(result.exit_code, 0)
msg_stdout, parsed_policy = cleaned_stdout.split('\n', 1)
msg_stdout += '\n'
parsed_policy = json.loads(parsed_policy.replace('\n', ''))

fail_scan = parsed_policy.get("fail_scan", None)
self.assertIsNotNone(fail_scan)
fail_of_any = fail_scan["dependency_vulnerabilities"]["fail_on_any_of"]
fail_of_any["cvss_severity"] = sorted(fail_of_any["cvss_severity"])
fail_of_any["exploitability"] = sorted(fail_of_any["exploitability"])

self.assertEqual(msg, msg_stdout)
self.assertEqual(parsed, parsed_policy)
self.assertEqual(result.exit_code, 0)


def test_validate_with_policy_file_using_invalid_keyword(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_safety.py
Expand Up @@ -494,6 +494,8 @@ def test_get_announcements_http_ok(self, get_used_options):
@patch("safety.util.get_used_options")
@patch.object(click, 'get_current_context', Mock(command=Mock(name=Mock(return_value='check'))))
def test_get_announcements_wrong_json_response_handling(self, get_used_options):
get_used_options.return_value = {}

# wrong JSON structure
announcements = {
"type": "notice",
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Expand Up @@ -5,8 +5,8 @@ isolated_build = true

[testenv]
deps =
pytest-cov
pytest
pytest-cov==4.1.0
pytest==7.4.4

commands =
pytest -rP tests/ --cov=safety/ --cov-report=html
Expand Down

0 comments on commit f63b6fb

Please sign in to comment.