mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Improve memory efficiency in VolumeSampler
Summary: Avoids use of `torch.cat` operation when rendering a volume by instead issuing multiple calls to `torch.nn.functional.grid_sample`. Density and color tensors can be large. Reviewed By: bottler Differential Revision: D40072399 fbshipit-source-id: eb4cd34f6171d54972bbf2877065f973db497de0
This commit is contained in:
parent
0d8608b9f9
commit
4c8338b00f
@ -363,35 +363,40 @@ class VolumeSampler(torch.nn.Module):
|
|||||||
volumes_densities = self._volumes.densities()
|
volumes_densities = self._volumes.densities()
|
||||||
dim_density = volumes_densities.shape[1]
|
dim_density = volumes_densities.shape[1]
|
||||||
volumes_features = self._volumes.features()
|
volumes_features = self._volumes.features()
|
||||||
# adjust the volumes_features variable in case we have a feature-less volume
|
|
||||||
if volumes_features is None:
|
|
||||||
dim_feature = 0
|
|
||||||
data_to_sample = volumes_densities
|
|
||||||
else:
|
|
||||||
dim_feature = volumes_features.shape[1]
|
|
||||||
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
|
|
||||||
|
|
||||||
# reshape to a size which grid_sample likes
|
# reshape to a size which grid_sample likes
|
||||||
rays_points_local_flat = rays_points_local.view(
|
rays_points_local_flat = rays_points_local.view(
|
||||||
rays_points_local.shape[0], -1, 1, 1, 3
|
rays_points_local.shape[0], -1, 1, 1, 3
|
||||||
)
|
)
|
||||||
|
|
||||||
# run the grid sampler
|
# run the grid sampler on the volumes densities
|
||||||
data_sampled = torch.nn.functional.grid_sample(
|
rays_densities = torch.nn.functional.grid_sample(
|
||||||
data_to_sample,
|
volumes_densities,
|
||||||
rays_points_local_flat,
|
rays_points_local_flat,
|
||||||
align_corners=True,
|
align_corners=True,
|
||||||
mode=self._sample_mode,
|
mode=self._sample_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# permute the dimensions & reshape after sampling
|
# permute the dimensions & reshape densities after sampling
|
||||||
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
|
rays_densities = rays_densities.permute(0, 2, 3, 4, 1).view(
|
||||||
*rays_points_local.shape[:-1], data_sampled.shape[1]
|
*rays_points_local.shape[:-1], volumes_densities.shape[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
# split back to densities and features
|
# if features exist, run grid sampler again on the features densities
|
||||||
rays_densities, rays_features = data_sampled.split(
|
if volumes_features is None:
|
||||||
[dim_density, dim_feature], dim=-1
|
dim_feature = 0
|
||||||
)
|
_, rays_features = rays_densities.split([dim_density, dim_feature], dim=-1)
|
||||||
|
else:
|
||||||
|
rays_features = torch.nn.functional.grid_sample(
|
||||||
|
volumes_features,
|
||||||
|
rays_points_local_flat,
|
||||||
|
align_corners=True,
|
||||||
|
mode=self._sample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# permute the dimensions & reshape features after sampling
|
||||||
|
rays_features = rays_features.permute(0, 2, 3, 4, 1).view(
|
||||||
|
*rays_points_local.shape[:-1], volumes_features.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
return rays_densities, rays_features
|
return rays_densities, rays_features
|
||||||
|
Loading…
x
Reference in New Issue
Block a user