Split Volumes class to data and location part

Summary: Split Volumes class to data and location part so that location part can be reused in planned VoxelGrid classes.

Reviewed By: bottler

Differential Revision: D38782015

fbshipit-source-id: 489da09c5c236f3b81961ce9b09edbd97afaa7c8
This commit is contained in:
Darijan Gudelj
2022-08-18 08:12:33 -07:00
committed by Facebook GitHub Bot
parent fdaaa299a7
commit f825f7e42c
2 changed files with 721 additions and 297 deletions

View File

@@ -11,7 +11,7 @@ import unittest
import numpy as np
import torch
from pytorch3d.structures.volumes import Volumes
from pytorch3d.structures.volumes import VolumeLocator, Volumes
from pytorch3d.transforms import Scale
from .common_testing import TestCaseMixin
@@ -53,8 +53,8 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
for selectedIdx, index in indices:
self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
self.assertClose(
v._local_to_world_transform.get_matrix()[index],
selected._local_to_world_transform.get_matrix()[selectedIdx],
v.locator._local_to_world_transform.get_matrix()[index],
selected.locator._local_to_world_transform.get_matrix()[selectedIdx],
)
if selected.features() is not None:
self.assertClose(selected.features()[selectedIdx], v.features()[index])
@@ -149,10 +149,55 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
with self.assertRaises(IndexError):
v_selected = v[index]
def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
with self.subTest("VolumeLocator init with all sizes equal"):
grid_sizes = [resolution for _ in range(batch_size)]
locator_tuple = VolumeLocator(
batch_size=batch_size, grid_sizes=resolution, device=torch.device("cpu")
)
locator_list = VolumeLocator(
batch_size=batch_size, grid_sizes=grid_sizes, device=torch.device("cpu")
)
locator_tensor = VolumeLocator(
batch_size=batch_size,
grid_sizes=torch.tensor(grid_sizes),
device=torch.device("cpu"),
)
expected_grid_sizes = torch.tensor(grid_sizes)
expected_resolution = resolution
assert torch.allclose(expected_grid_sizes, locator_tuple._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
self.assertEqual(expected_resolution, locator_tuple._resolution)
self.assertEqual(expected_resolution, locator_list._resolution)
self.assertEqual(expected_resolution, locator_tensor._resolution)
with self.subTest("VolumeLocator with different sizes in different grids"):
grid_sizes_list = [
torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
]
grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
locator_list = VolumeLocator(
batch_size=batch_size,
grid_sizes=grid_sizes_list,
device=torch.device("cpu"),
)
locator_tensor = VolumeLocator(
batch_size=batch_size,
grid_sizes=grid_sizes_tensor,
device=torch.device("cpu"),
)
expected_grid_sizes = grid_sizes_tensor
expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
self.assertEqual(expected_resolution, locator_list._resolution)
self.assertEqual(expected_resolution, locator_tensor._resolution)
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
"""
Test the correctness of the conversion between the internal
Transform3D Volumes._local_to_world_transform and the initialization
Transform3D Volumes.VolumeLocator._local_to_world_transform and the initialization
from the translation and voxel_size.
"""
@@ -440,7 +485,10 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
for var_name, var in vars(v).items():
if var_name != "device":
if var is not None:
self.assertTrue(var.device.type == desired_device.type)
self.assertTrue(
var.device.type == desired_device.type,
(var_name, var.device, desired_device),
)
else:
self.assertTrue(var.type == desired_device.type)
@@ -456,60 +504,74 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
)
densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
volumes = Volumes(densities=densities, features=features)
locator = VolumeLocator(
batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device
)
# Test support for str and torch.device
cpu_device = torch.device("cpu")
for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)):
with self.subTest(f"Moving {name} from/to gpu and cpu"):
# Test support for str and torch.device
cpu_device = torch.device("cpu")
converted_volumes = volumes.to("cpu")
self.assertEqual(cpu_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIs(volumes, converted_volumes)
converted_obj = obj.to("cpu")
self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIs(obj, converted_obj)
converted_volumes = volumes.to(cpu_device)
self.assertEqual(cpu_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIs(volumes, converted_volumes)
converted_obj = obj.to(cpu_device)
self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIs(obj, converted_obj)
cuda_device = torch.device("cuda:0")
cuda_device = torch.device("cuda:0")
converted_volumes = volumes.to("cuda:0")
self.assertEqual(cuda_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIsNot(volumes, converted_volumes)
converted_obj = obj.to("cuda:0")
self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIsNot(obj, converted_obj)
converted_volumes = volumes.to(cuda_device)
self.assertEqual(cuda_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIsNot(volumes, converted_volumes)
converted_obj = obj.to(cuda_device)
self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIsNot(obj, converted_obj)
# Test device placement of internal tensors
features = features.to(cuda_device)
densities = features.to(cuda_device)
with self.subTest("Test device placement of internal tensors of Volumes"):
features = features.to(cuda_device)
densities = features.to(cuda_device)
for features_ in (features, None):
volumes = Volumes(densities=densities, features=features_)
for features_ in (features, None):
volumes = Volumes(densities=densities, features=features_)
cpu_volumes = volumes.cpu()
cuda_volumes = cpu_volumes.cuda()
cuda_volumes2 = cuda_volumes.cuda()
cpu_volumes2 = cuda_volumes2.cpu()
cpu_volumes = volumes.cpu()
cuda_volumes = cpu_volumes.cuda()
cuda_volumes2 = cuda_volumes.cuda()
cpu_volumes2 = cuda_volumes2.cpu()
for volumes1, volumes2 in itertools.combinations(
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
):
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
# checks that we do not copy if the devices stay the same
assert_fun = self.assertIs
else:
assert_fun = self.assertSeparate
assert_fun(volumes1._densities, volumes2._densities)
if features_ is not None:
assert_fun(volumes1._features, volumes2._features)
for volumes_ in (volumes1, volumes2):
if volumes_ in (cpu_volumes, cpu_volumes2):
self._check_vars_on_device(volumes_, cpu_device)
for volumes1, volumes2 in itertools.combinations(
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
):
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
# checks that we do not copy if the devices stay the same
assert_fun = self.assertIs
else:
self._check_vars_on_device(volumes_, cuda_device)
assert_fun = self.assertSeparate
assert_fun(volumes1._densities, volumes2._densities)
if features_ is not None:
assert_fun(volumes1._features, volumes2._features)
for volumes_ in (volumes1, volumes2):
if volumes_ in (cpu_volumes, cpu_volumes2):
self._check_vars_on_device(volumes_, cpu_device)
else:
self._check_vars_on_device(volumes_, cuda_device)
with self.subTest("Test device placement of internal tensors of VolumeLocator"):
for device1, device2 in itertools.combinations(
(torch.device("cpu"), torch.device("cuda:0")), 2
):
locator = locator.to(device1)
locator = locator.to(device2)
self.assertEqual(locator._grid_sizes.device, device2)
self.assertEqual(locator._local_to_world_transform.device, device2)
def _check_padded(self, x_pad, x_list, grid_sizes):
"""