Make Transform3d.to() not ignore dtype

Summary: Make Transform3d.to() honor a requested dtype when the device is unchanged and no copy is requested; previously the call returned self (or an unconverted clone) and silently ignored the dtype. Also fix the other methods (compose(), inverse(), stack(), and clone()) that constructed new Transform3d objects without propagating the dtype.
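
A minimal sketch of the behavior being fixed, mirroring the new unit tests below (illustrative only):

    import torch
    from pytorch3d.transforms import Transform3d

    t = Transform3d()  # float32 on CPU by default

    # Previously, to() returned `self` whenever the device already matched
    # and copy=False, silently dropping the requested dtype.
    t64 = t.to(dtype=torch.float64, device=torch.device("cpu"))

    # With this fix, a dtype change produces a converted copy:
    assert t64 is not t
    assert t64.dtype == torch.float64
    assert t.dtype == torch.float32  # the original is unchanged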

Reviewed By: nikhilaravi

Differential Revision: D28981171

fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
Patrick Labatut authored on 2021-06-09 15:48:56 -07:00 (committed by Facebook GitHub Bot)
parent 626bf3fe23
commit 44508ed0db
2 changed files with 39 additions and 10 deletions

pytorch3d/transforms/transform3d.py

@@ -162,13 +162,15 @@ class Transform3d:
                 raise ValueError(
                     '"matrix" has to be a tensor of shape (minibatch, 4, 4)'
                 )
-            # set the device from matrix
+            # set dtype and device from matrix
+            dtype = matrix.dtype
             device = matrix.device
             self._matrix = matrix.view(-1, 4, 4)

         self._transforms = []  # store transforms to compose
         self._lu = None
         self.device = make_device(device)
+        self.dtype = dtype

     def __len__(self):
         return self.get_matrix().shape[0]
@@ -200,7 +202,7 @@
         Returns:
             A new Transform3d with the stored transforms
         """
-        out = Transform3d(device=self.device)
+        out = Transform3d(dtype=self.dtype, device=self.device)
         out._matrix = self._matrix.clone()
         for other in others:
             if not isinstance(other, Transform3d):
@@ -259,7 +261,7 @@
             transformation.
         """
-        tinv = Transform3d(device=self.device)
+        tinv = Transform3d(dtype=self.dtype, device=self.device)

         if invert_composed:
             # first compose then invert
@@ -278,7 +280,7 @@
                 # right-multiplies by the inverse of self._matrix
                 # at the end of the composition.
                 tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
-                last = Transform3d(device=self.device)
+                last = Transform3d(dtype=self.dtype, device=self.device)
                 last._matrix = i_matrix
                 tinv._transforms.append(last)
             else:
@@ -291,7 +293,7 @@
     def stack(self, *others):
         transforms = [self] + list(others)
         matrix = torch.cat([t._matrix for t in transforms], dim=0)
-        out = Transform3d()
+        out = Transform3d(dtype=self.dtype, device=self.device)
         out._matrix = matrix
         return out
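
The same one-line fix recurs in compose(), inverse(), stack(), and (below) clone(): each built its output with a bare Transform3d() or Transform3d(device=self.device), so the result's dtype attribute reported the float32 default regardless of the inputs. A hedged sketch of the observable effect, using stack():

    import torch
    from pytorch3d.transforms import Transform3d

    a = Transform3d(dtype=torch.float64)
    b = Transform3d(dtype=torch.float64)

    # stack() concatenates the underlying (N, 4, 4) matrices; after this
    # change its output also inherits dtype and device from `a`.
    s = a.stack(b)
    assert s.dtype == torch.float64
    assert len(s) == 2  # one 4x4 matrix per stacked transform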
@@ -392,7 +394,7 @@
         Returns:
             new Transforms object.
         """
-        other = Transform3d(device=self.device)
+        other = Transform3d(dtype=self.dtype, device=self.device)
         if self._lu is not None:
             other._lu = [elem.clone() for elem in self._lu]
         other._matrix = self._matrix.clone()
@@ -422,17 +424,22 @@
             Transform3d object.
         """
         device_ = make_device(device)
-        if not copy and self.device == device_:
+        dtype_ = self.dtype if dtype is None else dtype
+        skip_to = self.device == device_ and self.dtype == dtype_
+
+        if not copy and skip_to:
             return self

         other = self.clone()
-        if self.device == device_:
+
+        if skip_to:
             return other

         other.device = device_
-        other._matrix = self._matrix.to(device=device_, dtype=dtype)
+        other.dtype = dtype_
+        other._matrix = other._matrix.to(device=device_, dtype=dtype_)
         other._transforms = [
-            t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
+            t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms
         ]
         return other
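
At the usage level, the updated to() now has three outcomes, exercised by the tests below (a sketch assuming a CPU-resident transform):

    import torch
    from pytorch3d.transforms import Transform3d

    t = Transform3d()

    # Same device and dtype with copy=False: skip_to holds, `self` is returned.
    assert t.to("cpu") is t

    # Same device and dtype with copy=True: a clone is returned, with no
    # tensor conversion needed.
    assert t.to("cpu", copy=True) is not t

    # A different dtype defeats skip_to even on the same device: the matrix
    # and all composed transforms are converted on a fresh copy.
    t64 = t.to("cpu", dtype=torch.float64)
    assert t64.get_matrix().dtype == torch.float64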

tests/test_transforms.py

@@ -28,23 +28,45 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         cpu_t = t.to("cpu")
         self.assertEqual(cpu_device, cpu_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cpu_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIs(t, cpu_t)

         cpu_t = t.to(cpu_device)
         self.assertEqual(cpu_device, cpu_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cpu_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIs(t, cpu_t)

+        cpu_t = t.to(dtype=torch.float64, device=cpu_device)
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float64, cpu_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
+        self.assertIsNot(t, cpu_t)
+
         cuda_device = torch.device("cuda")

         cuda_t = t.to("cuda")
         self.assertEqual(cuda_device, cuda_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cuda_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIsNot(t, cuda_t)

         cuda_t = t.to(cuda_device)
         self.assertEqual(cuda_device, cuda_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cuda_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIsNot(t, cuda_t)

+        cuda_t = t.to(dtype=torch.float64, device=cuda_device)
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float64, cuda_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
+        self.assertIsNot(t, cuda_t)
+
         cpu_points = torch.rand(9, 3)