diff --git a/src/dotenv/main.py b/src/dotenv/main.py index 05d377a9..33217885 100644 --- a/src/dotenv/main.py +++ b/src/dotenv/main.py @@ -125,15 +125,16 @@ def rewrite( path: Union[str, os.PathLike], encoding: Optional[str], ) -> Iterator[Tuple[IO[str], IO[str]]]: + dest = None try: if not os.path.isfile(path): with open(path, "w+", encoding=encoding) as source: source.write("") - with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding=encoding) as dest: - with open(path, encoding=encoding) as source: - yield (source, dest) # type: ignore + dest = tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding=encoding) + with open(path, encoding=encoding) as source: + yield (source, dest) # type: ignore except BaseException: - if os.path.isfile(dest.name): + if dest and os.path.isfile(dest.name): os.unlink(dest.name) raise else: diff --git a/tests/test_main.py b/tests/test_main.py index 82c73ba1..84a982fe 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -22,6 +22,11 @@ def test_set_key_no_file(tmp_path): assert os.path.exists(nx_file) +def test_set_key_invalid_file(): + with pytest.raises(TypeError): + result = dotenv.set_key(None, "foo", "bar") + + @pytest.mark.parametrize( "before,key,value,expected,after", [