From 75c0669f187ab96b110a769f93ace3ecea9cc19f Mon Sep 17 00:00:00 2001 From: Hubert Kario Date: Fri, 4 Oct 2019 21:08:05 +0200 Subject: [PATCH] harden also key decoding as assert statements will be removed in optimised build, do use a custom exception that inherits from AssertionError so that the failures are caught --- ecdsa/keys.py | 43 ++++++++++++++++++++++++++++++++----------- ecdsa/test_pyecdsa.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/ecdsa/keys.py b/ecdsa/keys.py index 72c0c18f..45198170 100644 --- a/ecdsa/keys.py +++ b/ecdsa/keys.py @@ -3,6 +3,7 @@ from . import ecdsa from . import der from . import rfc6979 +from . import ellipticcurve from .curves import NIST192p, find_curve from .util import string_to_number, number_to_string, randrange from .util import sigencode_string, sigdecode_string @@ -15,6 +16,11 @@ class BadSignatureError(Exception): class BadDigestError(Exception): pass + +class MalformedPointError(AssertionError): + pass + + class VerifyingKey: def __init__(self, _error__please_use_generate=None): if not _error__please_use_generate: @@ -33,17 +39,21 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1): def from_string(klass, string, curve=NIST192p, hashfunc=sha1, validate_point=True): order = curve.order - assert len(string) == curve.verifying_key_length, \ - (len(string), curve.verifying_key_length) + if len(string) != curve.verifying_key_length: + raise MalformedPointError( + "Malformed encoding of public point. Expected string {0} bytes" + " long, received {1} bytes long string".format( + curve.verifying_key_length, len(string))) xs = string[:curve.baselen] ys = string[curve.baselen:] - assert len(xs) == curve.baselen, (len(xs), curve.baselen) - assert len(ys) == curve.baselen, (len(ys), curve.baselen) + if len(xs) != curve.baselen: + raise MalformedPointError("Unexpected length of encoded x") + if len(ys) != curve.baselen: + raise MalformedPointError("Unexpected length of encoded y") x = string_to_number(xs) y = string_to_number(ys) - if validate_point: - assert ecdsa.point_is_valid(curve.generator, x, y) - from . import ellipticcurve + if validate_point and not ecdsa.point_is_valid(curve.generator, x, y): + raise MalformedPointError("Point does not lie on the curve") point = ellipticcurve.Point(curve.curve, x, y, order) return klass.from_public_point(point, curve, hashfunc) @@ -65,13 +75,18 @@ def from_der(klass, string): if empty != b(""): raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" % binascii.hexlify(empty)) - assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey) + if oid_pk != oid_ecPublicKey: + raise der.UnexpectedDER( + "Unexpected OID in encoding, received {0}, expected {1}" + .format(oid_pk, oid_ecPublicKey)) curve = find_curve(oid_curve) point_str, empty = der.remove_bitstring(point_str_bitstring) if empty != b(""): raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" % binascii.hexlify(empty)) - assert point_str.startswith(b("\x00\x04")) + if not point_str.startswith(b("\x00\x04")): + raise der.UnexpectedDER( + "Unsupported or invalid encoding of pubcli key") return klass.from_string(point_str[2:], curve) def to_string(self): @@ -137,7 +152,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1): self.default_hashfunc = hashfunc self.baselen = curve.baselen n = curve.order - assert 1 <= secexp < n + if not 1 <= secexp < n: + raise MalformedPointError( + "Invalid value for secexp, expected integer between 1 and {0}" + .format(n)) pubkey_point = curve.generator*secexp pubkey = ecdsa.Public_key(curve.generator, pubkey_point) pubkey.order = n @@ -149,7 +167,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1): @classmethod def from_string(klass, string, curve=NIST192p, hashfunc=sha1): - assert len(string) == curve.baselen, (len(string), curve.baselen) + if len(string) != curve.baselen: + raise MalformedPointError( + "Invalid length of private key, received {0}, expected {1}" + .format(len(string), curve.baselen)) secexp = string_to_number(string) return klass.from_secret_exponent(secexp, curve, hashfunc) diff --git a/ecdsa/test_pyecdsa.py b/ecdsa/test_pyecdsa.py index 326cac4a..60be3c83 100644 --- a/ecdsa/test_pyecdsa.py +++ b/ecdsa/test_pyecdsa.py @@ -1,6 +1,9 @@ from __future__ import with_statement, division -import unittest +try: + import unittest2 as unittest +except ImportError: + import unittest import os import time import shutil @@ -10,7 +13,7 @@ from .six import b, print_, binary_type from .keys import SigningKey, VerifyingKey -from .keys import BadSignatureError +from .keys import BadSignatureError, MalformedPointError from . import util from .util import sigencode_der, sigencode_strings from .util import sigdecode_der, sigdecode_strings @@ -299,6 +302,27 @@ def test_hashfunc(self): curve=NIST256p) self.assertTrue(vk3.verify(sig, data, hashfunc=sha256)) + def test_decoding_with_malformed_uncompressed(self): + enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3' + '\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4' + 'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*') + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(b('\x02') + enc) + + def test_decoding_with_point_not_on_curve(self): + enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3' + '\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4' + 'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*') + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(enc[:47] + b('\x00')) + + def test_decoding_with_point_at_infinity(self): + # decoding it is unsupported, as it's not necessary to encode it + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(b('\x00')) + class OpenSSL(unittest.TestCase): # test interoperability with OpenSSL tools. Note that openssl's ECDSA