Skip to content

Commit

Permalink
Merge pull request #28 from SUNET/lundberg_jwsd_fix
Browse files Browse the repository at this point in the history
Fix jwsd implementation
  • Loading branch information
johanlundberg committed May 3, 2024
2 parents 82583e6 + 8c02b48 commit 9aceb59
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 38 deletions.
16 changes: 16 additions & 0 deletions src/auth_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
from typing import Dict, Type, cast

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.staticfiles import StaticFiles
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY

from auth_server.config import AuthServerConfig, ConfigurationError, FlowName, load_config
from auth_server.context import ContextRequestRoute
Expand Down Expand Up @@ -78,4 +82,16 @@ def init_auth_server_api() -> AuthServer:
app.mount(
"/static", StaticFiles(packages=["auth_server"]), name="static"
) # defaults to the "statics" directory (the ending s is not a mistake) because starlette says so

config = load_config()
if config.debug or config.testing:
# log more info about 422 errors to ease fault tracing
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):

exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
logger.exception(f"{exc}")
content = {"status_code": 10422, "message": exc_str, "data": None}
return JSONResponse(content=content, status_code=HTTP_422_UNPROCESSABLE_ENTITY)

return app
1 change: 1 addition & 0 deletions src/auth_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TLSFEDMetadata(BaseModel):
class AuthServerConfig(BaseSettings):
app_name: str = Field(default="auth-server")
environment: Environment = Field(default=Environment.PROD)
debug: bool = False
testing: bool = False
log_level: str = Field(default="INFO")
log_color: bool = True
Expand Down
1 change: 1 addition & 0 deletions src/auth_server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Context(BaseModel):
client_cert: Optional[str] = None
jws_obj: Optional[jws.JWS] = None
detached_jws: Optional[str] = None
detached_jws_body: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)

def to_dict(self):
Expand Down
9 changes: 6 additions & 3 deletions src/auth_server/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ async def continue_transaction(self, continue_request: ContinueRequest) -> Optio
self.state.proof_ok = await self.check_proof(
gnap_key=self.state.grant_request.client.key, gnap_request=continue_request
)
if not self.state.proof_ok:
logger.error("could not validate proof of key possession in continue response, aborting")
raise StopTransactionException(status_code=401, detail="could not validate proof of key possession")

# run the remaining steps in the flow
steps = await self.steps()
Expand Down Expand Up @@ -170,15 +173,14 @@ async def check_proof(self, gnap_key: Key, gnap_request: Optional[Union[GrantReq
return await check_jwsd_proof(
request=self.request,
gnap_key=gnap_key,
gnap_request=gnap_request,
key_reference=self.state.key_reference,
access_token=self.state.continue_access_token,
)
else:
raise NextFlowException(status_code=400, detail="no supported proof method")

async def create_claims(self) -> Claims:
if self.state.auth_source is None:
logger.error("no auth_source set, aborting")
raise NextFlowException(status_code=400, detail="no auth source set")

claims = Claims(
Expand Down Expand Up @@ -384,7 +386,8 @@ async def handle_subject_response(self) -> Optional[GrantResponse]:

async def create_auth_token(self) -> Optional[GrantResponse]:
if not self.state.proof_ok:
return None
logger.error("could not validate proof of key possession, running next flow")
raise NextFlowException(status_code=401, detail="could not validate proof of key possession")

# Create claims
claims = await self.create_claims()
Expand Down
48 changes: 42 additions & 6 deletions src/auth_server/middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from typing import Optional

from jwcrypto import jws
from jwcrypto.common import JWException
from loguru import logger
Expand Down Expand Up @@ -33,22 +35,43 @@ async def get_body(request: Request) -> bytes:
return body


def get_header_index(request: Request, header_key: bytes) -> Optional[int]:
for key, value in request.scope["headers"]:
if key == header_key:
return request.scope["headers"].index((key, value))
return None


def set_header(request: Request, header_key: str, header_value: str) -> None:
b_header_key = header_key.encode("utf-8")
b_header_value = header_value.encode("utf-8")
content_type_index = get_header_index(request, b_header_key)
if content_type_index:
logger.debug(
f"Replacing header {request.scope['headers'][content_type_index]} with {(b_header_key, b_header_value)}"
)
request.scope["headers"][content_type_index] = (b_header_key, b_header_value)
else:
# no header to replace, just set it
request.scope["headers"].append((b_header_key, b_header_value))


class JOSEMiddleware(BaseHTTPMiddleware, ContextRequestMixin):
def __init__(self, app):
super().__init__(app)

async def dispatch(self, request: Request, call_next):
if request.headers.get("content-type") == "application/jose":
# Return a more helpful error message for a common mistake
return return_error_response(status_code=422, detail="content-type needs to be application/jose+json")
acceptable_jose_content_types = ["application/jose", "application/jose+json"]
is_jose = request.headers.get("content-type") in acceptable_jose_content_types
is_detached_jws = request.headers.get("Detached-JWS") is not None

if request.headers.get("content-type") == "application/jose+json":
if is_jose and not is_detached_jws:
request = self.make_context_request(request)
logger.info("got application/jose request")
logger.info("got application/jose+json request")
body = await get_body(request)
# deserialize jws
body_str = body.decode("utf-8")
logger.debug(f"body: {body_str}")
logger.debug(f"JWS body: {body_str}")
jwstoken = jws.JWS()
try:
jwstoken.deserialize(body_str)
Expand All @@ -62,5 +85,18 @@ async def dispatch(self, request: Request, call_next):
request.context.jws_obj = jwstoken
# replace body with unverified deserialized token - verification is done when verifying proof
await set_body(request, jwstoken.objects["payload"])
# set content-type to application/json as the body has changed
set_header(request, "content-type", "application/json")
# update content-length header to match the new body
set_header(request, "content-length", str(len(jwstoken.objects["payload"])))

if is_detached_jws:
request = self.make_context_request(request)
logger.info("got detached jws request")
# save original body for the detached jws validation
body = await get_body(request)
body_str = body.decode("utf-8")
logger.debug(f"JWSD body: {body_str}")
request.context.detached_jws_body = body_str

return await call_next(request)
2 changes: 1 addition & 1 deletion src/auth_server/models/gnap.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class FinishInteraction(GnapBaseModel):
method: FinishInteractionMethod
uri: str
nonce: str
hash_method: HashMethod = Field(default=HashMethod.SHA_256)
hash_method: Optional[HashMethod] = None


class Hints(GnapBaseModel):
Expand Down
32 changes: 16 additions & 16 deletions src/auth_server/proof/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cryptography.hazmat.primitives.hashes import SHA256, SHA384, SHA512
from fastapi import HTTPException
from jwcrypto import jwk, jws
from jwcrypto.common import base64url_encode
from jwcrypto.common import base64url_decode, base64url_encode
from loguru import logger
from pydantic import ValidationError

Expand Down Expand Up @@ -99,42 +99,42 @@ async def check_jws_proof(
async def check_jwsd_proof(
request: ContextRequest,
gnap_key: Key,
gnap_request: Union[GrantRequest, ContinueRequest],
key_reference: Optional[str] = None,
access_token: Optional[str] = None,
) -> bool:
if request.context.detached_jws is None:
if request.context.detached_jws is None or request.context.detached_jws_body is None:
raise HTTPException(status_code=400, detail="No detached JWS found")

logger.debug(f"detached_jws: {request.context.detached_jws}")
logger.debug(f"detached_jws_body: {request.context.detached_jws_body}")

# recreate jws
try:
header, _, signature = request.context.detached_jws.split(".")
header, client_payload_hash, signature = request.context.detached_jws.split(".")
except ValueError as e:
logger.error(f"invalid detached jws: {e}")
return False
raise HTTPException(status_code=400, detail="invalid format for detached jws")

gnap_request_orig = gnap_request.copy(deep=True)
if isinstance(gnap_request_orig, GrantRequest) and key_reference is not None:
# If key was sent as reference in grant request we need to mirror that when
# rebuilding the request as that was what was signed
assert isinstance(gnap_request_orig.client, Client) # please mypy
gnap_request_orig.client.key = key_reference
payload = base64url_encode(request.context.detached_jws_body)
logger.debug(f"payload: {payload}")

# check hash of payload
payload_hash = hash_with(SHA256(), request.context.detached_jws_body.encode())
if payload_hash != base64url_decode(client_payload_hash):
logger.error(f"invalid payload hash: {repr(payload_hash)}")
raise HTTPException(status_code=400, detail="invalid payload hash")

logger.debug(f"gnap_request_orig: {gnap_request_orig.json(exclude_unset=True)}")
payload = base64url_encode(gnap_request_orig.json(exclude_unset=True))
raw_jws = f"{header}.{payload}.{signature}"
_jws = jws.JWS()
logger.debug(f"raw_jws: {raw_jws}")

# deserialize jws
_jws = jws.JWS()
try:
_jws.deserialize(raw_jws=raw_jws)
logger.info("Detached JWS token deserialized")
logger.debug(f"JWS: {_jws.objects}")
except jws.InvalidJWSObject as e:
logger.error(f"Failed to deserialize detached jws: {e}")
return False
raise HTTPException(status_code=400, detail=str(e))

verify_jws(jws_obj=_jws, gnap_key=gnap_key)

Expand Down
48 changes: 36 additions & 12 deletions src/auth_server/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from cryptography import x509
from cryptography.hazmat.primitives.hashes import SHA256
from jwcrypto import jwk, jws, jwt
from jwcrypto.common import base64url_encode
from starlette.testclient import TestClient

from auth_server.api import init_auth_server_api
Expand Down Expand Up @@ -308,7 +309,7 @@ def test_transaction_jws(self):
)
data = _jws.serialize(compact=True)

client_header = {"Content-Type": "application/jose+json"}
client_header = {"Content-Type": "application/jose"}
response = self.client.post("/transaction", content=data, headers=client_header)

assert response.status_code == 200
Expand All @@ -321,7 +322,7 @@ def test_transaction_jws(self):
assert claims["auth_source"] == AuthSource.TEST

def test_transaction_jwsd(self):
client_key_dict = self.client_jwk.export(as_dict=True)
client_key_dict = self.client_jwk.export_public(as_dict=True)
client_jwk = ECJWK(**client_key_dict)
req = GrantRequest(
client=Client(key=Key(proof=Proof(method=ProofMethod.JWSD), jwk=client_jwk)),
Expand All @@ -335,7 +336,15 @@ def test_transaction_jwsd(self):
"uri": "http://testserver/transaction",
"created": int(utc_now().timestamp()),
}
_jws = jws.JWS(payload=req.json(exclude_unset=True))

payload = req.model_dump_json(exclude_unset=True)

# create a hash of payload to send in payload place
payload_digest = hash_with(SHA256(), payload.encode())
payload_hash = base64url_encode(payload_digest)

# create detached jws
_jws = jws.JWS(payload=payload)
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
Expand All @@ -344,9 +353,11 @@ def test_transaction_jwsd(self):

# Remove payload from serialized jws
header, _, signature = data.split(".")
client_header = {"Detached-JWS": f"{header}..{signature}"}
client_header = {"Detached-JWS": f"{header}.{payload_hash}.{signature}"}

response = self.client.post("/transaction", json=req.dict(exclude_unset=True), headers=client_header)
response = self.client.post(
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
)

assert response.status_code == 200
assert "access_token" in response.json()
Expand Down Expand Up @@ -1148,7 +1159,7 @@ def test_transaction_jwsd_continue(self):
self.config["auth_flows"] = json.dumps(["InteractionFlow"])
self._update_app_config(config=self.config)

client_key_dict = self.client_jwk.export(as_dict=True)
client_key_dict = self.client_jwk.export_public(as_dict=True)
client_jwk = ECJWK(**client_key_dict)

req = GrantRequest(
Expand All @@ -1164,7 +1175,14 @@ def test_transaction_jwsd_continue(self):
"uri": "http://testserver/transaction",
"created": int(utc_now().timestamp()),
}
_jws = jws.JWS(payload=req.json(exclude_unset=True))

payload = req.model_dump_json(exclude_unset=True)

# create a hash of payload to send in payload place
payload_digest = hash_with(SHA256(), payload.encode())
payload_hash = base64url_encode(payload_digest)

_jws = jws.JWS(payload=payload)
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
Expand All @@ -1173,9 +1191,11 @@ def test_transaction_jwsd_continue(self):

# Remove payload from serialized jws
header, _, signature = data.split(".")
client_header = {"Detached-JWS": f"{header}..{signature}"}
client_header = {"Detached-JWS": f"{header}.{payload_hash}.{signature}"}

response = self.client.post("/transaction", json=req.dict(exclude_unset=True), headers=client_header)
response = self.client.post(
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
)
assert response.status_code == 200

# continue response with no continue reference in uri
Expand Down Expand Up @@ -1207,7 +1227,11 @@ def test_transaction_jwsd_continue(self):
# calculate ath header value
access_token_hash = hash_with(SHA256(), continue_response["access_token"]["value"].encode())
jws_header["ath"] = base64.urlsafe_b64encode(access_token_hash).decode("ascii").rstrip("=")
_jws = jws.JWS(payload="{}")
# create hash of empty payload to send in payload place
payload = "{}"
payload_digest = hash_with(SHA256(), payload.encode())
payload_hash = base64url_encode(payload_digest)
_jws = jws.JWS(payload=payload)
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
Expand All @@ -1216,11 +1240,11 @@ def test_transaction_jwsd_continue(self):

# Remove payload from serialized jws
continue_header, _, continue_signature = continue_data.split(".")
client_header = {"Detached-JWS": f"{continue_header}..{continue_signature}"}
client_header = {"Detached-JWS": f"{continue_header}.{payload_hash}.{continue_signature}"}

authorization_header = f'GNAP {continue_response["access_token"]["value"]}'
client_header["Authorization"] = authorization_header
response = self.client.post(continue_response["uri"], json=dict(), headers=client_header)
response = self.client.post(continue_response["uri"], content=payload, headers=client_header)

assert response.status_code == 200
assert "access_token" in response.json()
Expand Down

0 comments on commit 9aceb59

Please sign in to comment.