diff --git a/docs/tutorials/data/cow_mesh/README.md b/docs/tutorials/data/cow_mesh/README.md index c7ff345e..3e5cefc7 100644 --- a/docs/tutorials/data/cow_mesh/README.md +++ b/docs/tutorials/data/cow_mesh/README.md @@ -1,5 +1,5 @@ # Acknowledgements -Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain. +Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain. ###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/ diff --git a/docs/tutorials/utils/generate_cow_renders.py b/docs/tutorials/utils/generate_cow_renders.py index eaf0ced5..89f9ba4f 100644 --- a/docs/tutorials/utils/generate_cow_renders.py +++ b/docs/tutorials/utils/generate_cow_renders.py @@ -44,6 +44,8 @@ def generate_cow_renders( data_dir: The folder that contains the cow mesh files. If the cow mesh files do not exist in the folder, this function will automatically download them. + azimuth_range: number of degrees on each side of the start position to + take samples Returns: cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 4ef2b510..2ecba628 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -101,6 +101,15 @@ data_source_ImplicitronDataSource_args: n_known_frames_for_test: null path_manager_factory_PathManagerFactory_args: silence_logs: true + dataset_map_provider_RenderedMeshDatasetMapProvider_args: + num_views: 40 + data_file: null + azimuth_range: 180.0 + resolution: 128 + use_point_light: true + path_manager_factory_class_type: PathManagerFactory + path_manager_factory_PathManagerFactory_args: + silence_logs: true data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py index 9696597a..880679a6 100644 --- a/pytorch3d/implicitron/dataset/data_source.py +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -19,6 +19,7 @@ from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa +from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa class DataSourceBase(ReplaceableBase): diff --git a/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py b/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py new file mode 100644 index 00000000..1c4fca43 --- /dev/null +++ b/pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from os.path import dirname, join, realpath
+from typing import Optional, Tuple
+
+import torch
+from pytorch3d.implicitron.tools.config import (
+    expand_args_fields,
+    registry,
+    run_auto_creation,
+)
+from pytorch3d.io import IO
+from pytorch3d.renderer import (
+    AmbientLights,
+    BlendParams,
+    CamerasBase,
+    FoVPerspectiveCameras,
+    HardPhongShader,
+    look_at_view_transform,
+    MeshRasterizer,
+    MeshRendererWithFragments,
+    PointLights,
+    RasterizationSettings,
+)
+from pytorch3d.structures.meshes import Meshes
+
+from .dataset_map_provider import (
+    DatasetMap,
+    DatasetMapProviderBase,
+    PathManagerFactory,
+    Task,
+)
+from .single_sequence_dataset import SingleSceneDataset
+from .utils import DATASET_TYPE_KNOWN
+
+
+@registry.register
+class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
+    """
+    A simple single-scene dataset based on PyTorch3D renders of a mesh.
+    Provides `num_views` renders of the mesh as train, with no val
+    or test. The renders are generated from viewpoints at evenly spaced
+    azimuth angles. The elevation is kept constant so that the
+    camera's vertical position coincides with the equator.
+
+    By default, uses Keenan Crane's cow model, and the camera locations are
+    set to make sense for that.
+
+    Although the rendering used to generate this dataset will use a GPU
+    if one is available, the data it produces is on the CPU just like
+    the data returned by implicitron's other dataset map providers.
+    This is because both datasets and models can be large, so implicitron's
+    GenericModel.forward (etc) expects data on the CPU and only moves
+    what it needs to the device.
+
+    For a more detailed explanation of this code, please refer to the
+    docs/tutorials/fit_textured_mesh.ipynb notebook.
+
+    Members:
+        num_views: The number of generated renders.
+        data_file: The path of the mesh file to load. By default, finds
+            the cow mesh in the same repo as this code.
+        azimuth_range: number of degrees on each side of the start position to
+            take samples
+        resolution: the common height and width of the output images.
+        use_point_light: whether to use a particular point light as opposed
+            to ambient white.
+    """
+
+    num_views: int = 40
+    data_file: Optional[str] = None
+    azimuth_range: float = 180
+    resolution: int = 128
+    use_point_light: bool = True
+    path_manager_factory: PathManagerFactory
+    path_manager_factory_class_type: str = "PathManagerFactory"
+
+    def get_dataset_map(self) -> DatasetMap:
+        # pyre-ignore[16]
+        return DatasetMap(train=self.train_dataset, val=None, test=None)
+
+    def get_task(self) -> Task:
+        return Task.SINGLE_SEQUENCE
+
+    def get_all_train_cameras(self) -> CamerasBase:
+        # pyre-ignore[16]
+        return self.poses
+
+    def __post_init__(self) -> None:
+        super().__init__()
+        run_auto_creation(self)
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+        if self.data_file is None:
+            data_file = join(
+                dirname(dirname(dirname(dirname(realpath(__file__))))),
+                "docs",
+                "tutorials",
+                "data",
+                "cow_mesh",
+                "cow.obj",
+            )
+        else:
+            data_file = self.data_file
+        io = IO(path_manager=self.path_manager_factory.get())
+        mesh = io.load_mesh(data_file, device=device)
+        poses, images, masks = _generate_cow_renders(
+            num_views=self.num_views,
+            mesh=mesh,
+            azimuth_range=self.azimuth_range,
+            resolution=self.resolution,
+            device=device,
+            use_point_light=self.use_point_light,
+        )
+        # pyre-ignore[16]
+        self.poses = poses.cpu()
+        expand_args_fields(SingleSceneDataset)
+        # pyre-ignore[16]
+        self.train_dataset = SingleSceneDataset(  # pyre-ignore[28]
+            object_name="cow",
+            images=list(images.permute(0, 3, 1, 2).cpu()),
+            fg_probabilities=list(masks[:, None].cpu()),
+            poses=[self.poses[i] for i in range(len(poses))],
+            frame_types=[DATASET_TYPE_KNOWN] * len(poses),
+            eval_batches=None,
+        )
+
+
+@torch.no_grad()
+def _generate_cow_renders(
+    *,
+    num_views: int,
+    mesh: Meshes,
+    azimuth_range: float,
+    resolution: int,
+    device: torch.device,
+    use_point_light: bool,
+) -> Tuple[CamerasBase, torch.Tensor, torch.Tensor]:
+    """
+    Returns:
+        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
+            images are rendered.
+        images: A tensor of shape `(num_views, height, width, 3)` containing
+            the rendered images.
+        silhouettes: A tensor of shape `(num_views, height, width)` containing
+            the rendered silhouettes.
+    """
+
+    # Normalize and center the mesh (already loaded by the caller) to fit in a
+    # sphere of radius 1 centered at (0,0,0). Normalizing the mesh is not
+    # strictly necessary, but it makes the fixed camera distance used below
+    # suitable for meshes of any size.
+    verts = mesh.verts_packed()
+    N = verts.shape[0]
+    center = verts.mean(0)
+    scale = max((verts - center).abs().max(0)[0])
+    mesh.offset_verts_(-(center.expand(N, 3)))
+    mesh.scale_verts_((1.0 / float(scale)))
+
+    # Get a batch of viewing angles.
+    elev = torch.linspace(0, 0, num_views)  # keep constant
+    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0
+
+    # Place a point light in front of the object. The front of the cow faces
+    # the -z direction.
+    if use_point_light:
+        lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
+    else:
+        lights = AmbientLights(device=device)
+
+    # Initialize an OpenGL perspective camera that represents a batch of different
+    # viewing angles. All the camera helper methods support mixed type inputs and
+    # broadcasting, so we can use a single distance of dist=2.7 and then specify
+    # elevation and azimuth angles for each viewpoint as tensors.
+    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
+    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+
+    # Define the settings for rasterization and shading. As we are rendering
+    # dataset images rather than differentiating through the renderer, we set
+    # faces_per_pixel=1 and blur_radius=0.0. Refer to rasterize_meshes.py for
+    # explanations of these parameters. We also leave bin_size and
+    # max_faces_per_bin at their default values of None, which sets their
+    # values using heuristics and ensures that the faster coarse-to-fine
+    # rasterization method is used. Refer to docs/notes/renderer.md for an
+    # explanation of the difference between naive and coarse-to-fine rasterization.
+    raster_settings = RasterizationSettings(
+        image_size=resolution, blur_radius=0.0, faces_per_pixel=1
+    )
+
+    # Create a Phong renderer by composing a rasterizer and a shader. The textured
+    # Phong shader will interpolate the texture uv coordinates for each vertex,
+    # sample from a texture image and apply the Phong lighting model.
+    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
+    rasterizer_type = MeshRasterizer
+    renderer = MeshRendererWithFragments(
+        rasterizer=rasterizer_type(cameras=cameras, raster_settings=raster_settings),
+        shader=HardPhongShader(
+            device=device, cameras=cameras, lights=lights, blend_params=blend_params
+        ),
+    )
+
+    # Create a batch of meshes by repeating the cow mesh and associated textures.
+    # Meshes has a useful `extend` method which allows us to do this very easily.
+    # This also extends the textures.
+    meshes = mesh.extend(num_views)
+
+    # Render the cow mesh from each viewing angle.
+    target_images, fragments = renderer(meshes, cameras=cameras, lights=lights)
+    silhouette_binary = (fragments.pix_to_face[..., 0] >= 0).float()
+
+    return cameras, target_images[..., :3], silhouette_binary
diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index f7f6c9e9..74caf902 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -1661,9 +1661,9 @@ def look_at_rotation(
 
 
 def look_at_view_transform(
-    dist: float = 1.0,
-    elev: float = 0.0,
-    azim: float = 0.0,
+    dist: _BatchFloatType = 1.0,
+    elev: _BatchFloatType = 0.0,
+    azim: _BatchFloatType = 0.0,
     degrees: bool = True,
     eye: Optional[Union[Sequence, torch.Tensor]] = None,
     at=((0, 0, 0),),  # (1, 3)
diff --git a/tests/data/missing_usemtl/README.md b/tests/data/missing_usemtl/README.md
index 82045cd6..1c9b08be 100644
--- a/tests/data/missing_usemtl/README.md
+++ b/tests/data/missing_usemtl/README.md
@@ -2,6 +2,6 @@
 This is copied version of docs/tutorials/data/cow_mesh with removed line 6159 (usemtl material_1)
 to test behavior without usemtl material_1 declaration.
 
-Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
+Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/ diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index 3760f944..0c212c96 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -90,6 +90,15 @@ dataset_map_provider_LlffDatasetMapProvider_args: n_known_frames_for_test: null path_manager_factory_PathManagerFactory_args: silence_logs: true +dataset_map_provider_RenderedMeshDatasetMapProvider_args: + num_views: 40 + data_file: null + azimuth_range: 180.0 + resolution: 128 + use_point_light: true + path_manager_factory_class_type: PathManagerFactory + path_manager_factory_PathManagerFactory_args: + silence_logs: true data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 diff --git a/tests/implicitron/test_data_cow.py b/tests/implicitron/test_data_cow.py new file mode 100644 index 00000000..07b0b339 --- /dev/null +++ b/tests/implicitron/test_data_cow.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +import torch +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import ( + RenderedMeshDatasetMapProvider, +) +from pytorch3d.implicitron.tools.config import expand_args_fields +from pytorch3d.renderer import FoVPerspectiveCameras +from tests.common_testing import TestCaseMixin + + +inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False) + + +class TestDataCow(TestCaseMixin, unittest.TestCase): + def test_simple(self): + if inside_re_worker: + return + expand_args_fields(RenderedMeshDatasetMapProvider) + self._runtest(use_point_light=True, num_views=4) + self._runtest(use_point_light=False, num_views=4) + + def _runtest(self, **kwargs): + provider = RenderedMeshDatasetMapProvider(**kwargs) + dataset_map = provider.get_dataset_map() + known_matrix = torch.zeros(1, 4, 4) + known_matrix[0, 0, 0] = 1.7321 + known_matrix[0, 1, 1] = 1.7321 + known_matrix[0, 2, 2] = 1.0101 + known_matrix[0, 3, 2] = -1.0101 + known_matrix[0, 2, 3] = 1 + + self.assertIsNone(dataset_map.val) + self.assertIsNone(dataset_map.test) + self.assertEqual(len(dataset_map.train), provider.num_views) + + value = dataset_map.train[0] + self.assertIsInstance(value, FrameData) + + self.assertEqual(value.image_rgb.shape, (3, 128, 128)) + self.assertEqual(value.fg_probability.shape, (1, 128, 128)) + # corner of image is background + self.assertEqual(value.fg_probability[0, 0, 0], 0) + self.assertEqual(value.fg_probability.max(), 1.0) + self.assertIsInstance(value.camera, FoVPerspectiveCameras) + self.assertEqual(len(value.camera), 1) + self.assertIsNone(value.camera.K) + matrix = value.camera.get_projection_transform().get_matrix() + self.assertClose(matrix, known_matrix, atol=1e-4)
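
Note (not part of the diff): the snippet below is a minimal usage sketch of the new provider, mirroring what `tests/implicitron/test_data_cow.py` above exercises. It assumes a PyTorch3D repo checkout laid out as in this patch, so that the default cow mesh under docs/tutorials/data/cow_mesh can be found.

```python
# Minimal usage sketch for RenderedMeshDatasetMapProvider (assumes a repo
# checkout so the default cow mesh can be located on disk).
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
    RenderedMeshDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Configurable classes need their dataclass fields expanded before instantiation.
expand_args_fields(RenderedMeshDatasetMapProvider)

# Render 4 views of the default cow mesh at the default 128x128 resolution.
provider = RenderedMeshDatasetMapProvider(num_views=4, use_point_light=True)
dataset_map = provider.get_dataset_map()

frame = dataset_map.train[0]                   # a FrameData instance
print(frame.image_rgb.shape)                   # torch.Size([3, 128, 128])
print(frame.fg_probability.shape)              # torch.Size([1, 128, 128])
print(len(provider.get_all_train_cameras()))   # 4
```

In an Implicitron experiment config, the same provider would presumably be selected via the data source's `dataset_map_provider_class_type`, with its defaults as listed in the `dataset_map_provider_RenderedMeshDatasetMapProvider_args` blocks added to the yaml files above.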