diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml
index 613ca168..e8cdba05 100644
--- a/projects/implicitron_trainer/tests/experiment.yaml
+++ b/projects/implicitron_trainer/tests/experiment.yaml
@@ -103,8 +103,10 @@ data_source_ImplicitronDataSource_args:
     num_views: 40
     data_file: null
     azimuth_range: 180.0
+    distance: 2.7
     resolution: 128
     use_point_light: true
+    gpu_idx: 0
     path_manager_factory_class_type: PathManagerFactory
     path_manager_factory_PathManagerFactory_args:
       silence_logs: true
diff --git a/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py b/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py
index 3cc793fd..953ca904 100644
--- a/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py
+++ b/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py
@@ -49,7 +49,7 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
     if one is available, the data it produces is on the CPU just like
     the data returned by implicitron's other dataset map providers.
     This is because both datasets and models can be large, so implicitron's
-    GenericModel.forward (etc) expects data on the CPU and only moves
+    training loop expects data on the CPU and only moves
     what it needs to the device.
 
     For a more detailed explanation of this code, please refer to the
@@ -61,16 +61,23 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             the cow mesh in the same repo as this code.
         azimuth_range: number of degrees on each side of the start position to
            take samples
+        distance: distance from camera centres to the origin.
         resolution: the common height and width of the output images.
         use_point_light: whether to use a particular point light as opposed
            to ambient white.
+        gpu_idx: which gpu to use for rendering the mesh; None means render on the CPU.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
""" num_views: int = 40 data_file: Optional[str] = None azimuth_range: float = 180 + distance: float = 2.7 resolution: int = 128 use_point_light: bool = True + gpu_idx: Optional[int] = 0 path_manager_factory: PathManagerFactory path_manager_factory_class_type: str = "PathManagerFactory" @@ -85,8 +92,8 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13 def __post_init__(self) -> None: super().__init__() run_auto_creation(self) - if torch.cuda.is_available(): - device = torch.device("cuda:0") + if torch.cuda.is_available() and self.gpu_idx is not None: + device = torch.device(f"cuda:{self.gpu_idx}") else: device = torch.device("cpu") if self.data_file is None: @@ -106,13 +113,13 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13 num_views=self.num_views, mesh=mesh, azimuth_range=self.azimuth_range, + distance=self.distance, resolution=self.resolution, device=device, use_point_light=self.use_point_light, ) # pyre-ignore[16] self.poses = poses.cpu() - expand_args_fields(SingleSceneDataset) # pyre-ignore[16] self.train_dataset = SingleSceneDataset( # pyre-ignore[28] object_name="cow", @@ -130,6 +137,7 @@ def _generate_cow_renders( num_views: int, mesh: Meshes, azimuth_range: float, + distance: float, resolution: int, device: torch.device, use_point_light: bool, @@ -168,11 +176,11 @@ def _generate_cow_renders( else: lights = AmbientLights(device=device) - # Initialize an OpenGL perspective camera that represents a batch of different + # Initialize a perspective camera that represents a batch of different # viewing angles. All the cameras helper methods support mixed type inputs and - # broadcasting. So we can view the camera from the a distance of dist=2.7, and + # broadcasting. So we can view the camera from a fixed distance, and # then specify elevation and azimuth angles for each viewpoint as tensors. - R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim) + R, T = look_at_view_transform(dist=distance, elev=elev, azim=azim) cameras = FoVPerspectiveCameras(device=device, R=R, T=T) # Define the settings for rasterization and shading.