Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_jwk static method to ECAlgorithm #732

Merged
merged 5 commits into from May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -15,6 +15,7 @@ Fixed

Added
~~~~~
- Add to_jwk static method to ECAlgorithm by @leonsmith in https://github.com/jpadilla/pyjwt/pull/732

`v2.4.0 <https://github.com/jpadilla/pyjwt/compare/2.3.0...2.4.0>`__
-----------------------------------------------------------------------
Expand Down
35 changes: 35 additions & 0 deletions jwt/algorithms.py
Expand Up @@ -439,6 +439,41 @@ def verify(self, msg, key, sig):
except InvalidSignature:
return False

@staticmethod
def to_jwk(key_obj):

if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
public_numbers = key_obj.public_numbers()
else:
raise InvalidKeyError("Not a public or private key")

if isinstance(key_obj.curve, ec.SECP256R1):
crv = "P-256"
elif isinstance(key_obj.curve, ec.SECP384R1):
crv = "P-384"
elif isinstance(key_obj.curve, ec.SECP521R1):
crv = "P-521"
elif isinstance(key_obj.curve, ec.SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

obj = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
"y": to_base64url_uint(public_numbers.y).decode(),
}

if isinstance(key_obj, EllipticCurvePrivateKey):
obj["d"] = to_base64url_uint(
key_obj.private_numbers().private_value
).decode()

return json.dumps(obj)

@staticmethod
def from_jwk(jwk):
try:
Expand Down
5 changes: 5 additions & 0 deletions tests/keys/testkey_ec_secp192r1.priv
@@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MG8CAQAwEwYHKoZIzj0CAQYIKoZIzj0DAQEEVTBTAgEBBBiON6kYcPu8ZUDRTu8W
eXJ2FmX7e9yq0hahNAMyAARHecLjkXWDUJfZ4wiFH61JpmonCYH1GpinVlqw68Sf
wtDHg2F6SifQEFC6VKj1ZXw=
-----END PRIVATE KEY-----
97 changes: 97 additions & 0 deletions tests/test_algorithms.py
Expand Up @@ -236,6 +236,103 @@ def test_ec_jwk_fails_on_invalid_json(self):
f'{{"kty": "EC", "crv": "{curve}", "x": "{point["x"]}", "y": "{point["y"]}", "d": "dGVzdA=="}}'
)

@crypto_required
def test_ec_private_key_to_jwk_works_with_from_jwk(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with open(key_path("testkey_ec.priv")) as ec_key:
orig_key = algo.prepare_key(ec_key.read())

parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
assert parsed_key.private_numbers() == orig_key.private_numbers()
assert (
parsed_key.private_numbers().public_numbers
== orig_key.private_numbers().public_numbers
)

@crypto_required
def test_ec_public_key_to_jwk_works_with_from_jwk(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with open(key_path("testkey_ec.pub")) as ec_key:
orig_key = algo.prepare_key(ec_key.read())

parsed_key = algo.from_jwk(algo.to_jwk(orig_key))
assert parsed_key.public_numbers() == orig_key.public_numbers()

@crypto_required
def test_ec_to_jwk_returns_correct_values_for_public_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with open(key_path("testkey_ec.pub")) as keyfile:
pub_key = algo.prepare_key(keyfile.read())

key = algo.to_jwk(pub_key)

expected = {
"kty": "EC",
"crv": "P-256",
"x": "HzAcUWSlGBHcuf3y3RiNrWI-pE6-dD2T7fIzg9t6wEc",
"y": "t2G02kbWiOqimYfQAfnARdp2CTycsJPhwA8rn1Cn0SQ",
}

assert json.loads(key) == expected

@crypto_required
def test_ec_to_jwk_returns_correct_values_for_private_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with open(key_path("testkey_ec.priv")) as keyfile:
priv_key = algo.prepare_key(keyfile.read())

key = algo.to_jwk(priv_key)

expected = {
"kty": "EC",
"crv": "P-256",
"x": "HzAcUWSlGBHcuf3y3RiNrWI-pE6-dD2T7fIzg9t6wEc",
"y": "t2G02kbWiOqimYfQAfnARdp2CTycsJPhwA8rn1Cn0SQ",
"d": "2nninfu2jMHDwAbn9oERUhRADS6duQaJEadybLaa0YQ",
}

assert json.loads(key) == expected

@crypto_required
def test_ec_to_jwk_raises_exception_on_invalid_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with pytest.raises(InvalidKeyError):
algo.to_jwk({"not": "a valid key"})

@crypto_required
def test_ec_to_jwk_with_valid_curves(self):
tests = {
"P-256": ECAlgorithm.SHA256,
"P-384": ECAlgorithm.SHA384,
"P-521": ECAlgorithm.SHA512,
"secp256k1": ECAlgorithm.SHA256,
}
for (curve, hash) in tests.items():
algo = ECAlgorithm(hash)

with open(key_path(f"jwk_ec_pub_{curve}.json")) as keyfile:
pub_key = algo.from_jwk(keyfile.read())
assert json.loads(algo.to_jwk(pub_key))["crv"] == curve

with open(key_path(f"jwk_ec_key_{curve}.json")) as keyfile:
priv_key = algo.from_jwk(keyfile.read())
assert json.loads(algo.to_jwk(priv_key))["crv"] == curve

@crypto_required
def test_ec_to_jwk_with_invalid_curve(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)

with open(key_path("testkey_ec_secp192r1.priv")) as keyfile:
priv_key = algo.prepare_key(keyfile.read())

with pytest.raises(InvalidKeyError):
algo.to_jwk(priv_key)

@crypto_required
def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
Expand Down