mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Align_corners switch in Volumes
Summary:
Porting this commit by davnov134 .
93a3a62800 (diff-a8e107ebe039de52ca112ac6ddfba6ebccd53b4f53030b986e13f019fe57a378)
Capability to interpret world/local coordinates with various align_corners semantics.
Reviewed By: bottler
Differential Revision: D51855420
fbshipit-source-id: 834cd220c25d7f0143d8a55ba880da5977099dd6
This commit is contained in:
parent
fbc6725f03
commit
94da8841af
@ -98,6 +98,13 @@ def save_model(model, stats, fl, optimizer=None, cfg=None):
|
|||||||
return flstats, flmodel, flopt
|
return flstats, flmodel, flopt
|
||||||
|
|
||||||
|
|
||||||
|
def save_stats(stats, fl, cfg=None):
|
||||||
|
flstats = get_stats_path(fl)
|
||||||
|
logger.info("saving model stats to %s" % flstats)
|
||||||
|
stats.save(flstats)
|
||||||
|
return flstats
|
||||||
|
|
||||||
|
|
||||||
def load_model(fl, map_location: Optional[dict]):
|
def load_model(fl, map_location: Optional[dict]):
|
||||||
flstats = get_stats_path(fl)
|
flstats = get_stats_path(fl)
|
||||||
flmodel = get_model_path(fl)
|
flmodel = get_model_path(fl)
|
||||||
|
@ -291,6 +291,7 @@ def add_pointclouds_to_volumes(
|
|||||||
mask=mask,
|
mask=mask,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
rescale_features=rescale_features,
|
rescale_features=rescale_features,
|
||||||
|
align_corners=initial_volumes.get_align_corners(),
|
||||||
_python=_python,
|
_python=_python,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -310,6 +311,7 @@ def add_points_features_to_volume_densities_features(
|
|||||||
grid_sizes: Optional[torch.LongTensor] = None,
|
grid_sizes: Optional[torch.LongTensor] = None,
|
||||||
rescale_features: bool = True,
|
rescale_features: bool = True,
|
||||||
_python: bool = False,
|
_python: bool = False,
|
||||||
|
align_corners: bool = True,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Convert a batch of point clouds represented with tensors of per-point
|
Convert a batch of point clouds represented with tensors of per-point
|
||||||
@ -356,6 +358,7 @@ def add_points_features_to_volume_densities_features(
|
|||||||
output densities are just summed without rescaling, so
|
output densities are just summed without rescaling, so
|
||||||
you may need to rescale them afterwards.
|
you may need to rescale them afterwards.
|
||||||
_python: Set to True to use a pure Python implementation.
|
_python: Set to True to use a pure Python implementation.
|
||||||
|
align_corners: as for grid_sample.
|
||||||
Returns:
|
Returns:
|
||||||
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
|
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
|
||||||
volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
|
volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
|
||||||
@ -409,7 +412,7 @@ def add_points_features_to_volume_densities_features(
|
|||||||
grid_sizes,
|
grid_sizes,
|
||||||
1.0, # point_weight
|
1.0, # point_weight
|
||||||
mask,
|
mask,
|
||||||
True, # align_corners
|
align_corners, # align_corners
|
||||||
splat,
|
splat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -382,9 +382,9 @@ class VolumeSampler(torch.nn.Module):
|
|||||||
rays_densities = torch.nn.functional.grid_sample(
|
rays_densities = torch.nn.functional.grid_sample(
|
||||||
volumes_densities,
|
volumes_densities,
|
||||||
rays_points_local_flat,
|
rays_points_local_flat,
|
||||||
align_corners=True,
|
|
||||||
mode=self._sample_mode,
|
mode=self._sample_mode,
|
||||||
padding_mode=self._padding_mode,
|
padding_mode=self._padding_mode,
|
||||||
|
align_corners=self._volumes.get_align_corners(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# permute the dimensions & reshape densities after sampling
|
# permute the dimensions & reshape densities after sampling
|
||||||
@ -400,9 +400,9 @@ class VolumeSampler(torch.nn.Module):
|
|||||||
rays_features = torch.nn.functional.grid_sample(
|
rays_features = torch.nn.functional.grid_sample(
|
||||||
volumes_features,
|
volumes_features,
|
||||||
rays_points_local_flat,
|
rays_points_local_flat,
|
||||||
align_corners=True,
|
|
||||||
mode=self._sample_mode,
|
mode=self._sample_mode,
|
||||||
padding_mode=self._padding_mode,
|
padding_mode=self._padding_mode,
|
||||||
|
align_corners=self._volumes.get_align_corners(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# permute the dimensions & reshape features after sampling
|
# permute the dimensions & reshape features after sampling
|
||||||
|
@ -85,7 +85,7 @@ class Volumes:
|
|||||||
are linearly interpolated over the spatial dimensions of the volume.
|
are linearly interpolated over the spatial dimensions of the volume.
|
||||||
- Note that the convention is the same as for the 5D version of the
|
- Note that the convention is the same as for the 5D version of the
|
||||||
`torch.nn.functional.grid_sample` function called with
|
`torch.nn.functional.grid_sample` function called with
|
||||||
`align_corners==True`.
|
the same value of `align_corners` argument.
|
||||||
- Note that the local coordinate convention of `Volumes`
|
- Note that the local coordinate convention of `Volumes`
|
||||||
(+X = left to right, +Y = top to bottom, +Z = away from the user)
|
(+X = left to right, +Y = top to bottom, +Z = away from the user)
|
||||||
is *different* from the world coordinate convention of the
|
is *different* from the world coordinate convention of the
|
||||||
@ -143,7 +143,7 @@ class Volumes:
|
|||||||
torch.nn.functional.grid_sample(
|
torch.nn.functional.grid_sample(
|
||||||
v.densities(),
|
v.densities(),
|
||||||
v.get_coord_grid(world_coordinates=False),
|
v.get_coord_grid(world_coordinates=False),
|
||||||
align_corners=True,
|
align_corners=align_corners,
|
||||||
) == v.densities(),
|
) == v.densities(),
|
||||||
|
|
||||||
i.e. sampling the volume at trivial local coordinates
|
i.e. sampling the volume at trivial local coordinates
|
||||||
@ -157,6 +157,7 @@ class Volumes:
|
|||||||
features: Optional[_TensorBatch] = None,
|
features: Optional[_TensorBatch] = None,
|
||||||
voxel_size: _VoxelSize = 1.0,
|
voxel_size: _VoxelSize = 1.0,
|
||||||
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||||
|
align_corners: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -186,6 +187,10 @@ class Volumes:
|
|||||||
b) a Tensor of shape (3,)
|
b) a Tensor of shape (3,)
|
||||||
c) a Tensor of shape (minibatch, 3)
|
c) a Tensor of shape (minibatch, 3)
|
||||||
d) a Tensor of shape (1,) (square voxels)
|
d) a Tensor of shape (1,) (square voxels)
|
||||||
|
**align_corners**: If set (default), the coordinates of the corner voxels are
|
||||||
|
exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
|
||||||
|
correspond to the centers of the corner voxels. Cf. the namesake argument to
|
||||||
|
`torch.nn.functional.grid_sample`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# handle densities
|
# handle densities
|
||||||
@ -206,6 +211,7 @@ class Volumes:
|
|||||||
voxel_size=voxel_size,
|
voxel_size=voxel_size,
|
||||||
volume_translation=volume_translation,
|
volume_translation=volume_translation,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
align_corners=align_corners,
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle features
|
# handle features
|
||||||
@ -336,6 +342,13 @@ class Volumes:
|
|||||||
return None
|
return None
|
||||||
return self._features_densities_list(features_)
|
return self._features_densities_list(features_)
|
||||||
|
|
||||||
|
def get_align_corners(self) -> bool:
|
||||||
|
"""
|
||||||
|
Return whether the corners of the voxels should be aligned with the
|
||||||
|
image pixels.
|
||||||
|
"""
|
||||||
|
return self.locator._align_corners
|
||||||
|
|
||||||
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
|
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Retrieve the list representation of features/densities.
|
Retrieve the list representation of features/densities.
|
||||||
@ -576,7 +589,7 @@ class VolumeLocator:
|
|||||||
are linearly interpolated over the spatial dimensions of the volume.
|
are linearly interpolated over the spatial dimensions of the volume.
|
||||||
- Note that the convention is the same as for the 5D version of the
|
- Note that the convention is the same as for the 5D version of the
|
||||||
`torch.nn.functional.grid_sample` function called with
|
`torch.nn.functional.grid_sample` function called with
|
||||||
`align_corners==True`.
|
the same value of `align_corners` argument.
|
||||||
- Note that the local coordinate convention of `VolumeLocator`
|
- Note that the local coordinate convention of `VolumeLocator`
|
||||||
(+X = left to right, +Y = top to bottom, +Z = away from the user)
|
(+X = left to right, +Y = top to bottom, +Z = away from the user)
|
||||||
is *different* from the world coordinate convention of the
|
is *different* from the world coordinate convention of the
|
||||||
@ -634,7 +647,7 @@ class VolumeLocator:
|
|||||||
torch.nn.functional.grid_sample(
|
torch.nn.functional.grid_sample(
|
||||||
v.densities(),
|
v.densities(),
|
||||||
v.get_coord_grid(world_coordinates=False),
|
v.get_coord_grid(world_coordinates=False),
|
||||||
align_corners=True,
|
align_corners=align_corners,
|
||||||
) == v.densities(),
|
) == v.densities(),
|
||||||
|
|
||||||
i.e. sampling the volume at trivial local coordinates
|
i.e. sampling the volume at trivial local coordinates
|
||||||
@ -651,6 +664,7 @@ class VolumeLocator:
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
voxel_size: _VoxelSize = 1.0,
|
voxel_size: _VoxelSize = 1.0,
|
||||||
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||||
|
align_corners: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
**batch_size** : Batch size of the underlying grids
|
**batch_size** : Batch size of the underlying grids
|
||||||
@ -674,15 +688,21 @@ class VolumeLocator:
|
|||||||
b) a Tensor of shape (3,)
|
b) a Tensor of shape (3,)
|
||||||
c) a Tensor of shape (minibatch, 3)
|
c) a Tensor of shape (minibatch, 3)
|
||||||
d) a Tensor of shape (1,) (square voxels)
|
d) a Tensor of shape (1,) (square voxels)
|
||||||
|
**align_corners**: If set (default), the coordinates of the corner voxels are
|
||||||
|
exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
|
||||||
|
correspond to the centers of the corner voxels. Cf. the namesake argument to
|
||||||
|
`torch.nn.functional.grid_sample`.
|
||||||
"""
|
"""
|
||||||
self.device = device
|
self.device = device
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
|
self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
|
||||||
self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
|
self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
|
||||||
|
self._align_corners = align_corners
|
||||||
|
|
||||||
# set the local_to_world transform
|
# set the local_to_world transform
|
||||||
self._set_local_to_world_transform(
|
self._set_local_to_world_transform(
|
||||||
voxel_size=voxel_size, volume_translation=volume_translation
|
voxel_size=voxel_size,
|
||||||
|
volume_translation=volume_translation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_grid_sizes2tensor(
|
def _convert_grid_sizes2tensor(
|
||||||
@ -806,8 +826,17 @@ class VolumeLocator:
|
|||||||
grid_sizes = self.get_grid_sizes()
|
grid_sizes = self.get_grid_sizes()
|
||||||
|
|
||||||
# generate coordinate axes
|
# generate coordinate axes
|
||||||
|
def corner_coord_adjustment(r):
|
||||||
|
return 0.0 if self._align_corners else 1.0 / r
|
||||||
|
|
||||||
vol_axes = [
|
vol_axes = [
|
||||||
torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
|
torch.linspace(
|
||||||
|
-1.0 + corner_coord_adjustment(r),
|
||||||
|
1.0 - corner_coord_adjustment(r),
|
||||||
|
r,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
for r in (de, he, wi)
|
for r in (de, he, wi)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -312,6 +312,49 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
|
|||||||
).permute(0, 2, 3, 4, 1)
|
).permute(0, 2, 3, 4, 1)
|
||||||
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
|
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
v_trivial = Volumes(densities=densities, align_corners=align_corners)
|
||||||
|
|
||||||
|
# check the case with x_world=(0,0,0)
|
||||||
|
pts_world = torch.zeros(
|
||||||
|
num_volumes, 1, 3, device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
pts_local = v_trivial.world_to_local_coords(pts_world)
|
||||||
|
pts_local_expected = torch.zeros_like(pts_local)
|
||||||
|
self.assertClose(pts_local, pts_local_expected)
|
||||||
|
|
||||||
|
# check the case with x_world=(-2, 3, -2)
|
||||||
|
pts_world_tuple = [-2, 3, -2]
|
||||||
|
pts_world = torch.tensor(
|
||||||
|
pts_world_tuple, device=device, dtype=torch.float32
|
||||||
|
)[None, None].repeat(num_volumes, 1, 1)
|
||||||
|
pts_local = v_trivial.world_to_local_coords(pts_world)
|
||||||
|
pts_local_expected = torch.tensor(
|
||||||
|
[-1, 1, -1], device=device, dtype=torch.float32
|
||||||
|
)[None, None].repeat(num_volumes, 1, 1)
|
||||||
|
self.assertClose(pts_local, pts_local_expected)
|
||||||
|
|
||||||
|
# # check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0)
|
||||||
|
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
|
||||||
|
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
|
||||||
|
for grid in (grid_world, grid_local):
|
||||||
|
x0 = grid[0, :, :, 2, 0]
|
||||||
|
y0 = grid[0, :, 3, :, 1]
|
||||||
|
z0 = grid[0, 2, :, :, 2]
|
||||||
|
for coord_line in (x0, y0, z0):
|
||||||
|
self.assertClose(
|
||||||
|
coord_line, torch.zeros_like(coord_line), atol=1e-7
|
||||||
|
)
|
||||||
|
|
||||||
|
# resample grid_world using grid_sampler with local coords
|
||||||
|
# -> make sure the resampled version is the same as original
|
||||||
|
grid_world_resampled = torch.nn.functional.grid_sample(
|
||||||
|
grid_world.permute(0, 4, 1, 2, 3),
|
||||||
|
grid_local,
|
||||||
|
align_corners=align_corners,
|
||||||
|
).permute(0, 2, 3, 4, 1)
|
||||||
|
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
|
||||||
|
|
||||||
def test_coord_grid_convention_heterogeneous(
|
def test_coord_grid_convention_heterogeneous(
|
||||||
self, num_channels=4, dtype=torch.float32
|
self, num_channels=4, dtype=torch.float32
|
||||||
):
|
):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user