Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for password protected certificate files #143

Merged
merged 14 commits into from Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 35 additions & 9 deletions python2/httplib2/__init__.py
Expand Up @@ -76,7 +76,7 @@


def _ssl_wrap_socket(
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password
):
if disable_validation:
cert_reqs = ssl.CERT_NONE
Expand All @@ -90,11 +90,16 @@ def _ssl_wrap_socket(
context.verify_mode = cert_reqs
context.check_hostname = cert_reqs != ssl.CERT_NONE
if cert_file:
context.load_cert_chain(cert_file, key_file)
if key_password:
context.load_cert_chain(cert_file, key_file, key_password)
else:
context.load_cert_chain(cert_file, key_file)
if ca_certs:
context.load_verify_locations(ca_certs)
return context.wrap_socket(sock, server_hostname=hostname)
else:
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
return ssl.wrap_socket(
sock,
keyfile=key_file,
Expand All @@ -106,14 +111,16 @@ def _ssl_wrap_socket(


def _ssl_wrap_socket_unsupported(
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname
sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password
):
if not disable_validation:
raise CertificateValidationUnsupported(
"SSL certificate validation is not supported without "
"the ssl module installed. To avoid this error, install "
"the ssl module, or explicity disable validation."
)
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
ssl_sock = socket.ssl(sock, key_file, cert_file)
return httplib.FakeSocket(sock, ssl_sock)

Expand Down Expand Up @@ -978,8 +985,13 @@ def iter(self, domain):
class KeyCerts(Credentials):
"""Identical to Credentials except that
name/password are mapped to key/cert."""
def add(self, key, cert, domain, password):
self.credentials.append((domain.lower(), key, cert, password))

pass
def iter(self, domain):
for (cdomain, key, cert, password) in self.credentials:
if cdomain == "" or domain == cdomain:
yield (key, cert, password)


class AllHosts(object):
Expand Down Expand Up @@ -1253,10 +1265,19 @@ def __init__(
ca_certs=None,
disable_ssl_certificate_validation=False,
ssl_version=None,
key_password=None,
):
httplib.HTTPSConnection.__init__(
self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict
)
if key_password:
httplib.HTTPSConnection.__init__(self, host, port=port, strict=strict)
self._context.load_cert_chain(cert_file, key_file, key_password)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password
else:
httplib.HTTPSConnection.__init__(
self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict
)
self.key_password = None
self.timeout = timeout
self.proxy_info = proxy_info
if ca_certs is None:
Expand Down Expand Up @@ -1366,6 +1387,7 @@ def connect(self):
self.ca_certs,
self.ssl_version,
self.host,
self.key_password,
)
if self.debuglevel > 0:
print("connect: (%s, %s)" % (self.host, self.port))
Expand Down Expand Up @@ -1515,7 +1537,10 @@ def __init__(
ca_certs=None,
disable_ssl_certificate_validation=False,
ssl_version=None,
key_password=None,
):
if key_password:
raise NotSupportedOnThisPlatform("Certificate with password is not supported.")
httplib.HTTPSConnection.__init__(
self,
host,
Expand Down Expand Up @@ -1680,10 +1705,10 @@ def add_credentials(self, name, password, domain=""):
any time a request requires authentication."""
self.credentials.add(name, password, domain)

def add_certificate(self, key, cert, domain):
def add_certificate(self, key, cert, domain, password=None):
"""Add a key and cert that will be used
any time a request requires authentication."""
self.certificates.add(key, cert, domain)
self.certificates.add(key, cert, domain, password)

def clear_credentials(self):
"""Remove all the names and passwords
Expand Down Expand Up @@ -1958,6 +1983,7 @@ def request(
ca_certs=self.ca_certs,
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
ssl_version=self.ssl_version,
key_password = certs[0][2],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just this minor spaces around '=', CI linter should've caught this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

)
else:
conn = self.connections[conn_key] = connection_type(
Expand Down
23 changes: 16 additions & 7 deletions python3/httplib2/__init__.py
Expand Up @@ -175,7 +175,7 @@ class ProxiesUnavailableError(HttpLib2Error):

def _build_ssl_context(
disable_ssl_certificate_validation, ca_certs, cert_file=None, key_file=None,
maximum_version=None, minimum_version=None,
maximum_version=None, minimum_version=None, key_password=None,
):
if not hasattr(ssl, "SSLContext"):
raise RuntimeError("httplib2 requires Python 3.2+ for ssl.SSLContext")
Expand Down Expand Up @@ -207,7 +207,7 @@ def _build_ssl_context(
context.load_verify_locations(ca_certs)

if cert_file:
context.load_cert_chain(cert_file, key_file)
context.load_cert_chain(cert_file, key_file, key_password)

return context

Expand Down Expand Up @@ -959,8 +959,13 @@ def iter(self, domain):
class KeyCerts(Credentials):
"""Identical to Credentials except that
name/password are mapped to key/cert."""
def add(self, key, cert, domain, password):
self.credentials.append((domain.lower(), key, cert, password))

pass
def iter(self, domain):
for (cdomain, key, cert, password) in self.credentials:
if cdomain == "" or domain == cdomain:
yield (key, cert, password)


class AllHosts(object):
Expand Down Expand Up @@ -1245,6 +1250,7 @@ def __init__(
disable_ssl_certificate_validation=False,
tls_maximum_version=None,
tls_minimum_version=None,
key_password=None,
):

self.disable_ssl_certificate_validation = disable_ssl_certificate_validation
Expand All @@ -1257,15 +1263,17 @@ def __init__(
context = _build_ssl_context(
self.disable_ssl_certificate_validation, self.ca_certs, cert_file, key_file,
maximum_version=tls_maximum_version, minimum_version=tls_minimum_version,
key_password=key_password,
)
super(HTTPSConnectionWithTimeout, self).__init__(
host,
port=port,
key_file=key_file,
cert_file=cert_file,
timeout=timeout,
context=context,
)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password

def connect(self):
"""Connect to a host on a given (SSL) port."""
Expand Down Expand Up @@ -1507,10 +1515,10 @@ def add_credentials(self, name, password, domain=""):
any time a request requires authentication."""
self.credentials.add(name, password, domain)

def add_certificate(self, key, cert, domain):
def add_certificate(self, key, cert, domain, password=None):
"""Add a key and cert that will be used
any time a request requires authentication."""
self.certificates.add(key, cert, domain)
self.certificates.add(key, cert, domain, password)

def clear_credentials(self):
"""Remove all the names and passwords
Expand Down Expand Up @@ -1782,6 +1790,7 @@ def request(
disable_ssl_certificate_validation=self.disable_ssl_certificate_validation,
tls_maximum_version=self.tls_maximum_version,
tls_minimum_version=self.tls_minimum_version,
key_password=certs[0][2],
)
else:
conn = self.connections[conn_key] = connection_type(
Expand Down
63 changes: 63 additions & 0 deletions tests/__init__.py
Expand Up @@ -8,12 +8,14 @@
import gzip
import hashlib
import httplib2
from http.server import BaseHTTPRequestHandler, HTTPServer
import os
import random
import re
import shutil
import six
import socket
import ssl
import struct
import sys
import threading
Expand All @@ -23,6 +25,12 @@
from six.moves import http_client, queue


SERVER_CERTFILE = 'tests/testdata/test_server_cert.pem'
CLIENT_CERTFILE = 'tests/testdata/test_cert.pem'
CLIENT_CERT_PASSWORD = '12345'
CLIENT_CERT_SERIAL = '5ECC68A6F89CAA16D032C838CCDDC7E577264CDB'


@contextlib.contextmanager
def assert_raises(exc_type):
def _name(t):
Expand Down Expand Up @@ -260,6 +268,61 @@ def getresponse(self):
raise http_client.BadStatusLine("")


def _get_free_port():
s = socket.socket(socket.AF_INET, type=socket.SOCK_STREAM)
s.bind(('localhost', 0))
address, port = s.getsockname()
s.close()
return port


class _MockServerRequestHandler(BaseHTTPRequestHandler):
"""Server request handler which always returns 200 and saves client cert info."""
def do_GET(self):
# save client cert
self.server.last_client_cert = self.connection.getpeercert()
# Process an HTTP GET request and return a response with an HTTP 200 status.
self.send_response(200)
self.end_headers()
return


class MockHttpServer():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There already exists code that does same function, right below this class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I respectfully disagree with that statement.
The code below doesn't support SSL and has some hacky "with" implementation / object life time management. It's also impossible to get back client cert used by the server.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thanks for criticism, I very much agree that code turned out less simple than it should.

So if you disagree here, please wait for me to add https/cert support in existing server stub. There is no good reason to have two testing server implementations.

"""This creates local http server in a separate thread."""
def __init__(self, handler=None, port=0, use_ssl=False):
self.handler = handler if handler else _MockServerRequestHandler
self.port = port if port else _get_free_port()
self.use_ssl = use_ssl
self.client_certfile = CLIENT_CERTFILE
self.certfile = SERVER_CERTFILE

def __enter__(self):
self.server = HTTPServer(('localhost', self.port), self.handler)

# wrap socket when SSL server requested
if self.use_ssl:
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
# ask client to present own cert for mutual auth
context.verify_mode = ssl.CERT_OPTIONAL
if self.client_certfile:
# avoid verification failure by preloading matching client cert
context.load_verify_locations(self.client_certfile)
# load server cert
context.load_cert_chain(self.certfile)
self.server.socket = context.wrap_socket(
sock=self.server.socket, server_side=True)

# Start running mock server in a separate thread.
# Daemon threads automatically shut down when the main process exits.
server_thread = threading.Thread(target=self.server.serve_forever)
server_thread.setDaemon(True)
server_thread.start()
return self

def __exit__(self, type, value, traceback):
self.server.shutdown()


@contextlib.contextmanager
def server_socket(fun, request_count=1, timeout=5):
gresult = [None]
Expand Down
60 changes: 60 additions & 0 deletions tests/test_external.py
Expand Up @@ -62,6 +62,66 @@ def test_get_via_https_key_cert():
pass


def test_get_via_https_key_cert_password():
# At this point I can only test
# that the key and cert files are passed in
# correctly to httplib. It would be nice to have
# a real https endpoint to test against.
http = httplib2.Http(timeout=2)
http.add_certificate("akeyfile", "acertfile", "bitworking.org", "apassword")
try:
http.request("https://bitworking.org", "GET")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still trying to hit external website.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test without relying on external network.
The test without relying on external network is implemented - test_get_via_https_key_cert_password_with_pem_local_server.

Still trying to hit external website.
Just to clarify, do you expect all tests to use local server? There are quite a few tests which do that in test_external.py file. Or are you referring to newly added tests only?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are all transient, were created because local server stub was not implemented or not capable of required features at that time. See file docstring:

"""These tests rely on replies from public internet services

TODO: reimplement with local stubs
"""

except AttributeError:
assert http.connections["https:bitworking.org"].key_file == "akeyfile"
assert http.connections["https:bitworking.org"].cert_file == "acertfile"
assert http.connections["https:bitworking.org"].key_password == "apassword"
except IOError:
# Skip on 3.2
pass

try:
http.request("https://notthere.bitworking.org", "GET")
except httplib2.ServerNotFoundError:
assert http.connections["https:notthere.bitworking.org"].key_file is None
assert http.connections["https:notthere.bitworking.org"].cert_file is None
assert http.connections["https:notthere.bitworking.org"].key_password is None
except IOError:
# Skip on 3.2
pass


def test_get_via_https_key_cert_password_with_pem():
# At this point I can only test
# that the key and cert files are passed in
# correctly to httplib. It would be nice to have
# a real https endpoint to test against.
http = httplib2.Http(timeout=2)
http.add_certificate(tests.CLIENT_CERTFILE, tests.CLIENT_CERTFILE,
"bitworking.org", tests.CLIENT_CERT_PASSWORD)
http.request("https://bitworking.org", "GET")

# try invalid password
http = httplib2.Http(timeout=2)
http.add_certificate(tests.CLIENT_CERTFILE, tests.CLIENT_CERTFILE,
"bitworking.org", "invalid")
with tests.assert_raises(ssl.SSLError):
http.request("https://bitworking.org", "GET")


def test_get_via_https_key_cert_password_with_pem_local_server():
with tests.MockHttpServer(use_ssl=True) as server:
# load matching server cert to avoid verification failure
http = httplib2.Http(ca_certs=server.certfile)
# load client cert to be presented when server asks for it
http.add_certificate(tests.CLIENT_CERTFILE, tests.CLIENT_CERTFILE,
'', tests.CLIENT_CERT_PASSWORD)
url = 'https://localhost:{port}/'.format(port=server.port)
response, content = http.request(url, "GET")
assert response.status == 200
# verify that client cert was presented with matching serial number
assert server.server.last_client_cert['serialNumber'] == tests.CLIENT_CERT_SERIAL


def test_ssl_invalid_ca_certs_path():
# Test that we get an ssl.SSLError when specifying a non-existent CA
# certs file.
Expand Down