diff --git a/src/lxml/apihelpers.pxi b/src/lxml/apihelpers.pxi index 5eb341634..c16627629 100644 --- a/src/lxml/apihelpers.pxi +++ b/src/lxml/apihelpers.pxi @@ -1582,6 +1582,25 @@ cdef bint _isFilePath(const_xmlChar* c_path): # assume it's a relative path return REL_FILE_PATH +cdef object _NO_FSPATH = object() + +cdef object _getFSPathOrObject(object obj): + """ + Get the __fspath__ attribute of an object if it exists. + Otherwise, the original object is returned. + """ + if _isString(obj): + return obj + if python.PY_VERSION_HEX >= 0x03060000: + try: + return python.PY_FSPath(obj) + except TypeError: + return obj + fspath = getattr(obj, '__fspath__', _NO_FSPATH) + if fspath is not _NO_FSPATH and callable(fspath): + return fspath() + return obj + cdef object _encodeFilename(object filename): u"""Make sure a filename is 8-bit encoded (or None). """ diff --git a/src/lxml/dtd.pxi b/src/lxml/dtd.pxi index 2b4bf762f..17242fb8f 100644 --- a/src/lxml/dtd.pxi +++ b/src/lxml/dtd.pxi @@ -279,6 +279,7 @@ cdef class DTD(_Validator): def __init__(self, file=None, *, external_id=None): _Validator.__init__(self) if file is not None: + file = _getFSPathOrObject(file) if _isString(file): file = _encodeFilename(file) with self._error_log: @@ -290,7 +291,7 @@ cdef class DTD(_Validator): self._c_dtd = _parseDtdFromFilelike(file) _reset_document_loader(orig_loader) else: - raise DTDParseError, u"file must be a filename or file-like object" + raise DTDParseError, u"file must be a filename, file-like or path-like object" elif external_id is not None: with self._error_log: orig_loader = _register_document_loader() diff --git a/src/lxml/includes/etree_defs.h b/src/lxml/includes/etree_defs.h index 20d4b9d11..c702e0473 100644 --- a/src/lxml/includes/etree_defs.h +++ b/src/lxml/includes/etree_defs.h @@ -247,6 +247,12 @@ long _ftol2( double dblSource ) { return _ftol( dblSource ); } #define _isString(obj) (PyUnicode_Check(obj) || PyBytes_Check(obj)) #endif +#if PY_VERSION_HEX >= 0x03060000 +#define lxml_PyOS_FSPath(obj) (PyOS_FSPath(obj)) +#else +#define lxml_PyOS_FSPath(obj) (NULL) +#endif + #define _isElement(c_node) \ (((c_node)->type == XML_ELEMENT_NODE) || \ ((c_node)->type == XML_COMMENT_NODE) || \ diff --git a/src/lxml/iterparse.pxi b/src/lxml/iterparse.pxi index 4c20506a4..138c23a6a 100644 --- a/src/lxml/iterparse.pxi +++ b/src/lxml/iterparse.pxi @@ -72,6 +72,7 @@ cdef class iterparse: html=False, recover=None, huge_tree=False, collect_ids=True, XMLSchema schema=None): if not hasattr(source, 'read'): + source = _getFSPathOrObject(source) self._filename = source if python.IS_PYTHON2: source = _encodeFilename(source) diff --git a/src/lxml/parser.pxi b/src/lxml/parser.pxi index 35b51458a..f5baf29b9 100644 --- a/src/lxml/parser.pxi +++ b/src/lxml/parser.pxi @@ -1870,6 +1870,7 @@ cdef xmlNode* _copyNodeToDoc(xmlNode* c_node, xmlDoc* c_doc) except NULL: cdef _Document _parseDocument(source, _BaseParser parser, base_url): cdef _Document doc + source = _getFSPathOrObject(source) if _isString(source): # parse the file directly from the filesystem doc = _parseDocumentFromURL(_encodeFilename(source), parser) diff --git a/src/lxml/python.pxd b/src/lxml/python.pxd index 0d26cdd54..62307aa11 100644 --- a/src/lxml/python.pxd +++ b/src/lxml/python.pxd @@ -127,6 +127,7 @@ cdef extern from "includes/etree_defs.h": # redefines some functions as macros cdef bint IS_PYTHON2 cdef bint IS_PYTHON3 # legacy, avoid cdef bint IS_PYPY + cdef object PY_FSPath "lxml_PyOS_FSPath" (object obj) cdef extern from "lxml_endian.h": cdef bint PY_BIG_ENDIAN # defined in later Py3.x versions diff --git a/src/lxml/serializer.pxi b/src/lxml/serializer.pxi index ec45cf1d4..79a02829e 100644 --- a/src/lxml/serializer.pxi +++ b/src/lxml/serializer.pxi @@ -627,6 +627,7 @@ cdef object _open_utf8_file @contextmanager def _open_utf8_file(file, compression=0): + file = _getFSPathOrObject(file) if _isString(file): if compression: with gzip.GzipFile(file, mode='wb', compresslevel=compression) as zf: @@ -723,6 +724,7 @@ cdef _tofilelike(f, _Element element, encoding, doctype, method, with GzipFile(fileobj=bytes_out, mode='wb', compresslevel=compression) as gzip_file: gzip_file.write(data) data = bytes_out.getvalue() + f = _getFSPathOrObject(f) if _isString(f): filename8 = _encodeFilename(f) with open(filename8, 'wb') as f: @@ -787,6 +789,7 @@ cdef _FilelikeWriter _create_output_buffer( raise LookupError( f"unknown encoding: '{c_enc.decode('UTF-8') if c_enc is not NULL else u''}'") try: + f = _getFSPathOrObject(f) if _isString(f): filename8 = _encodeFilename(f) if b'%' in filename8 and ( @@ -852,6 +855,7 @@ cdef _tofilelikeC14N(f, _Element element, bint exclusive, bint with_comments, _convert_ns_prefixes(c_doc.dict, inclusive_ns_prefixes) if inclusive_ns_prefixes else NULL) + f = _getFSPathOrObject(f) if _isString(f): filename8 = _encodeFilename(f) c_filename = _cstr(filename8) diff --git a/src/lxml/tests/common_imports.py b/src/lxml/tests/common_imports.py index 57097e3c4..68db7c2b2 100644 --- a/src/lxml/tests/common_imports.py +++ b/src/lxml/tests/common_imports.py @@ -251,6 +251,12 @@ def iterelements(self, depth): yield self.chars yield _str('') +class SimpleFSPath(object): + def __init__(self, path): + self.path = path + def __fspath__(self): + return self.path + def fileInTestDir(name): _testdir = os.path.dirname(__file__) return os.path.join(_testdir, name) diff --git a/src/lxml/tests/test_dtd.py b/src/lxml/tests/test_dtd.py index 779f9e849..5c9b1c024 100644 --- a/src/lxml/tests/test_dtd.py +++ b/src/lxml/tests/test_dtd.py @@ -9,7 +9,7 @@ from .common_imports import ( etree, html, BytesIO, _bytes, _str, HelperTestCase, make_doctest, skipIf, - fileInTestDir, fileUrlInTestDir + fileInTestDir, fileUrlInTestDir, SimpleFSPath ) @@ -24,6 +24,14 @@ def test_dtd_file(self): dtd = etree.DTD(fileInTestDir("test.dtd")) self.assertTrue(dtd.validate(root)) + + def test_dtd_file_pathlike(self): + parse = etree.parse + tree = parse(fileInTestDir("test.xml")) + root = tree.getroot() + + dtd = etree.DTD(SimpleFSPath(fileInTestDir("test.dtd"))) + self.assertTrue(dtd.validate(root)) def test_dtd_stringio(self): root = etree.XML(_bytes("")) diff --git a/src/lxml/tests/test_etree.py b/src/lxml/tests/test_etree.py index ef5c54b7b..e5f084692 100644 --- a/src/lxml/tests/test_etree.py +++ b/src/lxml/tests/test_etree.py @@ -25,6 +25,7 @@ from .common_imports import fileInTestDir, fileUrlInTestDir, read_file, path2url, tmpfile from .common_imports import SillyFileLike, LargeFileLikeUnicode, doctest, make_doctest from .common_imports import canonicalize, _str, _bytes +from .common_imports import SimpleFSPath print(""" TESTED VERSION: %s""" % etree.__version__ + """ @@ -4599,6 +4600,20 @@ def test_proxy_collect_siblings_text(self): self.assertEqual('child1', c2.getprevious().tag) self.assertEqual('abc', c2.getprevious().tail) + def test_parse_source_pathlike(self): + etree = self.etree + tounicode = self.etree.tounicode + + tree = etree.parse(SimpleFSPath(fileInTestDir('test.xml'))) + self.assertEqual(_bytes(''), + canonicalize(tounicode(tree))) + + def test_iterparse_source_pathlike(self): + iterparse = self.etree.iterparse + + events = list(iterparse(SimpleFSPath(fileInTestDir('test.xml')))) + self.assertEqual(2, len(events)) + # helper methods def _writeElement(self, element, encoding='us-ascii', compression=0): @@ -4883,6 +4898,14 @@ def test_c14n_file(self): data = read_file(filename, 'rb') self.assertEqual(_bytes(''), data) + + def test_c14n_file_pathlike(self): + tree = self.parse(_bytes('')) + with tmpfile() as filename: + tree.write_c14n(SimpleFSPath(filename)) + data = read_file(filename, 'rb') + self.assertEqual(_bytes(''), + data) def test_c14n_file_gzip(self): tree = self.parse(_bytes(''+''*200+'')) @@ -4892,6 +4915,15 @@ def test_c14n_file_gzip(self): data = f.read() self.assertEqual(_bytes(''+''*200+''), data) + + def test_c14n_file_gzip_pathlike(self): + tree = self.parse(_bytes(''+''*200+'')) + with tmpfile() as filename: + tree.write_c14n(SimpleFSPath(filename), compression=9) + with gzip.open(filename, 'rb') as f: + data = f.read() + self.assertEqual(_bytes(''+''*200+''), + data) def test_c14n2_file_gzip(self): tree = self.parse(_bytes(''+''*200+'')) @@ -5182,6 +5214,14 @@ def test_write_file(self): data = read_file(filename, 'rb') self.assertEqual(_bytes(''), data) + + def test_write_file_pathlike(self): + tree = self.parse(_bytes('')) + with tmpfile() as filename: + tree.write(SimpleFSPath(filename)) + data = read_file(filename, 'rb') + self.assertEqual(_bytes(''), + data) def test_write_file_gzip(self): tree = self.parse(_bytes(''+''*200+'')) @@ -5192,6 +5232,15 @@ def test_write_file_gzip(self): self.assertEqual(_bytes(''+''*200+''), data) + def test_write_file_gzip_pathlike(self): + tree = self.parse(_bytes(''+''*200+'')) + with tmpfile() as filename: + tree.write(SimpleFSPath(filename), compression=9) + with gzip.open(filename, 'rb') as f: + data = f.read() + self.assertEqual(_bytes(''+''*200+''), + data) + def test_write_file_gzip_parse(self): tree = self.parse(_bytes(''+''*200+'')) with tmpfile() as filename: diff --git a/src/lxml/tests/test_xmlschema.py b/src/lxml/tests/test_xmlschema.py index c5653c1e5..dbfc251a5 100644 --- a/src/lxml/tests/test_xmlschema.py +++ b/src/lxml/tests/test_xmlschema.py @@ -8,7 +8,7 @@ import unittest -from .common_imports import etree, BytesIO, HelperTestCase, fileInTestDir, make_doctest +from .common_imports import etree, BytesIO, HelperTestCase, fileInTestDir, make_doctest, SimpleFSPath class ETreeXMLSchemaTestCase(HelperTestCase): @@ -387,6 +387,11 @@ def test_create_from_partial_doc(self): etree.XMLSchema(schema_element) etree.XMLSchema(schema_element) + def test_xmlschema_pathlike(self): + schema = etree.XMLSchema(file=SimpleFSPath(fileInTestDir('test.xsd'))) + tree_valid = self.parse('') + self.assertTrue(schema.validate(tree_valid)) + class ETreeXMLSchemaResolversTestCase(HelperTestCase): resolver_schema_int = BytesIO("""\ diff --git a/src/lxml/tests/test_xslt.py b/src/lxml/tests/test_xslt.py index cde23357c..0ef076694 100644 --- a/src/lxml/tests/test_xslt.py +++ b/src/lxml/tests/test_xslt.py @@ -29,7 +29,7 @@ basestring = str from .common_imports import ( - etree, BytesIO, HelperTestCase, fileInTestDir, _bytes, make_doctest, skipif + etree, BytesIO, HelperTestCase, fileInTestDir, _bytes, make_doctest, skipif, SimpleFSPath ) @@ -195,6 +195,19 @@ def test_xslt_write_output_file_path(self): res[0] = f.read().decode("UTF-16") finally: os.unlink(f.name) + + def test_xslt_write_output_file_pathlike(self): + with self._xslt_setup() as res: + f = NamedTemporaryFile(delete=False) + try: + try: + res[0].write_output(SimpleFSPath(f.name), compression=9) + finally: + f.close() + with gzip.GzipFile(f.name) as f: + res[0] = f.read().decode("UTF-16") + finally: + os.unlink(f.name) def test_xslt_write_output_file_path_urlescaped(self): # libxml2 should not unescape file paths. diff --git a/src/lxml/xmlschema.pxi b/src/lxml/xmlschema.pxi index ab26d935e..fe7a2bacb 100644 --- a/src/lxml/xmlschema.pxi +++ b/src/lxml/xmlschema.pxi @@ -56,6 +56,7 @@ cdef class XMLSchema(_Validator): self._doc = _documentFactory(c_doc, doc._parser) parser_ctxt = xmlschema.xmlSchemaNewDocParserCtxt(c_doc) elif file is not None: + file = _getFSPathOrObject(file) if _isString(file): filename = _encodeFilename(file) parser_ctxt = xmlschema.xmlSchemaNewParserCtxt(_cstr(filename))