Skip to content

Commit

Permalink
harden also key decoding
Browse files Browse the repository at this point in the history
as assert statements will be removed in optimised build, do use a custom
exception that inherits from AssertionError so that the failures are
caught
  • Loading branch information
tomato42 committed Oct 4, 2019
1 parent 3b4bd7e commit 75c0669
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
43 changes: 32 additions & 11 deletions ecdsa/keys.py
Expand Up @@ -3,6 +3,7 @@
from . import ecdsa
from . import der
from . import rfc6979
from . import ellipticcurve
from .curves import NIST192p, find_curve
from .util import string_to_number, number_to_string, randrange
from .util import sigencode_string, sigdecode_string
Expand All @@ -15,6 +16,11 @@ class BadSignatureError(Exception):
class BadDigestError(Exception):
pass


class MalformedPointError(AssertionError):
pass


class VerifyingKey:
def __init__(self, _error__please_use_generate=None):
if not _error__please_use_generate:
Expand All @@ -33,17 +39,21 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
validate_point=True):
order = curve.order
assert len(string) == curve.verifying_key_length, \
(len(string), curve.verifying_key_length)
if len(string) != curve.verifying_key_length:
raise MalformedPointError(
"Malformed encoding of public point. Expected string {0} bytes"
" long, received {1} bytes long string".format(
curve.verifying_key_length, len(string)))
xs = string[:curve.baselen]
ys = string[curve.baselen:]
assert len(xs) == curve.baselen, (len(xs), curve.baselen)
assert len(ys) == curve.baselen, (len(ys), curve.baselen)
if len(xs) != curve.baselen:
raise MalformedPointError("Unexpected length of encoded x")
if len(ys) != curve.baselen:
raise MalformedPointError("Unexpected length of encoded y")
x = string_to_number(xs)
y = string_to_number(ys)
if validate_point:
assert ecdsa.point_is_valid(curve.generator, x, y)
from . import ellipticcurve
if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
raise MalformedPointError("Point does not lie on the curve")
point = ellipticcurve.Point(curve.curve, x, y, order)
return klass.from_public_point(point, curve, hashfunc)

Expand All @@ -65,13 +75,18 @@ def from_der(klass, string):
if empty != b(""):
raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" %
binascii.hexlify(empty))
assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey)
if oid_pk != oid_ecPublicKey:
raise der.UnexpectedDER(
"Unexpected OID in encoding, received {0}, expected {1}"
.format(oid_pk, oid_ecPublicKey))
curve = find_curve(oid_curve)
point_str, empty = der.remove_bitstring(point_str_bitstring)
if empty != b(""):
raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" %
binascii.hexlify(empty))
assert point_str.startswith(b("\x00\x04"))
if not point_str.startswith(b("\x00\x04")):
raise der.UnexpectedDER(
"Unsupported or invalid encoding of pubcli key")
return klass.from_string(point_str[2:], curve)

def to_string(self):
Expand Down Expand Up @@ -137,7 +152,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):
self.default_hashfunc = hashfunc
self.baselen = curve.baselen
n = curve.order
assert 1 <= secexp < n
if not 1 <= secexp < n:
raise MalformedPointError(
"Invalid value for secexp, expected integer between 1 and {0}"
.format(n))
pubkey_point = curve.generator*secexp
pubkey = ecdsa.Public_key(curve.generator, pubkey_point)
pubkey.order = n
Expand All @@ -149,7 +167,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):

@classmethod
def from_string(klass, string, curve=NIST192p, hashfunc=sha1):
assert len(string) == curve.baselen, (len(string), curve.baselen)
if len(string) != curve.baselen:
raise MalformedPointError(
"Invalid length of private key, received {0}, expected {1}"
.format(len(string), curve.baselen))
secexp = string_to_number(string)
return klass.from_secret_exponent(secexp, curve, hashfunc)

Expand Down
28 changes: 26 additions & 2 deletions ecdsa/test_pyecdsa.py
@@ -1,6 +1,9 @@
from __future__ import with_statement, division

import unittest
try:
import unittest2 as unittest
except ImportError:
import unittest
import os
import time
import shutil
Expand All @@ -10,7 +13,7 @@

from .six import b, print_, binary_type
from .keys import SigningKey, VerifyingKey
from .keys import BadSignatureError
from .keys import BadSignatureError, MalformedPointError
from . import util
from .util import sigencode_der, sigencode_strings
from .util import sigdecode_der, sigdecode_strings
Expand Down Expand Up @@ -299,6 +302,27 @@ def test_hashfunc(self):
curve=NIST256p)
self.assertTrue(vk3.verify(sig, data, hashfunc=sha256))

def test_decoding_with_malformed_uncompressed(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x02') + enc)

def test_decoding_with_point_not_on_curve(self):
enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3'
'\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4'
'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*')

with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(enc[:47] + b('\x00'))

def test_decoding_with_point_at_infinity(self):
# decoding it is unsupported, as it's not necessary to encode it
with self.assertRaises(MalformedPointError):
VerifyingKey.from_string(b('\x00'))


class OpenSSL(unittest.TestCase):
# test interoperability with OpenSSL tools. Note that openssl's ECDSA
Expand Down

0 comments on commit 75c0669

Please sign in to comment.