provide cow dataset

Summary: Make a dummy single-scene dataset using the code from generate_cow_renders (used in existing NeRF tutorials)

Reviewed By: kjchalup

Differential Revision: D38116910

fbshipit-source-id: 8db6df7098aa221c81d392e5cd21b0e67f65bd70
This commit is contained in:
Jeremy Reizenstein 2022-08-01 01:52:12 -07:00 committed by Facebook GitHub Bot
parent 1b0584f7bd
commit 14bd5e28e8
9 changed files with 302 additions and 5 deletions

View File

@ -1,5 +1,5 @@
# Acknowledgements
Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/

View File

@ -44,6 +44,8 @@ def generate_cow_renders(
data_dir: The folder that contains the cow mesh files. If the cow mesh
files do not exist in the folder, this function will automatically
download them.
azimuth_range: number of degrees on each side of the start position to
take samples
Returns:
cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the

View File

@ -101,6 +101,15 @@ data_source_ImplicitronDataSource_args:
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_RenderedMeshDatasetMapProvider_args:
num_views: 40
data_file: null
azimuth_range: 180.0
resolution: 128
use_point_light: true
path_manager_factory_class_type: PathManagerFactory
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0

View File

@ -19,6 +19,7 @@ from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
class DataSourceBase(ReplaceableBase):

View File

@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from os.path import dirname, join, realpath
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry,
run_auto_creation,
)
from pytorch3d.io import IO
from pytorch3d.renderer import (
AmbientLights,
BlendParams,
CamerasBase,
FoVPerspectiveCameras,
HardPhongShader,
look_at_view_transform,
MeshRasterizer,
MeshRendererWithFragments,
PointLights,
RasterizationSettings,
)
from pytorch3d.structures.meshes import Meshes
from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .single_sequence_dataset import SingleSceneDataset
from .utils import DATASET_TYPE_KNOWN
@registry.register
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"""
A simple single-scene dataset based on PyTorch3D renders of a mesh.
Provides `num_views` renders of the mesh as train, with no val
and test. The renders are generated from viewpoints sampled at uniformly
distributed azimuth intervals. The elevation is kept constant so that the
camera's vertical position coincides with the equator.
By default, uses Keenan Crane's cow model, and the camera locations are
set to make sense for that.
Although the rendering used to generate this dataset will use a GPU
if one is available, the data it produces is on the CPU just like
the data returned by implicitron's other dataset map providers.
This is because both datasets and models can be large, so implicitron's
GenericModel.forward (etc) expects data on the CPU and only moves
what it needs to the device.
For a more detailed explanation of this code, please refer to the
docs/tutorials/fit_textured_mesh.ipynb notebook.
Members:
num_views: The number of generated renders.
data_file: The folder that contains the mesh file. By default, finds
the cow mesh in the same repo as this code.
azimuth_range: number of degrees on each side of the start position to
take samples
resolution: the common height and width of the output images.
use_point_light: whether to use a particular point light as opposed
to ambient white.
"""
num_views: int = 40
data_file: Optional[str] = None
azimuth_range: float = 180
resolution: int = 128
use_point_light: bool = True
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def get_dataset_map(self) -> DatasetMap:
# pyre-ignore[16]
return DatasetMap(train=self.train_dataset, val=None, test=None)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> CamerasBase:
# pyre-ignore[16]
return self.poses
def __post_init__(self) -> None:
super().__init__()
run_auto_creation(self)
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
if self.data_file is None:
data_file = join(
dirname(dirname(dirname(dirname(realpath(__file__))))),
"docs",
"tutorials",
"data",
"cow_mesh",
"cow.obj",
)
else:
data_file = self.data_file
io = IO(path_manager=self.path_manager_factory.get())
mesh = io.load_mesh(data_file, device=device)
poses, images, masks = _generate_cow_renders(
num_views=self.num_views,
mesh=mesh,
azimuth_range=self.azimuth_range,
resolution=self.resolution,
device=device,
use_point_light=self.use_point_light,
)
# pyre-ignore[16]
self.poses = poses.cpu()
expand_args_fields(SingleSceneDataset)
# pyre-ignore[16]
self.train_dataset = SingleSceneDataset( # pyre-ignore[28]
object_name="cow",
images=list(images.permute(0, 3, 1, 2).cpu()),
fg_probabilities=list(masks[:, None].cpu()),
poses=[self.poses[i] for i in range(len(poses))],
frame_types=[DATASET_TYPE_KNOWN] * len(poses),
eval_batches=None,
)
@torch.no_grad()
def _generate_cow_renders(
*,
num_views: int,
mesh: Meshes,
azimuth_range: float,
resolution: int,
device: torch.device,
use_point_light: bool,
) -> Tuple[CamerasBase, torch.Tensor, torch.Tensor]:
"""
Returns:
cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
images are rendered.
images: A tensor of shape `(num_views, height, width, 3)` containing
the rendered images.
silhouettes: A tensor of shape `(num_views, height, width)` containing
the rendered silhouettes.
"""
# Load obj file
# We scale normalize and center the target mesh to fit in a sphere of radius 1
# centered at (0,0,0). (scale, center) will be used to bring the predicted mesh
# to its original center and scale. Note that normalizing the target mesh,
# speeds up the optimization but is not necessary!
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
scale = max((verts - center).abs().max(0)[0])
mesh.offset_verts_(-(center.expand(N, 3)))
mesh.scale_verts_((1.0 / float(scale)))
# Get a batch of viewing angles.
elev = torch.linspace(0, 0, num_views) # keep constant
azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0
# Place a point light in front of the object. As mentioned above, the front of
# the cow is facing the -z direction.
if use_point_light:
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
else:
lights = AmbientLights(device=device)
# Initialize an OpenGL perspective camera that represents a batch of different
# viewing angles. All the cameras helper methods support mixed type inputs and
# broadcasting. So we can view the camera from the a distance of dist=2.7, and
# then specify elevation and azimuth angles for each viewpoint as tensors.
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
# Define the settings for rasterization and shading.
# As we are rendering images for visualization
# purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to
# rasterize_meshes.py for explanations of these parameters. We also leave
# bin_size and max_faces_per_bin to their default values of None, which sets
# their values using heuristics and ensures that the faster coarse-to-fine
# rasterization method is used. Refer to docs/notes/renderer.md for an
# explanation of the difference between naive and coarse-to-fine rasterization.
raster_settings = RasterizationSettings(
image_size=resolution, blur_radius=0.0, faces_per_pixel=1
)
# Create a Phong renderer by composing a rasterizer and a shader. The textured
# Phong shader will interpolate the texture uv coordinates for each vertex,
# sample from a texture image and apply the Phong lighting model
blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
rasterizer_type = MeshRasterizer
renderer = MeshRendererWithFragments(
rasterizer=rasterizer_type(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(
device=device, cameras=cameras, lights=lights, blend_params=blend_params
),
)
# Create a batch of meshes by repeating the cow mesh and associated textures.
# Meshes has a useful `extend` method which allows us do this very easily.
# This also extends the textures.
meshes = mesh.extend(num_views)
# Render the cow mesh from each viewing angle
target_images, fragments = renderer(meshes, cameras=cameras, lights=lights)
silhouette_binary = (fragments.pix_to_face[..., 0] >= 0).float()
return cameras, target_images[..., :3], silhouette_binary

View File

@ -1661,9 +1661,9 @@ def look_at_rotation(
def look_at_view_transform(
dist: float = 1.0,
elev: float = 0.0,
azim: float = 0.0,
dist: _BatchFloatType = 1.0,
elev: _BatchFloatType = 0.0,
azim: _BatchFloatType = 0.0,
degrees: bool = True,
eye: Optional[Union[Sequence, torch.Tensor]] = None,
at=((0, 0, 0),), # (1, 3)

View File

@ -2,6 +2,6 @@
This is copied version of docs/tutorials/data/cow_mesh with removed line 6159 (usemtl material_1) to test behavior without usemtl material_1 declaration.
Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/

View File

@ -90,6 +90,15 @@ dataset_map_provider_LlffDatasetMapProvider_args:
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_RenderedMeshDatasetMapProvider_args:
num_views: 40
data_file: null
azimuth_range: 180.0
resolution: 128
use_point_light: true
path_manager_factory_class_type: PathManagerFactory
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0

View File

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
RenderedMeshDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer import FoVPerspectiveCameras
from tests.common_testing import TestCaseMixin
inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
class TestDataCow(TestCaseMixin, unittest.TestCase):
def test_simple(self):
if inside_re_worker:
return
expand_args_fields(RenderedMeshDatasetMapProvider)
self._runtest(use_point_light=True, num_views=4)
self._runtest(use_point_light=False, num_views=4)
def _runtest(self, **kwargs):
provider = RenderedMeshDatasetMapProvider(**kwargs)
dataset_map = provider.get_dataset_map()
known_matrix = torch.zeros(1, 4, 4)
known_matrix[0, 0, 0] = 1.7321
known_matrix[0, 1, 1] = 1.7321
known_matrix[0, 2, 2] = 1.0101
known_matrix[0, 3, 2] = -1.0101
known_matrix[0, 2, 3] = 1
self.assertIsNone(dataset_map.val)
self.assertIsNone(dataset_map.test)
self.assertEqual(len(dataset_map.train), provider.num_views)
value = dataset_map.train[0]
self.assertIsInstance(value, FrameData)
self.assertEqual(value.image_rgb.shape, (3, 128, 128))
self.assertEqual(value.fg_probability.shape, (1, 128, 128))
# corner of image is background
self.assertEqual(value.fg_probability[0, 0, 0], 0)
self.assertEqual(value.fg_probability.max(), 1.0)
self.assertIsInstance(value.camera, FoVPerspectiveCameras)
self.assertEqual(len(value.camera), 1)
self.assertIsNone(value.camera.K)
matrix = value.camera.get_projection_transform().get_matrix()
self.assertClose(matrix, known_matrix, atol=1e-4)