mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
scaffold
Summary: Forward method is sped up using the scaffold, a low resolution voxel grid which is used to filter out the points in empty space. These points will be predicted as having 0 density and (0, 0, 0) color. The points which were not evaluated as empty space will be passed through the steps outlined above. Reviewed By: kjchalup Differential Revision: D39579671 fbshipit-source-id: 8eab8bb43ef77c2a73557efdb725e99a6c60d415
This commit is contained in:
parent
95a2acf763
commit
56d3465b09
@ -15,8 +15,9 @@ these classes.
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
|
||||
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
@ -164,8 +165,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
|
||||
def change_resolution(
|
||||
self,
|
||||
epoch: int,
|
||||
grid_values: VoxelGridValuesBase,
|
||||
epoch: int,
|
||||
*,
|
||||
mode: str = "linear",
|
||||
align_corners: bool = True,
|
||||
antialias: bool = False,
|
||||
@ -177,8 +179,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
epoch: current training epoch, used to see if the grid needs regridding
|
||||
grid_values: instance of self.values_type which contains
|
||||
the voxel grid which will be interpolated to create the new grid
|
||||
wanted_resolution: tuple of (x, y, z) resolutions which determine
|
||||
new grid's resolution
|
||||
epoch: epoch which is used to get the resolution of the new
|
||||
`grid_values` using `self.resolution_changes`.
|
||||
align_corners: as for torch.nn.functional.interpolate
|
||||
mode: as for torch.nn.functional.interpolate
|
||||
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
|
||||
@ -225,11 +227,17 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
# pyre-ignore[29]
|
||||
return self.values_type(**params), True
|
||||
|
||||
def get_resolution_change_epochs(self) -> List[int]:
|
||||
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
|
||||
"""
|
||||
Returns epochs at which this grid should change resolution.
|
||||
"""
|
||||
return list(self.resolution_changes.keys())
|
||||
return tuple(self.resolution_changes.keys())
|
||||
|
||||
def get_align_corners(self) -> bool:
    """
    Reports whether this voxel grid is configured with align_corners=True.

    Returns:
        the value of the `align_corners` member of this grid.
    """
    uses_aligned_corners = self.align_corners
    return uses_aligned_corners
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -583,6 +591,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
A wrapper torch.nn.Module for the VoxelGrid classes, which
|
||||
contains parameters that are needed to train the VoxelGrid classes.
|
||||
Can contain the parameters for the voxel grid as pytorch parameters
|
||||
or as registered buffers.
|
||||
|
||||
Members:
|
||||
voxel_grid_class_type: The name of the class to use for voxel_grid,
|
||||
@ -596,17 +606,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
with mean=init_mean and std=init_std. Default 0.1
|
||||
init_mean: Parameters are initialized using the gaussian distribution
|
||||
with mean=init_mean and std=init_std. Default 0.
|
||||
hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
|
||||
will be saved as parameters and therefore be trainable. Default True.
|
||||
"""
|
||||
|
||||
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
||||
voxel_grid: VoxelGridBase
|
||||
|
||||
extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
|
||||
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
|
||||
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
|
||||
init_std: float = 0.1
|
||||
init_mean: float = 0
|
||||
|
||||
hold_voxel_grid_as_parameters: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
@ -619,7 +633,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
)
|
||||
for name, shape in shapes.items()
|
||||
}
|
||||
self.params = torch.nn.ParameterDict(params)
|
||||
|
||||
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**params))
|
||||
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
|
||||
|
||||
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
||||
@ -632,31 +647,29 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor of shape (..., n_features)
|
||||
"""
|
||||
locator = VolumeLocator(
|
||||
batch_size=1,
|
||||
# The resolution of the voxel grid does not need to be known
|
||||
# to the locator object. It is easiest to fix the resolution of the locator.
|
||||
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
|
||||
# desired size. The locator object uses (z, y, x) convention for the grid_size,
|
||||
# and this module uses (x, y, z) convention so the order has to be reversed
|
||||
# (irrelevant in this case since they are all equal).
|
||||
# It is (2, 2, 2) because the VolumeLocator object behaves like
|
||||
# align_corners=True, which means that the points are in the corners of
|
||||
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
|
||||
grid_sizes=(2, 2, 2),
|
||||
# The locator object uses (x, y, z) convention for the
|
||||
# voxel size and translation.
|
||||
voxel_size=tuple(self.extents),
|
||||
volume_translation=tuple(self.translation),
|
||||
# pyre-ignore[29]
|
||||
device=next(val for val in self.params.values() if val is not None).device,
|
||||
)
|
||||
locator = self._get_volume_locator()
|
||||
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
||||
# torch.nn.modules.module.Module]` is not a function.
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||
|
||||
def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
    """
    Replaces the values held for the underlying voxel grid.

    Depending on `hold_voxel_grid_as_parameters` the new values are stored
    either as a `torch.nn.ParameterDict` or as registered buffers.

    Args:
        params: values of type `self.voxel_grid.values_type` which take
            the place of the currently stored ones
    """
    named_values = vars(params)
    if not self.hold_voxel_grid_as_parameters:
        # Buffers can only be registered at object level, so a dedicated
        # torch Module is used to hold them.
        self.params = _RegistratedBufferDict(named_values)
        return
    # pyre-ignore [16]
    self.params = torch.nn.ParameterDict(named_values)
|
||||
|
||||
@staticmethod
|
||||
def get_output_dim(args: DictConfig) -> int:
|
||||
"""
|
||||
@ -672,12 +685,12 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
|
||||
)
|
||||
|
||||
def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
|
||||
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
|
||||
"""
|
||||
Method which expresses interest in subscribing to optimization epoch updates.
|
||||
|
||||
Returns:
|
||||
list of epochs on which to call a callable and callable to be called on
|
||||
tuple of epochs on which to call a callable and callable to be called on
|
||||
particular epoch. The callable returns True if parameter change has
|
||||
happened else False and it must be supplied with one argument, epoch.
|
||||
"""
|
||||
@ -697,13 +710,12 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
# pyre-ignore[29]
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values)
|
||||
grid_values, change = self.voxel_grid.change_resolution(
|
||||
grid_values, epoch=epoch
|
||||
)
|
||||
if change:
|
||||
# pyre-ignore[16]
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{name: tensor for name, tensor in vars(grid_values).items()}
|
||||
)
|
||||
return change
|
||||
self.set_voxel_grid_parameters(grid_values)
|
||||
return change and self.hold_voxel_grid_as_parameters
|
||||
|
||||
def _create_parameters_with_new_size(
|
||||
self,
|
||||
@ -749,5 +761,113 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
key = prefix + "params." + name
|
||||
if key in state_dict:
|
||||
new_params[name] = torch.zeros_like(state_dict[key])
|
||||
# pyre-ignore[16]
|
||||
self.params = torch.nn.ParameterDict(new_params)
|
||||
# pyre-ignore[29]
|
||||
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
|
||||
|
||||
def get_device(self) -> torch.device:
    """
    Reports the torch.device on which the module parameters are located.

    The device is read from the first entry of `self.params` that is not None.
    """
    # pyre-ignore[29]
    populated = (tensor for tensor in self.params.values() if tensor is not None)
    return next(populated).device
|
||||
|
||||
def _get_volume_locator(self) -> VolumeLocator:
    """
    Builds the VolumeLocator described by the `extents` and `translation` members.
    """
    # The locator does not need to know the real resolution of the voxel grid,
    # so it is pinned to (2, 2, 2). The VolumeLocator behaves like
    # align_corners=True (points sit in the corners of the volume), hence a
    # (2, 2, 2) grid contains exactly one voxel of the desired size.
    # grid_sizes follows the (z, y, x) convention while this module uses
    # (x, y, z); the reversal is irrelevant here because all entries are equal.
    single_voxel_grid = (2, 2, 2)

    # Voxel size and translation use the (x, y, z) convention in the locator.
    voxel_extents = tuple(self.extents)

    # `VolumeLocator` defines volume_translation as the vector from the origin
    # of the local coordinate frame to the origin of the world coordinate
    # frame, that is: x_world = x_local * extents/2 - translation.
    # This module needs the reverse direction, so every component is negated.
    negated_translation = tuple(-offset for offset in self.translation)

    return VolumeLocator(
        batch_size=1,
        grid_sizes=single_voxel_grid,
        voxel_size=voxel_extents,
        volume_translation=negated_translation,
        device=self.get_device(),
    )
|
||||
|
||||
def get_grid_points(self, epoch: int) -> torch.Tensor:
    """
    Returns a grid of points that represent centers of voxels of the
    underlying voxel grid in world coordinates at specific epoch.

    Args:
        epoch: underlying voxel grids change resolution depending on the
            epoch, this argument is used to determine the resolution
            of the voxel grid at that epoch.
    Returns:
        tensor of shape [xresolution, yresolution, zresolution, 3] where
        xresolution, yresolution, zresolution are resolutions of the
        underlying voxel grid
    """
    xresolution, yresolution, zresolution = self.voxel_grid.get_resolution(epoch)
    width, height, depth = self.extents
    align_corners = self.voxel_grid.get_align_corners()
    device = self.get_device()

    axes = []
    for extent, resolution in (
        (width, xresolution),
        (height, yresolution),
        (depth, zresolution),
    ):
        # Without align_corners the outermost points are pulled in from the
        # volume boundary, so the span covered along an axis shrinks by the
        # factor (resolution - 1) / resolution. The factor must be computed
        # from the resolution of that same axis.
        if not align_corners and resolution > 1:
            extent = extent * (resolution - 1) / resolution
        axes.append(
            torch.linspace(-extent / 2, extent / 2, resolution, device=device)
        )

    # NOTE(review): the points are centered on the origin; `self.translation`
    # is not applied here - confirm this is what callers expect.
    xmesh, ymesh, zmesh = torch.meshgrid(*axes, indexing="ij")
    return torch.stack((xmesh, ymesh, zmesh), dim=3)
|
||||
|
||||
|
||||
class _RegistratedBufferDict(torch.nn.Module, Mapping):
    """
    Mapping class and a torch.nn.Module that registers its values
    with `self.register_buffer`. Can be indexed like a regular Python
    dictionary, but torch.Tensors it contains are properly registered, and will be visible
    by all Module methods. Supports only `torch.Tensor` as value and str as key.
    Keys are iterated in the order in which they were first inserted.
    """

    def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        """
        Args:
            init_dict: dictionary which will be used to populate the object
        """
        super().__init__()
        # A dict is used as an insertion-ordered set of the registered keys so
        # that iteration order is the same in every process.
        self._keys: Dict[str, None] = {}
        if init_dict is not None:
            for k, v in init_dict.items():
                self[k] = v

    def __iter__(self) -> Iterator[str]:
        # Mapping iteration yields keys; values are reached through __getitem__.
        return iter(self._keys)

    def __len__(self) -> int:
        return len(self._keys)

    def __getitem__(self, key: str) -> torch.Tensor:
        # Raise KeyError for unknown keys so that the Mapping mixins relying on
        # it (`in`, `.get()`) work, and so that unrelated module attributes
        # are never returned as if they were entries.
        if key not in self._keys:
            raise KeyError(key)
        return getattr(self, key)

    def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None:
        self._keys[key] = None
        self.register_buffer(key, value)

    def __hash__(self) -> int:
        # Mapping defines __eq__, which leaves subclasses without a default
        # hash; presumably this explicit hash keeps the module usable wherever
        # torch needs to hash modules - TODO confirm.
        return hash(repr(self))
|
||||
|
@ -653,7 +653,7 @@ class VolumeLocator:
|
||||
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||
):
|
||||
"""
|
||||
**batch_size** : Batch size of the underlaying grids
|
||||
**batch_size** : Batch size of the underlying grids
|
||||
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
|
||||
a) tuple of form (H, W, D)
|
||||
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)
|
||||
|
@ -35,9 +35,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
one by one sample and comparing with the batched implementation.
|
||||
"""
|
||||
|
||||
def test_my_code(self):
|
||||
return
|
||||
|
||||
def get_random_normalized_points(
|
||||
self, n_grids, n_points=None, dimension=3
|
||||
) -> torch.Tensor:
|
||||
@ -293,6 +290,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
padding_mode="zeros",
|
||||
mode="bilinear",
|
||||
),
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
with self.subTest("2D interpolation"):
|
||||
points = self.get_random_normalized_points(
|
||||
@ -308,6 +307,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
padding_mode="zeros",
|
||||
mode="bilinear",
|
||||
),
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
|
||||
with self.subTest("3D interpolation"):
|
||||
@ -325,6 +326,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
mode="bilinear",
|
||||
),
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
|
||||
def test_floating_point_query(self):
|
||||
@ -378,7 +380,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(points, params),
|
||||
expected_result,
|
||||
rtol=0.00001,
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
), grid.evaluate_local(points, params)
|
||||
with self.subTest("CP"):
|
||||
grid = CPFactorizedVoxelGrid(
|
||||
@ -446,14 +449,16 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(points, params),
|
||||
expected_result_matrix,
|
||||
rtol=0.00001,
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
del params.basis_matrix
|
||||
with self.subTest("CP with sum reduction"):
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(points, params),
|
||||
expected_result_sum,
|
||||
rtol=0.00001,
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
|
||||
with self.subTest("VM"):
|
||||
@ -540,7 +545,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(points, params),
|
||||
expected_result_matrix,
|
||||
rtol=0.00001,
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
del params.basis_matrix
|
||||
with self.subTest("VM with sum reduction"):
|
||||
@ -548,6 +554,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
grid.evaluate_local(points, params),
|
||||
expected_result_sum,
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
), grid.evaluate_local(points, params)
|
||||
|
||||
def test_forward_with_small_init_std(self):
|
||||
@ -613,6 +620,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
grid(world_point)[0, 0],
|
||||
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
|
||||
rtol=0.0001,
|
||||
atol=0.0001,
|
||||
)
|
||||
|
||||
def test_resolution_change(self, n_times=10):
|
||||
|
Loading…
x
Reference in New Issue
Block a user