Mirror of https://github.com/facebookresearch/pytorch3d.git
Tidy uses of torch.device in Transform3d
Summary: Tidy uses of `torch.device` in `Transform3d`:
- Allow `str` or `torch.device` in user-facing methods
- Consistently use `torch.device` for internal types
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28929486

fbshipit-source-id: bd1d6cc7ede3d8fd549fd3224a9b07eec53f8164
commit 13a0110b69
parent 48faf8eb7e
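A note on the "fix comparison of devices" item: the old code stored whatever the caller passed (possibly a plain str, given the old `device="cpu"` default) and compared it directly against other devices. A minimal illustration of the pitfalls, using only standard torch.device semantics:

import torch

# An unindexed and an indexed CUDA device never compare equal under
# torch.device equality, even when they refer to the same physical GPU:
assert torch.device("cuda") != torch.device("cuda:0")

# Likewise, a raw str is not a torch.device, so `self.device == device`
# could be False even when both named the same device. Funnelling every
# incoming value through make_device() makes all internal comparisons
# torch.device vs torch.device, built the same way.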
pytorch3d/transforms/transform3d.py
@@ -6,6 +6,7 @@ from typing import List, Optional, Union
 import torch

+from ..common.types import Device, get_device, make_device
 from .rotation_conversions import _axis_angle_rotation
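The helpers imported here live outside this diff. A plausible sketch of their behavior, inferred from how the rest of the commit uses them and from the removed `_get_device` helper below (assumption: the real implementations in pytorch3d/common/types.py may differ in detail):

from typing import Optional, Union

import torch

# User-facing alias: methods accept either spelling of a device.
Device = Union[str, torch.device]


def make_device(device: Device) -> torch.device:
    # Normalize a str ("cpu", "cuda:0", ...) into a torch.device so that
    # internal state and comparisons use one type consistently.
    return torch.device(device) if isinstance(device, str) else device


def get_device(x, device: Optional[Device] = None) -> torch.device:
    # An explicit device argument wins; otherwise infer from a tensor
    # input; otherwise default to the CPU. Mirrors the removed
    # _get_device below, but always returns a torch.device, never a str.
    if device is not None:
        return make_device(device)
    if torch.is_tensor(x):
        return x.device
    return torch.device("cpu")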
@@ -137,7 +138,7 @@ class Transform3d:
     def __init__(
         self,
         dtype: torch.dtype = torch.float32,
-        device="cpu",
+        device: Device = "cpu",
         matrix: Optional[torch.Tensor] = None,
     ):
         """
@@ -167,7 +168,7 @@ class Transform3d:
         self._transforms = []  # store transforms to compose
         self._lu = None
-        self.device = device
+        self.device = make_device(device)

     def __len__(self):
         return self.get_matrix().shape[0]
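With the constructor normalizing up front, `Transform3d.device` is always a `torch.device`, however the caller spelled it. A hypothetical usage check (assuming the sketch of make_device above):

t = Transform3d(device="cpu")            # a "cuda:0" str would work the same way
assert isinstance(t.device, torch.device)  # stored type, not the raw str
assert t.device == torch.device("cpu")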
@@ -398,7 +399,12 @@ class Transform3d:
         other._transforms = [t.clone() for t in self._transforms]
         return other

-    def to(self, device, copy: bool = False, dtype=None):
+    def to(
+        self,
+        device: Device,
+        copy: bool = False,
+        dtype: Optional[torch.dtype] = None,
+    ):
         """
         Match functionality of torch.Tensor.to()
         If copy = True or the self Tensor is on a different device, the
@@ -407,7 +413,7 @@ class Transform3d:
         then self is returned.

         Args:
-          device: Device id for the new tensor.
+          device: Device (as str or torch.device) for the new tensor.
           copy: Boolean indicator whether or not to clone self. Default False.
           dtype: If not None, casts the internal tensor variables
               to a given torch.dtype.
@@ -415,26 +421,37 @@ class Transform3d:
         Returns:
           Transform3d object.
         """
-        if not copy and self.device == device:
+        device_ = make_device(device)
+        if not copy and self.device == device_:
             return self
+
         other = self.clone()
-        if self.device != device:
-            other.device = device
-            other._matrix = self._matrix.to(device=device, dtype=dtype)
-            other._transforms = [
-                t.to(device, copy=copy, dtype=dtype) for t in other._transforms
-            ]
+        if self.device == device_:
+            return other
+
+        other.device = device_
+        other._matrix = self._matrix.to(device=device_, dtype=dtype)
+        other._transforms = [
+            t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
+        ]
         return other

     def cpu(self):
-        return self.to(torch.device("cpu"))
+        return self.to("cpu")

     def cuda(self):
-        return self.to(torch.device("cuda"))
+        return self.to("cuda")


 class Translate(Transform3d):
-    def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
+    def __init__(
+        self,
+        x,
+        y=None,
+        z=None,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+    ):
         """
         Create a new Transform3d representing 3D translations.
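After the rewrite, `to()` accepts either spelling and short-circuits when no move is needed, which the test at the bottom of this commit also exercises. A hypothetical usage sketch:

t = Transform3d()

# Both spellings are accepted, and a no-op move returns self
# rather than allocating a copy:
assert t.to("cpu") is t
assert t.to(torch.device("cpu")) is t

# Passing copy=True (or moving to another device) clones first:
assert t.to("cpu", copy=True) is not t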
@@ -468,7 +485,14 @@ class Translate(Transform3d):


 class Scale(Transform3d):
-    def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
+    def __init__(
+        self,
+        x,
+        y=None,
+        z=None,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+    ):
         """
         A Transform3d representing a scaling operation, with different scale
         factors along each coordinate axis.
@@ -509,7 +533,11 @@ class Scale(Transform3d):

 class Rotate(Transform3d):
     def __init__(
-        self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
+        self,
+        R: torch.Tensor,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+        orthogonal_tol: float = 1e-5,
     ):
         """
         Create a new Transform3d representing 3D rotation using a rotation
@@ -520,17 +548,17 @@ class Rotate(Transform3d):
             orthogonal_tol: tolerance for the test of the orthogonality of R

         """
-        device = _get_device(R, device)
-        super().__init__(device=device)
+        device_ = get_device(R, device)
+        super().__init__(device=device_)
         if R.dim() == 2:
             R = R[None]
         if R.shape[-2:] != (3, 3):
             msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
             raise ValueError(msg % repr(R.shape))
-        R = R.to(dtype=dtype).to(device=device)
+        R = R.to(dtype=dtype).to(device=device_)
         _check_valid_rotation_matrix(R, tol=orthogonal_tol)
         N = R.shape[0]
-        mat = torch.eye(4, dtype=dtype, device=device)
+        mat = torch.eye(4, dtype=dtype, device=device_)
         mat = mat.view(1, 4, 4).repeat(N, 1, 1)
         mat[:, :3, :3] = R
         self._matrix = mat
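`Rotate` now goes through `get_device`, so the transform's device follows the rotation matrix unless explicitly overridden. A hypothetical usage sketch:

R = torch.eye(3)[None]  # (1, 3, 3) identity rotation on the CPU
r = Rotate(R)
assert r.device == torch.device("cpu")  # inferred from R

# An explicit argument still wins, e.g. on a machine with a GPU:
# r = Rotate(R, device="cuda:0")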
@@ -548,8 +576,8 @@ class RotateAxisAngle(Rotate):
         angle,
         axis: str = "X",
         degrees: bool = True,
-        dtype=torch.float64,
-        device=None,
+        dtype: torch.dtype = torch.float64,
+        device: Optional[Device] = None,
     ):
         """
         Create a new Transform3d representing 3D rotation about an axis
@@ -582,7 +610,7 @@ class RotateAxisAngle(Rotate):
         super().__init__(device=angle.device, R=R)


-def _handle_coord(c, dtype, device):
+def _handle_coord(c, dtype: torch.dtype, device: torch.device):
     """
     Helper function for _handle_input.

@@ -601,20 +629,15 @@ def _handle_coord(c, dtype, device):
     return c


-def _get_device(x, device=None):
-    if device is not None:
-        # User overriding device, leave
-        device = device
-    elif torch.is_tensor(x):
-        # Set device based on input tensor
-        device = x.device
-    else:
-        # Default device is cpu
-        device = "cpu"
-    return device
-
-
-def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
+def _handle_input(
+    x,
+    y,
+    z,
+    dtype: torch.dtype,
+    device: Optional[Device],
+    name: str,
+    allow_singleton: bool = False,
+):
     """
     Helper function to handle parsing logic for building transforms. The output
     is always a tensor of shape (N, 3), but there are several types of allowed
@@ -642,7 +665,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
     Returns:
         xyz: Tensor of shape (N, 3)
     """
-    device = _get_device(x, device)
+    device_ = get_device(x, device)
     # If x is actually a tensor of shape (N, 3) then just return it
     if torch.is_tensor(x) and x.dim() == 2:
         if x.shape[1] != 3:
@@ -651,14 +674,14 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
         if y is not None or z is not None:
             msg = "Expected y and z to be None (in %s)" % name
             raise ValueError(msg)
-        return x.to(device=device)
+        return x.to(device=device_)

     if allow_singleton and y is None and z is None:
         y = x
         z = x

     # Convert all to 1D tensors
-    xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
+    xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]]

     # Broadcast and concatenate
     sizes = [c.shape[0] for c in xyz]
@@ -672,7 +695,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
     return xyz


-def _handle_angle_input(x, dtype, device, name: str):
+def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
     """
     Helper function for building a rotation function using angles.
     The output is always of shape (N,).
@@ -682,12 +705,12 @@ def _handle_angle_input(x, dtype, device, name: str):
     - Python scalar
     - Torch scalar
     """
-    device = _get_device(x, device)
+    device_ = get_device(x, device)
     if torch.is_tensor(x) and x.dim() > 1:
         msg = "Expected tensor of shape (N,); got %r (in %s)"
         raise ValueError(msg % (x.shape, name))
     else:
-        return _handle_coord(x, dtype, device)
+        return _handle_coord(x, dtype, device_)


 def _broadcast_bmm(a, b):
tests/test_transforms.py
@@ -20,10 +20,35 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
     def test_to(self):
         tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
         R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
-        cpu_points = torch.rand(9, 3)
-        cuda_points = cpu_points.cuda()
         R = Rotate(R)
         t = Transform3d().compose(R, tr)

+        cpu_device = torch.device("cpu")
+
+        cpu_t = t.to("cpu")
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIs(t, cpu_t)
+
+        cpu_t = t.to(cpu_device)
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIs(t, cpu_t)
+
+        cuda_device = torch.device("cuda")
+
+        cuda_t = t.to("cuda")
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIsNot(t, cuda_t)
+
+        cuda_t = t.to(cuda_device)
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIsNot(t, cuda_t)
+
+        cpu_points = torch.rand(9, 3)
+        cuda_points = cpu_points.cuda()
         for _ in range(3):
             t = t.cpu()
             t.transform_points(cpu_points)