Tidy uses of torch.device in Transform3d

Summary:
Tidy uses of `torch.device` in `Transform3d`:
- Allow `str` or `torch.device` in user-facing methods
- Consistently use `torch.device` for internal types
- Fix comparison of devices
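
In practice, any user-facing entry point now accepts either spelling of a device. A minimal usage sketch (illustrative, not part of the diff; assumes the public pytorch3d.transforms API and a CUDA-enabled build):

    import torch
    from pytorch3d.transforms import Transform3d

    # Both forms are normalized internally via make_device, so both
    # transforms report the same torch.device.
    t1 = Transform3d(device="cuda:0")
    t2 = Transform3d(device=torch.device("cuda:0"))
    assert t1.device == t2.device

    # to() mirrors torch.Tensor.to(): a str or a torch.device both work.
    t3 = t1.to("cpu")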

Reviewed By: nikhilaravi

Differential Revision: D28929486

fbshipit-source-id: bd1d6cc7ede3d8fd549fd3224a9b07eec53f8164
Patrick Labatut, 2021-06-09 15:48:56 -07:00, committed by Facebook GitHub Bot
parent 48faf8eb7e
commit 13a0110b69
2 changed files with 93 additions and 45 deletions


@@ -6,6 +6,7 @@ from typing import List, Optional, Union
 import torch
+from ..common.types import Device, get_device, make_device
 from .rotation_conversions import _axis_angle_rotation
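
For reference, the two imported helpers behave roughly as sketched below. This is inferred from how they are used in this diff (the removed _get_device further down had the same inference logic), not copied from the actual pytorch3d.common.types implementation:

    from typing import Optional, Union

    import torch

    # The alias used in the new signatures below.
    Device = Union[str, torch.device]

    def make_device(device: Device) -> torch.device:
        # Normalize a str or torch.device to a torch.device; a bare "cuda"
        # gets an explicit index so device comparisons behave consistently.
        device = torch.device(device) if isinstance(device, str) else device
        if device.type == "cuda" and device.index is None:
            device = torch.device(f"cuda:{torch.cuda.current_device()}")
        return device

    def get_device(x, device: Optional[Device] = None) -> torch.device:
        # An explicit device argument wins; otherwise infer the device from
        # a tensor input; otherwise default to CPU.
        if device is not None:
            return make_device(device)
        if torch.is_tensor(x):
            return x.device
        return torch.device("cpu")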
@@ -137,7 +138,7 @@ class Transform3d:
     def __init__(
         self,
         dtype: torch.dtype = torch.float32,
-        device="cpu",
+        device: Device = "cpu",
         matrix: Optional[torch.Tensor] = None,
     ):
         """
@@ -167,7 +168,7 @@ class Transform3d:
         self._transforms = []  # store transforms to compose
         self._lu = None
-        self.device = device
+        self.device = make_device(device)

     def __len__(self):
         return self.get_matrix().shape[0]
@@ -398,7 +399,12 @@ class Transform3d:
         other._transforms = [t.clone() for t in self._transforms]
         return other

-    def to(self, device, copy: bool = False, dtype=None):
+    def to(
+        self,
+        device: Device,
+        copy: bool = False,
+        dtype: Optional[torch.dtype] = None,
+    ):
         """
         Match functionality of torch.Tensor.to()
         If copy = True or the self Tensor is on a different device, the
@@ -407,7 +413,7 @@ class Transform3d:
         then self is returned.

         Args:
-          device: Device id for the new tensor.
+          device: Device (as str or torch.device) for the new tensor.
          copy: Boolean indicator whether or not to clone self. Default False.
          dtype: If not None, casts the internal tensor variables
            to a given torch.dtype.
@@ -415,26 +421,37 @@ class Transform3d:
         Returns:
           Transform3d object.
         """
-        if not copy and self.device == device:
+        device_ = make_device(device)
+        if not copy and self.device == device_:
             return self
+
         other = self.clone()
-        if self.device != device:
-            other.device = device
-            other._matrix = self._matrix.to(device=device, dtype=dtype)
-            other._transforms = [
-                t.to(device, copy=copy, dtype=dtype) for t in other._transforms
-            ]
+
+        if self.device == device_:
+            return other
+
+        other.device = device_
+        other._matrix = self._matrix.to(device=device_, dtype=dtype)
+        other._transforms = [
+            t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
+        ]
         return other

     def cpu(self):
-        return self.to(torch.device("cpu"))
+        return self.to("cpu")

     def cuda(self):
-        return self.to(torch.device("cuda"))
+        return self.to("cuda")


 class Translate(Transform3d):
-    def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
+    def __init__(
+        self,
+        x,
+        y=None,
+        z=None,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+    ):
         """
         Create a new Transform3d representing 3D translations.
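
The make_device normalization matters for the self.device == device_ check in to() above: torch.device equality is strict, so unnormalized inputs can silently skip the early-return path. A small illustration (hedged: the exact str-vs-device comparison behavior varies across PyTorch versions):

    import torch

    # Devices with and without an explicit index compare unequal, even
    # though they can refer to the same physical GPU.
    assert torch.device("cuda") != torch.device("cuda:0")

    # Comparing a torch.device against a plain str is also unreliable,
    # which is why to() converts its argument with make_device first.
    print(torch.device("cpu") == "cpu")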
@@ -468,7 +485,14 @@ class Translate(Transform3d):


 class Scale(Transform3d):
-    def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
+    def __init__(
+        self,
+        x,
+        y=None,
+        z=None,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+    ):
         """
         A Transform3d representing a scaling operation, with different scale
         factors along each coordinate axis.
@@ -509,7 +533,11 @@ class Scale(Transform3d):


 class Rotate(Transform3d):
     def __init__(
-        self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
+        self,
+        R: torch.Tensor,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+        orthogonal_tol: float = 1e-5,
     ):
         """
         Create a new Transform3d representing 3D rotation using a rotation
@@ -520,17 +548,17 @@ class Rotate(Transform3d):
           orthogonal_tol: tolerance for the test of the orthogonality of R
         """
-        device = _get_device(R, device)
-        super().__init__(device=device)
+        device_ = get_device(R, device)
+        super().__init__(device=device_)
         if R.dim() == 2:
             R = R[None]
         if R.shape[-2:] != (3, 3):
             msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
             raise ValueError(msg % repr(R.shape))
-        R = R.to(dtype=dtype).to(device=device)
+        R = R.to(dtype=dtype).to(device=device_)
         _check_valid_rotation_matrix(R, tol=orthogonal_tol)
         N = R.shape[0]
-        mat = torch.eye(4, dtype=dtype, device=device)
+        mat = torch.eye(4, dtype=dtype, device=device_)
         mat = mat.view(1, 4, 4).repeat(N, 1, 1)
         mat[:, :3, :3] = R
         self._matrix = mat
@@ -548,8 +576,8 @@ class RotateAxisAngle(Rotate):
         angle,
         axis: str = "X",
         degrees: bool = True,
-        dtype=torch.float64,
-        device=None,
+        dtype: torch.dtype = torch.float64,
+        device: Optional[Device] = None,
     ):
         """
         Create a new Transform3d representing 3D rotation about an axis
@@ -582,7 +610,7 @@ class RotateAxisAngle(Rotate):
         super().__init__(device=angle.device, R=R)


-def _handle_coord(c, dtype, device):
+def _handle_coord(c, dtype: torch.dtype, device: torch.device):
     """
     Helper function for _handle_input.
@@ -601,20 +629,15 @@ def _handle_coord(c, dtype, device):
     return c


-def _get_device(x, device=None):
-    if device is not None:
-        # User overriding device, leave
-        device = device
-    elif torch.is_tensor(x):
-        # Set device based on input tensor
-        device = x.device
-    else:
-        # Default device is cpu
-        device = "cpu"
-    return device
-
-
-def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
+def _handle_input(
+    x,
+    y,
+    z,
+    dtype: torch.dtype,
+    device: Optional[Device],
+    name: str,
+    allow_singleton: bool = False,
+):
     """
     Helper function to handle parsing logic for building transforms. The output
     is always a tensor of shape (N, 3), but there are several types of allowed
@@ -642,7 +665,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
     Returns:
         xyz: Tensor of shape (N, 3)
     """
-    device = _get_device(x, device)
+    device_ = get_device(x, device)

     # If x is actually a tensor of shape (N, 3) then just return it
     if torch.is_tensor(x) and x.dim() == 2:
         if x.shape[1] != 3:
@@ -651,14 +674,14 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
         if y is not None or z is not None:
             msg = "Expected y and z to be None (in %s)" % name
             raise ValueError(msg)
-        return x.to(device=device)
+        return x.to(device=device_)

     if allow_singleton and y is None and z is None:
         y = x
         z = x

     # Convert all to 1D tensors
-    xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
+    xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]]

     # Broadcast and concatenate
     sizes = [c.shape[0] for c in xyz]
@@ -672,7 +695,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
     return xyz


-def _handle_angle_input(x, dtype, device, name: str):
+def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
     """
     Helper function for building a rotation function using angles.
     The output is always of shape (N,).
@@ -682,12 +705,12 @@ def _handle_angle_input(x, dtype, device, name: str):
       - Python scalar
       - Torch scalar
     """
-    device = _get_device(x, device)
+    device_ = get_device(x, device)
    if torch.is_tensor(x) and x.dim() > 1:
        msg = "Expected tensor of shape (N,); got %r (in %s)"
        raise ValueError(msg % (x.shape, name))
    else:
-        return _handle_coord(x, dtype, device)
+        return _handle_coord(x, dtype, device_)


 def _broadcast_bmm(a, b):
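
The _handle_input / _handle_angle_input helpers above normalize the several accepted argument forms to tensors of fixed shape ((N, 3) and (N,) respectively). A sketch of the accepted inputs, based on the docstrings in this diff:

    import torch
    from pytorch3d.transforms import RotateAxisAngle, Scale, Translate

    t1 = Translate(torch.rand(4, 3))   # an (N, 3) tensor is taken as-is
    t2 = Translate(1.0, 2.0, 3.0)      # three scalars are stacked to (1, 3)
    s = Scale(2.0)                     # allow_singleton: uniform (2, 2, 2) scale
    r = RotateAxisAngle(torch.tensor([0.0, 90.0]), axis="Y")  # angles -> (N,)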


@@ -20,10 +20,35 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
     def test_to(self):
         tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
         R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
-        cpu_points = torch.rand(9, 3)
-        cuda_points = cpu_points.cuda()
         R = Rotate(R)
         t = Transform3d().compose(R, tr)

+        cpu_device = torch.device("cpu")
+
+        cpu_t = t.to("cpu")
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIs(t, cpu_t)
+
+        cpu_t = t.to(cpu_device)
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIs(t, cpu_t)
+
+        cuda_device = torch.device("cuda")
+
+        cuda_t = t.to("cuda")
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIsNot(t, cuda_t)
+
+        cuda_t = t.to(cuda_device)
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIsNot(t, cuda_t)
+
+        cpu_points = torch.rand(9, 3)
+        cuda_points = cpu_points.cuda()
+
         for _ in range(3):
             t = t.cpu()
             t.transform_points(cpu_points)