Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Apr 27, 2024
1 parent 4d69b26 commit dc3bbe1
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
17 changes: 6 additions & 11 deletions tests/tests_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,19 @@
# limitations under the License.
import os
import warnings
from pathlib import Path

import pytest

_TEST_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
_TEMP_PATH = os.path.join(_PROJECT_ROOT, "test_temp")
_PATH_DATASETS = os.path.join(_PROJECT_ROOT, "Datasets")
_PATH_LEGACY = os.path.join(_PROJECT_ROOT, "legacy")
_TEST_ROOT = Path(__file__).parent
_PROJECT_ROOT = _TEST_ROOT.parent.parent
_PATH_DATASETS = _PROJECT_ROOT / "Datasets"
_PATH_LEGACY = _PROJECT_ROOT / "legacy"

# todo: this setting `PYTHONPATH` may not be used by other envs like Conda for importing packages
if _PROJECT_ROOT not in os.getenv("PYTHONPATH", ""):
if str(_PROJECT_ROOT) not in os.getenv("PYTHONPATH", ""):
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'


if not os.path.isdir(_TEMP_PATH):
os.mkdir(_TEMP_PATH)


# Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel)
warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")
2 changes: 2 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_val_check_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def test_validation_check_interval_exceed_data_length_wrong():
trainer = Trainer(
limit_train_batches=10,
val_check_interval=100,
logger=False,
enable_checkpointing=False,
)

model = BoringModel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def test_backward_count_simple(torch_backward, num_steps):
def test_backward_count_with_grad_accumulation(torch_backward):
"""Test that backward is called the correct number of times when accumulating gradients."""
model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2)
trainer = Trainer(max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 6

torch_backward.reset_mock()

trainer = Trainer(max_steps=6, accumulate_grad_batches=2)
trainer = Trainer(max_steps=6, accumulate_grad_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 12

Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def test_warning_with_small_dataloader_and_logging_interval(tmp_path):

with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
trainer = Trainer(
default_root_dir=tmp_path, max_epochs=1, log_every_n_steps=2, limit_train_batches=1, logger=CSVLogger(".")
default_root_dir=tmp_path, max_epochs=1, log_every_n_steps=2, limit_train_batches=1, logger=CSVLogger(tmp_path)
)
trainer.fit(model)

Expand Down

0 comments on commit dc3bbe1

Please sign in to comment.