diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 9cb4555f99a1..bf430929067c 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -583,6 +583,28 @@ array_tostring(PyArrayObject *self, PyObject *args, PyObject *kwds) return PyArray_ToString(self, order); } +/* Like PyArray_ToFile but takes the file as a python object */ +static int +PyArray_ToFileObject(PyArrayObject *self, PyObject *file, char *sep, char *format) +{ + npy_off_t orig_pos = 0; + FILE *fd = npy_PyFile_Dup2(file, "wb", &orig_pos); + + if (fd == NULL) { + return -1; + } + + int write_ret = PyArray_ToFile(self, fd, sep, format); + PyObject *err_type, *err_value, *err_traceback; + PyErr_Fetch(&err_type, &err_value, &err_traceback); + int close_ret = npy_PyFile_DupClose2(file, fd, orig_pos); + npy_PyErr_ChainExceptions(err_type, err_value, err_traceback); + + if (write_ret || close_ret) { + return -1; + } + return 0; +} /* This should grow an order= keyword to be consistent */ @@ -592,10 +614,8 @@ array_tofile(PyArrayObject *self, PyObject *args, PyObject *kwds) { int own; PyObject *file; - FILE *fd; char *sep = ""; char *format = ""; - npy_off_t orig_pos = 0; static char *kwlist[] = {"file", "sep", "format", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|ss:tofile", kwlist, @@ -620,25 +640,22 @@ array_tofile(PyArrayObject *self, PyObject *args, PyObject *kwds) own = 0; } - fd = npy_PyFile_Dup2(file, "wb", &orig_pos); - if (fd == NULL) { - goto fail; - } - if (PyArray_ToFile(self, fd, sep, format) < 0) { - goto fail; - } - if (npy_PyFile_DupClose2(file, fd, orig_pos) < 0) { - goto fail; - } - if (own && npy_PyFile_CloseFile(file) < 0) { - goto fail; + int file_ret = PyArray_ToFileObject(self, file, sep, format); + int close_ret = 0; + + if (own) { + PyObject *err_type, *err_value, *err_traceback; + PyErr_Fetch(&err_type, &err_value, &err_traceback); + close_ret = npy_PyFile_CloseFile(file); + npy_PyErr_ChainExceptions(err_type, err_value, err_traceback); } - Py_DECREF(file); - Py_RETURN_NONE; -fail: Py_DECREF(file); - return NULL; + + if (file_ret || close_ret) { + return NULL; + } + Py_RETURN_NONE; } static PyObject * diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 2e8db751889a..eef0dd8a43ea 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4985,6 +4985,17 @@ def test_tofile_format(self): s = f.read() assert_equal(s, '1.51,2.00,3.51,4.00') + def test_tofile_cleanup(self): + x = np.zeros((10), dtype=object) + with open(self.filename, 'wb') as f: + assert_raises(IOError, lambda: x.tofile(f, sep='')) + # Dup-ed file handle should be closed or remove will fail on Windows OS + os.remove(self.filename) + + # Also make sure that we close the Python handle + assert_raises(IOError, lambda: x.tofile(self.filename)) + os.remove(self.filename) + def test_locale(self): with CommaDecimalPointLocale(): self.test_numbers()