mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Improve memory efficiency in VolumeSampler
Summary: Avoids use of `torch.cat` operation when rendering a volume by instead issuing multiple calls to `torch.nn.functional.grid_sample`. Density and color tensors can be large. Reviewed By: bottler Differential Revision: D40072399 fbshipit-source-id: eb4cd34f6171d54972bbf2877065f973db497de0
This commit is contained in:
		
							parent
							
								
									0d8608b9f9
								
							
						
					
					
						commit
						4c8338b00f
					
				@ -363,35 +363,40 @@ class VolumeSampler(torch.nn.Module):
 | 
			
		||||
        volumes_densities = self._volumes.densities()
 | 
			
		||||
        dim_density = volumes_densities.shape[1]
 | 
			
		||||
        volumes_features = self._volumes.features()
 | 
			
		||||
        # adjust the volumes_features variable in case we have a feature-less volume
 | 
			
		||||
        if volumes_features is None:
 | 
			
		||||
            dim_feature = 0
 | 
			
		||||
            data_to_sample = volumes_densities
 | 
			
		||||
        else:
 | 
			
		||||
            dim_feature = volumes_features.shape[1]
 | 
			
		||||
            data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
 | 
			
		||||
 | 
			
		||||
        # reshape to a size which grid_sample likes
 | 
			
		||||
        rays_points_local_flat = rays_points_local.view(
 | 
			
		||||
            rays_points_local.shape[0], -1, 1, 1, 3
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # run the grid sampler
 | 
			
		||||
        data_sampled = torch.nn.functional.grid_sample(
 | 
			
		||||
            data_to_sample,
 | 
			
		||||
        # run the grid sampler on the volumes densities
 | 
			
		||||
        rays_densities = torch.nn.functional.grid_sample(
 | 
			
		||||
            volumes_densities,
 | 
			
		||||
            rays_points_local_flat,
 | 
			
		||||
            align_corners=True,
 | 
			
		||||
            mode=self._sample_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # permute the dimensions & reshape after sampling
 | 
			
		||||
        data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
 | 
			
		||||
            *rays_points_local.shape[:-1], data_sampled.shape[1]
 | 
			
		||||
        # permute the dimensions & reshape densities after sampling
 | 
			
		||||
        rays_densities = rays_densities.permute(0, 2, 3, 4, 1).view(
 | 
			
		||||
            *rays_points_local.shape[:-1], volumes_densities.shape[1]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # split back to densities and features
 | 
			
		||||
        rays_densities, rays_features = data_sampled.split(
 | 
			
		||||
            [dim_density, dim_feature], dim=-1
 | 
			
		||||
        )
 | 
			
		||||
        # if features exist, run grid sampler again on the features densities
 | 
			
		||||
        if volumes_features is None:
 | 
			
		||||
            dim_feature = 0
 | 
			
		||||
            _, rays_features = rays_densities.split([dim_density, dim_feature], dim=-1)
 | 
			
		||||
        else:
 | 
			
		||||
            rays_features = torch.nn.functional.grid_sample(
 | 
			
		||||
                volumes_features,
 | 
			
		||||
                rays_points_local_flat,
 | 
			
		||||
                align_corners=True,
 | 
			
		||||
                mode=self._sample_mode,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # permute the dimensions & reshape features after sampling
 | 
			
		||||
            rays_features = rays_features.permute(0, 2, 3, 4, 1).view(
 | 
			
		||||
                *rays_points_local.shape[:-1], volumes_features.shape[1]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return rays_densities, rays_features
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user