fix get cuda device test error

Summary:
Cuda test failing on circle with the error `random_ expects 'from' to be less than 'to', but got from=0 >= to=0`

This is because the `high` value in `torch.randint` is 1 more than the highest value in the distribution from which a value is drawn. So if there is only 1 cuda device available then the low and high are 0.

Reviewed By: gkioxari

Differential Revision: D21236669

fbshipit-source-id: 46c312d431c474f1f2c50747b1d5e7afbd7df3a9
This commit is contained in:
Nikhila Ravi 2020-04-24 16:10:11 -07:00 committed by Facebook GitHub Bot
parent f8acecb6b3
commit cf84dacf2e

View File

@ -28,8 +28,10 @@ def get_random_cuda_device() -> str:
any device without having to set the device explicitly.
"""
num_devices = torch.cuda.device_count()
rand_device_id = torch.randint(high=num_devices, size=(1,)).item()
return "cuda:%d" % rand_device_id
device_id = (
torch.randint(high=num_devices, size=(1,)).item() if num_devices > 1 else 0
)
return "cuda:%d" % device_id
class TestCaseMixin(unittest.TestCase):