Skip to content

Commit

Permalink
add support for password protected certificate files
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoshelev-g authored and temoto committed Oct 21, 2019
1 parent 3a6d7cd commit 4009e8e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 16 deletions.
44 changes: 35 additions & 9 deletions python2/httplib2/__init__.py
Expand Up @@ -76,7 +76,7 @@


def _ssl_wrap_socket(
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password
):
if disable_validation:
cert_reqs = ssl.CERT_NONE
Expand All @@ -90,11 +90,16 @@ def _ssl_wrap_socket(
context.verify_mode = cert_reqs
context.check_hostname = cert_reqs != ssl.CERT_NONE
if cert_file:
context.load_cert_chain(cert_file, key_file)
if key_password:
context.load_cert_chain(cert_file, key_file, key_password)
else:
context.load_cert_chain(cert_file, key_file)
if ca_certs:
context.load_verify_locations(ca_certs)
return context.wrap_socket(sock, server_hostname=hostname)
else:
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
return ssl.wrap_socket(
sock,
keyfile=key_file,
Expand All @@ -106,14 +111,16 @@ def _ssl_wrap_socket(


def _ssl_wrap_socket_unsupported(
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password
):
if not disable_validation:
raise CertificateValidationUnsupported(
"SSL certificate validation is not supported without "
"the ssl module installed. To avoid this error, install "
"the ssl module, or explicity disable validation."
)
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
ssl_sock = socket.ssl(sock, key_file, cert_file)
return httplib.FakeSocket(sock, ssl_sock)

Expand Down Expand Up @@ -978,8 +985,13 @@ def iter(self, domain):
class KeyCerts(Credentials):
"""Identical to Credentials except that
name/password are mapped to key/cert."""
def add(self, key, cert, domain, password):
self.credentials.append((domain.lower(), key, cert, password))

pass
def iter(self, domain):
for (cdomain, key, cert, password) in self.credentials:
if cdomain == "" or domain == cdomain:
yield (key, cert, password)


class AllHosts(object):
Expand Down Expand Up @@ -1253,10 +1265,19 @@ def __init__(
ca_certs=None,
disable_ssl_certificate_validation=False,
ssl_version=None,
key_password=None,
):
httplib.HTTPSConnection.__init__(
self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict
)
if key_password:
httplib.HTTPSConnection.__init__(self, host, port=port, strict=strict)
self._context.load_cert_chain(cert_file, key_file, key_password)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password
else:
httplib.HTTPSConnection.__init__(
self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict
)
self.key_password = None
self.timeout = timeout
self.proxy_info = proxy_info
if ca_certs is None:
Expand Down Expand Up @@ -1366,6 +1387,7 @@ def connect(self):
self.ca_certs,
self.ssl_version,
self.host,
self.key_password,
)
if self.debuglevel > 0:
print("connect: (%s, %s)" % (self.host, self.port))
Expand Down Expand Up @@ -1515,7 +1537,10 @@ def __init__(
ca_certs=None,
disable_ssl_certificate_validation=False,
ssl_version=None,
key_password=None,
):
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
httplib.HTTPSConnection.__init__(
self,
host,
Expand Down Expand Up @@ -1680,10 +1705,10 @@ def add_credentials(self, name, password, domain=""):
any time a request requires authentication."""
self.credentials.add(name, password, domain)

def add_certificate(self, key, cert, domain):
def add_certificate(self, key, cert, domain, password=None):
"""Add a key and cert that will be used
any time a request requires authentication."""
self.certificates.add(key, cert, domain)
self.certificates.add(key, cert, domain, password)

def clear_credentials(self):
"""Remove all the names and passwords
Expand Down Expand Up @@ -1958,6 +1983,7 @@ def request(
ca_certs=self.ca_certs,
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
ssl_version=self.ssl_version,
key_password=certs[0][2],
)
else:
conn = self.connections[conn_key] = connection_type(
Expand Down
23 changes: 16 additions & 7 deletions python3/httplib2/__init__.py
Expand Up @@ -175,7 +175,7 @@ class ProxiesUnavailableError(HttpLib2Error):

def _build_ssl_context(
disable_ssl_certificate_validation, ca_certs, cert_file=None, key_file=None,
maximum_version=None, minimum_version=None,
maximum_version=None, minimum_version=None, key_password=None,
):
if not hasattr(ssl, "SSLContext"):
raise RuntimeError("httplib2 requires Python 3.2+ for ssl.SSLContext")
Expand Down Expand Up @@ -207,7 +207,7 @@ def _build_ssl_context(
context.load_verify_locations(ca_certs)

if cert_file:
context.load_cert_chain(cert_file, key_file)
context.load_cert_chain(cert_file, key_file, key_password)

return context

Expand Down Expand Up @@ -959,8 +959,13 @@ def iter(self, domain):
class KeyCerts(Credentials):
"""Identical to Credentials except that
name/password are mapped to key/cert."""
def add(self, key, cert, domain, password):
self.credentials.append((domain.lower(), key, cert, password))

pass
def iter(self, domain):
for (cdomain, key, cert, password) in self.credentials:
if cdomain == "" or domain == cdomain:
yield (key, cert, password)


class AllHosts(object):
Expand Down Expand Up @@ -1245,6 +1250,7 @@ def __init__(
disable_ssl_certificate_validation=False,
tls_maximum_version=None,
tls_minimum_version=None,
key_password=None,
):

self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
Expand All @@ -1257,15 +1263,17 @@ def __init__(
context = _build_ssl_context(
self.disable_ssl_certificate_validation, self.ca_certs, cert_file, key_file,
maximum_version=tls_maximum_version, minimum_version=tls_minimum_version,
key_password=key_password,
)
super(HTTPSConnectionWithTimeout, self).__init__(
host,
port=port,
key_file=key_file,
cert_file=cert_file,
timeout=timeout,
context=context,
)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password

def connect(self):
"""Connect to a host on a given (SSL) port."""
Expand Down Expand Up @@ -1507,10 +1515,10 @@ def add_credentials(self, name, password, domain=""):
any time a request requires authentication."""
self.credentials.add(name, password, domain)

def add_certificate(self, key, cert, domain):
def add_certificate(self, key, cert, domain, password=None):
"""Add a key and cert that will be used
any time a request requires authentication."""
self.certificates.add(key, cert, domain)
self.certificates.add(key, cert, domain, password)

def clear_credentials(self):
"""Remove all the names and passwords
Expand Down Expand Up @@ -1782,6 +1790,7 @@ def request(
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
tls_maximum_version=self.tls_maximum_version,
tls_minimum_version=self.tls_minimum_version,
key_password=certs[0][2],
)
else:
conn = self.connections[conn_key] = connection_type(
Expand Down
24 changes: 24 additions & 0 deletions tests/test_https.py
Expand Up @@ -161,6 +161,30 @@ def handler(request):
assert cert_log[0]["serialNumber"] == "E2AA6A96D1BF1AEC"


def test_client_cert_password_verified():
cert_log = []

def setup_tls(context, server, skip_errors):
context.load_verify_locations(cafile=tests.CA_CERTS)
context.verify_mode = ssl.CERT_REQUIRED
return context.wrap_socket(server, server_side=True)

def handler(request):
cert_log.append(request.client_sock.getpeercert())
return tests.http_response_bytes()

http = httplib2.Http(ca_certs=tests.CA_CERTS)
with tests.server_request(handler, tls=setup_tls) as uri:
uri_parsed = urllib.parse.urlparse(uri)
http.add_certificate(tests.CLIENT_ENCRYPTED_PEM, tests.CLIENT_ENCRYPTED_PEM,
uri_parsed.netloc, password="12345")
http.request(uri)

assert len(cert_log) == 1
# TODO extract serial from tests.CLIENT_PEM
assert cert_log[0]["serialNumber"] == "E2AA6A96D1BF1AED"


@pytest.mark.skipif(
not hasattr(tests.ssl_context(), "set_servername_callback"),
reason="SSLContext.set_servername_callback is not available",
Expand Down

0 comments on commit 4009e8e

Please sign in to comment.