Skip to content

Commit

Permalink
Apply suggestions from review; backport feature to all python versions
Browse files Browse the repository at this point in the history
Instead of checking for __fspath__ in multiple places, a Cython function
_getFSPathOrObject is added, which calls the object's __fspath__ method if
it exists and otherwise returns the object unchanged.

On Python 3.6+, this function uses PyOS_FSPath.
  • Loading branch information
janssenhenning committed Jan 28, 2022
1 parent 77d7a61 commit 146b9a1
Show file tree
Hide file tree
Showing 13 changed files with 72 additions and 114 deletions.
22 changes: 19 additions & 3 deletions src/lxml/apihelpers.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1582,12 +1582,28 @@ cdef bint _isFilePath(const_xmlChar* c_path):
# assume it's a relative path
return REL_FILE_PATH

cdef object _getFSPathOrObject(object obj):
    u"""Resolve a path-like object through its ``__fspath__`` protocol.

    Returns the value produced by ``__fspath__`` when the object provides
    it; any other object is handed back unchanged so that the caller can
    apply its own string / file-like checks afterwards.

    Raises ValueError on Python < 3.6 if the object has a ``__fspath__``
    attribute that is not callable.
    """
    if python.PY_VERSION_HEX < 0x03060000:
        # Older interpreters: emulate the protocol by hand.
        if not hasattr(obj, '__fspath__'):
            return obj
        fspath_method = obj.__fspath__
        if not callable(fspath_method):
            raise ValueError('__fspath__ is not callable')
        return fspath_method()

    # Python 3.6+: delegate to PY_FSPath (a macro around PyOS_FSPath).
    # A TypeError from that call results in the original object being
    # returned as-is instead of propagating the error.
    try:
        return python.PY_FSPath(obj)
    except TypeError:
        return obj

cdef object _encodeFilename(object filename):
u"""Make sure a filename is 8-bit encoded (or None).
"""
if hasattr(filename, u"__fspath__") and python.PY_VERSION_HEX >= 0x03060000:
filename = os_fspath(filename)

if filename is None:
return None
elif isinstance(filename, bytes):
Expand Down
3 changes: 2 additions & 1 deletion src/lxml/dtd.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ cdef class DTD(_Validator):
def __init__(self, file=None, *, external_id=None):
_Validator.__init__(self)
if file is not None:
if _isStringOrPathLike(file):
file = _getFSPathOrObject(file)
if _isString(file):
file = _encodeFilename(file)
with self._error_log:
orig_loader = _register_document_loader()
Expand Down
8 changes: 1 addition & 7 deletions src/lxml/etree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ from lxml cimport python
from lxml.includes cimport tree, config
from lxml.includes.tree cimport xmlDoc, xmlNode, xmlAttr, xmlNs, _isElement, _getNs
from lxml.includes.tree cimport const_xmlChar, xmlChar, _xcstr
from lxml.python cimport _cstr, _isString, _isStringOrPathLike
from lxml.python cimport _cstr, _isString
from lxml.includes cimport xpath
from lxml.includes cimport c14n

Expand All @@ -64,12 +64,6 @@ from libc.string cimport const_char
cdef object os_path_abspath
from os.path import abspath as os_path_abspath

cdef object os_fspath
try:
from os import fspath as os_fspath # py3.6+
except ImportError:
os_fspath = None

cdef object BytesIO, StringIO
from io import BytesIO, StringIO

Expand Down
16 changes: 3 additions & 13 deletions src/lxml/includes/etree_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,10 @@ long _ftol2( double dblSource ) { return _ftol( dblSource ); }
#define _isString(obj) (PyUnicode_Check(obj) || PyBytes_Check(obj))
#endif

#if PY_MAJOR_VERSION < 3
#define _isStringOrPathLike(obj) (PyString_CheckExact(obj) || \
PyUnicode_CheckExact(obj) || \
PyType_IsSubtype(Py_TYPE(obj), &PyBaseString_Type))
#if PY_VERSION_HEX >= 0x03060000
#define PY_FSPath(obj) (PyOS_FSPath(obj))
#else
#if PY_MINOR_VERSION < 6
/* builtin subtype type checks are almost as fast as exact checks in Py2.7+
* and Unicode is more common in Py3 */
#define _isStringOrPathLike(obj) (PyUnicode_Check(obj) || PyBytes_Check(obj))
#else
#define _isStringOrPathLike(obj) (PyUnicode_Check(obj) || \
PyBytes_Check(obj) || \
PyObject_HasAttrString(obj, "__fspath__"))
#endif
#define PY_FSPath(obj) (NULL)
#endif

#define _isElement(c_node) \
Expand Down
1 change: 1 addition & 0 deletions src/lxml/iterparse.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cdef class iterparse:
html=False, recover=None, huge_tree=False, collect_ids=True,
XMLSchema schema=None):
if not hasattr(source, 'read'):
source = _getFSPathOrObject(source)
self._filename = source
if python.IS_PYTHON2:
source = _encodeFilename(source)
Expand Down
3 changes: 2 additions & 1 deletion src/lxml/parser.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1870,7 +1870,8 @@ cdef xmlNode* _copyNodeToDoc(xmlNode* c_node, xmlDoc* c_doc) except NULL:

cdef _Document _parseDocument(source, _BaseParser parser, base_url):
cdef _Document doc
if _isStringOrPathLike(source):
source = _getFSPathOrObject(source)
if _isString(source):
# parse the file directly from the filesystem
doc = _parseDocumentFromURL(_encodeFilename(source), parser)
# fix base URL if requested
Expand Down
2 changes: 1 addition & 1 deletion src/lxml/python.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ cdef extern from "includes/etree_defs.h": # redefines some functions as macros
cdef void lxml_free(void* mem)
cdef void* lxml_unpack_xmldoc_capsule(object capsule, bint* is_owned) except? NULL
cdef bint _isString(object obj)
cdef bint _isStringOrPathLike(object obj)
cdef const_char* _fqtypename(object t)
cdef object PY_NEW(object t)
cdef bint LXML_UNICODE_STRINGS
cdef bint IS_PYTHON2
cdef bint IS_PYTHON3 # legacy, avoid
cdef bint IS_PYPY
cdef object PY_FSPath(object obj)

cdef extern from "lxml_endian.h":
cdef bint PY_BIG_ENDIAN # defined in later Py3.x versions
12 changes: 8 additions & 4 deletions src/lxml/serializer.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ cdef object _open_utf8_file

@contextmanager
def _open_utf8_file(file, compression=0):
if _isStringOrPathLike(file):
file = _getFSPathOrObject(file)
if _isString(file):
if compression:
with gzip.GzipFile(file, mode='wb', compresslevel=compression) as zf:
yield utf8_writer(zf)
Expand Down Expand Up @@ -723,7 +724,8 @@ cdef _tofilelike(f, _Element element, encoding, doctype, method,
with GzipFile(fileobj=bytes_out, mode='wb', compresslevel=compression) as gzip_file:
gzip_file.write(data)
data = bytes_out.getvalue()
if _isStringOrPathLike(f):
f = _getFSPathOrObject(f)
if _isString(f):
filename8 = _encodeFilename(f)
with open(filename8, 'wb') as f:
f.write(data)
Expand Down Expand Up @@ -787,7 +789,8 @@ cdef _FilelikeWriter _create_output_buffer(
raise LookupError(
f"unknown encoding: '{c_enc.decode('UTF-8') if c_enc is not NULL else u''}'")
try:
if _isStringOrPathLike(f):
f = _getFSPathOrObject(f)
if _isString(f):
filename8 = _encodeFilename(f)
if b'%' in filename8 and (
# Exclude absolute Windows paths and file:// URLs.
Expand Down Expand Up @@ -852,7 +855,8 @@ cdef _tofilelikeC14N(f, _Element element, bint exclusive, bint with_comments,
_convert_ns_prefixes(c_doc.dict, inclusive_ns_prefixes)
if inclusive_ns_prefixes else NULL)

if _isStringOrPathLike(f):
f = _getFSPathOrObject(f)
if _isString(f):
filename8 = _encodeFilename(f)
c_filename = _cstr(filename8)
with nogil:
Expand Down
9 changes: 2 additions & 7 deletions src/lxml/tests/test_dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,8 @@ def __init__(self,path):
def __fspath__(self):
return self.path

if sys.version_info >= (3,6):
dtd = etree.DTD(Path(fileInTestDir("test.dtd")))
self.assertTrue(dtd.validate(root))
else:
with self.assertRaises(etree.DTDParseError) as cm:
etree.DTD(Path(fileInTestDir("test.dtd")))
self.assertEqual(str(cm.exception),'file must be a filename, file-like or path-like object')
dtd = etree.DTD(Path(fileInTestDir("test.dtd")))
self.assertTrue(dtd.validate(root))

def test_dtd_stringio(self):
root = etree.XML(_bytes("<b/>"))
Expand Down
87 changes: 27 additions & 60 deletions src/lxml/tests/test_etree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4610,14 +4610,9 @@ def __init__(self, path):
def __fspath__(self):
return self.path

if sys.version_info >= (3,6):
tree = etree.parse(Path(fileInTestDir('test.xml')))
self.assertEqual(_bytes('<a><b></b></a>'),
canonicalize(tounicode(tree)))
else:
with self.assertRaises(TypeError) as cm:
etree.parse(Path(fileInTestDir('test.xml')))
self.assertEqual(str(cm.exception),"cannot parse from 'Path'")
tree = etree.parse(Path(fileInTestDir('test.xml')))
self.assertEqual(_bytes('<a><b></b></a>'),
canonicalize(tounicode(tree)))

def test_iterparse_source_pathlike(self):
iterparse = self.etree.iterparse
Expand All @@ -4629,12 +4624,8 @@ def __init__(self, path):
def __fspath__(self):
return self.path

if sys.version_info >= (3,6):
events = list(iterparse(Path(fileInTestDir('test.xml'))))
self.assertEqual(2, len(events))
else:
with self.assertRaises(TypeError) as cm:
list(iterparse(Path(fileInTestDir('test.xml'))))
events = list(iterparse(Path(fileInTestDir('test.xml'))))
self.assertEqual(2, len(events))

# helper methods

Expand Down Expand Up @@ -4929,17 +4920,11 @@ def __fspath__(self):
return self.path

tree = self.parse(_bytes('<a><b/></a>'))
if sys.version_info >= (3,6):
with tmpfile() as filename:
tree.write_c14n(Path(filename))
data = read_file(filename, 'rb')
self.assertEqual(_bytes('<a><b></b></a>'),
data)
else:
with tmpfile() as filename:
with self.assertRaises(TypeError) as cm:
tree.write_c14n(Path(filename))
self.assertEqual(str(cm.exception), "File or filename expected, got 'Path'")
with tmpfile() as filename:
tree.write_c14n(Path(filename))
data = read_file(filename, 'rb')
self.assertEqual(_bytes('<a><b></b></a>'),
data)

def test_c14n_file_gzip(self):
tree = self.parse(_bytes('<a>'+'<b/>'*200+'</a>'))
Expand All @@ -4958,18 +4943,12 @@ def __fspath__(self):
return self.path

tree = self.parse(_bytes('<a>'+'<b/>'*200+'</a>'))
if sys.version_info >= (3,6):
with tmpfile() as filename:
tree.write_c14n(Path(filename), compression=9)
with gzip.open(filename, 'rb') as f:
data = f.read()
self.assertEqual(_bytes('<a>'+'<b></b>'*200+'</a>'),
data)
else:
with tmpfile() as filename:
with self.assertRaises(TypeError) as cm:
tree.write_c14n(Path(filename), compression=9)
self.assertEqual(str(cm.exception), "File or filename expected, got 'Path'")
with tmpfile() as filename:
tree.write_c14n(Path(filename), compression=9)
with gzip.open(filename, 'rb') as f:
data = f.read()
self.assertEqual(_bytes('<a>'+'<b></b>'*200+'</a>'),
data)

def test_c14n2_file_gzip(self):
tree = self.parse(_bytes('<a>'+'<b/>'*200+'</a>'))
Expand Down Expand Up @@ -5269,17 +5248,11 @@ def __fspath__(self):
return self.path

tree = self.parse(_bytes('<a><b/></a>'))
if sys.version_info >= (3,6):
with tmpfile() as filename:
tree.write(Path(filename))
data = read_file(filename, 'rb')
self.assertEqual(_bytes('<a><b/></a>'),
data)
else:
with tmpfile() as filename:
with self.assertRaises(TypeError) as cm:
tree.write(Path(filename))
self.assertEqual(str(cm.exception),"File or filename expected, got 'Path'")
with tmpfile() as filename:
tree.write(Path(filename))
data = read_file(filename, 'rb')
self.assertEqual(_bytes('<a><b/></a>'),
data)

def test_write_file_gzip(self):
tree = self.parse(_bytes('<a>'+'<b/>'*200+'</a>'))
Expand All @@ -5299,18 +5272,12 @@ def __fspath__(self):
return self.path

tree = self.parse(_bytes('<a>'+'<b/>'*200+'</a>'))
if sys.version_info >= (3,6):
with tmpfile() as filename:
tree.write(Path(filename), compression=9)
with gzip.open(filename, 'rb') as f:
data = f.read()
self.assertEqual(_bytes('<a>'+'<b/>'*200+'</a>'),
data)
else:
with tmpfile() as filename:
with self.assertRaises(TypeError) as cm:
tree.write(Path(filename), compression=9)
self.assertEqual(str(cm.exception),"File or filename expected, got 'Path'")
with tmpfile() as filename:
tree.write(Path(filename), compression=9)
with gzip.open(filename, 'rb') as f:
data = f.read()
self.assertEqual(_bytes('<a>'+'<b/>'*200+'</a>'),
data)

def test_write_file_gzip_parse(self):
tree = self.parse(_bytes('<a>'+'<b/>'*200+'</a>'))
Expand Down
12 changes: 3 additions & 9 deletions src/lxml/tests/test_xmlschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import absolute_import

import unittest
import sys

from .common_imports import etree, BytesIO, HelperTestCase, fileInTestDir, make_doctest

Expand Down Expand Up @@ -396,14 +395,9 @@ def __init__(self, path):
def __fspath__(self):
return self.path

if sys.version_info >= (3,6):
schema = etree.XMLSchema(file=Path(fileInTestDir('test.xsd')))
tree_valid = self.parse('<a><b></b></a>')
self.assertTrue(schema.validate(tree_valid))
else:
with self.assertRaises(TypeError) as cm:
etree.XMLSchema(file=Path(fileInTestDir('test.xsd')))
self.assertEqual(str(cm.exception),"cannot parse from 'Path'")
schema = etree.XMLSchema(file=Path(fileInTestDir('test.xsd')))
tree_valid = self.parse('<a><b></b></a>')
self.assertTrue(schema.validate(tree_valid))


class ETreeXMLSchemaResolversTestCase(HelperTestCase):
Expand Down
8 changes: 1 addition & 7 deletions src/lxml/tests/test_xslt.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,7 @@ def __fspath__(self):
f = NamedTemporaryFile(delete=False)
try:
try:
if sys.version_info >= (3,6):
res[0].write_output(Path(f.name), compression=9)
else:
with self.assertRaises(TypeError) as cm:
res[0].write_output(Path(f.name), compression=9)
self.assertEqual(str(cm.exception),"File or filename expected, got 'Path'")
res[0].write_output(f.name, compression=9) #Otherwise the test cannot continue
res[0].write_output(Path(f.name), compression=9)
finally:
f.close()
with gzip.GzipFile(f.name) as f:
Expand Down
3 changes: 2 additions & 1 deletion src/lxml/xmlschema.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ cdef class XMLSchema(_Validator):
self._doc = _documentFactory(c_doc, doc._parser)
parser_ctxt = xmlschema.xmlSchemaNewDocParserCtxt(c_doc)
elif file is not None:
if _isStringOrPathLike(file):
file = _getFSPathOrObject(file)
if _isString(file):
filename = _encodeFilename(file)
parser_ctxt = xmlschema.xmlSchemaNewParserCtxt(_cstr(filename))
else:
Expand Down

0 comments on commit 146b9a1

Please sign in to comment.