Skip to content

Commit

Permalink
Restore CUDA_VISIBLE_DEVICES after test
Browse files Browse the repository at this point in the history
Signed-off-by: Enrico Minack <github@enrico.minack.dev>
  • Loading branch information
EnricoMi committed Mar 1, 2022
1 parent 787fcb9 commit 4e9c44c
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions test/single/test_ray.py
Expand Up @@ -5,7 +5,6 @@
import os
import socket
import sys
import time

import pytest
import ray
Expand Down Expand Up @@ -39,19 +38,27 @@ def ray_start_4_cpus():
@pytest.fixture
def ray_start_6_cpus():
address_info = ray.init(num_cpus=6)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
try:
yield address_info
finally:
# The code after the yield will run as teardown code.
ray.shutdown()


@pytest.fixture
def ray_start_4_cpus_4_gpus():
orig_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
address_info = ray.init(num_cpus=4, num_gpus=4)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
del os.environ["CUDA_VISIBLE_DEVICES"]
try:
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
finally:
if orig_devices:
os.environ["CUDA_VISIBLE_DEVICES"] = orig_devices
else:
del os.environ["CUDA_VISIBLE_DEVICES"]


@pytest.fixture
Expand Down Expand Up @@ -168,7 +175,7 @@ def test_gpu_ids(ray_start_4_cpus_4_gpus):
all_envs = hjob.execute(lambda _: os.environ.copy())
all_cudas = {ev["CUDA_VISIBLE_DEVICES"] for ev in all_envs}
assert len(all_cudas) == 1, all_cudas
assert len(all_envs[0]["CUDA_VISIBLE_DEVICES"].split(",")) == 4
assert len(all_envs[0]["CUDA_VISIBLE_DEVICES"].split(",")) == 4, all_envs[0]["CUDA_VISIBLE_DEVICES"]
hjob.shutdown()


Expand All @@ -184,7 +191,7 @@ def test_gpu_ids_num_workers(ray_start_4_cpus_4_gpus):
all_cudas = {ev["CUDA_VISIBLE_DEVICES"] for ev in all_envs}

assert len(all_cudas) == 1, all_cudas
assert len(all_envs[0]["CUDA_VISIBLE_DEVICES"].split(",")) == 4
assert len(all_envs[0]["CUDA_VISIBLE_DEVICES"].split(",")) == 4, all_envs[0]["CUDA_VISIBLE_DEVICES"]

def _test(worker):
import horovod.torch as hvd
Expand Down

0 comments on commit 4e9c44c

Please sign in to comment.