diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index 63272569..b9f8c1bf 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -166,14 +166,16 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): def change_resolution( self, grid_values: VoxelGridValuesBase, - epoch: int, *, + epoch: Optional[int] = None, + grid_values_with_wanted_resolution: Optional[VoxelGridValuesBase] = None, mode: str = "linear", align_corners: bool = True, antialias: bool = False, ) -> Tuple[VoxelGridValuesBase, bool]: """ - Changes resolution of tensors in `grid_values` to match the `wanted_resolution`. + Changes resolution of tensors in `grid_values` to match the + `grid_values_with_wanted_resolution` or resolution on wanted epoch. Args: epoch: current training epoch, used to see if the grid needs regridding @@ -181,6 +183,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): the voxel grid which will be interpolated to create the new grid epoch: epoch which is used to get the resolution of the new `grid_values` using `self.resolution_changes`. + grid_values_with_wanted_resolution: `VoxelGridValuesBase` to whose resolution + to interpolate grid_values align_corners: as for torch.nn.functional.interpolate mode: as for torch.nn.functional.interpolate 'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'. @@ -195,8 +199,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): - new voxel grid_values of desired resolution, of type self.values_type - True if regridding has happened. """ - if epoch not in self.resolution_changes: - return grid_values, False + + if (epoch is None) == (grid_values_with_wanted_resolution is None): + raise ValueError( + "Exactly one of `epoch` or " + "`grid_values_with_wanted_resolution` has to be defined." + ) if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"): raise ValueError( @@ -219,11 +227,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): recompute_scale_factor=False, ) - wanted_shapes = self.get_shapes(epoch=epoch) - params = { - name: change_individual_resolution(getattr(grid_values, name), shape[1:]) - for name, shape in wanted_shapes.items() - } + if epoch is not None: + if epoch not in self.resolution_changes: + return grid_values, False + + wanted_shapes = self.get_shapes(epoch=epoch) + params = { + name: change_individual_resolution( + getattr(grid_values, name), shape[1:] + ) + for name, shape in wanted_shapes.items() + } + else: + params = { + name: ( + change_individual_resolution( + getattr(grid_values, name), tensor.shape[2:] + ) + if tensor is not None + else None + ) + for name, tensor in vars(grid_values_with_wanted_resolution).items() + } # pyre-ignore[29] return self.values_type(**params), True @@ -239,6 +264,82 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): """ return self.align_corners + def crop_world( + self, + min_point_world: torch.Tensor, + max_point_world: torch.Tensor, + grid_values: VoxelGridValuesBase, + volume_locator: VolumeLocator, + ) -> VoxelGridValuesBase: + """ + Crops the voxel grid based on minimum and maximum occupied point in + world coordinates. After cropping all 8 corner points are preserved in + the voxel grid. This is achieved by preserving all the voxels needed to + calculate the point. + + +--------B + / /| + / / | + +--------+ | <==== Bounding box represented by points A and B: + | | | - B has x, y and z coordinates bigger or equal + | | + to all other points of the object + | | / - A has x, y and z coordinates smaller or equal + | |/ to all other points of the object + A--------+ + + Args: + min_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates + smaller or equal to all other occupied points. Point A from the + picture above. + max_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates + bigger or equal to all other occupied points. Point B from the + picture above. + grid_values: instance of self.values_type which contains + the voxel grid which will be cropped to create the new grid + volume_locator: VolumeLocator object used to convert world to local + cordinates + Returns: + instance of self.values_type which has volume cropped to desired size. + """ + min_point_local = volume_locator.world_to_local_coords(min_point_world[None])[0] + max_point_local = volume_locator.world_to_local_coords(max_point_world[None])[0] + return self.crop_local(min_point_local, max_point_local, grid_values) + + def crop_local( + self, + min_point_local: torch.Tensor, + max_point_local: torch.Tensor, + grid_values: VoxelGridValuesBase, + ) -> VoxelGridValuesBase: + """ + Crops the voxel grid based on minimum and maximum occupied point in local + coordinates. After cropping both min and max point are preserved in the voxel + grid. This is achieved by preserving all the voxels needed to calculate the point. + + +--------B + / /| + / / | + +--------+ | <==== Bounding box represented by points A and B: + | | | - B has x, y and z coordinates bigger or equal + | | + to all other points of the object + | | / - A has x, y and z coordinates smaller or equal + | |/ to all other points of the object + A--------+ + + Args: + min_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates + smaller or equal to all other occupied points. Point A from the + picture above. All elements in [-1, 1]. + max_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates + bigger or equal to all other occupied points. Point B from the + picture above. All elements in [-1, 1]. + grid_values: instance of self.values_type which contains + the voxel grid which will be cropped to create the new grid + Returns: + instance of self.values_type which has volume cropped to desired size. + """ + raise NotImplementedError() + @dataclass class FullResolutionVoxelGridValues(VoxelGridValuesBase): @@ -288,6 +389,34 @@ class FullResolutionVoxelGrid(VoxelGridBase): width, height, depth = self.get_resolution(epoch) return {"voxel_grid": (self.n_features, width, height, depth)} + # pyre-ignore[14] + def crop_local( + self, + min_point_local: torch.Tensor, + max_point_local: torch.Tensor, + grid_values: FullResolutionVoxelGridValues, + ) -> FullResolutionVoxelGridValues: + assert torch.all(min_point_local < max_point_local) + min_point_local = torch.clamp(min_point_local, -1, 1) + max_point_local = torch.clamp(max_point_local, -1, 1) + _, _, width, height, depth = grid_values.voxel_grid.shape + resolution = grid_values.voxel_grid.new_tensor([width, height, depth]) + min_point_local01 = (min_point_local + 1) / 2 + max_point_local01 = (max_point_local + 1) / 2 + + if self.align_corners: + minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long() + maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long() + else: + minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long() + maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long() + + return FullResolutionVoxelGridValues( + voxel_grid=grid_values.voxel_grid[ + :, :, minx : maxx + 1, miny : maxy + 1, minz : maxz + 1 + ] + ) + @dataclass class CPFactorizedVoxelGridValues(VoxelGridValuesBase): @@ -388,6 +517,37 @@ class CPFactorizedVoxelGrid(VoxelGridBase): shape_dict["basis_matrix"] = (self.n_components, self.n_features) return shape_dict + # pyre-ignore[14] + def crop_local( + self, + min_point_local: torch.Tensor, + max_point_local: torch.Tensor, + grid_values: CPFactorizedVoxelGridValues, + ) -> CPFactorizedVoxelGridValues: + assert torch.all(min_point_local < max_point_local) + min_point_local = torch.clamp(min_point_local, -1, 1) + max_point_local = torch.clamp(max_point_local, -1, 1) + _, _, width = grid_values.vector_components_x.shape + _, _, height = grid_values.vector_components_y.shape + _, _, depth = grid_values.vector_components_z.shape + resolution = grid_values.vector_components_x.new_tensor([width, height, depth]) + min_point_local01 = (min_point_local + 1) / 2 + max_point_local01 = (max_point_local + 1) / 2 + + if self.align_corners: + minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long() + maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long() + else: + minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long() + maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long() + + return CPFactorizedVoxelGridValues( + vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1], + vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1], + vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1], + basis_matrix=grid_values.basis_matrix, + ) + @dataclass class VMFactorizedVoxelGridValues(VoxelGridValuesBase): @@ -585,6 +745,46 @@ class VMFactorizedVoxelGrid(VoxelGridBase): return shape_dict + # pyre-ignore[14] + def crop_local( + self, + min_point_local: torch.Tensor, + max_point_local: torch.Tensor, + grid_values: VMFactorizedVoxelGridValues, + ) -> VMFactorizedVoxelGridValues: + assert torch.all(min_point_local < max_point_local) + min_point_local = torch.clamp(min_point_local, -1, 1) + max_point_local = torch.clamp(max_point_local, -1, 1) + _, _, width = grid_values.vector_components_x.shape + _, _, height = grid_values.vector_components_y.shape + _, _, depth = grid_values.vector_components_z.shape + resolution = grid_values.vector_components_x.new_tensor([width, height, depth]) + min_point_local01 = (min_point_local + 1) / 2 + max_point_local01 = (max_point_local + 1) / 2 + + if self.align_corners: + minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long() + maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long() + else: + minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long() + maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long() + + return VMFactorizedVoxelGridValues( + vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1], + vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1], + vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1], + matrix_components_xy=grid_values.matrix_components_xy[ + :, :, minx : maxx + 1, miny : maxy + 1 + ], + matrix_components_yz=grid_values.matrix_components_yz[ + :, :, miny : maxy + 1, minz : maxz + 1 + ], + matrix_components_xz=grid_values.matrix_components_xz[ + :, :, minx : maxx + 1, minz : maxz + 1 + ], + basis_matrix=grid_values.basis_matrix, + ) + # pyre-fixme[13]: Attribute `voxel_grid` is never initialized. class VoxelGridModule(Configurable, torch.nn.Module): @@ -771,6 +971,37 @@ class VoxelGridModule(Configurable, torch.nn.Module): # pyre-ignore[29] return next(val for val in self.params.values() if val is not None).device + def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None: + """ + Crops self to only represent points between min_point and max_point (inclusive). + + Args: + min_point: torch.Tensor of shape (3,). Has x, y and z coordinates + smaller or equal to all other occupied points. + max_point: torch.Tensor of shape (3,). Has x, y and z coordinates + bigger or equal to all other occupied points. + Returns: + nothing + """ + locator = self._get_volume_locator() + # pyre-fixme[29]: `Union[torch._tensor.Tensor, + # torch.nn.modules.module.Module]` is not a function. + old_grid_values = self.voxel_grid.values_type(**self.params) + new_grid_values = self.voxel_grid.crop_world( + min_point, max_point, old_grid_values, locator + ) + grid_values, _ = self.voxel_grid.change_resolution( + new_grid_values, grid_values_with_wanted_resolution=old_grid_values + ) + # pyre-ignore [16] + self.params = torch.nn.ParameterDict( + {k: v for k, v in vars(grid_values).items()} + ) + # New center of voxel grid is the middle point between max and min points. + self.translation = tuple((max_point + min_point) / 2) + # new extents of voxel grid are distances between min and max points + self.extents = tuple(max_point - min_point) + def _get_volume_locator(self) -> VolumeLocator: """ Returns VolumeLocator calculated from `extents` and `translation` members. diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py index 8e138411..6fdcdb29 100644 --- a/tests/implicitron/test_voxel_grids.py +++ b/tests/implicitron/test_voxel_grids.py @@ -9,7 +9,7 @@ import unittest from typing import Optional, Tuple import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pytorch3d.implicitron.models.implicit_function.utils import ( interpolate_line, @@ -19,11 +19,12 @@ from pytorch3d.implicitron.models.implicit_function.utils import ( from pytorch3d.implicitron.models.implicit_function.voxel_grid import ( CPFactorizedVoxelGrid, FullResolutionVoxelGrid, + FullResolutionVoxelGridValues, VMFactorizedVoxelGrid, VoxelGridModule, ) -from pytorch3d.implicitron.tools.config import expand_args_fields +from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args from tests.common_testing import TestCaseMixin @@ -693,6 +694,140 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): grid.voxel_grid.get_resolution(epoch=epoch), ) + def _get_min_max_tuple( + self, n=4, denominator_base=2, max_exponent=6, add_edge_cases=True + ): + if add_edge_cases: + n -= 2 + + def get_pair(): + def get_one(): + sign = -1 if torch.rand((1,)) < 0.5 else 1 + exponent = int(torch.randint(1, max_exponent, (1,))) + denominator = denominator_base**exponent + numerator = int(torch.randint(1, denominator, (1,))) + return sign * numerator / denominator * 1.0 + + while True: + a, b = get_one(), get_one() + if a < b: + return a, b + + for _ in range(n): + a, b, c = get_pair(), get_pair(), get_pair() + yield torch.tensor((a[0], b[0], c[0])), torch.tensor((a[1], b[1], c[1])) + if add_edge_cases: + yield torch.tensor((-1.0, -1.0, -1.0)), torch.tensor((1.0, 1.0, 1.0)) + yield torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0]) + + def test_cropping_voxel_grids(self, n_times=1): + """ + If the grid is 1d and we crop at A and B + ---------A---------B--- + and choose point p between them + ---------A-----p---B--- + it can be represented as + p = A + (B-A) * p_c + where p_c is local coordinate of p in cropped grid. So we now just see + if grid evaluated at p and cropped grid evaluated at p_c agree. + """ + for points_min, points_max in self._get_min_max_tuple(n=10): + n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist() + n_grids = 1 + n_components *= 3 + resolution_changes = {0: (128 + 1, 128 + 1, 128 + 1)} + for cls, kwargs in ( + ( + FullResolutionVoxelGrid, + { + "n_features": n_features, + "resolution_changes": resolution_changes, + }, + ), + ( + CPFactorizedVoxelGrid, + { + "n_features": n_features, + "resolution_changes": resolution_changes, + "n_components": n_components, + }, + ), + ( + VMFactorizedVoxelGrid, + { + "n_features": n_features, + "resolution_changes": resolution_changes, + "n_components": n_components, + }, + ), + ): + with self.subTest( + cls.__name__ + f" points {points_min} and {points_max}" + ): + grid = cls(**kwargs) + shapes = grid.get_shapes(epoch=0) + params = { + name: torch.normal( + mean=torch.zeros((n_grids, *shape)), + std=1, + ) + for name, shape in shapes.items() + } + grid_values = grid.values_type(**params) + + grid_values_cropped = grid.crop_local( + points_min, points_max, grid_values + ) + + points_local_cropped = torch.rand((1, n_times, 3)) + points_local = ( + points_min[None, None] + + (points_max - points_min)[None, None] * points_local_cropped + ) + points_local_cropped = (points_local_cropped - 0.5) * 2 + + pred = grid.evaluate_local(points_local, grid_values) + pred_cropped = grid.evaluate_local( + points_local_cropped, grid_values_cropped + ) + + assert torch.allclose(pred, pred_cropped, rtol=1e-4, atol=1e-4), ( + pred, + pred_cropped, + points_local, + points_local_cropped, + ) + + def test_cropping_voxel_grid_module(self, n_times=1): + for points_min, points_max in self._get_min_max_tuple(n=5, max_exponent=5): + extents = torch.ones((3,)) * 2 + translation = torch.ones((3,)) * 0.2 + points_min += translation + points_max += translation + + default_cfg = get_default_args(VoxelGridModule) + custom_cfg = DictConfig( + { + "extents": tuple(float(e) for e in extents), + "translation": tuple(float(t) for t in translation), + "voxel_grid_FullResolutionVoxelGrid_args": { + "resolution_changes": {0: (128 + 1, 128 + 1, 128 + 1)} + }, + } + ) + cfg = OmegaConf.merge(default_cfg, custom_cfg) + grid = VoxelGridModule(**cfg) + + points = (torch.rand(3) * (points_max - points_min) + points_min)[None] + result = grid(points) + grid.crop_self(points_min, points_max) + result_cropped = grid(points) + + assert torch.allclose(result, result_cropped, rtol=0.001, atol=0.001), ( + result, + result_cropped, + ) + def test_loading_state_dict(self): """ Test loading state dict after rescaling.