Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for password protected certificate files #143

Merged
merged 14 commits into from Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 35 additions & 9 deletions python2/httplib2/__init__.py
Expand Up @@ -76,7 +76,7 @@


def _ssl_wrap_socket(
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password
):
if disable_validation:
cert_reqs = ssl.CERT_NONE
Expand All @@ -90,11 +90,16 @@ def _ssl_wrap_socket(
context.verify_mode = cert_reqs
context.check_hostname = cert_reqs != ssl.CERT_NONE
if cert_file:
context.load_cert_chain(cert_file, key_file)
if key_password:
context.load_cert_chain(cert_file, key_file, key_password)
else:
context.load_cert_chain(cert_file, key_file)
if ca_certs:
context.load_verify_locations(ca_certs)
return context.wrap_socket(sock, server_hostname=hostname)
else:
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
return ssl.wrap_socket(
sock,
keyfile=key_file,
Expand All @@ -106,14 +111,16 @@ def _ssl_wrap_socket(


def _ssl_wrap_socket_unsupported(
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password
):
if not disable_validation:
raise CertificateValidationUnsupported(
"SSL certificate validation is not supported without "
"the ssl module installed. To avoid this error, install "
"the ssl module, or explicity disable validation."
)
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
ssl_sock = socket.ssl(sock, key_file, cert_file)
return httplib.FakeSocket(sock, ssl_sock)

Expand Down Expand Up @@ -978,8 +985,13 @@ def iter(self, domain):
class KeyCerts(Credentials):
"""Identical to Credentials except that
name/password are mapped to key/cert."""
def add(self, key, cert, domain, password):
self.credentials.append((domain.lower(), key, cert, password))

pass
def iter(self, domain):
for (cdomain, key, cert, password) in self.credentials:
if cdomain == "" or domain == cdomain:
yield (key, cert, password)


class AllHosts(object):
Expand Down Expand Up @@ -1253,10 +1265,19 @@ def __init__(
ca_certs=None,
disable_ssl_certificate_validation=False,
ssl_version=None,
key_password=None,
):
httplib.HTTPSConnection.__init__(
self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict
)
if key_password:
httplib.HTTPSConnection.__init__(self, host, port=port, strict=strict)
self._context.load_cert_chain(cert_file, key_file, key_password)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password
else:
httplib.HTTPSConnection.__init__(
self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict
)
self.key_password = None
self.timeout = timeout
self.proxy_info = proxy_info
if ca_certs is None:
Expand Down Expand Up @@ -1366,6 +1387,7 @@ def connect(self):
self.ca_certs,
self.ssl_version,
self.host,
self.key_password,
)
if self.debuglevel > 0:
print("connect: (%s, %s)" % (self.host, self.port))
Expand Down Expand Up @@ -1515,7 +1537,10 @@ def __init__(
ca_certs=None,
disable_ssl_certificate_validation=False,
ssl_version=None,
key_password=None,
):
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
httplib.HTTPSConnection.__init__(
self,
host,
Expand Down Expand Up @@ -1680,10 +1705,10 @@ def add_credentials(self, name, password, domain=""):
any time a request requires authentication."""
self.credentials.add(name, password, domain)

def add_certificate(self, key, cert, domain):
def add_certificate(self, key, cert, domain, password=None):
"""Add a key and cert that will be used
any time a request requires authentication."""
self.certificates.add(key, cert, domain)
self.certificates.add(key, cert, domain, password)

def clear_credentials(self):
"""Remove all the names and passwords
Expand Down Expand Up @@ -1958,6 +1983,7 @@ def request(
ca_certs=self.ca_certs,
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
ssl_version=self.ssl_version,
key_password=certs[0][2],
)
else:
conn = self.connections[conn_key] = connection_type(
Expand Down
23 changes: 16 additions & 7 deletions python3/httplib2/__init__.py
Expand Up @@ -175,7 +175,7 @@ class ProxiesUnavailableError(HttpLib2Error):

def _build_ssl_context(
disable_ssl_certificate_validation, ca_certs, cert_file=None, key_file=None,
maximum_version=None, minimum_version=None,
maximum_version=None, minimum_version=None, key_password=None,
):
if not hasattr(ssl, "SSLContext"):
raise RuntimeError("httplib2 requires Python 3.2+ for ssl.SSLContext")
Expand Down Expand Up @@ -207,7 +207,7 @@ def _build_ssl_context(
context.load_verify_locations(ca_certs)

if cert_file:
context.load_cert_chain(cert_file, key_file)
context.load_cert_chain(cert_file, key_file, key_password)

return context

Expand Down Expand Up @@ -959,8 +959,13 @@ def iter(self, domain):
class KeyCerts(Credentials):
"""Identical to Credentials except that
name/password are mapped to key/cert."""
def add(self, key, cert, domain, password):
self.credentials.append((domain.lower(), key, cert, password))

pass
def iter(self, domain):
for (cdomain, key, cert, password) in self.credentials:
if cdomain == "" or domain == cdomain:
yield (key, cert, password)


class AllHosts(object):
Expand Down Expand Up @@ -1245,6 +1250,7 @@ def __init__(
disable_ssl_certificate_validation=False,
tls_maximum_version=None,
tls_minimum_version=None,
key_password=None,
):

self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
Expand All @@ -1257,15 +1263,17 @@ def __init__(
context = _build_ssl_context(
self.disable_ssl_certificate_validation, self.ca_certs, cert_file, key_file,
maximum_version=tls_maximum_version, minimum_version=tls_minimum_version,
key_password=key_password,
)
super(HTTPSConnectionWithTimeout, self).__init__(
host,
port=port,
key_file=key_file,
cert_file=cert_file,
timeout=timeout,
context=context,
)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password

def connect(self):
"""Connect to a host on a given (SSL) port."""
Expand Down Expand Up @@ -1507,10 +1515,10 @@ def add_credentials(self, name, password, domain=""):
any time a request requires authentication."""
self.credentials.add(name, password, domain)

def add_certificate(self, key, cert, domain):
def add_certificate(self, key, cert, domain, password=None):
"""Add a key and cert that will be used
any time a request requires authentication."""
self.certificates.add(key, cert, domain)
self.certificates.add(key, cert, domain, password)

def clear_credentials(self):
"""Remove all the names and passwords
Expand Down Expand Up @@ -1782,6 +1790,7 @@ def request(
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
tls_maximum_version=self.tls_maximum_version,
tls_minimum_version=self.tls_minimum_version,
key_password=certs[0][2],
)
else:
conn = self.connections[conn_key] = connection_type(
Expand Down
24 changes: 24 additions & 0 deletions tests/test_https.py
Expand Up @@ -161,6 +161,30 @@ def handler(request):
assert cert_log[0]["serialNumber"] == "E2AA6A96D1BF1AEC"


def test_client_cert_password_verified():
cert_log = []

def setup_tls(context, server, skip_errors):
context.load_verify_locations(cafile=tests.CA_CERTS)
context.verify_mode = ssl.CERT_REQUIRED
return context.wrap_socket(server, server_side=True)

def handler(request):
cert_log.append(request.client_sock.getpeercert())
return tests.http_response_bytes()

http = httplib2.Http(ca_certs=tests.CA_CERTS)
with tests.server_request(handler, tls=setup_tls) as uri:
uri_parsed = urllib.parse.urlparse(uri)
http.add_certificate(tests.CLIENT_ENCRYPTED_PEM, tests.CLIENT_ENCRYPTED_PEM,
uri_parsed.netloc, password="12345")
http.request(uri)

assert len(cert_log) == 1
# TODO extract serial from tests.CLIENT_PEM
assert cert_log[0]["serialNumber"] == "E2AA6A96D1BF1AED"


@pytest.mark.skipif(
not hasattr(tests.ssl_context(), "set_servername_callback"),
reason="SSLContext.set_servername_callback is not available",
Expand Down