Mirror of https://github.com/facebookresearch/pytorch3d.git
Make Transform3d.to() not ignore dtype
Summary: Make Transform3d.to() not ignore a different dtype when the device is the same and no copy is requested. Also fix other methods where dtype was ignored.

Reviewed By: nikhilaravi

Differential Revision: D28981171

fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
parent 626bf3fe23
commit 44508ed0db
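For illustration only (not part of the commit itself), a minimal sketch of the behaviour the patched to() is expected to have. It assumes a pytorch3d build that includes this change; the variable names are made up for the example.

import torch
from pytorch3d.transforms import Transform3d

t = Transform3d()  # defaults to torch.float32 on the CPU

# Request a different dtype while keeping the same device and copy=False.
# Before this change, to() returned `t` itself and silently ignored the dtype;
# with the fix it returns a converted copy and leaves `t` untouched.
t64 = t.to("cpu", dtype=torch.float64)

assert t64.dtype == torch.float64
assert t64 is not t
assert t.dtype == torch.float32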
@@ -162,13 +162,15 @@ class Transform3d:
                 raise ValueError(
                     '"matrix" has to be a tensor of shape (minibatch, 4, 4)'
                 )
-            # set the device from matrix
+            # set dtype and device from matrix
+            dtype = matrix.dtype
             device = matrix.device
             self._matrix = matrix.view(-1, 4, 4)
 
         self._transforms = []  # store transforms to compose
         self._lu = None
         self.device = make_device(device)
+        self.dtype = dtype
 
     def __len__(self):
         return self.get_matrix().shape[0]
@@ -200,7 +202,7 @@ class Transform3d:
         Returns:
             A new Transform3d with the stored transforms
         """
-        out = Transform3d(device=self.device)
+        out = Transform3d(dtype=self.dtype, device=self.device)
         out._matrix = self._matrix.clone()
         for other in others:
             if not isinstance(other, Transform3d):
@@ -259,7 +261,7 @@ class Transform3d:
             transformation.
         """
 
-        tinv = Transform3d(device=self.device)
+        tinv = Transform3d(dtype=self.dtype, device=self.device)
 
         if invert_composed:
             # first compose then invert
@@ -278,7 +280,7 @@ class Transform3d:
             # right-multiplies by the inverse of self._matrix
             # at the end of the composition.
             tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
-            last = Transform3d(device=self.device)
+            last = Transform3d(dtype=self.dtype, device=self.device)
             last._matrix = i_matrix
             tinv._transforms.append(last)
         else:
@@ -291,7 +293,7 @@ class Transform3d:
     def stack(self, *others):
         transforms = [self] + list(others)
         matrix = torch.cat([t._matrix for t in transforms], dim=0)
-        out = Transform3d()
+        out = Transform3d(dtype=self.dtype, device=self.device)
         out._matrix = matrix
         return out
 
@@ -392,7 +394,7 @@ class Transform3d:
         Returns:
             new Transforms object.
         """
-        other = Transform3d(device=self.device)
+        other = Transform3d(dtype=self.dtype, device=self.device)
         if self._lu is not None:
             other._lu = [elem.clone() for elem in self._lu]
         other._matrix = self._matrix.clone()
@@ -422,17 +424,22 @@ class Transform3d:
             Transform3d object.
         """
         device_ = make_device(device)
-        if not copy and self.device == device_:
+        dtype_ = self.dtype if dtype is None else dtype
+        skip_to = self.device == device_ and self.dtype == dtype_
+
+        if not copy and skip_to:
             return self
 
         other = self.clone()
-        if self.device == device_:
+
+        if skip_to:
             return other
 
         other.device = device_
-        other._matrix = self._matrix.to(device=device_, dtype=dtype)
+        other.dtype = dtype_
+        other._matrix = other._matrix.to(device=device_, dtype=dtype_)
         other._transforms = [
-            t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
+            t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms
         ]
         return other
 
@@ -28,23 +28,45 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         cpu_t = t.to("cpu")
         self.assertEqual(cpu_device, cpu_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cpu_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIs(t, cpu_t)
 
         cpu_t = t.to(cpu_device)
         self.assertEqual(cpu_device, cpu_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cpu_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIs(t, cpu_t)
 
+        cpu_t = t.to(dtype=torch.float64, device=cpu_device)
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float64, cpu_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
+        self.assertIsNot(t, cpu_t)
+
         cuda_device = torch.device("cuda")
 
         cuda_t = t.to("cuda")
         self.assertEqual(cuda_device, cuda_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cuda_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIsNot(t, cuda_t)
 
         cuda_t = t.to(cuda_device)
         self.assertEqual(cuda_device, cuda_t.device)
         self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float32, cuda_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
+        self.assertIsNot(t, cuda_t)
+
+        cuda_t = t.to(dtype=torch.float64, device=cuda_device)
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertEqual(torch.float64, cuda_t.dtype)
+        self.assertEqual(torch.float32, t.dtype)
         self.assertIsNot(t, cuda_t)
 
         cpu_points = torch.rand(9, 3)