rendered_mesh_dataset improvements

Summary: Allow choosing the rendering device and the camera distance

Reviewed By: shapovalov

Differential Revision: D42451605

fbshipit-source-id: 214f02d09da94eb127b3cc308d5bae800dc7b9e2
Author: Jeremy Reizenstein, 2023-01-16 07:46:41 -08:00 (committed by Facebook GitHub Bot)
parent acc60db4f4
commit b7e3b7b16c
2 changed files with 17 additions and 7 deletions

@@ -103,8 +103,10 @@ data_source_ImplicitronDataSource_args:
   num_views: 40
   data_file: null
   azimuth_range: 180.0
+  distance: 2.7
   resolution: 128
   use_point_light: true
+  gpu_idx: 0
   path_manager_factory_class_type: PathManagerFactory
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true

@@ -49,7 +49,7 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
     if one is available, the data it produces is on the CPU just like
     the data returned by implicitron's other dataset map providers.
     This is because both datasets and models can be large, so implicitron's
-    GenericModel.forward (etc) expects data on the CPU and only moves
+    training loop expects data on the CPU and only moves
     what it needs to the device.

     For a more detailed explanation of this code, please refer to the
@@ -61,16 +61,23 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             the cow mesh in the same repo as this code.
         azimuth_range: number of degrees on each side of the start position to
             take samples
+        distance: distance from camera centres to the origin.
         resolution: the common height and width of the output images.
         use_point_light: whether to use a particular point light as opposed
            to ambient white.
+        gpu_idx: which gpu to use for rendering the mesh.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
     """

     num_views: int = 40
     data_file: Optional[str] = None
     azimuth_range: float = 180
+    distance: float = 2.7
     resolution: int = 128
     use_point_light: bool = True
+    gpu_idx: Optional[int] = 0
     path_manager_factory: PathManagerFactory
     path_manager_factory_class_type: str = "PathManagerFactory"
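
If it helps to see the new fields in use: a minimal construction sketch, assuming the module paths and the get_dataset_map() accessor from the pytorch3d repo (the override values here are illustrative):

    from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
        RenderedMeshDatasetMapProvider,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # The config system generates the dataclass __init__, after which the new
    # fields are ordinary keyword arguments.
    expand_args_fields(RenderedMeshDatasetMapProvider)
    provider = RenderedMeshDatasetMapProvider(distance=3.0, gpu_idx=None)

    # Rendering happens in __post_init__; as the docstring above notes, the
    # data handed out is on the CPU regardless of the rendering device.
    train_dataset = provider.get_dataset_map().train
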
@@ -85,8 +92,8 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
     def __post_init__(self) -> None:
         super().__init__()
         run_auto_creation(self)
-        if torch.cuda.is_available():
-            device = torch.device("cuda:0")
+        if torch.cuda.is_available() and self.gpu_idx is not None:
+            device = torch.device(f"cuda:{self.gpu_idx}")
         else:
             device = torch.device("cpu")
         if self.data_file is None:
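
The new guard makes gpu_idx double as a device switch: an integer selects a specific GPU for rendering, while None keeps rendering on the CPU even when CUDA is available. A standalone sketch of the same selection logic (pick_device is a hypothetical helper, not part of the diff):

    from typing import Optional

    import torch

    def pick_device(gpu_idx: Optional[int]) -> torch.device:
        # Mirrors the provider's logic: None forces CPU rendering even on a
        # CUDA machine; an integer picks the corresponding GPU.
        if torch.cuda.is_available() and gpu_idx is not None:
            return torch.device(f"cuda:{gpu_idx}")
        return torch.device("cpu")
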
@@ -106,13 +113,13 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             num_views=self.num_views,
             mesh=mesh,
             azimuth_range=self.azimuth_range,
+            distance=self.distance,
             resolution=self.resolution,
             device=device,
             use_point_light=self.use_point_light,
         )
         # pyre-ignore[16]
         self.poses = poses.cpu()
-        expand_args_fields(SingleSceneDataset)
         # pyre-ignore[16]
         self.train_dataset = SingleSceneDataset(  # pyre-ignore[28]
             object_name="cow",
@@ -130,6 +137,7 @@ def _generate_cow_renders(
     num_views: int,
     mesh: Meshes,
     azimuth_range: float,
+    distance: float,
     resolution: int,
     device: torch.device,
     use_point_light: bool,
@@ -168,11 +176,11 @@ def _generate_cow_renders(
     else:
         lights = AmbientLights(device=device)

-    # Initialize an OpenGL perspective camera that represents a batch of different
+    # Initialize a perspective camera that represents a batch of different
     # viewing angles. All the cameras helper methods support mixed type inputs and
-    # broadcasting. So we can view the camera from the a distance of dist=2.7, and
+    # broadcasting. So we can view the camera from a fixed distance, and
     # then specify elevation and azimuth angles for each viewpoint as tensors.
-    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
+    R, T = look_at_view_transform(dist=distance, elev=elev, azim=azim)
     cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

     # Define the settings for rasterization and shading.
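
For reference, a small sketch of the camera helper being parametrized here, showing the broadcasting the comment describes, with the distance now a variable instead of the hard-coded 2.7 (the elevation/azimuth values are illustrative):

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform

    distance = 2.7  # previously hard-coded, now a provider field
    elev = torch.zeros(4)
    azim = torch.linspace(-180.0, 180.0, 4)

    # dist broadcasts against the per-view elev/azim tensors, yielding one
    # camera pose (R, T) per viewpoint.
    R, T = look_at_view_transform(dist=distance, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(R=R, T=T)
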