Skip to content

Commit

Permalink
make sure assert is not optimised out
Browse files Browse the repository at this point in the history
as assert will be optimised out if the module is compiled with
optimisations on, we can't use them for checking user-provided data

use an exception that inherits from it, so that existing code will
behave as expected
  • Loading branch information
tomato42 committed Oct 9, 2019
1 parent 35ea44c commit 0e907c0
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 13 deletions.
6 changes: 1 addition & 5 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,5 @@
include =
src/ecdsa/*
omit =
src/ecdsa/six.py
src/ecdsa/_version.py
src/ecdsa/test_ecdsa.py
src/ecdsa/test_ellipticcurve.py
src/ecdsa/test_numbertheory.py
src/ecdsa/test_pyecdsa.py
src/ecdsa/test_*
21 changes: 15 additions & 6 deletions src/ecdsa/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
@staticmethod
def _from_raw_encoding(string, curve, validate_point):
order = curve.order
assert (len(string) == curve.verifying_key_length), \
(len(string), curve.verifying_key_length)
# real assert, from_string() should not call us with different length
assert len(string) == curve.verifying_key_length
xs = string[:curve.baselen]
ys = string[curve.baselen:]
assert len(xs) == curve.baselen, (len(xs), curve.baselen)
assert len(ys) == curve.baselen, (len(ys), curve.baselen)
if len(xs) != curve.baselen:
raise MalformedPointError("Unexpected length of encoded x")
if len(ys) != curve.baselen:
raise MalformedPointError("Unexpected length of encoded y")
x = string_to_number(xs)
y = string_to_number(ys)
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
Expand Down Expand Up @@ -86,6 +88,7 @@ def _from_compressed(string, curve, validate_point):

@classmethod
def _from_hybrid(cls, string, curve, validate_point):
# real assert, from_string() should not call us with different types
assert string[:1] in (b('\x06'), b('\x07'))

# primarily use the uncompressed as it's easiest to handle
Expand Down Expand Up @@ -271,7 +274,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):
self.default_hashfunc = hashfunc
self.baselen = curve.baselen
n = curve.order
assert 1 <= secexp < n
if not 1 <= secexp < n:
raise MalformedPointError(
"Invalid value for secexp, expected integer between 1 and {0}"
.format(n))
pubkey_point = curve.generator * secexp
pubkey = ecdsa.Public_key(curve.generator, pubkey_point)
pubkey.order = n
Expand All @@ -283,7 +289,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1):
assert len(string) == curve.baselen, (len(string), curve.baselen)
if len(string) != curve.baselen:
raise MalformedPointError(
"Invalid length of private key, received {0}, expected {1}"
.format(len(string), curve.baselen))
secexp = string_to_number(string)
return klass.from_secret_exponent(secexp, curve, hashfunc)

Expand Down
170 changes: 168 additions & 2 deletions src/ecdsa/test_pyecdsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@

from six import b, print_, binary_type
from .keys import SigningKey, VerifyingKey
from .keys import BadSignatureError, MalformedPointError
from .keys import BadSignatureError, MalformedPointError, BadDigestError
from . import util
from .util import sigencode_der, sigencode_strings
from .util import sigdecode_der, sigdecode_strings
from .util import number_to_string
from .util import number_to_string, encoded_oid_ecPublicKey, \
MalformedSignature
from .curves import Curve, UnknownCurveError
from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, \
SECP256k1, curves
from .ellipticcurve import Point
from . import der
from . import rfc6979
from . import ecdsa


class SubprocessError(Exception):
Expand Down Expand Up @@ -275,6 +277,47 @@ def order(self):
pub2 = VerifyingKey.from_pem(pem)
self.assertTruePubkeysEqual(pub1, pub2)

def test_vk_from_der_garbage_after_curve_oid(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + \
b('garbage')
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_invalid_key_type(self):
type_oid_der = der.encode_oid(*(1, 2, 3))
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_garbage_after_point_string(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x00\xff') + b('garbage')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_vk_from_der_invalid_bitstring(self):
type_oid_der = encoded_oid_ecPublicKey
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der)
point_der = der.encode_bitstring(b'\x08\xff')
to_decode = der.encode_sequence(enc_type_der, point_der)

with self.assertRaises(der.UnexpectedDER):
VerifyingKey.from_der(to_decode)

def test_signature_strings(self):
priv1 = SigningKey.generate()
pub1 = priv1.get_verifying_key()
Expand All @@ -298,6 +341,86 @@ def test_signature_strings(self):
self.assertEqual(type(sig_der), binary_type)
self.assertTrue(pub1.verify(sig_der, data, sigdecode=sigdecode_der))

def test_sig_decode_strings_with_invalid_count(self):
with self.assertRaises(MalformedSignature):
sigdecode_strings([b('one'), b('two'), b('three')], 0xff)

def test_sig_decode_strings_with_wrong_r_len(self):
with self.assertRaises(MalformedSignature):
sigdecode_strings([b('one'), b('two')], 0xff)

def test_sig_decode_strings_with_wrong_s_len(self):
with self.assertRaises(MalformedSignature):
sigdecode_strings([b('\xa0'), b('\xb0\xff')], 0xff)

def test_verify_with_too_long_input(self):
sk = SigningKey.generate()
vk = sk.verifying_key

with self.assertRaises(BadDigestError):
vk.verify_digest(None, b('\x00') * 128)

def test_sk_from_secret_exponent_with_wrong_sec_exponent(self):
with self.assertRaises(MalformedPointError):
SigningKey.from_secret_exponent(0)

def test_sk_from_string_with_wrong_len_string(self):
with self.assertRaises(MalformedPointError):
SigningKey.from_string(b('\x01'))

def test_sk_from_der_with_junk_after_sequence(self):
ver_der = der.encode_integer(1)
to_decode = der.encode_sequence(ver_der) + b('garbage')

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_with_wrong_version(self):
ver_der = der.encode_integer(0)
to_decode = der.encode_sequence(ver_der)

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_invalid_const_tag(self):
ver_der = der.encode_integer(1)
privkey_der = der.encode_octet_string(b('\x00\xff'))
curve_oid_der = der.encode_oid(*(1, 2, 3))
const_der = der.encode_constructed(1, curve_oid_der)
to_decode = der.encode_sequence(ver_der, privkey_der, const_der,
curve_oid_der)

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_garbage_after_privkey_oid(self):
ver_der = der.encode_integer(1)
privkey_der = der.encode_octet_string(b('\x00\xff'))
curve_oid_der = der.encode_oid(*(1, 2, 3)) + b('garbage')
const_der = der.encode_constructed(0, curve_oid_der)
to_decode = der.encode_sequence(ver_der, privkey_der, const_der,
curve_oid_der)

with self.assertRaises(der.UnexpectedDER):
SigningKey.from_der(to_decode)

def test_sk_from_der_with_short_privkey(self):
ver_der = der.encode_integer(1)
privkey_der = der.encode_octet_string(b('\x00\xff'))
curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1))
const_der = der.encode_constructed(0, curve_oid_der)
to_decode = der.encode_sequence(ver_der, privkey_der, const_der,
curve_oid_der)

sk = SigningKey.from_der(to_decode)
self.assertEqual(sk.privkey.secret_multiplier, 255)

def test_sign_with_too_long_hash(self):
sk = SigningKey.from_secret_exponent(12)

with self.assertRaises(BadDigestError):
sk.sign_digest(b('\xff') * 64)

def test_hashfunc(self):
sk = SigningKey.generate(curve=NIST256p, hashfunc=sha256)
data = b("security level is 128 bits")
Expand Down Expand Up @@ -448,6 +571,49 @@ def test_not_lying_on_curve(self):
with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x02') + enc)

def test_decoding_with_malformed_uncompressed(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x02') + enc)

def test_decoding_with_point_not_on_curve(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(enc[:47] + b('\x00'))

def test_decoding_with_point_at_infinity(self):
# decoding it is unsupported, as it's not necessary to encode it
with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00'))

def test_from_string_with_invalid_curve_too_short_ver_key_len(self):
# both verifying_key_length and baselen are calculated internally
# by the Curve constructor, but since we depend on them verify
# that inconsistent values are detected
curve = Curve("test", ecdsa.curve_192, ecdsa.generator_192, (1, 2))
curve.verifying_key_length = 16
curve.baselen = 32

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00')*16, curve)

def test_from_string_with_invalid_curve_too_long_ver_key_len(self):
# both verifying_key_length and baselen are calculated internally
# by the Curve constructor, but since we depend on them verify
# that inconsistent values are detected
curve = Curve("test", ecdsa.curve_192, ecdsa.generator_192, (1, 2))
curve.verifying_key_length = 16
curve.baselen = 16

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00')*16, curve)


@pytest.mark.parametrize("val,even",
[(i, j) for i in range(256) for j in [True, False]])
Expand Down

0 comments on commit 0e907c0

Please sign in to comment.