Tidy uses of torch.device in Transform3d

Summary:
Tidy uses of `torch.device` in `Transforms3d`:
- Allow `str` or `torch.device` in user-facing methods
- Consistently use `torch.device` for internal types
- Fix comparison of devices

Reviewed By: nikhilaravi

Differential Revision: D28929486

fbshipit-source-id: bd1d6cc7ede3d8fd549fd3224a9b07eec53f8164
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent 48faf8eb7e
commit 13a0110b69
2 changed files with 93 additions and 45 deletions

View File

@ -6,6 +6,7 @@ from typing import List, Optional, Union
import torch
from ..common.types import Device, get_device, make_device
from .rotation_conversions import _axis_angle_rotation
@ -137,7 +138,7 @@ class Transform3d:
def __init__(
self,
dtype: torch.dtype = torch.float32,
device="cpu",
device: Device = "cpu",
matrix: Optional[torch.Tensor] = None,
):
"""
@ -167,7 +168,7 @@ class Transform3d:
self._transforms = [] # store transforms to compose
self._lu = None
self.device = device
self.device = make_device(device)
def __len__(self):
return self.get_matrix().shape[0]
@ -398,7 +399,12 @@ class Transform3d:
other._transforms = [t.clone() for t in self._transforms]
return other
def to(self, device, copy: bool = False, dtype=None):
def to(
self,
device: Device,
copy: bool = False,
dtype: Optional[torch.dtype] = None,
):
"""
Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
@ -407,7 +413,7 @@ class Transform3d:
then self is returned.
Args:
device: Device id for the new tensor.
device: Device (as str or torch.device) for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False.
dtype: If not None, casts the internal tensor variables
to a given torch.dtype.
@ -415,26 +421,37 @@ class Transform3d:
Returns:
Transform3d object.
"""
if not copy and self.device == device:
device_ = make_device(device)
if not copy and self.device == device_:
return self
other = self.clone()
if self.device != device:
other.device = device
other._matrix = self._matrix.to(device=device, dtype=dtype)
other._transforms = [
t.to(device, copy=copy, dtype=dtype) for t in other._transforms
]
if self.device == device_:
return other
other.device = device_
other._matrix = self._matrix.to(device=device_, dtype=dtype)
other._transforms = [
t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
]
return other
def cpu(self):
return self.to(torch.device("cpu"))
return self.to("cpu")
def cuda(self):
return self.to(torch.device("cuda"))
return self.to("cuda")
class Translate(Transform3d):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
def __init__(
self,
x,
y=None,
z=None,
dtype: torch.dtype = torch.float32,
device: Optional[Device] = None,
):
"""
Create a new Transform3d representing 3D translations.
@ -468,7 +485,14 @@ class Translate(Transform3d):
class Scale(Transform3d):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
def __init__(
self,
x,
y=None,
z=None,
dtype: torch.dtype = torch.float32,
device: Optional[Device] = None,
):
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
@ -509,7 +533,11 @@ class Scale(Transform3d):
class Rotate(Transform3d):
def __init__(
self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
self,
R: torch.Tensor,
dtype: torch.dtype = torch.float32,
device: Optional[Device] = None,
orthogonal_tol: float = 1e-5,
):
"""
Create a new Transform3d representing 3D rotation using a rotation
@ -520,17 +548,17 @@ class Rotate(Transform3d):
orthogonal_tol: tolerance for the test of the orthogonality of R
"""
device = _get_device(R, device)
super().__init__(device=device)
device_ = get_device(R, device)
super().__init__(device=device_)
if R.dim() == 2:
R = R[None]
if R.shape[-2:] != (3, 3):
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
raise ValueError(msg % repr(R.shape))
R = R.to(dtype=dtype).to(device=device)
R = R.to(dtype=dtype).to(device=device_)
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
N = R.shape[0]
mat = torch.eye(4, dtype=dtype, device=device)
mat = torch.eye(4, dtype=dtype, device=device_)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, :3, :3] = R
self._matrix = mat
@ -548,8 +576,8 @@ class RotateAxisAngle(Rotate):
angle,
axis: str = "X",
degrees: bool = True,
dtype=torch.float64,
device=None,
dtype: torch.dtype = torch.float64,
device: Optional[Device] = None,
):
"""
Create a new Transform3d representing 3D rotation about an axis
@ -582,7 +610,7 @@ class RotateAxisAngle(Rotate):
super().__init__(device=angle.device, R=R)
def _handle_coord(c, dtype, device):
def _handle_coord(c, dtype: torch.dtype, device: torch.device):
"""
Helper function for _handle_input.
@ -601,20 +629,15 @@ def _handle_coord(c, dtype, device):
return c
def _get_device(x, device=None):
if device is not None:
# User overriding device, leave
device = device
elif torch.is_tensor(x):
# Set device based on input tensor
device = x.device
else:
# Default device is cpu
device = "cpu"
return device
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
def _handle_input(
x,
y,
z,
dtype: torch.dtype,
device: Optional[Device],
name: str,
allow_singleton: bool = False,
):
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
@ -642,7 +665,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
Returns:
xyz: Tensor of shape (N, 3)
"""
device = _get_device(x, device)
device_ = get_device(x, device)
# If x is actually a tensor of shape (N, 3) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 3:
@ -651,14 +674,14 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
if y is not None or z is not None:
msg = "Expected y and z to be None (in %s)" % name
raise ValueError(msg)
return x.to(device=device)
return x.to(device=device_)
if allow_singleton and y is None and z is None:
y = x
z = x
# Convert all to 1D tensors
xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]]
# Broadcast and concatenate
sizes = [c.shape[0] for c in xyz]
@ -672,7 +695,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
return xyz
def _handle_angle_input(x, dtype, device, name: str):
def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
"""
Helper function for building a rotation function using angles.
The output is always of shape (N,).
@ -682,12 +705,12 @@ def _handle_angle_input(x, dtype, device, name: str):
- Python scalar
- Torch scalar
"""
device = _get_device(x, device)
device_ = get_device(x, device)
if torch.is_tensor(x) and x.dim() > 1:
msg = "Expected tensor of shape (N,); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
else:
return _handle_coord(x, dtype, device)
return _handle_coord(x, dtype, device_)
def _broadcast_bmm(a, b):

View File

@ -20,10 +20,35 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
cpu_points = torch.rand(9, 3)
cuda_points = cpu_points.cuda()
R = Rotate(R)
t = Transform3d().compose(R, tr)
cpu_device = torch.device("cpu")
cpu_t = t.to("cpu")
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertIs(t, cpu_t)
cpu_t = t.to(cpu_device)
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertIs(t, cpu_t)
cuda_device = torch.device("cuda")
cuda_t = t.to("cuda")
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertIsNot(t, cuda_t)
cuda_t = t.to(cuda_device)
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertIsNot(t, cuda_t)
cpu_points = torch.rand(9, 3)
cuda_points = cpu_points.cuda()
for _ in range(3):
t = t.cpu()
t.transform_points(cpu_points)