mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
fix get cuda device test error
Summary: Cuda test failing on circle with the error `random_ expects 'from' to be less than 'to', but got from=0 >= to=0` This is because the `high` value in `torch.randint` is 1 more than the highest value in the distribution from which a value is drawn. So if there is only 1 cuda device available then the low and high are 0. Reviewed By: gkioxari Differential Revision: D21236669 fbshipit-source-id: 46c312d431c474f1f2c50747b1d5e7afbd7df3a9
This commit is contained in:
parent
f8acecb6b3
commit
cf84dacf2e
@ -28,8 +28,10 @@ def get_random_cuda_device() -> str:
|
||||
any device without having to set the device explicitly.
|
||||
"""
|
||||
num_devices = torch.cuda.device_count()
|
||||
rand_device_id = torch.randint(high=num_devices, size=(1,)).item()
|
||||
return "cuda:%d" % rand_device_id
|
||||
device_id = (
|
||||
torch.randint(high=num_devices, size=(1,)).item() if num_devices > 1 else 0
|
||||
)
|
||||
return "cuda:%d" % device_id
|
||||
|
||||
|
||||
class TestCaseMixin(unittest.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user