mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Tidy uses of torch.device in Volumes
Summary: Tidy uses of `torch.device` in `Volumes`: - Allow `str` or `torch.device` in `Volumes.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28970876 fbshipit-source-id: c640cc22ced684a54cc450ac38a0f4b3435d47be
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1db40ac566
commit
1f9661e150
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..common.types import Device, make_device
|
||||
from ..transforms import Scale, Transform3d
|
||||
from . import utils as struct_utils
|
||||
|
||||
@@ -694,7 +695,7 @@ class Volumes:
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def to(self, device, copy: bool = False) -> "Volumes":
|
||||
def to(self, device: Device, copy: bool = False) -> "Volumes":
|
||||
"""
|
||||
Match the functionality of torch.Tensor.to()
|
||||
If copy = True or the self Tensor is on a different device, the
|
||||
@@ -703,30 +704,34 @@ class Volumes:
|
||||
then self is returned.
|
||||
|
||||
Args:
|
||||
**device**: Device id for the new tensor.
|
||||
**copy**: Boolean indicator whether or not to clone self. Default False.
|
||||
device: Device (as str or torch.device) for the new tensor.
|
||||
copy: Boolean indicator whether or not to clone self. Default False.
|
||||
|
||||
Returns:
|
||||
Volumes object.
|
||||
Volumes object.
|
||||
"""
|
||||
if not copy and self.device == device:
|
||||
device_ = make_device(device)
|
||||
if not copy and self.device == device_:
|
||||
return self
|
||||
|
||||
other = self.clone()
|
||||
if self.device != device:
|
||||
other.device = device
|
||||
# pyre-fixme[16]: `List` has no attribute `to`.
|
||||
other._densities = self._densities.to(device)
|
||||
if self._features is not None:
|
||||
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
||||
other._features = self.features().to(device)
|
||||
other._local_to_world_transform = (
|
||||
self.get_local_to_world_coords_transform().to(device)
|
||||
)
|
||||
other._grid_sizes = self._grid_sizes.to(device)
|
||||
if self.device == device_:
|
||||
return other
|
||||
|
||||
other.device = device_
|
||||
# pyre-fixme[16]: `List` has no attribute `to`.
|
||||
other._densities = self._densities.to(device_)
|
||||
if self._features is not None:
|
||||
# pyre-fixme[16]: `Optional` has no attribute `to`.
|
||||
other._features = self.features().to(device_)
|
||||
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
|
||||
device_
|
||||
)
|
||||
other._grid_sizes = self._grid_sizes.to(device_)
|
||||
return other
|
||||
|
||||
def cpu(self) -> "Volumes":
|
||||
return self.to(torch.device("cpu"))
|
||||
return self.to("cpu")
|
||||
|
||||
def cuda(self) -> "Volumes":
|
||||
return self.to(torch.device("cuda"))
|
||||
return self.to("cuda")
|
||||
|
||||
Reference in New Issue
Block a user