Skip to content

Commit

Permalink
use original body for detached jws validation
Browse files Browse the repository at this point in the history
  • Loading branch information
johanlundberg committed Apr 22, 2024
1 parent 9f1d99a commit 4fb4015
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/auth_server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Context(BaseModel):
client_cert: Optional[str] = None
jws_obj: Optional[jws.JWS] = None
detached_jws: Optional[str] = None
detached_jws_body: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)

def to_dict(self):
Expand Down
2 changes: 0 additions & 2 deletions src/auth_server/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ async def check_proof(self, gnap_key: Key, gnap_request: Optional[Union[GrantReq
return await check_jwsd_proof(
request=self.request,
gnap_key=gnap_key,
gnap_request=gnap_request,
key_reference=self.state.key_reference,
access_token=self.state.continue_access_token,
)
else:
Expand Down
13 changes: 11 additions & 2 deletions src/auth_server/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ async def dispatch(self, request: Request, call_next):

if request.headers.get("content-type") == "application/jose+json":
request = self.make_context_request(request)
logger.info("got application/jose request")
logger.info("got application/jose+json request")
body = await get_body(request)
# deserialize jws
body_str = body.decode("utf-8")
logger.debug(f"body: {body_str}")
logger.debug(f"JWS body: {body_str}")
jwstoken = jws.JWS()
try:
jwstoken.deserialize(body_str)
Expand All @@ -63,4 +63,13 @@ async def dispatch(self, request: Request, call_next):
# replace body with unverified deserialized token - verification is done when verifying proof
await set_body(request, jwstoken.objects["payload"])

if request.headers.get("Detached-JWS"):
request = self.make_context_request(request)
logger.info("got detached jws request")
# save original body for the detached jws validation
body = await get_body(request)
body_str = body.decode("utf-8")
logger.debug(f"JWSD body: {body_str}")
request.context.detached_jws_body = body_str

return await call_next(request)
19 changes: 6 additions & 13 deletions src/auth_server/proof/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,13 @@ async def check_jws_proof(
async def check_jwsd_proof(
request: ContextRequest,
gnap_key: Key,
gnap_request: Union[GrantRequest, ContinueRequest],
key_reference: Optional[str] = None,
access_token: Optional[str] = None,
) -> bool:
if request.context.detached_jws is None:
if request.context.detached_jws is None or request.context.detached_jws_body is None:
raise HTTPException(status_code=400, detail="No detached JWS found")

logger.debug(f"detached_jws: {request.context.detached_jws}")
logger.debug(f"detached_jws_body: {request.context.detached_jws_body}")

# recreate jws
try:
Expand All @@ -115,19 +114,13 @@ async def check_jwsd_proof(
logger.error(f"invalid detached jws: {e}")
return False

gnap_request_orig = gnap_request.copy(deep=True)
if isinstance(gnap_request_orig, GrantRequest) and key_reference is not None:
# If key was sent as reference in grant request we need to mirror that when
# rebuilding the request as that was what was signed
assert isinstance(gnap_request_orig.client, Client) # please mypy
gnap_request_orig.client.key = key_reference

logger.debug(f"gnap_request_orig: {gnap_request_orig.json(exclude_unset=True)}")
payload = base64url_encode(gnap_request_orig.json(exclude_unset=True))
payload = base64url_encode(request.context.detached_jws_body)
logger.debug(f"payload: {payload}")
raw_jws = f"{header}.{payload}.{signature}"
_jws = jws.JWS()
logger.debug(f"raw_jws: {raw_jws}")

# deserialize jws
_jws = jws.JWS()
try:
_jws.deserialize(raw_jws=raw_jws)
logger.info("Detached JWS token deserialized")
Expand Down
18 changes: 12 additions & 6 deletions src/auth_server/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_transaction_jws(self):
assert claims["auth_source"] == AuthSource.TEST

def test_transaction_jwsd(self):
client_key_dict = self.client_jwk.export(as_dict=True)
client_key_dict = self.client_jwk.export_public(as_dict=True)
client_jwk = ECJWK(**client_key_dict)
req = GrantRequest(
client=Client(key=Key(proof=Proof(method=ProofMethod.JWSD), jwk=client_jwk)),
Expand All @@ -335,7 +335,9 @@ def test_transaction_jwsd(self):
"uri": "http://testserver/transaction",
"created": int(utc_now().timestamp()),
}
_jws = jws.JWS(payload=req.json(exclude_unset=True))

payload = req.model_dump_json(exclude_unset=True)
_jws = jws.JWS(payload=payload)
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
Expand All @@ -346,7 +348,9 @@ def test_transaction_jwsd(self):
header, _, signature = data.split(".")
client_header = {"Detached-JWS": f"{header}..{signature}"}

response = self.client.post("/transaction", json=req.dict(exclude_unset=True), headers=client_header)
response = self.client.post(
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
)

assert response.status_code == 200
assert "access_token" in response.json()
Expand Down Expand Up @@ -1148,7 +1152,7 @@ def test_transaction_jwsd_continue(self):
self.config["auth_flows"] = json.dumps(["InteractionFlow"])
self._update_app_config(config=self.config)

client_key_dict = self.client_jwk.export(as_dict=True)
client_key_dict = self.client_jwk.export_public(as_dict=True)
client_jwk = ECJWK(**client_key_dict)

req = GrantRequest(
Expand All @@ -1164,7 +1168,7 @@ def test_transaction_jwsd_continue(self):
"uri": "http://testserver/transaction",
"created": int(utc_now().timestamp()),
}
_jws = jws.JWS(payload=req.json(exclude_unset=True))
_jws = jws.JWS(payload=req.model_dump_json(exclude_unset=True))
_jws.add_signature(
key=self.client_jwk,
protected=json.dumps(jws_header),
Expand All @@ -1175,7 +1179,9 @@ def test_transaction_jwsd_continue(self):
header, _, signature = data.split(".")
client_header = {"Detached-JWS": f"{header}..{signature}"}

response = self.client.post("/transaction", json=req.dict(exclude_unset=True), headers=client_header)
response = self.client.post(
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
)
assert response.status_code == 200

# continue response with no continue reference in uri
Expand Down

0 comments on commit 4fb4015

Please sign in to comment.