diff --git a/python2/httplib2/__init__.py b/python2/httplib2/__init__.py index 98228e3b..d807bc1f 100644 --- a/python2/httplib2/__init__.py +++ b/python2/httplib2/__init__.py @@ -76,7 +76,7 @@ def _ssl_wrap_socket( - sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname + sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password ): if disable_validation: cert_reqs = ssl.CERT_NONE @@ -90,11 +90,16 @@ def _ssl_wrap_socket( context.verify_mode = cert_reqs context.check_hostname = cert_reqs != ssl.CERT_NONE if cert_file: - context.load_cert_chain(cert_file, key_file) + if key_password: + context.load_cert_chain(cert_file, key_file, key_password) + else: + context.load_cert_chain(cert_file, key_file) if ca_certs: context.load_verify_locations(ca_certs) return context.wrap_socket(sock, server_hostname=hostname) else: + if key_password: + raise NotSupportedOnThisPlatform("Certificate with password is not supported.") return ssl.wrap_socket( sock, keyfile=key_file, @@ -106,7 +111,7 @@ def _ssl_wrap_socket( def _ssl_wrap_socket_unsupported( - sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname + sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname, key_password ): if not disable_validation: raise CertificateValidationUnsupported( @@ -114,6 +119,8 @@ def _ssl_wrap_socket_unsupported( "the ssl module installed. To avoid this error, install " "the ssl module, or explicity disable validation." ) + if key_password: + raise NotSupportedOnThisPlatform("Certificate with password is not supported.") ssl_sock = socket.ssl(sock, key_file, cert_file) return httplib.FakeSocket(sock, ssl_sock) @@ -978,8 +985,13 @@ def iter(self, domain): class KeyCerts(Credentials): """Identical to Credentials except that name/password are mapped to key/cert.""" + def add(self, key, cert, domain, password): + self.credentials.append((domain.lower(), key, cert, password)) - pass + def iter(self, domain): + for (cdomain, key, cert, password) in self.credentials: + if cdomain == "" or domain == cdomain: + yield (key, cert, password) class AllHosts(object): @@ -1253,10 +1265,19 @@ def __init__( ca_certs=None, disable_ssl_certificate_validation=False, ssl_version=None, + key_password=None, ): - httplib.HTTPSConnection.__init__( - self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict - ) + if key_password: + httplib.HTTPSConnection.__init__(self, host, port=port, strict=strict) + self._context.load_cert_chain(cert_file, key_file, key_password) + self.key_file = key_file + self.cert_file = cert_file + self.key_password = key_password + else: + httplib.HTTPSConnection.__init__( + self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict + ) + self.key_password = None self.timeout = timeout self.proxy_info = proxy_info if ca_certs is None: @@ -1366,6 +1387,7 @@ def connect(self): self.ca_certs, self.ssl_version, self.host, + self.key_password, ) if self.debuglevel > 0: print("connect: (%s, %s)" % (self.host, self.port)) @@ -1515,7 +1537,10 @@ def __init__( ca_certs=None, disable_ssl_certificate_validation=False, ssl_version=None, + key_password=None, ): + if key_password: + raise NotSupportedOnThisPlatform("Certificate with password is not supported.") httplib.HTTPSConnection.__init__( self, host, @@ -1680,10 +1705,10 @@ def add_credentials(self, name, password, domain=""): any time a request requires authentication.""" self.credentials.add(name, password, domain) - def add_certificate(self, key, cert, domain): + def add_certificate(self, key, cert, domain, password=None): """Add a key and cert that will be used any time a request requires authentication.""" - self.certificates.add(key, cert, domain) + self.certificates.add(key, cert, domain, password) def clear_credentials(self): """Remove all the names and passwords @@ -1958,6 +1983,7 @@ def request( ca_certs=self.ca_certs, disable_ssl_certificate_validation=self.disable_ssl_certificate_validation, ssl_version=self.ssl_version, + key_password=certs[0][2], ) else: conn = self.connections[conn_key] = connection_type( diff --git a/python3/httplib2/__init__.py b/python3/httplib2/__init__.py index 4312f300..06b86bfb 100644 --- a/python3/httplib2/__init__.py +++ b/python3/httplib2/__init__.py @@ -175,7 +175,7 @@ class ProxiesUnavailableError(HttpLib2Error): def _build_ssl_context( disable_ssl_certificate_validation, ca_certs, cert_file=None, key_file=None, - maximum_version=None, minimum_version=None, + maximum_version=None, minimum_version=None, key_password=None, ): if not hasattr(ssl, "SSLContext"): raise RuntimeError("httplib2 requires Python 3.2+ for ssl.SSLContext") @@ -207,7 +207,7 @@ def _build_ssl_context( context.load_verify_locations(ca_certs) if cert_file: - context.load_cert_chain(cert_file, key_file) + context.load_cert_chain(cert_file, key_file, key_password) return context @@ -959,8 +959,13 @@ def iter(self, domain): class KeyCerts(Credentials): """Identical to Credentials except that name/password are mapped to key/cert.""" + def add(self, key, cert, domain, password): + self.credentials.append((domain.lower(), key, cert, password)) - pass + def iter(self, domain): + for (cdomain, key, cert, password) in self.credentials: + if cdomain == "" or domain == cdomain: + yield (key, cert, password) class AllHosts(object): @@ -1245,6 +1250,7 @@ def __init__( disable_ssl_certificate_validation=False, tls_maximum_version=None, tls_minimum_version=None, + key_password=None, ): self.disable_ssl_certificate_validation = disable_ssl_certificate_validation @@ -1257,15 +1263,17 @@ def __init__( context = _build_ssl_context( self.disable_ssl_certificate_validation, self.ca_certs, cert_file, key_file, maximum_version=tls_maximum_version, minimum_version=tls_minimum_version, + key_password=key_password, ) super(HTTPSConnectionWithTimeout, self).__init__( host, port=port, - key_file=key_file, - cert_file=cert_file, timeout=timeout, context=context, ) + self.key_file = key_file + self.cert_file = cert_file + self.key_password = key_password def connect(self): """Connect to a host on a given (SSL) port.""" @@ -1507,10 +1515,10 @@ def add_credentials(self, name, password, domain=""): any time a request requires authentication.""" self.credentials.add(name, password, domain) - def add_certificate(self, key, cert, domain): + def add_certificate(self, key, cert, domain, password=None): """Add a key and cert that will be used any time a request requires authentication.""" - self.certificates.add(key, cert, domain) + self.certificates.add(key, cert, domain, password) def clear_credentials(self): """Remove all the names and passwords @@ -1782,6 +1790,7 @@ def request( disable_ssl_certificate_validation=self.disable_ssl_certificate_validation, tls_maximum_version=self.tls_maximum_version, tls_minimum_version=self.tls_minimum_version, + key_password=certs[0][2], ) else: conn = self.connections[conn_key] = connection_type( diff --git a/tests/test_https.py b/tests/test_https.py index f494d7a1..39d7d59a 100644 --- a/tests/test_https.py +++ b/tests/test_https.py @@ -161,6 +161,30 @@ def handler(request): assert cert_log[0]["serialNumber"] == "E2AA6A96D1BF1AEC" +def test_client_cert_password_verified(): + cert_log = [] + + def setup_tls(context, server, skip_errors): + context.load_verify_locations(cafile=tests.CA_CERTS) + context.verify_mode = ssl.CERT_REQUIRED + return context.wrap_socket(server, server_side=True) + + def handler(request): + cert_log.append(request.client_sock.getpeercert()) + return tests.http_response_bytes() + + http = httplib2.Http(ca_certs=tests.CA_CERTS) + with tests.server_request(handler, tls=setup_tls) as uri: + uri_parsed = urllib.parse.urlparse(uri) + http.add_certificate(tests.CLIENT_ENCRYPTED_PEM, tests.CLIENT_ENCRYPTED_PEM, + uri_parsed.netloc, password="12345") + http.request(uri) + + assert len(cert_log) == 1 + # TODO extract serial from tests.CLIENT_PEM + assert cert_log[0]["serialNumber"] == "E2AA6A96D1BF1AED" + + @pytest.mark.skipif( not hasattr(tests.ssl_context(), "set_servername_callback"), reason="SSLContext.set_servername_callback is not available",