diff --git a/tests/common_testing.py b/tests/common_testing.py index 7f6898a8..ddf1fc3d 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -28,8 +28,10 @@ def get_random_cuda_device() -> str: any device without having to set the device explicitly. """ num_devices = torch.cuda.device_count() - rand_device_id = torch.randint(high=num_devices, size=(1,)).item() - return "cuda:%d" % rand_device_id + device_id = ( + torch.randint(high=num_devices, size=(1,)).item() if num_devices > 1 else 0 + ) + return "cuda:%d" % device_id class TestCaseMixin(unittest.TestCase):