Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify black configuration to be closer to upstream defaults #568

Merged
merged 5 commits into from
Dec 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def find_version(*file_paths):
string inside.
"""
version_file = read(*file_paths)
version_match = re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M
)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
Expand Down
16 changes: 4 additions & 12 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,23 +444,17 @@ def from_jwk(jwk):
if len(x) == len(y) == 32:
curve_obj = ec.SECP256R1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve P-256"
)
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
curve_obj = ec.SECP384R1()
else:
raise InvalidKeyError(
"Coords should be 48 bytes for curve P-384"
)
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
curve_obj = ec.SECP521R1()
else:
raise InvalidKeyError(
"Coords should be 66 bytes for curve P-521"
)
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
else:
raise InvalidKeyError(f"Invalid curve: {curve}")

Expand Down Expand Up @@ -568,8 +562,6 @@ def verify(self, msg, key, sig):
if isinstance(key, Ed25519PrivateKey):
key = key.public_key()
key.verify(sig, msg)
return (
True # If no exception was raised, the signature is valid.
)
return True # If no exception was raised, the signature is valid.
except cryptography.exceptions.InvalidSignature:
return False
8 changes: 2 additions & 6 deletions jwt/api_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@ def __init__(self, jwk_data, algorithm=None):
algorithm = self._jwk_data.get("alg", None)

if not algorithm:
raise PyJWKError(
"Unable to find a algorithm for key: %s" % self._jwk_data
)
raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data)

self.Algorithm = self._algorithms.get(algorithm)

if not self.Algorithm:
raise PyJWKError(
"Unable to find a algorithm for key: %s" % self._jwk_data
)
raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data)

self.key = self.Algorithm.from_jwk(self._jwk_data)

Expand Down
12 changes: 3 additions & 9 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ class PyJWS:
def __init__(self, algorithms=None, options=None):
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms)
if algorithms is not None
else set(self._algorithms)
set(algorithms) if algorithms is not None else set(self._algorithms)
)

# Remove algorithms that aren't on the whitelist
Expand Down Expand Up @@ -148,9 +146,7 @@ def decode_complete(
payload, signing_input, header, signature = self._load(jwt)

if verify_signature:
self._verify_signature(
signing_input, header, signature, key, algorithms
)
self._verify_signature(signing_input, header, signature, key, algorithms)

return {
"payload": payload,
Expand Down Expand Up @@ -230,9 +226,7 @@ def _verify_signature(
alg = header.get("alg")

if algorithms is not None and alg not in algorithms:
raise InvalidAlgorithmError(
"The specified alg value is not allowed"
)
raise InvalidAlgorithmError("The specified alg value is not allowed")

try:
alg_obj = self._algorithms[alg]
Expand Down
16 changes: 4 additions & 12 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,13 @@ def encode(
for time_claim in ["exp", "iat", "nbf"]:
# Convert datetime to a intDate value in known time-format claims
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(
payload[time_claim].utctimetuple()
)
payload[time_claim] = timegm(payload[time_claim].utctimetuple())

json_payload = json.dumps(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")

return api_jws.encode(
json_payload, key, algorithm, headers, json_encoder
)
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)

def decode_complete(
self,
Expand Down Expand Up @@ -154,9 +150,7 @@ def _validate_iat(self, payload, now, leeway):
try:
int(payload["iat"])
except ValueError:
raise InvalidIssuedAtError(
"Issued At claim (iat) must be an integer."
)
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")

def _validate_nbf(self, payload, now, leeway):
try:
Expand All @@ -171,9 +165,7 @@ def _validate_exp(self, payload, now, leeway):
try:
exp = int(payload["exp"])
except ValueError:
raise DecodeError(
"Expiration Time claim (exp) must be an" " integer."
)
raise DecodeError("Expiration Time claim (exp) must be an" " integer.")

if exp < (now - leeway):
raise ExpiredSignatureError("Signature has expired")
Expand Down
4 changes: 1 addition & 3 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def get_signing_keys(self):
signing_keys.append(jwk_set_key)

if len(signing_keys) == 0:
raise PyJWKClientError(
"The JWKS endpoint did not contain any signing keys"
)
raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

return signing_keys

Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ source = ["jwt", ".tox/*/site-packages"]
show_missing = true


[tool.black]
line-length = 79


[tool.isort]
profile = "black"
atomic = true
combine_as_imports = true
line_length = 79
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ exclude =
tests.*

[flake8]
max-line-length = 79
extend-ignore = E203, E501

[mypy]
Expand Down
8 changes: 2 additions & 6 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,7 @@ def test_ec_jwk_fails_on_invalid_json(self):

# EC coordinates not equally long
with pytest.raises(InvalidKeyError):
algo.from_jwk(
'{"kty": "EC", "x": "dGVzdHRlc3Q=", "y": "dGVzdA=="}'
)
algo.from_jwk('{"kty": "EC", "x": "dGVzdHRlc3Q=", "y": "dGVzdA=="}')

# EC coordinates length invalid
for curve in ("P-256", "P-384", "P-521"):
Expand Down Expand Up @@ -575,9 +573,7 @@ def test_hmac_verify_should_return_true_for_test_vector(self):
b"gd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4"
)

signature = base64url_decode(
b"s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0"
)
signature = base64url_decode(b"s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0")

algo = HMACAlgorithm(HMACAlgorithm.SHA256)
key = algo.prepare_key(load_hmac_key())
Expand Down
72 changes: 17 additions & 55 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ def test_decode_works_with_unicode_token(self, jws):

def test_decode_missing_segments_throws_exception(self, jws):
secret = "secret"
example_jws = (
"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
".eyJoZWxsbyI6ICJ3b3JsZCJ9"
""
) # Missing segment
example_jws = "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9" # Missing segment

with pytest.raises(DecodeError) as context:
jws.decode(example_jws, secret, algorithms=["HS256"])
Expand Down Expand Up @@ -160,9 +156,7 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):
exception = context.value
assert str(exception) == "Invalid header string: must be a json object"

def test_encode_algorithm_param_should_be_case_sensitive(
self, jws, payload
):
def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):

jws.encode(payload, "secret", algorithm="HS256")

Expand Down Expand Up @@ -207,9 +201,7 @@ def test_decodes_valid_jws(self, jws, payload):
b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
)

decoded_payload = jws.decode(
example_jws, example_secret, algorithms=["HS256"]
)
decoded_payload = jws.decode(example_jws, example_secret, algorithms=["HS256"])

assert decoded_payload == payload

Expand All @@ -221,9 +213,7 @@ def test_decodes_complete_valid_jws(self, jws, payload):
b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
)

decoded = jws.decode_complete(
example_jws, example_secret, algorithms=["HS256"]
)
decoded = jws.decode_complete(example_jws, example_secret, algorithms=["HS256"])

assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
Expand All @@ -248,9 +238,7 @@ def test_decodes_valid_es384_jws(self, jws):
b"eyJoZWxsbyI6IndvcmxkIn0.TORyNQab_MoXM7DvNKaTwbrJr4UY"
b"d2SsX8hhlnWelQFmPFSf_JzC2EbLnar92t-bXsDovzxp25ExazrVHkfPkQ"
)
decoded_payload = jws.decode(
example_jws, example_pubkey, algorithms=["ES256"]
)
decoded_payload = jws.decode(example_jws, example_pubkey, algorithms=["ES256"])
json_payload = json.loads(decoded_payload)

assert json_payload == example_payload
Expand All @@ -276,9 +264,7 @@ def test_decodes_valid_rs384_jws(self, jws):
b"uwmrtSWCBUjiN8sqJ00CDgycxKqHfUndZbEAOjcCAhBr"
b"qWW3mSVivUfubsYbwUdUG3fSRPjaUPcpe8A"
)
decoded_payload = jws.decode(
example_jws, example_pubkey, algorithms=["RS384"]
)
decoded_payload = jws.decode(example_jws, example_pubkey, algorithms=["RS384"])
json_payload = json.loads(decoded_payload)

assert json_payload == example_payload
Expand All @@ -299,9 +285,7 @@ def test_load_verify_valid_jws(self, jws, payload):
def test_allow_skip_verification(self, jws, payload):
right_secret = "foo"
jws_message = jws.encode(payload, right_secret)
decoded_payload = jws.decode(
jws_message, options={"verify_signature": False}
)
decoded_payload = jws.decode(jws_message, options={"verify_signature": False})

assert decoded_payload == payload

Expand Down Expand Up @@ -364,14 +348,8 @@ def test_verify_signature_with_no_secret(self, jws, payload):

assert "Signature verification" in str(exc.value)

def test_verify_signature_with_no_algo_header_throws_exception(
self, jws, payload
):
example_jws = (
b"e30"
b".eyJhIjo1fQ"
b".KEh186CjVw_Q8FadjJcaVnE7hO5Z9nHBbU8TgbhHcBY"
)
def test_verify_signature_with_no_algo_header_throws_exception(self, jws, payload):
example_jws = b"e30.eyJhIjo1fQ.KEh186CjVw_Q8FadjJcaVnE7hO5Z9nHBbU8TgbhHcBY"

with pytest.raises(InvalidAlgorithmError):
jws.decode(example_jws, "secret", algorithms=["HS256"])
Expand Down Expand Up @@ -467,9 +445,7 @@ def test_decode_with_algo_none_should_fail(self, jws, payload):
with pytest.raises(DecodeError):
jws.decode(jws_message, algorithms=["none"])

def test_decode_with_algo_none_and_verify_false_should_pass(
self, jws, payload
):
def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload):
jws_message = jws.encode(payload, key=None, algorithm=None)
jws.decode(jws_message, options={"verify_signature": False})

Expand All @@ -486,9 +462,7 @@ def test_get_unverified_header_returns_header_values(self, jws, payload):
assert "kid" in header
assert header["kid"] == "toomanysecrets"

def test_get_unverified_header_fails_on_bad_header_types(
self, jws, payload
):
def test_get_unverified_header_fails_on_bad_header_types(self, jws, payload):
# Contains a bad kid value (int 123 instead of string)
example_jws = (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6MTIzfQ"
Expand All @@ -505,9 +479,7 @@ def test_get_unverified_header_fails_on_bad_header_types(
def test_encode_decode_with_rsa_sha256(self, jws, payload):
# PEM-formatted RSA key
with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
rsa_priv_file.read(), password=None
)
priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_rsakey, algorithm="RS256")

with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file:
Expand All @@ -528,9 +500,7 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload):
def test_encode_decode_with_rsa_sha384(self, jws, payload):
# PEM-formatted RSA key
with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
rsa_priv_file.read(), password=None
)
priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_rsakey, algorithm="RS384")

with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file:
Expand All @@ -550,9 +520,7 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload):
def test_encode_decode_with_rsa_sha512(self, jws, payload):
# PEM-formatted RSA key
with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
rsa_priv_file.read(), password=None
)
priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_rsakey, algorithm="RS512")

with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file:
Expand Down Expand Up @@ -592,9 +560,7 @@ def test_rsa_related_algorithms(self, jws):
def test_encode_decode_with_ecdsa_sha256(self, jws, payload):
# PEM-formatted EC key
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_eckey = load_pem_private_key(
ec_priv_file.read(), password=None
)
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_eckey, algorithm="ES256")

with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
Expand All @@ -615,9 +581,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload):

# PEM-formatted EC key
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_eckey = load_pem_private_key(
ec_priv_file.read(), password=None
)
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_eckey, algorithm="ES384")

with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
Expand All @@ -637,9 +601,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload):
def test_encode_decode_with_ecdsa_sha512(self, jws, payload):
# PEM-formatted EC key
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_eckey = load_pem_private_key(
ec_priv_file.read(), password=None
)
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_eckey, algorithm="ES512")

with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
Expand Down