mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Align_corners switch in Volumes
Summary:
Porting this commit by davnov134 .
93a3a62800 (diff-a8e107ebe039de52ca112ac6ddfba6ebccd53b4f53030b986e13f019fe57a378)
Capability to interpret world/local coordinates with various align_corners semantics.
Reviewed By: bottler
Differential Revision: D51855420
fbshipit-source-id: 834cd220c25d7f0143d8a55ba880da5977099dd6
			
			
This commit is contained in:
		
							parent
							
								
									fbc6725f03
								
							
						
					
					
						commit
						94da8841af
					
				@ -98,6 +98,13 @@ def save_model(model, stats, fl, optimizer=None, cfg=None):
 | 
			
		||||
    return flstats, flmodel, flopt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def save_stats(stats, fl, cfg=None):
 | 
			
		||||
    flstats = get_stats_path(fl)
 | 
			
		||||
    logger.info("saving model stats to %s" % flstats)
 | 
			
		||||
    stats.save(flstats)
 | 
			
		||||
    return flstats
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_model(fl, map_location: Optional[dict]):
 | 
			
		||||
    flstats = get_stats_path(fl)
 | 
			
		||||
    flmodel = get_model_path(fl)
 | 
			
		||||
 | 
			
		||||
@ -291,6 +291,7 @@ def add_pointclouds_to_volumes(
 | 
			
		||||
        mask=mask,
 | 
			
		||||
        mode=mode,
 | 
			
		||||
        rescale_features=rescale_features,
 | 
			
		||||
        align_corners=initial_volumes.get_align_corners(),
 | 
			
		||||
        _python=_python,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -310,6 +311,7 @@ def add_points_features_to_volume_densities_features(
 | 
			
		||||
    grid_sizes: Optional[torch.LongTensor] = None,
 | 
			
		||||
    rescale_features: bool = True,
 | 
			
		||||
    _python: bool = False,
 | 
			
		||||
    align_corners: bool = True,
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    """
 | 
			
		||||
    Convert a batch of point clouds represented with tensors of per-point
 | 
			
		||||
@ -356,6 +358,7 @@ def add_points_features_to_volume_densities_features(
 | 
			
		||||
                            output densities are just summed without rescaling, so
 | 
			
		||||
                            you may need to rescale them afterwards.
 | 
			
		||||
        _python: Set to True to use a pure Python implementation.
 | 
			
		||||
        align_corners: as for grid_sample.
 | 
			
		||||
    Returns:
 | 
			
		||||
        volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
 | 
			
		||||
        volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
 | 
			
		||||
@ -409,7 +412,7 @@ def add_points_features_to_volume_densities_features(
 | 
			
		||||
        grid_sizes,
 | 
			
		||||
        1.0,  # point_weight
 | 
			
		||||
        mask,
 | 
			
		||||
        True,  # align_corners
 | 
			
		||||
        align_corners,  # align_corners
 | 
			
		||||
        splat,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -382,9 +382,9 @@ class VolumeSampler(torch.nn.Module):
 | 
			
		||||
        rays_densities = torch.nn.functional.grid_sample(
 | 
			
		||||
            volumes_densities,
 | 
			
		||||
            rays_points_local_flat,
 | 
			
		||||
            align_corners=True,
 | 
			
		||||
            mode=self._sample_mode,
 | 
			
		||||
            padding_mode=self._padding_mode,
 | 
			
		||||
            align_corners=self._volumes.get_align_corners(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # permute the dimensions & reshape densities after sampling
 | 
			
		||||
@ -400,9 +400,9 @@ class VolumeSampler(torch.nn.Module):
 | 
			
		||||
            rays_features = torch.nn.functional.grid_sample(
 | 
			
		||||
                volumes_features,
 | 
			
		||||
                rays_points_local_flat,
 | 
			
		||||
                align_corners=True,
 | 
			
		||||
                mode=self._sample_mode,
 | 
			
		||||
                padding_mode=self._padding_mode,
 | 
			
		||||
                align_corners=self._volumes.get_align_corners(),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # permute the dimensions & reshape features after sampling
 | 
			
		||||
 | 
			
		||||
@ -85,7 +85,7 @@ class Volumes:
 | 
			
		||||
              are linearly interpolated over the spatial dimensions of the volume.
 | 
			
		||||
            - Note that the convention is the same as for the 5D version of the
 | 
			
		||||
              `torch.nn.functional.grid_sample` function called with
 | 
			
		||||
              `align_corners==True`.
 | 
			
		||||
              the same value of `align_corners` argument.
 | 
			
		||||
            - Note that the local coordinate convention of `Volumes`
 | 
			
		||||
              (+X = left to right, +Y = top to bottom, +Z = away from the user)
 | 
			
		||||
              is *different* from the world coordinate convention of the
 | 
			
		||||
@ -143,7 +143,7 @@ class Volumes:
 | 
			
		||||
                torch.nn.functional.grid_sample(
 | 
			
		||||
                    v.densities(),
 | 
			
		||||
                    v.get_coord_grid(world_coordinates=False),
 | 
			
		||||
                    align_corners=True,
 | 
			
		||||
                    align_corners=align_corners,
 | 
			
		||||
                ) == v.densities(),
 | 
			
		||||
 | 
			
		||||
            i.e. sampling the volume at trivial local coordinates
 | 
			
		||||
@ -157,6 +157,7 @@ class Volumes:
 | 
			
		||||
        features: Optional[_TensorBatch] = None,
 | 
			
		||||
        voxel_size: _VoxelSize = 1.0,
 | 
			
		||||
        volume_translation: _Translation = (0.0, 0.0, 0.0),
 | 
			
		||||
        align_corners: bool = True,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
@ -186,6 +187,10 @@ class Volumes:
 | 
			
		||||
                b) a Tensor of shape (3,)
 | 
			
		||||
                c) a Tensor of shape (minibatch, 3)
 | 
			
		||||
                d) a Tensor of shape (1,) (square voxels)
 | 
			
		||||
            **align_corners**: If set (default), the coordinates of the corner voxels are
 | 
			
		||||
                exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
 | 
			
		||||
                correspond to the centers of the corner voxels. Cf. the namesake argument to
 | 
			
		||||
                `torch.nn.functional.grid_sample`.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # handle densities
 | 
			
		||||
@ -206,6 +211,7 @@ class Volumes:
 | 
			
		||||
            voxel_size=voxel_size,
 | 
			
		||||
            volume_translation=volume_translation,
 | 
			
		||||
            device=self.device,
 | 
			
		||||
            align_corners=align_corners,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # handle features
 | 
			
		||||
@ -336,6 +342,13 @@ class Volumes:
 | 
			
		||||
            return None
 | 
			
		||||
        return self._features_densities_list(features_)
 | 
			
		||||
 | 
			
		||||
    def get_align_corners(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Return whether the corners of the voxels should be aligned with the
 | 
			
		||||
        image pixels.
 | 
			
		||||
        """
 | 
			
		||||
        return self.locator._align_corners
 | 
			
		||||
 | 
			
		||||
    def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the list representation of features/densities.
 | 
			
		||||
@ -576,7 +589,7 @@ class VolumeLocator:
 | 
			
		||||
              are linearly interpolated over the spatial dimensions of the volume.
 | 
			
		||||
            - Note that the convention is the same as for the 5D version of the
 | 
			
		||||
              `torch.nn.functional.grid_sample` function called with
 | 
			
		||||
              `align_corners==True`.
 | 
			
		||||
              the same value of `align_corners` argument.
 | 
			
		||||
            - Note that the local coordinate convention of `VolumeLocator`
 | 
			
		||||
              (+X = left to right, +Y = top to bottom, +Z = away from the user)
 | 
			
		||||
              is *different* from the world coordinate convention of the
 | 
			
		||||
@ -634,7 +647,7 @@ class VolumeLocator:
 | 
			
		||||
                torch.nn.functional.grid_sample(
 | 
			
		||||
                    v.densities(),
 | 
			
		||||
                    v.get_coord_grid(world_coordinates=False),
 | 
			
		||||
                    align_corners=True,
 | 
			
		||||
                    align_corners=align_corners,
 | 
			
		||||
                ) == v.densities(),
 | 
			
		||||
 | 
			
		||||
            i.e. sampling the volume at trivial local coordinates
 | 
			
		||||
@ -651,6 +664,7 @@ class VolumeLocator:
 | 
			
		||||
        device: torch.device,
 | 
			
		||||
        voxel_size: _VoxelSize = 1.0,
 | 
			
		||||
        volume_translation: _Translation = (0.0, 0.0, 0.0),
 | 
			
		||||
        align_corners: bool = True,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        **batch_size** : Batch size of the underlying grids
 | 
			
		||||
@ -674,15 +688,21 @@ class VolumeLocator:
 | 
			
		||||
            b) a Tensor of shape (3,)
 | 
			
		||||
            c) a Tensor of shape (minibatch, 3)
 | 
			
		||||
            d) a Tensor of shape (1,) (square voxels)
 | 
			
		||||
        **align_corners**: If set (default), the coordinates of the corner voxels are
 | 
			
		||||
            exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
 | 
			
		||||
            correspond to the centers of the corner voxels. Cf. the namesake argument to
 | 
			
		||||
            `torch.nn.functional.grid_sample`.
 | 
			
		||||
        """
 | 
			
		||||
        self.device = device
 | 
			
		||||
        self._batch_size = batch_size
 | 
			
		||||
        self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
 | 
			
		||||
        self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
 | 
			
		||||
        self._align_corners = align_corners
 | 
			
		||||
 | 
			
		||||
        # set the local_to_world transform
 | 
			
		||||
        self._set_local_to_world_transform(
 | 
			
		||||
            voxel_size=voxel_size, volume_translation=volume_translation
 | 
			
		||||
            voxel_size=voxel_size,
 | 
			
		||||
            volume_translation=volume_translation,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _convert_grid_sizes2tensor(
 | 
			
		||||
@ -806,8 +826,17 @@ class VolumeLocator:
 | 
			
		||||
        grid_sizes = self.get_grid_sizes()
 | 
			
		||||
 | 
			
		||||
        # generate coordinate axes
 | 
			
		||||
        def corner_coord_adjustment(r):
 | 
			
		||||
            return 0.0 if self._align_corners else 1.0 / r
 | 
			
		||||
 | 
			
		||||
        vol_axes = [
 | 
			
		||||
            torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
 | 
			
		||||
            torch.linspace(
 | 
			
		||||
                -1.0 + corner_coord_adjustment(r),
 | 
			
		||||
                1.0 - corner_coord_adjustment(r),
 | 
			
		||||
                r,
 | 
			
		||||
                dtype=torch.float32,
 | 
			
		||||
                device=self.device,
 | 
			
		||||
            )
 | 
			
		||||
            for r in (de, he, wi)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -312,6 +312,49 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        ).permute(0, 2, 3, 4, 1)
 | 
			
		||||
        self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
 | 
			
		||||
 | 
			
		||||
        for align_corners in [True, False]:
 | 
			
		||||
            v_trivial = Volumes(densities=densities, align_corners=align_corners)
 | 
			
		||||
 | 
			
		||||
            # check the case with x_world=(0,0,0)
 | 
			
		||||
            pts_world = torch.zeros(
 | 
			
		||||
                num_volumes, 1, 3, device=device, dtype=torch.float32
 | 
			
		||||
            )
 | 
			
		||||
            pts_local = v_trivial.world_to_local_coords(pts_world)
 | 
			
		||||
            pts_local_expected = torch.zeros_like(pts_local)
 | 
			
		||||
            self.assertClose(pts_local, pts_local_expected)
 | 
			
		||||
 | 
			
		||||
            # check the case with x_world=(-2, 3, -2)
 | 
			
		||||
            pts_world_tuple = [-2, 3, -2]
 | 
			
		||||
            pts_world = torch.tensor(
 | 
			
		||||
                pts_world_tuple, device=device, dtype=torch.float32
 | 
			
		||||
            )[None, None].repeat(num_volumes, 1, 1)
 | 
			
		||||
            pts_local = v_trivial.world_to_local_coords(pts_world)
 | 
			
		||||
            pts_local_expected = torch.tensor(
 | 
			
		||||
                [-1, 1, -1], device=device, dtype=torch.float32
 | 
			
		||||
            )[None, None].repeat(num_volumes, 1, 1)
 | 
			
		||||
            self.assertClose(pts_local, pts_local_expected)
 | 
			
		||||
 | 
			
		||||
            # # check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0)
 | 
			
		||||
            grid_world = v_trivial.get_coord_grid(world_coordinates=True)
 | 
			
		||||
            grid_local = v_trivial.get_coord_grid(world_coordinates=False)
 | 
			
		||||
            for grid in (grid_world, grid_local):
 | 
			
		||||
                x0 = grid[0, :, :, 2, 0]
 | 
			
		||||
                y0 = grid[0, :, 3, :, 1]
 | 
			
		||||
                z0 = grid[0, 2, :, :, 2]
 | 
			
		||||
                for coord_line in (x0, y0, z0):
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        coord_line, torch.zeros_like(coord_line), atol=1e-7
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            # resample grid_world using grid_sampler with local coords
 | 
			
		||||
            # -> make sure the resampled version is the same as original
 | 
			
		||||
            grid_world_resampled = torch.nn.functional.grid_sample(
 | 
			
		||||
                grid_world.permute(0, 4, 1, 2, 3),
 | 
			
		||||
                grid_local,
 | 
			
		||||
                align_corners=align_corners,
 | 
			
		||||
            ).permute(0, 2, 3, 4, 1)
 | 
			
		||||
            self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
 | 
			
		||||
 | 
			
		||||
    def test_coord_grid_convention_heterogeneous(
 | 
			
		||||
        self, num_channels=4, dtype=torch.float32
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user