volume cropping

Summary:
TensoRF at step 2000 does volume cropping and resizing.
At those steps it finds the part of the voxel grid whose density is big enough to contain objects and resizes the grid to fit that object.
Change is done on 3 levels:
- implicit function subscribes to epochs and at specific epochs finds the bounding box of the object and calls resizing of the color and density voxel grids to fit it
- VoxelGrid module calls cropping of the underlying voxel grid and resizes it to fit the previous size; it also adjusts its extents and translation to match the wanted size
- Each voxel grid has its own way of cropping the underlying data

Reviewed By: kjchalup

Differential Revision: D39854548

fbshipit-source-id: 5435b6e599aef1eaab980f5421d3369ee4829c50
This commit is contained in:
Darijan Gudelj 2022-10-12 08:31:51 -07:00 committed by Facebook GitHub Bot
parent 0b5def5257
commit f55d37f07d
2 changed files with 377 additions and 11 deletions

View File

@ -166,14 +166,16 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
def change_resolution(
self,
grid_values: VoxelGridValuesBase,
epoch: int,
*,
epoch: Optional[int] = None,
grid_values_with_wanted_resolution: Optional[VoxelGridValuesBase] = None,
mode: str = "linear",
align_corners: bool = True,
antialias: bool = False,
) -> Tuple[VoxelGridValuesBase, bool]:
"""
Changes resolution of tensors in `grid_values` to match the `wanted_resolution`.
Changes resolution of tensors in `grid_values` to match the
`grid_values_with_wanted_resolution` or resolution on wanted epoch.
Args:
epoch: current training epoch, used to see if the grid needs regridding
@ -181,6 +183,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
the voxel grid which will be interpolated to create the new grid
epoch: epoch which is used to get the resolution of the new
`grid_values` using `self.resolution_changes`.
grid_values_with_wanted_resolution: `VoxelGridValuesBase` to whose resolution
to interpolate grid_values
align_corners: as for torch.nn.functional.interpolate
mode: as for torch.nn.functional.interpolate
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
@ -195,8 +199,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
- new voxel grid_values of desired resolution, of type self.values_type
- True if regridding has happened.
"""
if epoch not in self.resolution_changes:
return grid_values, False
if (epoch is None) == (grid_values_with_wanted_resolution is None):
raise ValueError(
"Exactly one of `epoch` or "
"`grid_values_with_wanted_resolution` has to be defined."
)
if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"):
raise ValueError(
@ -219,11 +227,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
recompute_scale_factor=False,
)
wanted_shapes = self.get_shapes(epoch=epoch)
params = {
name: change_individual_resolution(getattr(grid_values, name), shape[1:])
for name, shape in wanted_shapes.items()
}
if epoch is not None:
if epoch not in self.resolution_changes:
return grid_values, False
wanted_shapes = self.get_shapes(epoch=epoch)
params = {
name: change_individual_resolution(
getattr(grid_values, name), shape[1:]
)
for name, shape in wanted_shapes.items()
}
else:
params = {
name: (
change_individual_resolution(
getattr(grid_values, name), tensor.shape[2:]
)
if tensor is not None
else None
)
for name, tensor in vars(grid_values_with_wanted_resolution).items()
}
# pyre-ignore[29]
return self.values_type(**params), True
@ -239,6 +264,82 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
"""
return self.align_corners
def crop_world(
    self,
    min_point_world: torch.Tensor,
    max_point_world: torch.Tensor,
    grid_values: VoxelGridValuesBase,
    volume_locator: VolumeLocator,
) -> VoxelGridValuesBase:
    """
    Crops the voxel grid to the bounding box given by its minimum and maximum
    occupied points, both expressed in world coordinates. All 8 corners of the
    box stay inside the cropped grid, because every voxel needed to evaluate
    them is kept.

            +--------B
           /        /|
          /        / |
         +--------+  |  <==== Bounding box represented by points A and B:
         |        |  |        - B has x, y and z coordinates bigger or equal
         |        |  +            to all other points of the object
         |        | /         - A has x, y and z coordinates smaller or equal
         |        |/              to all other points of the object
         A--------+

    Args:
        min_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
            smaller or equal to all other occupied points. Point A from the
            picture above.
        max_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
            bigger or equal to all other occupied points. Point B from the
            picture above.
        grid_values: instance of self.values_type which contains
            the voxel grid which will be cropped to create the new grid
        volume_locator: VolumeLocator object used to convert world to local
            coordinates

    Returns:
        instance of self.values_type which has volume cropped to desired size.
    """

    def to_local(point_world: torch.Tensor) -> torch.Tensor:
        # The locator is called with a leading axis of size one, which is
        # dropped again from its result.
        return volume_locator.world_to_local_coords(point_world.unsqueeze(0))[0]

    corner_a_local = to_local(min_point_world)
    corner_b_local = to_local(max_point_world)
    return self.crop_local(corner_a_local, corner_b_local, grid_values)
def crop_local(
    self,
    min_point_local: torch.Tensor,
    max_point_local: torch.Tensor,
    grid_values: VoxelGridValuesBase,
) -> VoxelGridValuesBase:
    """
    Crops the voxel grid to the bounding box given by its minimum and maximum
    occupied points, both expressed in local coordinates. Both points stay
    inside the cropped grid, because every voxel needed to evaluate them is
    kept. This base implementation only defines the interface and always
    raises.

            +--------B
           /        /|
          /        / |
         +--------+  |  <==== Bounding box represented by points A and B:
         |        |  |        - B has x, y and z coordinates bigger or equal
         |        |  +            to all other points of the object
         |        | /         - A has x, y and z coordinates smaller or equal
         |        |/              to all other points of the object
         A--------+

    Args:
        min_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
            smaller or equal to all other occupied points. Point A from the
            picture above. All elements in [-1, 1].
        max_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
            bigger or equal to all other occupied points. Point B from the
            picture above. All elements in [-1, 1].
        grid_values: instance of self.values_type which contains
            the voxel grid which will be cropped to create the new grid

    Returns:
        instance of self.values_type which has volume cropped to desired size.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
@ -288,6 +389,34 @@ class FullResolutionVoxelGrid(VoxelGridBase):
width, height, depth = self.get_resolution(epoch)
return {"voxel_grid": (self.n_features, width, height, depth)}
# pyre-ignore[14]
def crop_local(
    self,
    min_point_local: torch.Tensor,
    max_point_local: torch.Tensor,
    grid_values: FullResolutionVoxelGridValues,
) -> FullResolutionVoxelGridValues:
    """
    Crops the dense voxel grid to the smallest index range that still contains
    both `min_point_local` and `max_point_local`.

    Args:
        min_point_local: torch.Tensor of shape (3,), lower corner of the
            bounding box in local coordinates. Clamped to [-1, 1].
        max_point_local: torch.Tensor of shape (3,), upper corner of the
            bounding box in local coordinates. Clamped to [-1, 1].
        grid_values: values whose `voxel_grid` tensor has 5 dimensions, the
            last three being width, height and depth.

    Returns:
        FullResolutionVoxelGridValues whose `voxel_grid` is sliced along the
        last three dimensions.
    """
    assert torch.all(min_point_local < max_point_local)
    min_point_local = torch.clamp(min_point_local, -1, 1)
    max_point_local = torch.clamp(max_point_local, -1, 1)
    _, _, width, height, depth = grid_values.voxel_grid.shape
    resolution = grid_values.voxel_grid.new_tensor([width, height, depth])
    min_point_local01 = (min_point_local + 1) / 2
    max_point_local01 = (max_point_local + 1) / 2

    if self.align_corners:
        first_index = torch.floor(min_point_local01 * (resolution - 1))
        last_index = torch.ceil(max_point_local01 * (resolution - 1))
    else:
        first_index = torch.floor(min_point_local01 * resolution - 0.5)
        last_index = torch.ceil(max_point_local01 * resolution - 0.5)

    # Without align_corners a point in the outer half voxel maps to index -1
    # (or to `resolution`). A negative slice start would count from the end of
    # the axis and select the wrong voxels, so keep indices in the valid range.
    first_index = torch.clamp(first_index, min=0)
    last_index = torch.minimum(last_index, resolution - 1)

    minx, miny, minz = first_index.long().tolist()
    maxx, maxy, maxz = last_index.long().tolist()

    return FullResolutionVoxelGridValues(
        voxel_grid=grid_values.voxel_grid[
            :, :, minx : maxx + 1, miny : maxy + 1, minz : maxz + 1
        ]
    )
@dataclass
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
@ -388,6 +517,37 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
return shape_dict
# pyre-ignore[14]
def crop_local(
    self,
    min_point_local: torch.Tensor,
    max_point_local: torch.Tensor,
    grid_values: CPFactorizedVoxelGridValues,
) -> CPFactorizedVoxelGridValues:
    """
    Crops the per-axis vector components to the smallest index range that
    still contains both `min_point_local` and `max_point_local`.

    Args:
        min_point_local: torch.Tensor of shape (3,), lower corner of the
            bounding box in local coordinates. Clamped to [-1, 1].
        max_point_local: torch.Tensor of shape (3,), upper corner of the
            bounding box in local coordinates. Clamped to [-1, 1].
        grid_values: values whose `vector_components_{x,y,z}` tensors have 3
            dimensions, the last one being the resolution along that axis.

    Returns:
        CPFactorizedVoxelGridValues with each vector component sliced along
        its last dimension; `basis_matrix` is passed through unchanged.
    """
    assert torch.all(min_point_local < max_point_local)
    min_point_local = torch.clamp(min_point_local, -1, 1)
    max_point_local = torch.clamp(max_point_local, -1, 1)
    _, _, width = grid_values.vector_components_x.shape
    _, _, height = grid_values.vector_components_y.shape
    _, _, depth = grid_values.vector_components_z.shape
    resolution = grid_values.vector_components_x.new_tensor([width, height, depth])
    min_point_local01 = (min_point_local + 1) / 2
    max_point_local01 = (max_point_local + 1) / 2

    if self.align_corners:
        first_index = torch.floor(min_point_local01 * (resolution - 1))
        last_index = torch.ceil(max_point_local01 * (resolution - 1))
    else:
        first_index = torch.floor(min_point_local01 * resolution - 0.5)
        last_index = torch.ceil(max_point_local01 * resolution - 0.5)

    # Without align_corners a point in the outer half voxel maps to index -1
    # (or to `resolution`). A negative slice start would count from the end of
    # the axis and select the wrong voxels, so keep indices in the valid range.
    first_index = torch.clamp(first_index, min=0)
    last_index = torch.minimum(last_index, resolution - 1)

    minx, miny, minz = first_index.long().tolist()
    maxx, maxy, maxz = last_index.long().tolist()

    return CPFactorizedVoxelGridValues(
        vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1],
        vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1],
        vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1],
        basis_matrix=grid_values.basis_matrix,
    )
@dataclass
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
@ -585,6 +745,46 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
return shape_dict
# pyre-ignore[14]
def crop_local(
    self,
    min_point_local: torch.Tensor,
    max_point_local: torch.Tensor,
    grid_values: VMFactorizedVoxelGridValues,
) -> VMFactorizedVoxelGridValues:
    """
    Crops the vector and matrix components to the smallest index range that
    still contains both `min_point_local` and `max_point_local`.

    Args:
        min_point_local: torch.Tensor of shape (3,), lower corner of the
            bounding box in local coordinates. Clamped to [-1, 1].
        max_point_local: torch.Tensor of shape (3,), upper corner of the
            bounding box in local coordinates. Clamped to [-1, 1].
        grid_values: values whose `vector_components_{x,y,z}` tensors have 3
            dimensions (the last being the resolution along that axis) and
            whose `matrix_components_{xy,yz,xz}` tensors have 4 dimensions
            (the last two being the resolutions along the named axes).

    Returns:
        VMFactorizedVoxelGridValues with every vector and matrix component
        sliced along its spatial dimensions; `basis_matrix` is passed through
        unchanged.
    """
    assert torch.all(min_point_local < max_point_local)
    min_point_local = torch.clamp(min_point_local, -1, 1)
    max_point_local = torch.clamp(max_point_local, -1, 1)
    _, _, width = grid_values.vector_components_x.shape
    _, _, height = grid_values.vector_components_y.shape
    _, _, depth = grid_values.vector_components_z.shape
    resolution = grid_values.vector_components_x.new_tensor([width, height, depth])
    min_point_local01 = (min_point_local + 1) / 2
    max_point_local01 = (max_point_local + 1) / 2

    if self.align_corners:
        first_index = torch.floor(min_point_local01 * (resolution - 1))
        last_index = torch.ceil(max_point_local01 * (resolution - 1))
    else:
        first_index = torch.floor(min_point_local01 * resolution - 0.5)
        last_index = torch.ceil(max_point_local01 * resolution - 0.5)

    # Without align_corners a point in the outer half voxel maps to index -1
    # (or to `resolution`). A negative slice start would count from the end of
    # the axis and select the wrong voxels, so keep indices in the valid range.
    first_index = torch.clamp(first_index, min=0)
    last_index = torch.minimum(last_index, resolution - 1)

    minx, miny, minz = first_index.long().tolist()
    maxx, maxy, maxz = last_index.long().tolist()

    return VMFactorizedVoxelGridValues(
        vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1],
        vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1],
        vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1],
        matrix_components_xy=grid_values.matrix_components_xy[
            :, :, minx : maxx + 1, miny : maxy + 1
        ],
        matrix_components_yz=grid_values.matrix_components_yz[
            :, :, miny : maxy + 1, minz : maxz + 1
        ],
        matrix_components_xz=grid_values.matrix_components_xz[
            :, :, minx : maxx + 1, minz : maxz + 1
        ],
        basis_matrix=grid_values.basis_matrix,
    )
# pyre-fixme[13]: Attribute `voxel_grid` is never initialized.
class VoxelGridModule(Configurable, torch.nn.Module):
@ -771,6 +971,37 @@ class VoxelGridModule(Configurable, torch.nn.Module):
# pyre-ignore[29]
return next(val for val in self.params.values() if val is not None).device
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
    """
    Crops self to only represent points between min_point and max_point (inclusive).

    The stored grid values are cropped to the box, interpolated back to the
    shapes they had before cropping, and stored as new parameters. `translation`
    and `extents` are then updated to describe the cropped box.

    Args:
        min_point: torch.Tensor of shape (3,). Has x, y and z coordinates
            smaller or equal to all other occupied points.
        max_point: torch.Tensor of shape (3,). Has x, y and z coordinates
            bigger or equal to all other occupied points.

    Returns:
        nothing
    """
    # The locator must be built before `translation`/`extents` are overwritten
    # below, since it describes the grid as it is before cropping.
    locator = self._get_volume_locator()
    # pyre-fixme[29]: `Union[torch._tensor.Tensor,
    #  torch.nn.modules.module.Module]` is not a function.
    old_grid_values = self.voxel_grid.values_type(**self.params)
    new_grid_values = self.voxel_grid.crop_world(
        min_point, max_point, old_grid_values, locator
    )
    grid_values, _ = self.voxel_grid.change_resolution(
        new_grid_values, grid_values_with_wanted_resolution=old_grid_values
    )
    # The resized values are plain tensors; wrap them explicitly so they are
    # registered as parameters regardless of whether ParameterDict converts
    # tensors on its own. Entries that are None are kept as None.
    # pyre-ignore [16]
    self.params = torch.nn.ParameterDict(
        {
            name: torch.nn.Parameter(value) if value is not None else None
            for name, value in vars(grid_values).items()
        }
    )
    # New center of voxel grid is the middle point between max and min points.
    # Stored as plain floats rather than 0-dim tensors.
    # NOTE(review): assumes `translation`/`extents` are meant to be tuples of
    # floats, as they are configured in the tests - confirm.
    self.translation = tuple(((max_point + min_point) / 2).tolist())
    # new extents of voxel grid are distances between min and max points
    self.extents = tuple((max_point - min_point).tolist())
def _get_volume_locator(self) -> VolumeLocator:
"""
Returns VolumeLocator calculated from `extents` and `translation` members.

View File

@ -9,7 +9,7 @@ import unittest
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.implicit_function.utils import (
interpolate_line,
@ -19,11 +19,12 @@ from pytorch3d.implicitron.models.implicit_function.utils import (
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
FullResolutionVoxelGridValues,
VMFactorizedVoxelGrid,
VoxelGridModule,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from tests.common_testing import TestCaseMixin
@ -693,6 +694,140 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
grid.voxel_grid.get_resolution(epoch=epoch),
)
def _get_min_max_tuple(
self, n=4, denominator_base=2, max_exponent=6, add_edge_cases=True
):
if add_edge_cases:
n -= 2
def get_pair():
def get_one():
sign = -1 if torch.rand((1,)) < 0.5 else 1
exponent = int(torch.randint(1, max_exponent, (1,)))
denominator = denominator_base**exponent
numerator = int(torch.randint(1, denominator, (1,)))
return sign * numerator / denominator * 1.0
while True:
a, b = get_one(), get_one()
if a < b:
return a, b
for _ in range(n):
a, b, c = get_pair(), get_pair(), get_pair()
yield torch.tensor((a[0], b[0], c[0])), torch.tensor((a[1], b[1], c[1]))
if add_edge_cases:
yield torch.tensor((-1.0, -1.0, -1.0)), torch.tensor((1.0, 1.0, 1.0))
yield torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0])
def test_cropping_voxel_grids(self, n_times=1):
    """
    Checks that a cropped grid gives the same values as the full grid.

    If the grid is 1d and we crop at A and B
        ---------A---------B---
    and choose point p between them
        ---------A-----p---B---
    it can be represented as
        p = A + (B-A) * p_c
    where p_c is local coordinate of p in cropped grid. So we now just see
    if grid evaluated at p and cropped grid evaluated at p_c agree.
    """
    for points_min, points_max in self._get_min_max_tuple(n=10):
        # Three values are drawn, but the number of grids is fixed to 1.
        _, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
        n_grids = 1
        n_components *= 3
        resolution_changes = {0: (128 + 1, 128 + 1, 128 + 1)}

        dense_kwargs = {
            "n_features": n_features,
            "resolution_changes": resolution_changes,
        }
        factorized_kwargs = {
            "n_features": n_features,
            "resolution_changes": resolution_changes,
            "n_components": n_components,
        }
        grid_configs = (
            (FullResolutionVoxelGrid, dense_kwargs),
            (CPFactorizedVoxelGrid, factorized_kwargs),
            (VMFactorizedVoxelGrid, factorized_kwargs),
        )

        for cls, kwargs in grid_configs:
            with self.subTest(
                cls.__name__ + f" points {points_min} and {points_max}"
            ):
                grid = cls(**kwargs)
                params = {}
                for name, shape in grid.get_shapes(epoch=0).items():
                    params[name] = torch.normal(
                        mean=torch.zeros((n_grids, *shape)),
                        std=1,
                    )
                grid_values = grid.values_type(**params)
                grid_values_cropped = grid.crop_local(
                    points_min, points_max, grid_values
                )

                # Sample in [0, 1) inside the crop, then express the same
                # point in the full grid's and in the cropped grid's local
                # coordinates.
                unit_coords = torch.rand((1, n_times, 3))
                box_size = points_max - points_min
                points_local = (
                    points_min[None, None] + box_size[None, None] * unit_coords
                )
                points_local_cropped = (unit_coords - 0.5) * 2

                pred = grid.evaluate_local(points_local, grid_values)
                pred_cropped = grid.evaluate_local(
                    points_local_cropped, grid_values_cropped
                )

                assert torch.allclose(pred, pred_cropped, rtol=1e-4, atol=1e-4), (
                    pred,
                    pred_cropped,
                    points_local,
                    points_local_cropped,
                )
def test_cropping_voxel_grid_module(self, n_times=1):
    """
    Checks that a VoxelGridModule returns the same value at a world point
    inside the crop box before and after `crop_self`.
    """
    # Neither tensor is modified inside the loop, so they are built once.
    extents = torch.ones((3,)) * 2
    translation = torch.ones((3,)) * 0.2
    resolution = (128 + 1, 128 + 1, 128 + 1)

    for points_min, points_max in self._get_min_max_tuple(n=5, max_exponent=5):
        # Shift the crop box by the grid's own translation (in place).
        points_min += translation
        points_max += translation

        overrides = DictConfig(
            {
                "extents": tuple(float(e) for e in extents),
                "translation": tuple(float(t) for t in translation),
                "voxel_grid_FullResolutionVoxelGrid_args": {
                    "resolution_changes": {0: resolution}
                },
            }
        )
        cfg = OmegaConf.merge(get_default_args(VoxelGridModule), overrides)
        grid = VoxelGridModule(**cfg)

        box_size = points_max - points_min
        points = (torch.rand(3) * box_size + points_min)[None]

        result = grid(points)
        grid.crop_self(points_min, points_max)
        result_cropped = grid(points)

        assert torch.allclose(result, result_cropped, rtol=0.001, atol=0.001), (
            result,
            result_cropped,
        )
def test_loading_state_dict(self):
"""
Test loading state dict after rescaling.