mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Tidy uses of torch.device in Transform3d
Summary: Tidy uses of `torch.device` in `Transforms3d`: - Allow `str` or `torch.device` in user-facing methods - Consistently use `torch.device` for internal types - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28929486 fbshipit-source-id: bd1d6cc7ede3d8fd549fd3224a9b07eec53f8164
This commit is contained in:
parent
48faf8eb7e
commit
13a0110b69
@ -6,6 +6,7 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..common.types import Device, get_device, make_device
|
||||||
from .rotation_conversions import _axis_angle_rotation
|
from .rotation_conversions import _axis_angle_rotation
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +138,7 @@ class Transform3d:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
device="cpu",
|
device: Device = "cpu",
|
||||||
matrix: Optional[torch.Tensor] = None,
|
matrix: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -167,7 +168,7 @@ class Transform3d:
|
|||||||
|
|
||||||
self._transforms = [] # store transforms to compose
|
self._transforms = [] # store transforms to compose
|
||||||
self._lu = None
|
self._lu = None
|
||||||
self.device = device
|
self.device = make_device(device)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.get_matrix().shape[0]
|
return self.get_matrix().shape[0]
|
||||||
@ -398,7 +399,12 @@ class Transform3d:
|
|||||||
other._transforms = [t.clone() for t in self._transforms]
|
other._transforms = [t.clone() for t in self._transforms]
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def to(self, device, copy: bool = False, dtype=None):
|
def to(
|
||||||
|
self,
|
||||||
|
device: Device,
|
||||||
|
copy: bool = False,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Match functionality of torch.Tensor.to()
|
Match functionality of torch.Tensor.to()
|
||||||
If copy = True or the self Tensor is on a different device, the
|
If copy = True or the self Tensor is on a different device, the
|
||||||
@ -407,7 +413,7 @@ class Transform3d:
|
|||||||
then self is returned.
|
then self is returned.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device: Device id for the new tensor.
|
device: Device (as str or torch.device) for the new tensor.
|
||||||
copy: Boolean indicator whether or not to clone self. Default False.
|
copy: Boolean indicator whether or not to clone self. Default False.
|
||||||
dtype: If not None, casts the internal tensor variables
|
dtype: If not None, casts the internal tensor variables
|
||||||
to a given torch.dtype.
|
to a given torch.dtype.
|
||||||
@ -415,26 +421,37 @@ class Transform3d:
|
|||||||
Returns:
|
Returns:
|
||||||
Transform3d object.
|
Transform3d object.
|
||||||
"""
|
"""
|
||||||
if not copy and self.device == device:
|
device_ = make_device(device)
|
||||||
|
if not copy and self.device == device_:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
other = self.clone()
|
other = self.clone()
|
||||||
if self.device != device:
|
if self.device == device_:
|
||||||
other.device = device
|
return other
|
||||||
other._matrix = self._matrix.to(device=device, dtype=dtype)
|
|
||||||
|
other.device = device_
|
||||||
|
other._matrix = self._matrix.to(device=device_, dtype=dtype)
|
||||||
other._transforms = [
|
other._transforms = [
|
||||||
t.to(device, copy=copy, dtype=dtype) for t in other._transforms
|
t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
|
||||||
]
|
]
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def cpu(self):
|
def cpu(self):
|
||||||
return self.to(torch.device("cpu"))
|
return self.to("cpu")
|
||||||
|
|
||||||
def cuda(self):
|
def cuda(self):
|
||||||
return self.to(torch.device("cuda"))
|
return self.to("cuda")
|
||||||
|
|
||||||
|
|
||||||
class Translate(Transform3d):
|
class Translate(Transform3d):
|
||||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
y=None,
|
||||||
|
z=None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
device: Optional[Device] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create a new Transform3d representing 3D translations.
|
Create a new Transform3d representing 3D translations.
|
||||||
|
|
||||||
@ -468,7 +485,14 @@ class Translate(Transform3d):
|
|||||||
|
|
||||||
|
|
||||||
class Scale(Transform3d):
|
class Scale(Transform3d):
|
||||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
y=None,
|
||||||
|
z=None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
device: Optional[Device] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
A Transform3d representing a scaling operation, with different scale
|
A Transform3d representing a scaling operation, with different scale
|
||||||
factors along each coordinate axis.
|
factors along each coordinate axis.
|
||||||
@ -509,7 +533,11 @@ class Scale(Transform3d):
|
|||||||
|
|
||||||
class Rotate(Transform3d):
|
class Rotate(Transform3d):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
|
self,
|
||||||
|
R: torch.Tensor,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
device: Optional[Device] = None,
|
||||||
|
orthogonal_tol: float = 1e-5,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new Transform3d representing 3D rotation using a rotation
|
Create a new Transform3d representing 3D rotation using a rotation
|
||||||
@ -520,17 +548,17 @@ class Rotate(Transform3d):
|
|||||||
orthogonal_tol: tolerance for the test of the orthogonality of R
|
orthogonal_tol: tolerance for the test of the orthogonality of R
|
||||||
|
|
||||||
"""
|
"""
|
||||||
device = _get_device(R, device)
|
device_ = get_device(R, device)
|
||||||
super().__init__(device=device)
|
super().__init__(device=device_)
|
||||||
if R.dim() == 2:
|
if R.dim() == 2:
|
||||||
R = R[None]
|
R = R[None]
|
||||||
if R.shape[-2:] != (3, 3):
|
if R.shape[-2:] != (3, 3):
|
||||||
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
|
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
|
||||||
raise ValueError(msg % repr(R.shape))
|
raise ValueError(msg % repr(R.shape))
|
||||||
R = R.to(dtype=dtype).to(device=device)
|
R = R.to(dtype=dtype).to(device=device_)
|
||||||
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
|
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
|
||||||
N = R.shape[0]
|
N = R.shape[0]
|
||||||
mat = torch.eye(4, dtype=dtype, device=device)
|
mat = torch.eye(4, dtype=dtype, device=device_)
|
||||||
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
|
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
|
||||||
mat[:, :3, :3] = R
|
mat[:, :3, :3] = R
|
||||||
self._matrix = mat
|
self._matrix = mat
|
||||||
@ -548,8 +576,8 @@ class RotateAxisAngle(Rotate):
|
|||||||
angle,
|
angle,
|
||||||
axis: str = "X",
|
axis: str = "X",
|
||||||
degrees: bool = True,
|
degrees: bool = True,
|
||||||
dtype=torch.float64,
|
dtype: torch.dtype = torch.float64,
|
||||||
device=None,
|
device: Optional[Device] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new Transform3d representing 3D rotation about an axis
|
Create a new Transform3d representing 3D rotation about an axis
|
||||||
@ -582,7 +610,7 @@ class RotateAxisAngle(Rotate):
|
|||||||
super().__init__(device=angle.device, R=R)
|
super().__init__(device=angle.device, R=R)
|
||||||
|
|
||||||
|
|
||||||
def _handle_coord(c, dtype, device):
|
def _handle_coord(c, dtype: torch.dtype, device: torch.device):
|
||||||
"""
|
"""
|
||||||
Helper function for _handle_input.
|
Helper function for _handle_input.
|
||||||
|
|
||||||
@ -601,20 +629,15 @@ def _handle_coord(c, dtype, device):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
def _get_device(x, device=None):
|
def _handle_input(
|
||||||
if device is not None:
|
x,
|
||||||
# User overriding device, leave
|
y,
|
||||||
device = device
|
z,
|
||||||
elif torch.is_tensor(x):
|
dtype: torch.dtype,
|
||||||
# Set device based on input tensor
|
device: Optional[Device],
|
||||||
device = x.device
|
name: str,
|
||||||
else:
|
allow_singleton: bool = False,
|
||||||
# Default device is cpu
|
):
|
||||||
device = "cpu"
|
|
||||||
return device
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
|
|
||||||
"""
|
"""
|
||||||
Helper function to handle parsing logic for building transforms. The output
|
Helper function to handle parsing logic for building transforms. The output
|
||||||
is always a tensor of shape (N, 3), but there are several types of allowed
|
is always a tensor of shape (N, 3), but there are several types of allowed
|
||||||
@ -642,7 +665,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
|
|||||||
Returns:
|
Returns:
|
||||||
xyz: Tensor of shape (N, 3)
|
xyz: Tensor of shape (N, 3)
|
||||||
"""
|
"""
|
||||||
device = _get_device(x, device)
|
device_ = get_device(x, device)
|
||||||
# If x is actually a tensor of shape (N, 3) then just return it
|
# If x is actually a tensor of shape (N, 3) then just return it
|
||||||
if torch.is_tensor(x) and x.dim() == 2:
|
if torch.is_tensor(x) and x.dim() == 2:
|
||||||
if x.shape[1] != 3:
|
if x.shape[1] != 3:
|
||||||
@ -651,14 +674,14 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
|
|||||||
if y is not None or z is not None:
|
if y is not None or z is not None:
|
||||||
msg = "Expected y and z to be None (in %s)" % name
|
msg = "Expected y and z to be None (in %s)" % name
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return x.to(device=device)
|
return x.to(device=device_)
|
||||||
|
|
||||||
if allow_singleton and y is None and z is None:
|
if allow_singleton and y is None and z is None:
|
||||||
y = x
|
y = x
|
||||||
z = x
|
z = x
|
||||||
|
|
||||||
# Convert all to 1D tensors
|
# Convert all to 1D tensors
|
||||||
xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
|
xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]]
|
||||||
|
|
||||||
# Broadcast and concatenate
|
# Broadcast and concatenate
|
||||||
sizes = [c.shape[0] for c in xyz]
|
sizes = [c.shape[0] for c in xyz]
|
||||||
@ -672,7 +695,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
|
|||||||
return xyz
|
return xyz
|
||||||
|
|
||||||
|
|
||||||
def _handle_angle_input(x, dtype, device, name: str):
|
def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
|
||||||
"""
|
"""
|
||||||
Helper function for building a rotation function using angles.
|
Helper function for building a rotation function using angles.
|
||||||
The output is always of shape (N,).
|
The output is always of shape (N,).
|
||||||
@ -682,12 +705,12 @@ def _handle_angle_input(x, dtype, device, name: str):
|
|||||||
- Python scalar
|
- Python scalar
|
||||||
- Torch scalar
|
- Torch scalar
|
||||||
"""
|
"""
|
||||||
device = _get_device(x, device)
|
device_ = get_device(x, device)
|
||||||
if torch.is_tensor(x) and x.dim() > 1:
|
if torch.is_tensor(x) and x.dim() > 1:
|
||||||
msg = "Expected tensor of shape (N,); got %r (in %s)"
|
msg = "Expected tensor of shape (N,); got %r (in %s)"
|
||||||
raise ValueError(msg % (x.shape, name))
|
raise ValueError(msg % (x.shape, name))
|
||||||
else:
|
else:
|
||||||
return _handle_coord(x, dtype, device)
|
return _handle_coord(x, dtype, device_)
|
||||||
|
|
||||||
|
|
||||||
def _broadcast_bmm(a, b):
|
def _broadcast_bmm(a, b):
|
||||||
|
@ -20,10 +20,35 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_to(self):
|
def test_to(self):
|
||||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||||
cpu_points = torch.rand(9, 3)
|
|
||||||
cuda_points = cpu_points.cuda()
|
|
||||||
R = Rotate(R)
|
R = Rotate(R)
|
||||||
t = Transform3d().compose(R, tr)
|
t = Transform3d().compose(R, tr)
|
||||||
|
|
||||||
|
cpu_device = torch.device("cpu")
|
||||||
|
|
||||||
|
cpu_t = t.to("cpu")
|
||||||
|
self.assertEqual(cpu_device, cpu_t.device)
|
||||||
|
self.assertEqual(cpu_device, t.device)
|
||||||
|
self.assertIs(t, cpu_t)
|
||||||
|
|
||||||
|
cpu_t = t.to(cpu_device)
|
||||||
|
self.assertEqual(cpu_device, cpu_t.device)
|
||||||
|
self.assertEqual(cpu_device, t.device)
|
||||||
|
self.assertIs(t, cpu_t)
|
||||||
|
|
||||||
|
cuda_device = torch.device("cuda")
|
||||||
|
|
||||||
|
cuda_t = t.to("cuda")
|
||||||
|
self.assertEqual(cuda_device, cuda_t.device)
|
||||||
|
self.assertEqual(cpu_device, t.device)
|
||||||
|
self.assertIsNot(t, cuda_t)
|
||||||
|
|
||||||
|
cuda_t = t.to(cuda_device)
|
||||||
|
self.assertEqual(cuda_device, cuda_t.device)
|
||||||
|
self.assertEqual(cpu_device, t.device)
|
||||||
|
self.assertIsNot(t, cuda_t)
|
||||||
|
|
||||||
|
cpu_points = torch.rand(9, 3)
|
||||||
|
cuda_points = cpu_points.cuda()
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
t = t.cpu()
|
t = t.cpu()
|
||||||
t.transform_points(cpu_points)
|
t.transform_points(cpu_points)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user