mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Make Transform3d.to() not ignore dtype
Summary: Make Transform3d.to() not ignore a different dtype when device is the same and no copy is requested. Fix other methods where dtype is ignored. Reviewed By: nikhilaravi Differential Revision: D28981171 fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
This commit is contained in:
parent
626bf3fe23
commit
44508ed0db
@ -162,13 +162,15 @@ class Transform3d:
|
||||
raise ValueError(
|
||||
'"matrix" has to be a tensor of shape (minibatch, 4, 4)'
|
||||
)
|
||||
# set the device from matrix
|
||||
# set dtype and device from matrix
|
||||
dtype = matrix.dtype
|
||||
device = matrix.device
|
||||
self._matrix = matrix.view(-1, 4, 4)
|
||||
|
||||
self._transforms = [] # store transforms to compose
|
||||
self._lu = None
|
||||
self.device = make_device(device)
|
||||
self.dtype = dtype
|
||||
|
||||
def __len__(self):
|
||||
return self.get_matrix().shape[0]
|
||||
@ -200,7 +202,7 @@ class Transform3d:
|
||||
Returns:
|
||||
A new Transform3d with the stored transforms
|
||||
"""
|
||||
out = Transform3d(device=self.device)
|
||||
out = Transform3d(dtype=self.dtype, device=self.device)
|
||||
out._matrix = self._matrix.clone()
|
||||
for other in others:
|
||||
if not isinstance(other, Transform3d):
|
||||
@ -259,7 +261,7 @@ class Transform3d:
|
||||
transformation.
|
||||
"""
|
||||
|
||||
tinv = Transform3d(device=self.device)
|
||||
tinv = Transform3d(dtype=self.dtype, device=self.device)
|
||||
|
||||
if invert_composed:
|
||||
# first compose then invert
|
||||
@ -278,7 +280,7 @@ class Transform3d:
|
||||
# right-multiplies by the inverse of self._matrix
|
||||
# at the end of the composition.
|
||||
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
|
||||
last = Transform3d(device=self.device)
|
||||
last = Transform3d(dtype=self.dtype, device=self.device)
|
||||
last._matrix = i_matrix
|
||||
tinv._transforms.append(last)
|
||||
else:
|
||||
@ -291,7 +293,7 @@ class Transform3d:
|
||||
def stack(self, *others):
|
||||
transforms = [self] + list(others)
|
||||
matrix = torch.cat([t._matrix for t in transforms], dim=0)
|
||||
out = Transform3d()
|
||||
out = Transform3d(dtype=self.dtype, device=self.device)
|
||||
out._matrix = matrix
|
||||
return out
|
||||
|
||||
@ -392,7 +394,7 @@ class Transform3d:
|
||||
Returns:
|
||||
new Transforms object.
|
||||
"""
|
||||
other = Transform3d(device=self.device)
|
||||
other = Transform3d(dtype=self.dtype, device=self.device)
|
||||
if self._lu is not None:
|
||||
other._lu = [elem.clone() for elem in self._lu]
|
||||
other._matrix = self._matrix.clone()
|
||||
@ -422,17 +424,22 @@ class Transform3d:
|
||||
Transform3d object.
|
||||
"""
|
||||
device_ = make_device(device)
|
||||
if not copy and self.device == device_:
|
||||
dtype_ = self.dtype if dtype is None else dtype
|
||||
skip_to = self.device == device_ and self.dtype == dtype_
|
||||
|
||||
if not copy and skip_to:
|
||||
return self
|
||||
|
||||
other = self.clone()
|
||||
if self.device == device_:
|
||||
|
||||
if skip_to:
|
||||
return other
|
||||
|
||||
other.device = device_
|
||||
other._matrix = self._matrix.to(device=device_, dtype=dtype)
|
||||
other.dtype = dtype_
|
||||
other._matrix = other._matrix.to(device=device_, dtype=dtype_)
|
||||
other._transforms = [
|
||||
t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
|
||||
t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms
|
||||
]
|
||||
return other
|
||||
|
||||
|
@ -28,23 +28,45 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
cpu_t = t.to("cpu")
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cpu_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIs(t, cpu_t)
|
||||
|
||||
cpu_t = t.to(cpu_device)
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cpu_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIs(t, cpu_t)
|
||||
|
||||
cpu_t = t.to(dtype=torch.float64, device=cpu_device)
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float64, cpu_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cpu_t)
|
||||
|
||||
cuda_device = torch.device("cuda")
|
||||
|
||||
cuda_t = t.to("cuda")
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cuda_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cuda_t = t.to(cuda_device)
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cuda_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cuda_t = t.to(dtype=torch.float64, device=cuda_device)
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float64, cuda_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cpu_points = torch.rand(9, 3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user