Skip to content

Commit

Permalink
Add test for unicode str and decode str if encoding is set in binary_…
Browse files Browse the repository at this point in the history
…ext handler
  • Loading branch information
fozzle committed Oct 1, 2019
1 parent a01001b commit 15e9cee
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
32 changes: 16 additions & 16 deletions py/erlpack/_unpacker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ cdef class ErlangTermDecoder(object):
raise ErlangTermDecodeError('Bad version number. Expected %d found %d' % (FORMAT_VERSION, version))
return self.decode_part(bytes, offset + 1)[0]

def _decode_str(self, st):
byte_elements_are_ints = isinstance(b'a'[0], int)
if self.encoding:
try:
return st.decode(self.encoding)
except UnicodeError:
pass

if byte_elements_are_ints:
return [x for x in st]

return [ord(x) for x in st]

cdef object decode_part(self, bytes, offset=0):
opcode = bytes[offset:offset+1]

Expand Down Expand Up @@ -190,21 +203,7 @@ cdef class ErlangTermDecoder(object):
"""STRING_EXT"""
length, = struct.unpack('>H', bytes[offset:offset + 2])
offset += 2
st = bytes[offset:offset + length]
byte_elements_are_ints = isinstance(b'a'[0], int)
if self.encoding:
try:
st = st.decode(self.encoding)
except UnicodeError:
if byte_elements_are_ints:
st = [x for x in st]
else:
st = [ord(x) for x in st]
else:
if byte_elements_are_ints:
st = [x for x in st]
else:
st = [ord(x) for x in st]
st = self._decode_str(bytes[offset:offset + length])
return st, offset + length

cdef object decode_l(self, bytes, offset):
Expand All @@ -226,7 +225,8 @@ cdef class ErlangTermDecoder(object):
"""BINARY_EXT"""
length, = struct.unpack('>L', bytes[offset:offset + 4])
offset += 4
return bytes[offset:offset + length], offset + length
rv = self._decode_str(bytes[offset:offset + length])
return rv, offset + length

cdef object decode_n(self, bytes, offset):
"""SMALL_BIG_EXT"""
Expand Down
10 changes: 9 additions & 1 deletion py/tests/test_unicode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from erlpack import pack, unpack, Atom
from erlpack import pack, unpack, Atom, ErlangTermDecoder


def test_unicode():
Expand Down Expand Up @@ -29,3 +29,11 @@ def test_unicode_atom_encode_raises():
def test_unicode_atom_decodes():
atm = unpack(b'\x83w\x15\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf\xe4\xb8\x96\xe7\x95\x8c')
assert atm == Atom(u'こんにちは世界')


def test_unicode_string_decodes_by_default():
    """A packed non-ASCII string round-trips through a UTF-8 decoder.

    Note that the decoder is built with ``encoding='utf-8'`` passed
    explicitly; the test does not rely on a constructor default.
    """
    expected = u'こんにちは世界'
    loads = ErlangTermDecoder(encoding='utf-8').loads
    assert loads(pack(expected)) == expected

0 comments on commit 15e9cee

Please sign in to comment.