mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Tidy uses of torch.device in Volumes
Summary: Tidy uses of `torch.device` in `Volumes`: - Allow `str` or `torch.device` in `Volumes.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28970876 fbshipit-source-id: c640cc22ced684a54cc450ac38a0f4b3435d47be
This commit is contained in:
parent
1db40ac566
commit
1f9661e150
@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..common.types import Device, make_device
|
||||||
from ..transforms import Scale, Transform3d
|
from ..transforms import Scale, Transform3d
|
||||||
from . import utils as struct_utils
|
from . import utils as struct_utils
|
||||||
|
|
||||||
@ -694,7 +695,7 @@ class Volumes:
|
|||||||
"""
|
"""
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
def to(self, device, copy: bool = False) -> "Volumes":
|
def to(self, device: Device, copy: bool = False) -> "Volumes":
|
||||||
"""
|
"""
|
||||||
Match the functionality of torch.Tensor.to()
|
Match the functionality of torch.Tensor.to()
|
||||||
If copy = True or the self Tensor is on a different device, the
|
If copy = True or the self Tensor is on a different device, the
|
||||||
@ -703,30 +704,34 @@ class Volumes:
|
|||||||
then self is returned.
|
then self is returned.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**device**: Device id for the new tensor.
|
device: Device (as str or torch.device) for the new tensor.
|
||||||
**copy**: Boolean indicator whether or not to clone self. Default False.
|
copy: Boolean indicator whether or not to clone self. Default False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Volumes object.
|
Volumes object.
|
||||||
"""
|
"""
|
||||||
if not copy and self.device == device:
|
device_ = make_device(device)
|
||||||
|
if not copy and self.device == device_:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
other = self.clone()
|
other = self.clone()
|
||||||
if self.device != device:
|
if self.device == device_:
|
||||||
other.device = device
|
return other
|
||||||
|
|
||||||
|
other.device = device_
|
||||||
# pyre-fixme[16]: `List` has no attribute `to`.
|
# pyre-fixme[16]: `List` has no attribute `to`.
|
||||||
other._densities = self._densities.to(device)
|
other._densities = self._densities.to(device_)
|
||||||
if self._features is not None:
|
if self._features is not None:
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
||||||
other._features = self.features().to(device)
|
other._features = self.features().to(device_)
|
||||||
other._local_to_world_transform = (
|
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
|
||||||
self.get_local_to_world_coords_transform().to(device)
|
device_
|
||||||
)
|
)
|
||||||
other._grid_sizes = self._grid_sizes.to(device)
|
other._grid_sizes = self._grid_sizes.to(device_)
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def cpu(self) -> "Volumes":
|
def cpu(self) -> "Volumes":
|
||||||
return self.to(torch.device("cpu"))
|
return self.to("cpu")
|
||||||
|
|
||||||
def cuda(self) -> "Volumes":
|
def cuda(self) -> "Volumes":
|
||||||
return self.to(torch.device("cuda"))
|
return self.to("cuda")
|
||||||
|
@ -445,38 +445,65 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
|
|||||||
Test the moving of the volumes from/to gpu and cpu
|
Test the moving of the volumes from/to gpu and cpu
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
device_cpu = torch.device("cpu")
|
|
||||||
|
|
||||||
features = torch.randn(
|
features = torch.randn(
|
||||||
size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
|
size=[num_volumes, num_channels, *size], dtype=torch.float32
|
||||||
)
|
)
|
||||||
densities = torch.rand(size=[num_volumes, 1, *size], device=device, dtype=dtype)
|
densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
|
||||||
|
volumes = Volumes(densities=densities, features=features)
|
||||||
|
|
||||||
|
# Test support for str and torch.device
|
||||||
|
cpu_device = torch.device("cpu")
|
||||||
|
|
||||||
|
converted_volumes = volumes.to("cpu")
|
||||||
|
self.assertEqual(cpu_device, converted_volumes.device)
|
||||||
|
self.assertEqual(cpu_device, volumes.device)
|
||||||
|
self.assertIs(volumes, converted_volumes)
|
||||||
|
|
||||||
|
converted_volumes = volumes.to(cpu_device)
|
||||||
|
self.assertEqual(cpu_device, converted_volumes.device)
|
||||||
|
self.assertEqual(cpu_device, volumes.device)
|
||||||
|
self.assertIs(volumes, converted_volumes)
|
||||||
|
|
||||||
|
cuda_device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
converted_volumes = volumes.to("cuda:0")
|
||||||
|
self.assertEqual(cuda_device, converted_volumes.device)
|
||||||
|
self.assertEqual(cpu_device, volumes.device)
|
||||||
|
self.assertIsNot(volumes, converted_volumes)
|
||||||
|
|
||||||
|
converted_volumes = volumes.to(cuda_device)
|
||||||
|
self.assertEqual(cuda_device, converted_volumes.device)
|
||||||
|
self.assertEqual(cpu_device, volumes.device)
|
||||||
|
self.assertIsNot(volumes, converted_volumes)
|
||||||
|
|
||||||
|
# Test device placement of internal tensors
|
||||||
|
features = features.to(cuda_device)
|
||||||
|
densities = features.to(cuda_device)
|
||||||
|
|
||||||
for features_ in (features, None):
|
for features_ in (features, None):
|
||||||
v = Volumes(densities=densities, features=features_)
|
volumes = Volumes(densities=densities, features=features_)
|
||||||
|
|
||||||
v_cpu = v.cpu()
|
cpu_volumes = volumes.cpu()
|
||||||
v_cuda = v_cpu.cuda()
|
cuda_volumes = cpu_volumes.cuda()
|
||||||
v_cuda_2 = v_cuda.cuda()
|
cuda_volumes2 = cuda_volumes.cuda()
|
||||||
v_cpu_2 = v_cuda_2.cpu()
|
cpu_volumes2 = cuda_volumes2.cpu()
|
||||||
|
|
||||||
for v1, v2 in itertools.combinations(
|
for volumes1, volumes2 in itertools.combinations(
|
||||||
(v, v_cpu, v_cpu_2, v_cuda, v_cuda_2), 2
|
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
|
||||||
):
|
):
|
||||||
if v1 is v_cuda and v2 is v_cuda_2:
|
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
|
||||||
# checks that we do not copy if the devices stay the same
|
# checks that we do not copy if the devices stay the same
|
||||||
assert_fun = self.assertIs
|
assert_fun = self.assertIs
|
||||||
else:
|
else:
|
||||||
assert_fun = self.assertSeparate
|
assert_fun = self.assertSeparate
|
||||||
assert_fun(v1._densities, v2._densities)
|
assert_fun(volumes1._densities, volumes2._densities)
|
||||||
if features_ is not None:
|
if features_ is not None:
|
||||||
assert_fun(v1._features, v2._features)
|
assert_fun(volumes1._features, volumes2._features)
|
||||||
for v_ in (v1, v2):
|
for volumes_ in (volumes1, volumes2):
|
||||||
if v_ in (v_cpu, v_cpu_2):
|
if volumes_ in (cpu_volumes, cpu_volumes2):
|
||||||
self._check_vars_on_device(v_, device_cpu)
|
self._check_vars_on_device(volumes_, cpu_device)
|
||||||
else:
|
else:
|
||||||
self._check_vars_on_device(v_, device)
|
self._check_vars_on_device(volumes_, cuda_device)
|
||||||
|
|
||||||
def _check_padded(self, x_pad, x_list, grid_sizes):
|
def _check_padded(self, x_pad, x_list, grid_sizes):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user