mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
volume cropping
Summary: TensoRF at step 2000 does volume croping and resizing. At those steps it calculates part of the voxel grid which has density big enough to have objects and resizes the grid to fit that object. Change is done on 3 levels: - implicit function subscribes to epochs and at specific epochs finds the bounding box of the object and calls resizing of the color and density voxel grids to fit it - VoxelGrid module calls cropping of the underlaying voxel grid and resizing to fit previous size it also adjusts its extends and translation to match wanted size - Each voxel grid has its own way of cropping the underlaying data Reviewed By: kjchalup Differential Revision: D39854548 fbshipit-source-id: 5435b6e599aef1eaab980f5421d3369ee4829c50
This commit is contained in:
parent
0b5def5257
commit
f55d37f07d
@ -166,14 +166,16 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
def change_resolution(
|
||||
self,
|
||||
grid_values: VoxelGridValuesBase,
|
||||
epoch: int,
|
||||
*,
|
||||
epoch: Optional[int] = None,
|
||||
grid_values_with_wanted_resolution: Optional[VoxelGridValuesBase] = None,
|
||||
mode: str = "linear",
|
||||
align_corners: bool = True,
|
||||
antialias: bool = False,
|
||||
) -> Tuple[VoxelGridValuesBase, bool]:
|
||||
"""
|
||||
Changes resolution of tensors in `grid_values` to match the `wanted_resolution`.
|
||||
Changes resolution of tensors in `grid_values` to match the
|
||||
`grid_values_with_wanted_resolution` or resolution on wanted epoch.
|
||||
|
||||
Args:
|
||||
epoch: current training epoch, used to see if the grid needs regridding
|
||||
@ -181,6 +183,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
the voxel grid which will be interpolated to create the new grid
|
||||
epoch: epoch which is used to get the resolution of the new
|
||||
`grid_values` using `self.resolution_changes`.
|
||||
grid_values_with_wanted_resolution: `VoxelGridValuesBase` to whose resolution
|
||||
to interpolate grid_values
|
||||
align_corners: as for torch.nn.functional.interpolate
|
||||
mode: as for torch.nn.functional.interpolate
|
||||
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
|
||||
@ -195,8 +199,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
- new voxel grid_values of desired resolution, of type self.values_type
|
||||
- True if regridding has happened.
|
||||
"""
|
||||
if epoch not in self.resolution_changes:
|
||||
return grid_values, False
|
||||
|
||||
if (epoch is None) == (grid_values_with_wanted_resolution is None):
|
||||
raise ValueError(
|
||||
"Exactly one of `epoch` or "
|
||||
"`grid_values_with_wanted_resolution` has to be defined."
|
||||
)
|
||||
|
||||
if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"):
|
||||
raise ValueError(
|
||||
@ -219,11 +227,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
|
||||
wanted_shapes = self.get_shapes(epoch=epoch)
|
||||
params = {
|
||||
name: change_individual_resolution(getattr(grid_values, name), shape[1:])
|
||||
for name, shape in wanted_shapes.items()
|
||||
}
|
||||
if epoch is not None:
|
||||
if epoch not in self.resolution_changes:
|
||||
return grid_values, False
|
||||
|
||||
wanted_shapes = self.get_shapes(epoch=epoch)
|
||||
params = {
|
||||
name: change_individual_resolution(
|
||||
getattr(grid_values, name), shape[1:]
|
||||
)
|
||||
for name, shape in wanted_shapes.items()
|
||||
}
|
||||
else:
|
||||
params = {
|
||||
name: (
|
||||
change_individual_resolution(
|
||||
getattr(grid_values, name), tensor.shape[2:]
|
||||
)
|
||||
if tensor is not None
|
||||
else None
|
||||
)
|
||||
for name, tensor in vars(grid_values_with_wanted_resolution).items()
|
||||
}
|
||||
# pyre-ignore[29]
|
||||
return self.values_type(**params), True
|
||||
|
||||
@ -239,6 +264,82 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
return self.align_corners
|
||||
|
||||
def crop_world(
|
||||
self,
|
||||
min_point_world: torch.Tensor,
|
||||
max_point_world: torch.Tensor,
|
||||
grid_values: VoxelGridValuesBase,
|
||||
volume_locator: VolumeLocator,
|
||||
) -> VoxelGridValuesBase:
|
||||
"""
|
||||
Crops the voxel grid based on minimum and maximum occupied point in
|
||||
world coordinates. After cropping all 8 corner points are preserved in
|
||||
the voxel grid. This is achieved by preserving all the voxels needed to
|
||||
calculate the point.
|
||||
|
||||
+--------B
|
||||
/ /|
|
||||
/ / |
|
||||
+--------+ | <==== Bounding box represented by points A and B:
|
||||
| | | - B has x, y and z coordinates bigger or equal
|
||||
| | + to all other points of the object
|
||||
| | / - A has x, y and z coordinates smaller or equal
|
||||
| |/ to all other points of the object
|
||||
A--------+
|
||||
|
||||
Args:
|
||||
min_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
|
||||
smaller or equal to all other occupied points. Point A from the
|
||||
picture above.
|
||||
max_point_world: torch.Tensor of shape (3,). Has x, y and z coordinates
|
||||
bigger or equal to all other occupied points. Point B from the
|
||||
picture above.
|
||||
grid_values: instance of self.values_type which contains
|
||||
the voxel grid which will be cropped to create the new grid
|
||||
volume_locator: VolumeLocator object used to convert world to local
|
||||
cordinates
|
||||
Returns:
|
||||
instance of self.values_type which has volume cropped to desired size.
|
||||
"""
|
||||
min_point_local = volume_locator.world_to_local_coords(min_point_world[None])[0]
|
||||
max_point_local = volume_locator.world_to_local_coords(max_point_world[None])[0]
|
||||
return self.crop_local(min_point_local, max_point_local, grid_values)
|
||||
|
||||
def crop_local(
|
||||
self,
|
||||
min_point_local: torch.Tensor,
|
||||
max_point_local: torch.Tensor,
|
||||
grid_values: VoxelGridValuesBase,
|
||||
) -> VoxelGridValuesBase:
|
||||
"""
|
||||
Crops the voxel grid based on minimum and maximum occupied point in local
|
||||
coordinates. After cropping both min and max point are preserved in the voxel
|
||||
grid. This is achieved by preserving all the voxels needed to calculate the point.
|
||||
|
||||
+--------B
|
||||
/ /|
|
||||
/ / |
|
||||
+--------+ | <==== Bounding box represented by points A and B:
|
||||
| | | - B has x, y and z coordinates bigger or equal
|
||||
| | + to all other points of the object
|
||||
| | / - A has x, y and z coordinates smaller or equal
|
||||
| |/ to all other points of the object
|
||||
A--------+
|
||||
|
||||
Args:
|
||||
min_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
|
||||
smaller or equal to all other occupied points. Point A from the
|
||||
picture above. All elements in [-1, 1].
|
||||
max_point_local: torch.Tensor of shape (3,). Has x, y and z coordinates
|
||||
bigger or equal to all other occupied points. Point B from the
|
||||
picture above. All elements in [-1, 1].
|
||||
grid_values: instance of self.values_type which contains
|
||||
the voxel grid which will be cropped to create the new grid
|
||||
Returns:
|
||||
instance of self.values_type which has volume cropped to desired size.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
||||
@ -288,6 +389,34 @@ class FullResolutionVoxelGrid(VoxelGridBase):
|
||||
width, height, depth = self.get_resolution(epoch)
|
||||
return {"voxel_grid": (self.n_features, width, height, depth)}
|
||||
|
||||
# pyre-ignore[14]
|
||||
def crop_local(
|
||||
self,
|
||||
min_point_local: torch.Tensor,
|
||||
max_point_local: torch.Tensor,
|
||||
grid_values: FullResolutionVoxelGridValues,
|
||||
) -> FullResolutionVoxelGridValues:
|
||||
assert torch.all(min_point_local < max_point_local)
|
||||
min_point_local = torch.clamp(min_point_local, -1, 1)
|
||||
max_point_local = torch.clamp(max_point_local, -1, 1)
|
||||
_, _, width, height, depth = grid_values.voxel_grid.shape
|
||||
resolution = grid_values.voxel_grid.new_tensor([width, height, depth])
|
||||
min_point_local01 = (min_point_local + 1) / 2
|
||||
max_point_local01 = (max_point_local + 1) / 2
|
||||
|
||||
if self.align_corners:
|
||||
minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long()
|
||||
maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long()
|
||||
else:
|
||||
minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long()
|
||||
maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long()
|
||||
|
||||
return FullResolutionVoxelGridValues(
|
||||
voxel_grid=grid_values.voxel_grid[
|
||||
:, :, minx : maxx + 1, miny : maxy + 1, minz : maxz + 1
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
|
||||
@ -388,6 +517,37 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
|
||||
return shape_dict
|
||||
|
||||
# pyre-ignore[14]
|
||||
def crop_local(
|
||||
self,
|
||||
min_point_local: torch.Tensor,
|
||||
max_point_local: torch.Tensor,
|
||||
grid_values: CPFactorizedVoxelGridValues,
|
||||
) -> CPFactorizedVoxelGridValues:
|
||||
assert torch.all(min_point_local < max_point_local)
|
||||
min_point_local = torch.clamp(min_point_local, -1, 1)
|
||||
max_point_local = torch.clamp(max_point_local, -1, 1)
|
||||
_, _, width = grid_values.vector_components_x.shape
|
||||
_, _, height = grid_values.vector_components_y.shape
|
||||
_, _, depth = grid_values.vector_components_z.shape
|
||||
resolution = grid_values.vector_components_x.new_tensor([width, height, depth])
|
||||
min_point_local01 = (min_point_local + 1) / 2
|
||||
max_point_local01 = (max_point_local + 1) / 2
|
||||
|
||||
if self.align_corners:
|
||||
minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long()
|
||||
maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long()
|
||||
else:
|
||||
minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long()
|
||||
maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long()
|
||||
|
||||
return CPFactorizedVoxelGridValues(
|
||||
vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1],
|
||||
vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1],
|
||||
vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1],
|
||||
basis_matrix=grid_values.basis_matrix,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
|
||||
@ -585,6 +745,46 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
|
||||
return shape_dict
|
||||
|
||||
# pyre-ignore[14]
|
||||
def crop_local(
|
||||
self,
|
||||
min_point_local: torch.Tensor,
|
||||
max_point_local: torch.Tensor,
|
||||
grid_values: VMFactorizedVoxelGridValues,
|
||||
) -> VMFactorizedVoxelGridValues:
|
||||
assert torch.all(min_point_local < max_point_local)
|
||||
min_point_local = torch.clamp(min_point_local, -1, 1)
|
||||
max_point_local = torch.clamp(max_point_local, -1, 1)
|
||||
_, _, width = grid_values.vector_components_x.shape
|
||||
_, _, height = grid_values.vector_components_y.shape
|
||||
_, _, depth = grid_values.vector_components_z.shape
|
||||
resolution = grid_values.vector_components_x.new_tensor([width, height, depth])
|
||||
min_point_local01 = (min_point_local + 1) / 2
|
||||
max_point_local01 = (max_point_local + 1) / 2
|
||||
|
||||
if self.align_corners:
|
||||
minx, miny, minz = torch.floor(min_point_local01 * (resolution - 1)).long()
|
||||
maxx, maxy, maxz = torch.ceil(max_point_local01 * (resolution - 1)).long()
|
||||
else:
|
||||
minx, miny, minz = torch.floor(min_point_local01 * resolution - 0.5).long()
|
||||
maxx, maxy, maxz = torch.ceil(max_point_local01 * resolution - 0.5).long()
|
||||
|
||||
return VMFactorizedVoxelGridValues(
|
||||
vector_components_x=grid_values.vector_components_x[:, :, minx : maxx + 1],
|
||||
vector_components_y=grid_values.vector_components_y[:, :, miny : maxy + 1],
|
||||
vector_components_z=grid_values.vector_components_z[:, :, minz : maxz + 1],
|
||||
matrix_components_xy=grid_values.matrix_components_xy[
|
||||
:, :, minx : maxx + 1, miny : maxy + 1
|
||||
],
|
||||
matrix_components_yz=grid_values.matrix_components_yz[
|
||||
:, :, miny : maxy + 1, minz : maxz + 1
|
||||
],
|
||||
matrix_components_xz=grid_values.matrix_components_xz[
|
||||
:, :, minx : maxx + 1, minz : maxz + 1
|
||||
],
|
||||
basis_matrix=grid_values.basis_matrix,
|
||||
)
|
||||
|
||||
|
||||
# pyre-fixme[13]: Attribute `voxel_grid` is never initialized.
|
||||
class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
@ -771,6 +971,37 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
# pyre-ignore[29]
|
||||
return next(val for val in self.params.values() if val is not None).device
|
||||
|
||||
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
|
||||
"""
|
||||
Crops self to only represent points between min_point and max_point (inclusive).
|
||||
|
||||
Args:
|
||||
min_point: torch.Tensor of shape (3,). Has x, y and z coordinates
|
||||
smaller or equal to all other occupied points.
|
||||
max_point: torch.Tensor of shape (3,). Has x, y and z coordinates
|
||||
bigger or equal to all other occupied points.
|
||||
Returns:
|
||||
nothing
|
||||
"""
|
||||
locator = self._get_volume_locator()
|
||||
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
||||
# torch.nn.modules.module.Module]` is not a function.
|
||||
old_grid_values = self.voxel_grid.values_type(**self.params)
|
||||
new_grid_values = self.voxel_grid.crop_world(
|
||||
min_point, max_point, old_grid_values, locator
|
||||
)
|
||||
grid_values, _ = self.voxel_grid.change_resolution(
|
||||
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
|
||||
)
|
||||
# pyre-ignore [16]
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{k: v for k, v in vars(grid_values).items()}
|
||||
)
|
||||
# New center of voxel grid is the middle point between max and min points.
|
||||
self.translation = tuple((max_point + min_point) / 2)
|
||||
# new extents of voxel grid are distances between min and max points
|
||||
self.extents = tuple(max_point - min_point)
|
||||
|
||||
def _get_volume_locator(self) -> VolumeLocator:
|
||||
"""
|
||||
Returns VolumeLocator calculated from `extents` and `translation` members.
|
||||
|
@ -9,7 +9,7 @@ import unittest
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
interpolate_line,
|
||||
@ -19,11 +19,12 @@ from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
||||
CPFactorizedVoxelGrid,
|
||||
FullResolutionVoxelGrid,
|
||||
FullResolutionVoxelGridValues,
|
||||
VMFactorizedVoxelGrid,
|
||||
VoxelGridModule,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@ -693,6 +694,140 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
grid.voxel_grid.get_resolution(epoch=epoch),
|
||||
)
|
||||
|
||||
def _get_min_max_tuple(
|
||||
self, n=4, denominator_base=2, max_exponent=6, add_edge_cases=True
|
||||
):
|
||||
if add_edge_cases:
|
||||
n -= 2
|
||||
|
||||
def get_pair():
|
||||
def get_one():
|
||||
sign = -1 if torch.rand((1,)) < 0.5 else 1
|
||||
exponent = int(torch.randint(1, max_exponent, (1,)))
|
||||
denominator = denominator_base**exponent
|
||||
numerator = int(torch.randint(1, denominator, (1,)))
|
||||
return sign * numerator / denominator * 1.0
|
||||
|
||||
while True:
|
||||
a, b = get_one(), get_one()
|
||||
if a < b:
|
||||
return a, b
|
||||
|
||||
for _ in range(n):
|
||||
a, b, c = get_pair(), get_pair(), get_pair()
|
||||
yield torch.tensor((a[0], b[0], c[0])), torch.tensor((a[1], b[1], c[1]))
|
||||
if add_edge_cases:
|
||||
yield torch.tensor((-1.0, -1.0, -1.0)), torch.tensor((1.0, 1.0, 1.0))
|
||||
yield torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0])
|
||||
|
||||
def test_cropping_voxel_grids(self, n_times=1):
|
||||
"""
|
||||
If the grid is 1d and we crop at A and B
|
||||
---------A---------B---
|
||||
and choose point p between them
|
||||
---------A-----p---B---
|
||||
it can be represented as
|
||||
p = A + (B-A) * p_c
|
||||
where p_c is local coordinate of p in cropped grid. So we now just see
|
||||
if grid evaluated at p and cropped grid evaluated at p_c agree.
|
||||
"""
|
||||
for points_min, points_max in self._get_min_max_tuple(n=10):
|
||||
n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
|
||||
n_grids = 1
|
||||
n_components *= 3
|
||||
resolution_changes = {0: (128 + 1, 128 + 1, 128 + 1)}
|
||||
for cls, kwargs in (
|
||||
(
|
||||
FullResolutionVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
},
|
||||
),
|
||||
(
|
||||
CPFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
(
|
||||
VMFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
):
|
||||
with self.subTest(
|
||||
cls.__name__ + f" points {points_min} and {points_max}"
|
||||
):
|
||||
grid = cls(**kwargs)
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = {
|
||||
name: torch.normal(
|
||||
mean=torch.zeros((n_grids, *shape)),
|
||||
std=1,
|
||||
)
|
||||
for name, shape in shapes.items()
|
||||
}
|
||||
grid_values = grid.values_type(**params)
|
||||
|
||||
grid_values_cropped = grid.crop_local(
|
||||
points_min, points_max, grid_values
|
||||
)
|
||||
|
||||
points_local_cropped = torch.rand((1, n_times, 3))
|
||||
points_local = (
|
||||
points_min[None, None]
|
||||
+ (points_max - points_min)[None, None] * points_local_cropped
|
||||
)
|
||||
points_local_cropped = (points_local_cropped - 0.5) * 2
|
||||
|
||||
pred = grid.evaluate_local(points_local, grid_values)
|
||||
pred_cropped = grid.evaluate_local(
|
||||
points_local_cropped, grid_values_cropped
|
||||
)
|
||||
|
||||
assert torch.allclose(pred, pred_cropped, rtol=1e-4, atol=1e-4), (
|
||||
pred,
|
||||
pred_cropped,
|
||||
points_local,
|
||||
points_local_cropped,
|
||||
)
|
||||
|
||||
def test_cropping_voxel_grid_module(self, n_times=1):
|
||||
for points_min, points_max in self._get_min_max_tuple(n=5, max_exponent=5):
|
||||
extents = torch.ones((3,)) * 2
|
||||
translation = torch.ones((3,)) * 0.2
|
||||
points_min += translation
|
||||
points_max += translation
|
||||
|
||||
default_cfg = get_default_args(VoxelGridModule)
|
||||
custom_cfg = DictConfig(
|
||||
{
|
||||
"extents": tuple(float(e) for e in extents),
|
||||
"translation": tuple(float(t) for t in translation),
|
||||
"voxel_grid_FullResolutionVoxelGrid_args": {
|
||||
"resolution_changes": {0: (128 + 1, 128 + 1, 128 + 1)}
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = OmegaConf.merge(default_cfg, custom_cfg)
|
||||
grid = VoxelGridModule(**cfg)
|
||||
|
||||
points = (torch.rand(3) * (points_max - points_min) + points_min)[None]
|
||||
result = grid(points)
|
||||
grid.crop_self(points_min, points_max)
|
||||
result_cropped = grid(points)
|
||||
|
||||
assert torch.allclose(result, result_cropped, rtol=0.001, atol=0.001), (
|
||||
result,
|
||||
result_cropped,
|
||||
)
|
||||
|
||||
def test_loading_state_dict(self):
|
||||
"""
|
||||
Test loading state dict after rescaling.
|
||||
|
Loading…
x
Reference in New Issue
Block a user