mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-11-04 18:02:14 +08:00

provide cow dataset

Summary: Make a dummy single-scene dataset using the code from generate_cow_renders (used in the existing NeRF tutorials).

Reviewed By: kjchalup

Differential Revision: D38116910

fbshipit-source-id: 8db6df7098aa221c81d392e5cd21b0e67f65bd70
This commit is contained in:

parent 1b0584f7bd
commit 14bd5e28e8
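As a quick orientation before the diffs: a minimal usage sketch of the new provider, assembled from the test added at the end of this commit (num_views=4 keeps the renders cheap; the default is 40 views):

    from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
        RenderedMeshDatasetMapProvider,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # Configurable classes must be processed before they can be instantiated.
    expand_args_fields(RenderedMeshDatasetMapProvider)

    provider = RenderedMeshDatasetMapProvider(num_views=4)
    dataset_map = provider.get_dataset_map()  # train only; val and test are None
    frame = dataset_map.train[0]              # a FrameData, on the CPU
    assert frame.image_rgb.shape == (3, 128, 128)
    assert frame.fg_probability.shape == (1, 128, 128)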
@@ -1,5 +1,5 @@
 # Acknowledgements
 
-Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
+Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
 
 ###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
@@ -44,6 +44,8 @@ def generate_cow_renders(
         data_dir: The folder that contains the cow mesh files. If the cow mesh
             files do not exist in the folder, this function will automatically
             download them.
+        azimuth_range: number of degrees on each side of the start position
+            from which to take samples.
 
     Returns:
         cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
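The newly documented azimuth_range parameter behaves as in the render code added later in this commit: viewpoints are spaced uniformly over [-azimuth_range, azimuth_range] and then shifted by 180 degrees. A tiny sketch of the angles this produces, assuming num_views=4 and the default azimuth_range=180:

    import torch

    num_views, azimuth_range = 4, 180.0
    elev = torch.zeros(num_views)  # constant elevation, on the equator
    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0
    print(azim)  # tensor([  0., 120., 240., 360.])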
@@ -101,6 +101,15 @@ data_source_ImplicitronDataSource_args:
     n_known_frames_for_test: null
     path_manager_factory_PathManagerFactory_args:
       silence_logs: true
+  dataset_map_provider_RenderedMeshDatasetMapProvider_args:
+    num_views: 40
+    data_file: null
+    azimuth_range: 180.0
+    resolution: 128
+    use_point_light: true
+    path_manager_factory_class_type: PathManagerFactory
+    path_manager_factory_PathManagerFactory_args:
+      silence_logs: true
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     num_workers: 0
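This generated block simply records the dataclass defaults of the new provider, so the YAML values track the Python values exactly. A small sanity check of that correspondence (a sketch; the module path is the one used by the test at the end of this commit):

    from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
        RenderedMeshDatasetMapProvider,
    )

    # The field defaults are the source of the generated YAML above.
    assert RenderedMeshDatasetMapProvider.num_views == 40
    assert RenderedMeshDatasetMapProvider.data_file is None
    assert RenderedMeshDatasetMapProvider.azimuth_range == 180.0
    assert RenderedMeshDatasetMapProvider.resolution == 128
    assert RenderedMeshDatasetMapProvider.use_point_light is True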
@@ -19,6 +19,7 @@ from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
 from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
 from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2  # noqa
 from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
+from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider  # noqa
 
 
 class DataSourceBase(ReplaceableBase):
pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py (new file, 219 lines)

@@ -0,0 +1,219 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from os.path import dirname, join, realpath
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    registry,
    run_auto_creation,
)
from pytorch3d.io import IO
from pytorch3d.renderer import (
    AmbientLights,
    BlendParams,
    CamerasBase,
    FoVPerspectiveCameras,
    HardPhongShader,
    look_at_view_transform,
    MeshRasterizer,
    MeshRendererWithFragments,
    PointLights,
    RasterizationSettings,
)
from pytorch3d.structures.meshes import Meshes

from .dataset_map_provider import (
    DatasetMap,
    DatasetMapProviderBase,
    PathManagerFactory,
    Task,
)
from .single_sequence_dataset import SingleSceneDataset
from .utils import DATASET_TYPE_KNOWN


@registry.register
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
    """
    A simple single-scene dataset based on PyTorch3D renders of a mesh.
    Provides `num_views` renders of the mesh as train, with no val
    and test. The renders are generated from viewpoints sampled at uniformly
    distributed azimuth intervals. The elevation is kept constant so that the
    camera's vertical position coincides with the equator.

    By default, uses Keenan Crane's cow model, and the camera locations are
    set to make sense for that.

    Although the rendering used to generate this dataset will use a GPU
    if one is available, the data it produces is on the CPU just like
    the data returned by implicitron's other dataset map providers.
    This is because both datasets and models can be large, so implicitron's
    GenericModel.forward (etc) expects data on the CPU and only moves
    what it needs to the device.

    For a more detailed explanation of this code, please refer to the
    docs/tutorials/fit_textured_mesh.ipynb notebook.

    Members:
        num_views: The number of generated renders.
        data_file: The mesh file to render. By default, finds
            the cow mesh in the same repo as this code.
        azimuth_range: number of degrees on each side of the start position
            from which to take samples.
        resolution: the common height and width of the output images.
        use_point_light: whether to use a particular point light as opposed
            to ambient white.
    """

    num_views: int = 40
    data_file: Optional[str] = None
    azimuth_range: float = 180
    resolution: int = 128
    use_point_light: bool = True
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"

    def get_dataset_map(self) -> DatasetMap:
        # pyre-ignore[16]
        return DatasetMap(train=self.train_dataset, val=None, test=None)

    def get_task(self) -> Task:
        return Task.SINGLE_SEQUENCE

    def get_all_train_cameras(self) -> CamerasBase:
        # pyre-ignore[16]
        return self.poses

    def __post_init__(self) -> None:
        super().__init__()
        run_auto_creation(self)
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
        if self.data_file is None:
            data_file = join(
                dirname(dirname(dirname(dirname(realpath(__file__))))),
                "docs",
                "tutorials",
                "data",
                "cow_mesh",
                "cow.obj",
            )
        else:
            data_file = self.data_file
        io = IO(path_manager=self.path_manager_factory.get())
        mesh = io.load_mesh(data_file, device=device)
        poses, images, masks = _generate_cow_renders(
            num_views=self.num_views,
            mesh=mesh,
            azimuth_range=self.azimuth_range,
            resolution=self.resolution,
            device=device,
            use_point_light=self.use_point_light,
        )
        # pyre-ignore[16]
        self.poses = poses.cpu()
        expand_args_fields(SingleSceneDataset)
        # pyre-ignore[16]
        self.train_dataset = SingleSceneDataset(  # pyre-ignore[28]
            object_name="cow",
            images=list(images.permute(0, 3, 1, 2).cpu()),
            fg_probabilities=list(masks[:, None].cpu()),
            poses=[self.poses[i] for i in range(len(poses))],
            frame_types=[DATASET_TYPE_KNOWN] * len(poses),
            eval_batches=None,
        )


@torch.no_grad()
def _generate_cow_renders(
    *,
    num_views: int,
    mesh: Meshes,
    azimuth_range: float,
    resolution: int,
    device: torch.device,
    use_point_light: bool,
) -> Tuple[CamerasBase, torch.Tensor, torch.Tensor]:
    """
    Returns:
        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
            images are rendered.
        images: A tensor of shape `(num_views, height, width, 3)` containing
            the rendered images.
        silhouettes: A tensor of shape `(num_views, height, width)` containing
            the rendered silhouettes.
    """

    # Normalize and center the target mesh so that it is centered at (0, 0, 0)
    # and its largest absolute vertex coordinate is 1. Note that normalizing
    # the target mesh speeds up the optimization but is not necessary!
    verts = mesh.verts_packed()
    N = verts.shape[0]
    center = verts.mean(0)
    scale = max((verts - center).abs().max(0)[0])
    mesh.offset_verts_(-(center.expand(N, 3)))
    mesh.scale_verts_((1.0 / float(scale)))

    # Get a batch of viewing angles.
    elev = torch.linspace(0, 0, num_views)  # keep constant
    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0

    # Place a point light in front of the object; the front of the cow
    # is facing the -z direction.
    if use_point_light:
        lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
    else:
        lights = AmbientLights(device=device)

    # Initialize an OpenGL perspective camera that represents a batch of different
    # viewing angles. All the camera helper methods support mixed-type inputs and
    # broadcasting. So we can view the object from a distance of dist=2.7, and
    # then specify elevation and azimuth angles for each viewpoint as tensors.
    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

    # Define the settings for rasterization and shading. As we are rendering
    # images for visualization purposes only, we set faces_per_pixel=1 and
    # blur_radius=0.0. Refer to rasterize_meshes.py for explanations of these
    # parameters. We also leave bin_size and max_faces_per_bin at their default
    # values of None, which sets their values using heuristics and ensures that
    # the faster coarse-to-fine rasterization method is used. Refer to
    # docs/notes/renderer.md for an explanation of the difference between naive
    # and coarse-to-fine rasterization.
    raster_settings = RasterizationSettings(
        image_size=resolution, blur_radius=0.0, faces_per_pixel=1
    )

    # Create a Phong renderer by composing a rasterizer and a shader. The textured
    # Phong shader will interpolate the texture uv coordinates for each vertex,
    # sample from a texture image and apply the Phong lighting model.
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
    rasterizer_type = MeshRasterizer
    renderer = MeshRendererWithFragments(
        rasterizer=rasterizer_type(cameras=cameras, raster_settings=raster_settings),
        shader=HardPhongShader(
            device=device, cameras=cameras, lights=lights, blend_params=blend_params
        ),
    )

    # Create a batch of meshes by repeating the cow mesh and associated textures.
    # Meshes has a useful `extend` method which allows us to do this very easily.
    # This also extends the textures.
    meshes = mesh.extend(num_views)

    # Render the cow mesh from each viewing angle.
    target_images, fragments = renderer(meshes, cameras=cameras, lights=lights)
    silhouette_binary = (fragments.pix_to_face[..., 0] >= 0).float()

    return cameras, target_images[..., :3], silhouette_binary
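One step in the code above worth spelling out: MeshRendererWithFragments returns the rasterizer fragments alongside the shaded images, and the silhouette is read off fragments.pix_to_face, which stores one face index per pixel and -1 where no face covers the pixel. A toy illustration of that thresholding, using only the tensor semantics just described:

    import torch

    # A fake pix_to_face for a single 2x2 render, shaped (H, W, faces_per_pixel);
    # -1 marks background pixels, any index >= 0 marks a hit face.
    pix_to_face = torch.tensor([[[-1], [3]],
                                [[0], [-1]]])
    silhouette = (pix_to_face[..., 0] >= 0).float()
    print(silhouette)  # tensor([[0., 1.], [1., 0.]])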
@@ -1661,9 +1661,9 @@ def look_at_rotation(
 
 
 def look_at_view_transform(
-    dist: float = 1.0,
-    elev: float = 0.0,
-    azim: float = 0.0,
+    dist: _BatchFloatType = 1.0,
+    elev: _BatchFloatType = 0.0,
+    azim: _BatchFloatType = 0.0,
     degrees: bool = True,
     eye: Optional[Union[Sequence, torch.Tensor]] = None,
     at=((0, 0, 0),),  # (1, 3)
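This signature change widens dist, elev and azim from float to _BatchFloatType, matching what the function already supports: scalars, sequences, or tensors, broadcast against each other. That is exactly how _generate_cow_renders above calls it, with a scalar dist and tensor elev/azim. A minimal sketch:

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform

    num_views = 4
    # The scalar dist broadcasts against the tensor elev/azim arguments,
    # yielding one camera pose per view.
    R, T = look_at_view_transform(
        dist=2.7,
        elev=torch.zeros(num_views),
        azim=torch.linspace(-180.0, 180.0, num_views) + 180.0,
    )
    cameras = FoVPerspectiveCameras(R=R, T=T)
    print(R.shape, T.shape)  # torch.Size([4, 3, 3]) torch.Size([4, 3])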
@@ -2,6 +2,6 @@
 
 This is a copy of docs/tutorials/data/cow_mesh with line 6159 (usemtl material_1) removed, to test behavior without a usemtl material_1 declaration.
 
-Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
+Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
 
 ###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
@@ -90,6 +90,15 @@ dataset_map_provider_LlffDatasetMapProvider_args:
   n_known_frames_for_test: null
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true
+dataset_map_provider_RenderedMeshDatasetMapProvider_args:
+  num_views: 40
+  data_file: null
+  azimuth_range: 180.0
+  resolution: 128
+  use_point_light: true
+  path_manager_factory_class_type: PathManagerFactory
+  path_manager_factory_PathManagerFactory_args:
+    silence_logs: true
 data_loader_map_provider_SequenceDataLoaderMapProvider_args:
   batch_size: 1
   num_workers: 0
							
								
								
									
tests/implicitron/test_data_cow.py (new file, 57 lines)

@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
    RenderedMeshDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer import FoVPerspectiveCameras
from tests.common_testing import TestCaseMixin


inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)


class TestDataCow(TestCaseMixin, unittest.TestCase):
    def test_simple(self):
        if inside_re_worker:
            return
        expand_args_fields(RenderedMeshDatasetMapProvider)
        self._runtest(use_point_light=True, num_views=4)
        self._runtest(use_point_light=False, num_views=4)

    def _runtest(self, **kwargs):
        provider = RenderedMeshDatasetMapProvider(**kwargs)
        dataset_map = provider.get_dataset_map()
        known_matrix = torch.zeros(1, 4, 4)
        known_matrix[0, 0, 0] = 1.7321
        known_matrix[0, 1, 1] = 1.7321
        known_matrix[0, 2, 2] = 1.0101
        known_matrix[0, 3, 2] = -1.0101
        known_matrix[0, 2, 3] = 1

        self.assertIsNone(dataset_map.val)
        self.assertIsNone(dataset_map.test)
        self.assertEqual(len(dataset_map.train), provider.num_views)

        value = dataset_map.train[0]
        self.assertIsInstance(value, FrameData)

        self.assertEqual(value.image_rgb.shape, (3, 128, 128))
        self.assertEqual(value.fg_probability.shape, (1, 128, 128))
        # The corner of the image is background.
        self.assertEqual(value.fg_probability[0, 0, 0], 0)
        self.assertEqual(value.fg_probability.max(), 1.0)
        self.assertIsInstance(value.camera, FoVPerspectiveCameras)
        self.assertEqual(len(value.camera), 1)
        self.assertIsNone(value.camera.K)
        matrix = value.camera.get_projection_transform().get_matrix()
        self.assertClose(matrix, known_matrix, atol=1e-4)
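The constants in known_matrix are not arbitrary: they follow from the default FoVPerspectiveCameras intrinsics (fov=60.0 degrees, znear=1.0, zfar=100.0, aspect_ratio=1.0), under which 1/tan(fov/2) ≈ 1.7321 and zfar/(zfar - znear) = 100/99 ≈ 1.0101. A quick check of the arithmetic:

    import math

    fov, znear, zfar = 60.0, 1.0, 100.0  # FoVPerspectiveCameras defaults

    s = 1.0 / math.tan(math.radians(fov) / 2.0)  # K[0, 0] and K[1, 1]
    z = zfar / (zfar - znear)                    # K[2, 2]
    w = -(zfar * znear) / (zfar - znear)         # K[3, 2]; equals -z here as znear=1
    print(round(s, 4), round(z, 4), round(w, 4))  # 1.7321 1.0101 -1.0101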