Skip to content

Commit

Permalink
Add type hint to util/ssl_match_hostname.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed Apr 21, 2021
1 parent 749209f commit a47fe23
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"src/urllib3/util/queue.py",
"src/urllib3/util/response.py",
"src/urllib3/util/ssl_.py",
"src/urllib3/util/ssl_match_hostname.py",
"src/urllib3/util/ssltransport.py",
"src/urllib3/util/url.py",
"src/urllib3/util/wait.py",
Expand Down
24 changes: 11 additions & 13 deletions src/urllib3/util/ssl_match_hostname.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@

import ipaddress
import re
from typing import Dict, Tuple, Union
from typing import Any, Match, Optional, Union

# https://github.com/python/typeshed/blob/master/stdlib/2and3/ssl.pyi
_PCTRTT = Tuple[Tuple[str, str], ...]
_PCTRTTT = Tuple[_PCTRTT, ...]
_PeerCertRetDictType = Dict[str, Union[str, _PCTRTTT, _PCTRTT]]
_PeerCertRetType = Union[_PeerCertRetDictType, bytes, None]
from .ssl_ import PeerCertRetType

__version__ = "3.5.0.1"

Expand All @@ -20,7 +16,9 @@ class CertificateError(ValueError):
pass


def _dnsname_match(dn, hostname, max_wildcards=1):
def _dnsname_match(
dn: Any, hostname: str, max_wildcards: int = 1
) -> Union[Optional[Match[str]], bool]:
"""Matching according to RFC 6125, section 6.4.3
http://tools.ietf.org/html/rfc6125#section-6.4.3
Expand All @@ -47,7 +45,7 @@ def _dnsname_match(dn, hostname, max_wildcards=1):

# speed up common case w/o wildcards
if not wildcards:
return dn.lower() == hostname.lower()
return bool(dn.lower() == hostname.lower())

# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
Expand All @@ -74,7 +72,7 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
return pat.match(hostname)


def _ipaddress_match(ipname, host_ip):
def _ipaddress_match(ipname: Any, host_ip: str) -> bool:
"""Exact matching of IP addresses.
RFC 6125 explicitly doesn't define an algorithm for this
Expand All @@ -83,10 +81,10 @@ def _ipaddress_match(ipname, host_ip):
# OpenSSL may add a trailing newline to a subjectAltName's IP address
# Divergence from upstream: ipaddress can't handle byte str
ip = ipaddress.ip_address(ipname.rstrip())
return ip == host_ip
return bool(ip == host_ip)


def match_hostname(cert: _PeerCertRetType, hostname: str) -> None:
def match_hostname(cert: PeerCertRetType, hostname: str) -> None:
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
rules are followed, but IP addresses are not accepted for *hostname*.
Expand All @@ -107,8 +105,8 @@ def match_hostname(cert: _PeerCertRetType, hostname: str) -> None:
# Not an IP address (common case)
host_ip = None
dnsnames = []
san = cert.get("subjectAltName", ())
for key, value in san:
san = cert.get("subjectAltName", ()) # type: ignore
for key, value in san: # type: ignore
if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname):
return
Expand Down

0 comments on commit a47fe23

Please sign in to comment.