Skip to content

Commit

Permalink
Resync type annotations with master
Browse files Browse the repository at this point in the history
  • Loading branch information
kjd committed Oct 3, 2021
1 parent b91e138 commit 8bbb873
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 223 deletions.
15 changes: 5 additions & 10 deletions idna/codec.py
Expand Up @@ -7,8 +7,7 @@

class Codec(codecs.Codec):

def encode(self, data, errors='strict'):
# type: (str, str) -> Tuple[bytes, int]
def encode(self, data: str, errors: str = 'strict') -> Tuple[bytes, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

Expand All @@ -17,8 +16,7 @@ def encode(self, data, errors='strict'):

return encode(data), len(data)

def decode(self, data, errors='strict'):
# type: (bytes, str) -> Tuple[str, int]
def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

Expand All @@ -28,8 +26,7 @@ def decode(self, data, errors='strict'):
return decode(data), len(data)

class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
def _buffer_encode(self, data, errors, final): # type: ignore
# type: (str, str, bool) -> Tuple[str, int]
def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[str, int]: # type: ignore
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

Expand Down Expand Up @@ -62,8 +59,7 @@ def _buffer_encode(self, data, errors, final): # type: ignore
return result_str, size

class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
def _buffer_decode(self, data, errors, final): # type: ignore
# type: (str, str, bool) -> Tuple[str, int]
def _buffer_decode(self, data: str, errors: str, final: bool) -> Tuple[str, int]: # type: ignore
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

Expand Down Expand Up @@ -103,8 +99,7 @@ class StreamReader(Codec, codecs.StreamReader):
pass


def getregentry():
# type: () -> codecs.CodecInfo
def getregentry() -> codecs.CodecInfo:
# Compatibility as a search_function for codecs.register()
return codecs.CodecInfo(
name='idna',
Expand Down
9 changes: 3 additions & 6 deletions idna/compat.py
Expand Up @@ -2,15 +2,12 @@
from .codec import *
from typing import Any, Union

def ToASCII(label):
# type: (str) -> bytes
def ToASCII(label: str) -> bytes:
return encode(label)

def ToUnicode(label):
# type: (Union[bytes, bytearray]) -> str
def ToUnicode(label: Union[bytes, bytearray]) -> str:
return decode(label)

def nameprep(s):
# type: (Any) -> None
def nameprep(s: Any) -> None:
raise NotImplementedError('IDNA 2008 does not utilise nameprep protocol')

54 changes: 18 additions & 36 deletions idna/core.py
Expand Up @@ -29,43 +29,36 @@ class InvalidCodepointContext(IDNAError):
pass


def _combining_class(cp):
# type: (int) -> int
def _combining_class(cp: int) -> int:
v = unicodedata.combining(chr(cp))
if v == 0:
if not unicodedata.name(chr(cp)):
raise ValueError('Unknown character in unicodedata')
return v

def _is_script(cp, script):
# type: (str, str) -> bool
def _is_script(cp: str, script: str) -> bool:
return intranges_contain(ord(cp), idnadata.scripts[script])

def _punycode(s):
# type: (str) -> bytes
def _punycode(s: str) -> bytes:
return s.encode('punycode')

def _unot(s):
# type: (int) -> str
def _unot(s: int) -> str:
return 'U+{:04X}'.format(s)


def valid_label_length(label):
# type: (Union[bytes, str]) -> bool
def valid_label_length(label: Union[bytes, str]) -> bool:
if len(label) > 63:
return False
return True


def valid_string_length(label, trailing_dot):
# type: (Union[bytes, str], bool) -> bool
def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
if len(label) > (254 if trailing_dot else 253):
return False
return True


def check_bidi(label, check_ltr=False):
# type: (str, bool) -> bool
def check_bidi(label: str, check_ltr: bool = False) -> bool:
# Bidi rules should only be applied if string contains RTL characters
bidi_label = False
for (idx, cp) in enumerate(label, 1):
Expand Down Expand Up @@ -124,30 +117,26 @@ def check_bidi(label, check_ltr=False):
return True


def check_initial_combiner(label):
# type: (str) -> bool
def check_initial_combiner(label: str) -> bool:
if unicodedata.category(label[0])[0] == 'M':
raise IDNAError('Label begins with an illegal combining character')
return True


def check_hyphen_ok(label):
# type: (str) -> bool
def check_hyphen_ok(label: str) -> bool:
if label[2:4] == '--':
raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
if label[0] == '-' or label[-1] == '-':
raise IDNAError('Label must not start or end with a hyphen')
return True


def check_nfc(label):
# type: (str) -> None
def check_nfc(label: str) -> None:
if unicodedata.normalize('NFC', label) != label:
raise IDNAError('Label must be in Normalization Form C')


def valid_contextj(label, pos):
# type: (str, int) -> bool
def valid_contextj(label: str, pos: int) -> bool:
cp_value = ord(label[pos])

if cp_value == 0x200c:
Expand Down Expand Up @@ -190,8 +179,7 @@ def valid_contextj(label, pos):
return False


def valid_contexto(label, pos, exception=False):
# type: (str, int, bool) -> bool
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
cp_value = ord(label[pos])

if cp_value == 0x00b7:
Expand Down Expand Up @@ -233,8 +221,7 @@ def valid_contexto(label, pos, exception=False):
return False


def check_label(label):
# type: (Union[str, bytes, bytearray]) -> None
def check_label(label: Union[str, bytes, bytearray]) -> None:
if isinstance(label, (bytes, bytearray)):
label = label.decode('utf-8')
if len(label) == 0:
Expand Down Expand Up @@ -265,8 +252,7 @@ def check_label(label):
check_bidi(label)


def alabel(label):
# type: (str) -> bytes
def alabel(label: str) -> bytes:
try:
label_bytes = label.encode('ascii')
ulabel(label_bytes)
Expand All @@ -290,8 +276,7 @@ def alabel(label):
return label_bytes


def ulabel(label):
# type: (Union[str, bytes, bytearray]) -> str
def ulabel(label: Union[str, bytes, bytearray]) -> str:
if not isinstance(label, (bytes, bytearray)):
try:
label_bytes = label.encode('ascii')
Expand Down Expand Up @@ -320,8 +305,7 @@ def ulabel(label):
return label


def uts46_remap(domain, std3_rules=True, transitional=False):
# type: (str, bool, bool) -> str
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
"""Re-map the characters in the string according to UTS46 processing."""
from .uts46data import uts46data
output = ''
Expand Down Expand Up @@ -353,8 +337,7 @@ def uts46_remap(domain, std3_rules=True, transitional=False):
return unicodedata.normalize('NFC', output)


def encode(s, strict=False, uts46=False, std3_rules=False, transitional=False):
# type: (Union[str, bytes, bytearray], bool, bool, bool, bool) -> bytes
def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
if isinstance(s, (bytes, bytearray)):
s = s.decode('ascii')
if uts46:
Expand Down Expand Up @@ -384,8 +367,7 @@ def encode(s, strict=False, uts46=False, std3_rules=False, transitional=False):
return s


def decode(s, strict=False, uts46=False, std3_rules=False):
# type: (Union[str, bytes, bytearray], bool, bool, bool) -> str
def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
try:
if isinstance(s, (bytes, bytearray)):
s = s.decode('ascii')
Expand Down
12 changes: 4 additions & 8 deletions idna/intranges.py
Expand Up @@ -8,8 +8,7 @@
import bisect
from typing import List, Tuple

def intranges_from_list(list_):
# type: (List[int]) -> Tuple[int, ...]
def intranges_from_list(list_: List[int]) -> Tuple[int, ...]:
"""Represent a list of integers as a sequence of ranges:
((start_0, end_0), (start_1, end_1), ...), such that the original
integers are exactly those x such that start_i <= x < end_i for some i.
Expand All @@ -30,17 +29,14 @@ def intranges_from_list(list_):

return tuple(ranges)

def _encode_range(start, end):
# type: (int, int) -> int
def _encode_range(start: int, end: int) -> int:
return (start << 32) | end

def _decode_range(r):
# type: (int) -> Tuple[int, int]
def _decode_range(r: int) -> Tuple[int, int]:
return (r >> 32), (r & ((1 << 32) - 1))


def intranges_contain(int_, ranges):
# type: (int, Tuple[int, ...]) -> bool
def intranges_contain(int_: int, ranges: Tuple[int, ...]) -> bool:
"""Determine if `int_` falls into one of the ranges in `ranges`."""
tuple_ = _encode_range(int_, 0)
pos = bisect.bisect_left(ranges, tuple_)
Expand Down

0 comments on commit 8bbb873

Please sign in to comment.