Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oauth: use a context manager for the server's thread #140

Merged
merged 2 commits into from
Jun 22, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
156 changes: 98 additions & 58 deletions sigstore/_internal/oidc/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import base64
import hashlib
import http.server
Expand Down Expand Up @@ -46,17 +48,40 @@
"""


class RedirectHandler(http.server.BaseHTTPRequestHandler):
class OAuthFlow:
def __init__(self, client_id: str, client_secret: str, issuer: Issuer):
self._client_id = client_id
self._client_secret = client_secret
self._issuer = issuer
self._server = OAuthRedirectServer(
self._client_id, self._client_secret, self._issuer
)
self._server_thread = threading.Thread(
target=lambda server: server.serve_forever(),
args=(self._server,),
)

def __enter__(self) -> OAuthRedirectServer:
self._server_thread.start()

return self._server

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._server.shutdown()
self._server_thread.join()


class OAuthRedirectHandler(http.server.BaseHTTPRequestHandler):
def log_message(self, _format: str, *_args: Any) -> None:
pass

def do_GET(self) -> None:
logger.debug(f"GET: {self.path} with {dict(self.headers)}")
server = cast(RedirectServer, self.server)
server = cast(OAuthRedirectServer, self.server)

# If the auth response has already been populated, the main thread will be stopping this
# thread and accessing the auth response shortly so we should stop servicing any requests.
if not server.active:
if server.auth_response is not None:
logger.debug(f"{self.path} unavailable (teardown)")
self.send_response(404)
return None
Expand All @@ -74,10 +99,9 @@ def do_GET(self) -> None:
self.end_headers()
self.wfile.write(body)
server.auth_response = urllib.parse.parse_qs(r.query)
elif r.path == server.request_path:
url = server.auth_request()
elif r.path == server.auth_request_path:
self.send_response(302)
self.send_header("Location", url)
self.send_header("Location", server.auth_endpoint)
self.end_headers()
else:
# Anything else sends a "Not Found" response.
Expand All @@ -87,28 +111,73 @@ def do_GET(self) -> None:
OOB_REDIRECT_URI = "urn:ietf:wg:oauth:2.0:oob"


class RedirectServer(http.server.HTTPServer):
def __init__(self, client_id: str, client_secret: str, issuer: Issuer) -> None:
super().__init__(("127.0.0.1", 0), RedirectHandler)
self.state: Optional[str] = None
self.nonce: Optional[str] = None
self.auth_response: Optional[Dict[str, List[str]]] = None
self._is_out_of_band = False
self._port: int = self.socket.getsockname()[1]
class OAuthSession:
def __init__(self, client_id: str, client_secret: str, issuer: Issuer):
self.__poison = False

self._client_id = client_id
self._client_secret = client_secret
self._issuer = issuer
self._state = str(uuid.uuid4())
self._nonce = str(uuid.uuid4())

self.code_verifier = (
base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
)

@property
def active(self) -> bool:
return self.auth_response is None
def code_challenge(self) -> str:
return (
base64.urlsafe_b64encode(
hashlib.sha256(self.code_verifier.encode()).digest()
)
.rstrip(b"=")
.decode()
)

def auth_endpoint(self, redirect_uri: str) -> str:
# Defensive programming: we don't have a nice way to limit the
# lifetime of the OAuth session here, so we use the internal
# "poison" flag to check if we're attempting to reuse it in a way
# that would compromise the flow's security (i.e. nonce reuse).
if self.__poison:
raise IdentityError("internal error: OAuth endpoint misuse")
else:
self.__poison = True

params = self._auth_params(redirect_uri)
return f"{self._issuer.auth_endpoint}?{urllib.parse.urlencode(params)}"

def _auth_params(self, redirect_uri: str) -> Dict[str, Any]:
return {
"response_type": "code",
"client_id": self._client_id,
"client_secret": self._client_secret,
"scope": "openid email",
"redirect_uri": redirect_uri,
"code_challenge": self.code_challenge,
"code_challenge_method": "S256",
"state": self._state,
"nonce": self._nonce,
}


class OAuthRedirectServer(http.server.HTTPServer):
def __init__(self, client_id: str, client_secret: str, issuer: Issuer) -> None:
super().__init__(("localhost", 0), OAuthRedirectHandler)
self.oauth_session = OAuthSession(client_id, client_secret, issuer)
self.auth_response: Optional[Dict[str, List[str]]] = None
self._is_out_of_band = False

@property
def base_uri(self) -> str:
return f"http://localhost:{self._port}"
# NOTE: We'd ideally use `self.server_name` here, but it uses
# the FQDN internally (which in turn confuses Sigstore).
return f"http://localhost:{self.server_port}"

@property
def request_path(self) -> str:
def auth_request_path(self) -> str:
# TODO: Maybe this should be /auth, for clarity?
return "/"

@property
Expand All @@ -123,36 +192,17 @@ def redirect_uri(self) -> str:
else OOB_REDIRECT_URI
)

def generate_code_challenge(self) -> bytes:
self.code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=")
return base64.urlsafe_b64encode(
hashlib.sha256(self.code_verifier).digest()
).rstrip(b"=")

def auth_request_params(self) -> Dict[str, str]:
code_challenge = self.generate_code_challenge()
self.state = str(uuid.uuid4())
self.nonce = str(uuid.uuid4())
return {
"response_type": "code",
"client_id": self._client_id,
"client_secret": self._client_secret,
"scope": "openid email",
"redirect_uri": self.redirect_uri,
"code_challenge": code_challenge.decode("utf-8"),
"code_challenge_method": "S256",
"state": self.state,
"nonce": self.nonce,
}

def auth_request(self) -> str:
params = self.auth_request_params()
return f"{self._issuer.auth_endpoint}?{urllib.parse.urlencode(params)}"
@property
def auth_endpoint(self) -> str:
return self.oauth_session.auth_endpoint(self.redirect_uri)

def enable_oob(self) -> None:
logger.debug("enabling out-of-band OAuth flow")
self._is_out_of_band = True

def is_oob(self) -> bool:
return self._is_out_of_band


def get_identity_token(client_id: str, client_secret: str, issuer: Issuer) -> str:
"""
Expand All @@ -165,26 +215,19 @@ def get_identity_token(client_id: str, client_secret: str, issuer: Issuer) -> st
force_oob = os.getenv("SIGSTORE_OAUTH_FORCE_OOB") is not None

code: str
with RedirectServer(client_id, client_secret, issuer) as server:
thread = threading.Thread(
target=lambda server: server.serve_forever(),
args=(server,),
)
thread.start()

with OAuthFlow(client_id, client_secret, issuer) as server:
# Launch web browser
if not force_oob and webbrowser.open(server.base_uri):
print("Waiting for browser interaction...")
else:
server.enable_oob()
print(
f"Go to the following link in a browser:\n\n\t{server.auth_request()}"
)
print(f"Go to the following link in a browser:\n\n\t{server.auth_endpoint}")

if not server._is_out_of_band:
if not server.is_oob():
# Wait until the redirect server populates the response
while server.auth_response is None:
time.sleep(0.1)

auth_error = server.auth_response.get("error")
if auth_error is not None:
raise IdentityError(
Expand All @@ -195,15 +238,12 @@ def get_identity_token(client_id: str, client_secret: str, issuer: Issuer) -> st
# In the out-of-band case, we wait until the user provides the code
code = input("Enter verification code: ")

server.shutdown()
thread.join()

# Provide code to token endpoint
data = {
"grant_type": "authorization_code",
"redirect_uri": server.redirect_uri,
"code": code,
"code_verifier": server.code_verifier.decode("utf-8"),
"code_verifier": server.oauth_session.code_verifier,
}
auth = (
client_id,
Expand Down