Skip to content

Commit

Permalink
Simplify black configuration to be closer to upstream defaults (#568)
Browse files Browse the repository at this point in the history
* Simplify black configuration to be closer to upstream defaults

Avoid extra configuration by simply going with Black defaults. This
allows removing some configuration options, thus simplifying the overall
configuration.

It also makes the code style closer to community conventions. As more
projects adopt black formatting, more code will look like the black
defaults.

Further, the default 88 tends to create more readable lines, IMO. The
black rationale is located at:
https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length

* Update tests/test_api_jws.py

Co-authored-by: José Padilla <jpadilla@webapplicate.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/test_api_jws.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: José Padilla <jpadilla@webapplicate.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 21, 2020
1 parent d5eff64 commit 811ae79
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 163 deletions.
4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def find_version(*file_paths):
string inside.
"""
version_file = read(*file_paths)
version_match = re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M
)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
Expand Down
16 changes: 4 additions & 12 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,23 +444,17 @@ def from_jwk(jwk):
if len(x) == len(y) == 32:
curve_obj = ec.SECP256R1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve P-256"
)
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
curve_obj = ec.SECP384R1()
else:
raise InvalidKeyError(
"Coords should be 48 bytes for curve P-384"
)
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
curve_obj = ec.SECP521R1()
else:
raise InvalidKeyError(
"Coords should be 66 bytes for curve P-521"
)
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
else:
raise InvalidKeyError(f"Invalid curve: {curve}")

Expand Down Expand Up @@ -568,8 +562,6 @@ def verify(self, msg, key, sig):
if isinstance(key, Ed25519PrivateKey):
key = key.public_key()
key.verify(sig, msg)
return (
True # If no exception was raised, the signature is valid.
)
return True # If no exception was raised, the signature is valid.
except cryptography.exceptions.InvalidSignature:
return False
8 changes: 2 additions & 6 deletions jwt/api_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@ def __init__(self, jwk_data, algorithm=None):
algorithm = self._jwk_data.get("alg", None)

if not algorithm:
raise PyJWKError(
"Unable to find a algorithm for key: %s" % self._jwk_data
)
raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data)

self.Algorithm = self._algorithms.get(algorithm)

if not self.Algorithm:
raise PyJWKError(
"Unable to find a algorithm for key: %s" % self._jwk_data
)
raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data)

self.key = self.Algorithm.from_jwk(self._jwk_data)

Expand Down
12 changes: 3 additions & 9 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ class PyJWS:
def __init__(self, algorithms=None, options=None):
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms)
if algorithms is not None
else set(self._algorithms)
set(algorithms) if algorithms is not None else set(self._algorithms)
)

# Remove algorithms that aren't on the whitelist
Expand Down Expand Up @@ -148,9 +146,7 @@ def decode_complete(
payload, signing_input, header, signature = self._load(jwt)

if verify_signature:
self._verify_signature(
signing_input, header, signature, key, algorithms
)
self._verify_signature(signing_input, header, signature, key, algorithms)

return {
"payload": payload,
Expand Down Expand Up @@ -230,9 +226,7 @@ def _verify_signature(
alg = header.get("alg")

if algorithms is not None and alg not in algorithms:
raise InvalidAlgorithmError(
"The specified alg value is not allowed"
)
raise InvalidAlgorithmError("The specified alg value is not allowed")

try:
alg_obj = self._algorithms[alg]
Expand Down
16 changes: 4 additions & 12 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,13 @@ def encode(
for time_claim in ["exp", "iat", "nbf"]:
# Convert datetime to a intDate value in known time-format claims
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(
payload[time_claim].utctimetuple()
)
payload[time_claim] = timegm(payload[time_claim].utctimetuple())

json_payload = json.dumps(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")

return api_jws.encode(
json_payload, key, algorithm, headers, json_encoder
)
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)

def decode_complete(
self,
Expand Down Expand Up @@ -154,9 +150,7 @@ def _validate_iat(self, payload, now, leeway):
try:
int(payload["iat"])
except ValueError:
raise InvalidIssuedAtError(
"Issued At claim (iat) must be an integer."
)
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")

def _validate_nbf(self, payload, now, leeway):
try:
Expand All @@ -171,9 +165,7 @@ def _validate_exp(self, payload, now, leeway):
try:
exp = int(payload["exp"])
except ValueError:
raise DecodeError(
"Expiration Time claim (exp) must be an" " integer."
)
raise DecodeError("Expiration Time claim (exp) must be an" " integer.")

if exp < (now - leeway):
raise ExpiredSignatureError("Signature has expired")
Expand Down
4 changes: 1 addition & 3 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def get_signing_keys(self):
signing_keys.append(jwk_set_key)

if len(signing_keys) == 0:
raise PyJWKClientError(
"The JWKS endpoint did not contain any signing keys"
)
raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

return signing_keys

Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ source = ["jwt", ".tox/*/site-packages"]
show_missing = true


[tool.black]
line-length = 79


[tool.isort]
profile = "black"
atomic = true
combine_as_imports = true
line_length = 79
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ exclude =
tests.*

[flake8]
max-line-length = 79
extend-ignore = E203, E501

[mypy]
Expand Down
8 changes: 2 additions & 6 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,7 @@ def test_ec_jwk_fails_on_invalid_json(self):

# EC coordinates not equally long
with pytest.raises(InvalidKeyError):
algo.from_jwk(
'{"kty": "EC", "x": "dGVzdHRlc3Q=", "y": "dGVzdA=="}'
)
algo.from_jwk('{"kty": "EC", "x": "dGVzdHRlc3Q=", "y": "dGVzdA=="}')

# EC coordinates length invalid
for curve in ("P-256", "P-384", "P-521"):
Expand Down Expand Up @@ -575,9 +573,7 @@ def test_hmac_verify_should_return_true_for_test_vector(self):
b"gd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4"
)

signature = base64url_decode(
b"s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0"
)
signature = base64url_decode(b"s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0")

algo = HMACAlgorithm(HMACAlgorithm.SHA256)
key = algo.prepare_key(load_hmac_key())
Expand Down
72 changes: 17 additions & 55 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ def test_decode_works_with_unicode_token(self, jws):

def test_decode_missing_segments_throws_exception(self, jws):
secret = "secret"
example_jws = (
"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
".eyJoZWxsbyI6ICJ3b3JsZCJ9"
""
) # Missing segment
example_jws = "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9" # Missing segment

with pytest.raises(DecodeError) as context:
jws.decode(example_jws, secret, algorithms=["HS256"])
Expand Down Expand Up @@ -160,9 +156,7 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):
exception = context.value
assert str(exception) == "Invalid header string: must be a json object"

def test_encode_algorithm_param_should_be_case_sensitive(
self, jws, payload
):
def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):

jws.encode(payload, "secret", algorithm="HS256")

Expand Down Expand Up @@ -207,9 +201,7 @@ def test_decodes_valid_jws(self, jws, payload):
b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
)

decoded_payload = jws.decode(
example_jws, example_secret, algorithms=["HS256"]
)
decoded_payload = jws.decode(example_jws, example_secret, algorithms=["HS256"])

assert decoded_payload == payload

Expand All @@ -221,9 +213,7 @@ def test_decodes_complete_valid_jws(self, jws, payload):
b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
)

decoded = jws.decode_complete(
example_jws, example_secret, algorithms=["HS256"]
)
decoded = jws.decode_complete(example_jws, example_secret, algorithms=["HS256"])

assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
Expand All @@ -248,9 +238,7 @@ def test_decodes_valid_es384_jws(self, jws):
b"eyJoZWxsbyI6IndvcmxkIn0.TORyNQab_MoXM7DvNKaTwbrJr4UY"
b"d2SsX8hhlnWelQFmPFSf_JzC2EbLnar92t-bXsDovzxp25ExazrVHkfPkQ"
)
decoded_payload = jws.decode(
example_jws, example_pubkey, algorithms=["ES256"]
)
decoded_payload = jws.decode(example_jws, example_pubkey, algorithms=["ES256"])
json_payload = json.loads(decoded_payload)

assert json_payload == example_payload
Expand All @@ -276,9 +264,7 @@ def test_decodes_valid_rs384_jws(self, jws):
b"uwmrtSWCBUjiN8sqJ00CDgycxKqHfUndZbEAOjcCAhBr"
b"qWW3mSVivUfubsYbwUdUG3fSRPjaUPcpe8A"
)
decoded_payload = jws.decode(
example_jws, example_pubkey, algorithms=["RS384"]
)
decoded_payload = jws.decode(example_jws, example_pubkey, algorithms=["RS384"])
json_payload = json.loads(decoded_payload)

assert json_payload == example_payload
Expand All @@ -299,9 +285,7 @@ def test_load_verify_valid_jws(self, jws, payload):
def test_allow_skip_verification(self, jws, payload):
right_secret = "foo"
jws_message = jws.encode(payload, right_secret)
decoded_payload = jws.decode(
jws_message, options={"verify_signature": False}
)
decoded_payload = jws.decode(jws_message, options={"verify_signature": False})

assert decoded_payload == payload

Expand Down Expand Up @@ -364,14 +348,8 @@ def test_verify_signature_with_no_secret(self, jws, payload):

assert "Signature verification" in str(exc.value)

def test_verify_signature_with_no_algo_header_throws_exception(
self, jws, payload
):
example_jws = (
b"e30"
b".eyJhIjo1fQ"
b".KEh186CjVw_Q8FadjJcaVnE7hO5Z9nHBbU8TgbhHcBY"
)
def test_verify_signature_with_no_algo_header_throws_exception(self, jws, payload):
example_jws = b"e30.eyJhIjo1fQ.KEh186CjVw_Q8FadjJcaVnE7hO5Z9nHBbU8TgbhHcBY"

with pytest.raises(InvalidAlgorithmError):
jws.decode(example_jws, "secret", algorithms=["HS256"])
Expand Down Expand Up @@ -467,9 +445,7 @@ def test_decode_with_algo_none_should_fail(self, jws, payload):
with pytest.raises(DecodeError):
jws.decode(jws_message, algorithms=["none"])

def test_decode_with_algo_none_and_verify_false_should_pass(
self, jws, payload
):
def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload):
jws_message = jws.encode(payload, key=None, algorithm=None)
jws.decode(jws_message, options={"verify_signature": False})

Expand All @@ -486,9 +462,7 @@ def test_get_unverified_header_returns_header_values(self, jws, payload):
assert "kid" in header
assert header["kid"] == "toomanysecrets"

def test_get_unverified_header_fails_on_bad_header_types(
self, jws, payload
):
def test_get_unverified_header_fails_on_bad_header_types(self, jws, payload):
# Contains a bad kid value (int 123 instead of string)
example_jws = (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6MTIzfQ"
Expand All @@ -505,9 +479,7 @@ def test_get_unverified_header_fails_on_bad_header_types(
def test_encode_decode_with_rsa_sha256(self, jws, payload):
# PEM-formatted RSA key
with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
rsa_priv_file.read(), password=None
)
priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_rsakey, algorithm="RS256")

with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file:
Expand All @@ -528,9 +500,7 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload):
def test_encode_decode_with_rsa_sha384(self, jws, payload):
# PEM-formatted RSA key
with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
rsa_priv_file.read(), password=None
)
priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_rsakey, algorithm="RS384")

with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file:
Expand All @@ -550,9 +520,7 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload):
def test_encode_decode_with_rsa_sha512(self, jws, payload):
# PEM-formatted RSA key
with open(key_path("testkey_rsa.priv"), "rb") as rsa_priv_file:
priv_rsakey = load_pem_private_key(
rsa_priv_file.read(), password=None
)
priv_rsakey = load_pem_private_key(rsa_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_rsakey, algorithm="RS512")

with open(key_path("testkey_rsa.pub"), "rb") as rsa_pub_file:
Expand Down Expand Up @@ -592,9 +560,7 @@ def test_rsa_related_algorithms(self, jws):
def test_encode_decode_with_ecdsa_sha256(self, jws, payload):
# PEM-formatted EC key
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_eckey = load_pem_private_key(
ec_priv_file.read(), password=None
)
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_eckey, algorithm="ES256")

with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
Expand All @@ -615,9 +581,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload):

# PEM-formatted EC key
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_eckey = load_pem_private_key(
ec_priv_file.read(), password=None
)
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_eckey, algorithm="ES384")

with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
Expand All @@ -637,9 +601,7 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload):
def test_encode_decode_with_ecdsa_sha512(self, jws, payload):
# PEM-formatted EC key
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
priv_eckey = load_pem_private_key(
ec_priv_file.read(), password=None
)
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
jws_message = jws.encode(payload, priv_eckey, algorithm="ES512")

with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
Expand Down

0 comments on commit 811ae79

Please sign in to comment.