Skip to content

Commit

Permalink
Merge pull request #145 from jribbens/master
Browse files Browse the repository at this point in the history
Accepting the mypy issues for now, given that they only affect the oldest versions and are very likely a mypy bug.
  • Loading branch information
kjd committed Jun 19, 2023
2 parents 55e98a5 + 81446d4 commit 701288b
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 58 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ You may use the codec encoding and decoding methods using the
.. code-block:: pycon
>>> import idna.codec
>>> print('домен.испытание'.encode('idna'))
>>> print('домен.испытание'.encode('idna2008'))
b'xn--d1acufc.xn--80akhbyknj4f'
>>> print(b'xn--d1acufc.xn--80akhbyknj4f'.decode('idna'))
>>> print(b'xn--d1acufc.xn--80akhbyknj4f'.decode('idna2008'))
домен.испытание
Conversions can be applied on a per-label basis using the ``ulabel`` or
Expand Down
31 changes: 16 additions & 15 deletions idna/codec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .core import encode, decode, alabel, ulabel, IDNAError
import codecs
import re
from typing import Tuple, Optional
from typing import Any, Tuple, Optional

_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')

Expand All @@ -26,24 +26,24 @@ def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
return decode(data), len(data)

class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[str, int]: # type: ignore
def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

if not data:
return "", 0
return b'', 0

labels = _unicode_dots_re.split(data)
trailing_dot = ''
trailing_dot = b''
if labels:
if not labels[-1]:
trailing_dot = '.'
trailing_dot = b'.'
del labels[-1]
elif not final:
# Keep potentially unfinished label until the next call
del labels[-1]
if labels:
trailing_dot = '.'
trailing_dot = b'.'

result = []
size = 0
Expand All @@ -54,18 +54,21 @@ def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[str, int]
size += len(label)

# Join with U+002E
result_str = '.'.join(result) + trailing_dot # type: ignore
result_bytes = b'.'.join(result) + trailing_dot
size += len(trailing_dot)
return result_str, size
return result_bytes, size

class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
def _buffer_decode(self, data: str, errors: str, final: bool) -> Tuple[str, int]: # type: ignore
def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
if errors != 'strict':
raise IDNAError('Unsupported error handling \"{}\"'.format(errors))

if not data:
return ('', 0)

if not isinstance(data, str):
data = str(data, 'ascii')

labels = _unicode_dots_re.split(data)
trailing_dot = ''
if labels:
Expand Down Expand Up @@ -99,13 +102,11 @@ class StreamReader(Codec, codecs.StreamReader):
pass


def getregentry(name: str) -> Optional[codecs.CodecInfo]:
if name != 'idna' and name != 'idna2008':
def search_function(name: str) -> Optional[codecs.CodecInfo]:
if name != 'idna2008':
return None

# Compatibility as a search_function for codecs.register()
return codecs.CodecInfo(
name='idna2008',
name=name,
encode=Codec().encode, # type: ignore
decode=Codec().decode, # type: ignore
incrementalencoder=IncrementalEncoder,
Expand All @@ -114,4 +115,4 @@ def getregentry(name: str) -> Optional[codecs.CodecInfo]:
streamreader=StreamReader,
)

codecs.register(getregentry)
codecs.register(search_function)
8 changes: 4 additions & 4 deletions idna/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False


def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
if isinstance(s, (bytes, bytearray)):
if not isinstance(s, str):
try:
s = s.decode('ascii')
s = str(s, 'ascii')
except UnicodeDecodeError:
raise IDNAError('should pass a unicode string to the function rather than a byte string.')
if uts46:
Expand Down Expand Up @@ -372,8 +372,8 @@ def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool =

def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
try:
if isinstance(s, (bytes, bytearray)):
s = s.decode('ascii')
if not isinstance(s, str):
s = str(s, 'ascii')
except UnicodeDecodeError:
raise IDNAError('Invalid ASCII in A-label')
if uts46:
Expand Down
66 changes: 37 additions & 29 deletions tests/test_idna.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,37 +231,45 @@ def test_valid_contexto(self):
self.assertTrue(idna.valid_contexto(ext_arabic_digit + ext_arabic_digit, 0))
self.assertFalse(idna.valid_contexto(ext_arabic_digit + arabic_digit, 0))

def test_encode(self):

self.assertEqual(idna.encode('xn--zckzah.xn--zckzah'), b'xn--zckzah.xn--zckzah')
self.assertEqual(idna.encode('\u30c6\u30b9\u30c8.xn--zckzah'), b'xn--zckzah.xn--zckzah')
self.assertEqual(idna.encode('\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8'), b'xn--zckzah.xn--zckzah')
self.assertEqual(idna.encode('abc.abc'), b'abc.abc')
self.assertEqual(idna.encode('xn--zckzah.abc'), b'xn--zckzah.abc')
self.assertEqual(idna.encode('\u30c6\u30b9\u30c8.abc'), b'xn--zckzah.abc')
self.assertEqual(idna.encode('\u0521\u0525\u0523-\u0523\u0523-----\u0521\u0523\u0523\u0523.aa'),
def test_encode(self, encode=None, skip_bytes=False):
if encode is None:
encode = idna.encode

self.assertEqual(encode('xn--zckzah.xn--zckzah'), b'xn--zckzah.xn--zckzah')
self.assertEqual(encode('\u30c6\u30b9\u30c8.xn--zckzah'), b'xn--zckzah.xn--zckzah')
self.assertEqual(encode('\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8'), b'xn--zckzah.xn--zckzah')
self.assertEqual(encode('abc.abc'), b'abc.abc')
self.assertEqual(encode('xn--zckzah.abc'), b'xn--zckzah.abc')
self.assertEqual(encode('\u30c6\u30b9\u30c8.abc'), b'xn--zckzah.abc')
self.assertEqual(encode('\u0521\u0525\u0523-\u0523\u0523-----\u0521\u0523\u0523\u0523.aa'),
b'xn---------90gglbagaar.aa')
self.assertRaises(idna.IDNAError, idna.encode,
'\u0521\u0524\u0523-\u0523\u0523-----\u0521\u0523\u0523\u0523.aa', uts46=False)
self.assertEqual(idna.encode('a'*63), b'a'*63)
self.assertRaises(idna.IDNAError, idna.encode, 'a'*64)
self.assertRaises(idna.core.InvalidCodepoint, idna.encode, '*')
self.assertRaises(idna.IDNAError, idna.encode, b'\x0a\x33\x81')

def test_decode(self):

self.assertEqual(idna.decode('xn--zckzah.xn--zckzah'), '\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8')
self.assertEqual(idna.decode('\u30c6\u30b9\u30c8.xn--zckzah'), '\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8')
self.assertEqual(idna.decode('\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8'),
'\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8')
self.assertEqual(idna.decode('abc.abc'), 'abc.abc')
self.assertEqual(idna.decode('xn---------90gglbagaar.aa'),
if encode is idna.encode:
self.assertRaises(idna.IDNAError, encode,
'\u0521\u0524\u0523-\u0523\u0523-----\u0521\u0523\u0523\u0523.aa', uts46=False)
self.assertEqual(encode('a'*63), b'a'*63)
self.assertRaises(idna.IDNAError, encode, 'a'*64)
self.assertRaises(idna.core.InvalidCodepoint, encode, '*')
if not skip_bytes:
self.assertRaises(idna.IDNAError, encode, b'\x0a\x33\x81')

def test_decode(self, decode=None, skip_str=False):
if decode is None:
decode = idna.decode
self.assertEqual(decode(b'xn--zckzah.xn--zckzah'), '\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8')
self.assertEqual(decode(b'xn--d1acufc.xn--80akhbyknj4f'),
'\u0434\u043e\u043c\u0435\u043d.\u0438\u0441\u043f\u044b\u0442\u0430\u043d\u0438\u0435')
if not skip_str:
self.assertEqual(decode('\u30c6\u30b9\u30c8.xn--zckzah'), '\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8')
self.assertEqual(decode('\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8'),
'\u30c6\u30b9\u30c8.\u30c6\u30b9\u30c8')
self.assertEqual(decode('abc.abc'), 'abc.abc')
self.assertEqual(decode(b'xn---------90gglbagaar.aa'),
'\u0521\u0525\u0523-\u0523\u0523-----\u0521\u0523\u0523\u0523.aa')
self.assertRaises(idna.IDNAError, idna.decode, 'XN---------90GGLBAGAAC.AA')
self.assertRaises(idna.IDNAError, idna.decode, 'xn---------90gglbagaac.aa')
self.assertRaises(idna.IDNAError, idna.decode, 'xn--')
self.assertRaises(idna.IDNAError, idna.decode, b'\x8d\xd2')
self.assertRaises(idna.IDNAError, idna.decode, b'A.A.0.a.a.A.0.a.A.A.0.a.A.0A.2.a.A.A.0.a.A.0.A.a.A0.a.a.A.0.a.fB.A.A.a.A.A.B.A.A.a.A.A.B.A.A.a.A.A.0.a.A.a.a.A.A.0.a.A.0.A.a.A0.a.a.A.0.a.fB.A.A.a.A.A.B.0A.A.a.A.A.B.A.A.a.A.A.a.A.A.B.A.A.a.A.0.a.B.A.A.a.A.B.A.a.A.A.5.a.A.0.a.Ba.A.B.A.A.a.A.0.a.Xn--B.A.A.A.a')
self.assertRaises(idna.IDNAError, decode, b'XN---------90GGLBAGAAC.AA')
self.assertRaises(idna.IDNAError, decode, b'xn---------90gglbagaac.aa')
self.assertRaises(idna.IDNAError, decode, b'xn--')
self.assertRaises(idna.IDNAError, decode, b'\x8d\xd2')
self.assertRaises(idna.IDNAError, decode, b'A.A.0.a.a.A.0.a.A.A.0.a.A.0A.2.a.A.A.0.a.A.0.A.a.A0.a.a.A.0.a.fB.A.A.a.A.A.B.A.A.a.A.A.B.A.A.a.A.A.0.a.A.a.a.A.A.0.a.A.0.A.a.A0.a.a.A.0.a.fB.A.A.a.A.A.B.0A.A.a.A.A.B.A.A.a.A.A.a.A.A.B.A.A.a.A.0.a.B.A.A.a.A.B.A.a.A.A.5.a.A.0.a.Ba.A.B.A.A.a.A.0.a.Xn--B.A.A.A.a')

if __name__ == '__main__':
unittest.main()
50 changes: 43 additions & 7 deletions tests/test_idna_codec.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,51 @@
#!/usr/bin/env python

import codecs
import sys
import io
import unittest

import idna.codec

CODEC_NAME = 'idna2008'

class IDNACodecTests(unittest.TestCase):

def setUp(self):
from . import test_idna
self.idnatests = test_idna.IDNATests()
self.idnatests.setUp()

def testCodec(self):
pass
self.assertIs(codecs.lookup(CODEC_NAME).incrementalencoder, idna.codec.IncrementalEncoder)

def testDirectDecode(self):
self.idnatests.test_decode(decode=lambda obj: codecs.decode(obj, CODEC_NAME))

def testIndirectDecode(self):
self.idnatests.test_decode(decode=lambda obj: obj.decode(CODEC_NAME), skip_str=True)

def testDirectEncode(self):
self.idnatests.test_encode(encode=lambda obj: codecs.encode(obj, CODEC_NAME))

def testIndirectEncode(self):
self.idnatests.test_encode(encode=lambda obj: obj.encode(CODEC_NAME), skip_bytes=True)

def testStreamReader(self):
def decode(obj):
if isinstance(obj, str):
obj = bytes(obj, 'ascii')
buffer = io.BytesIO(obj)
stream = codecs.getreader(CODEC_NAME)(buffer)
return stream.read()
return self.idnatests.test_decode(decode=decode, skip_str=True)

def testStreamWriter(self):
def encode(obj):
buffer = io.BytesIO()
stream = codecs.getwriter(CODEC_NAME)(buffer)
stream.write(obj)
stream.flush()
return buffer.getvalue()
return self.idnatests.test_encode(encode=encode)

def testIncrementalDecoder(self):

Expand All @@ -23,10 +59,10 @@ def testIncrementalDecoder(self):
)

for decoded, encoded in incremental_tests:
self.assertEqual("".join(codecs.iterdecode((bytes([c]) for c in encoded), "idna")),
self.assertEqual("".join(codecs.iterdecode((bytes([c]) for c in encoded), CODEC_NAME)),
decoded)

decoder = codecs.getincrementaldecoder("idna")()
decoder = codecs.getincrementaldecoder(CODEC_NAME)()
self.assertEqual(decoder.decode(b"xn--xam", ), "")
self.assertEqual(decoder.decode(b"ple-9ta.o", ), "\xe4xample.")
self.assertEqual(decoder.decode(b"rg"), "")
Expand All @@ -50,10 +86,10 @@ def testIncrementalEncoder(self):
("pyth\xf6n.org.", b"xn--pythn-mua.org."),
)
for decoded, encoded in incremental_tests:
self.assertEqual(b"".join(codecs.iterencode(decoded, "idna")),
self.assertEqual(b"".join(codecs.iterencode(decoded, CODEC_NAME)),
encoded)

encoder = codecs.getincrementalencoder("idna")()
encoder = codecs.getincrementalencoder(CODEC_NAME)()
self.assertEqual(encoder.encode("\xe4x"), b"")
self.assertEqual(encoder.encode("ample.org"), b"xn--xample-9ta.")
self.assertEqual(encoder.encode("", True), b"org")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_idna_uts46.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def shortDescription(self):

def runTest(self):
if not self.fields:
return ''
return
source, to_unicode, to_unicode_status, to_ascii, to_ascii_status, to_ascii_t, to_ascii_t_status = self.fields
if source in _SKIP_TESTS:
return
Expand Down

0 comments on commit 701288b

Please sign in to comment.