Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
provide cow dataset

Summary: Make a dummy single-scene dataset using the code from generate_cow_renders (used in existing NeRF tutorials).

Reviewed By: kjchalup

Differential Revision: D38116910

fbshipit-source-id: 8db6df7098aa221c81d392e5cd21b0e67f65bd70
Parent: 1b0584f7bd
Commit: 14bd5e28e8
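
For orientation, here is a minimal usage sketch of the provider this commit adds, based on the new module and test in the diff below. It constructs the provider directly, exactly as the new test does; the names and defaults (num_views, resolution, get_dataset_map, get_all_train_cameras) all come from the diff, but the snippet itself is illustrative and not part of the commit.

from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
    RenderedMeshDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Configurable classes need their fields expanded before direct construction,
# as the new test does.
expand_args_fields(RenderedMeshDatasetMapProvider)

provider = RenderedMeshDatasetMapProvider(num_views=4, resolution=128)
dataset_map = provider.get_dataset_map()       # train only; val and test are None
frame = dataset_map.train[0]                   # FrameData with image_rgb, fg_probability, camera
cameras = provider.get_all_train_cameras()     # batch of FoVPerspectiveCameras, kept on the CPU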
@@ -1,5 +1,5 @@
 # Acknowledgements

-Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
+Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.

 ###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
@@ -44,6 +44,8 @@ def generate_cow_renders(
         data_dir: The folder that contains the cow mesh files. If the cow mesh
             files do not exist in the folder, this function will automatically
             download them.
+        azimuth_range: number of degrees on each side of the start position to
+            take samples

     Returns:
         cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
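
To make the new azimuth_range documentation concrete: the viewpoints are spread uniformly over [-azimuth_range, azimuth_range] degrees around the start position while the elevation stays fixed. The two lines below mirror the sampling code added later in this commit (rendered_mesh_dataset_map_provider) and are shown here only for intuition.

import torch

num_views, azimuth_range = 40, 180.0
elev = torch.linspace(0, 0, num_views)  # elevation held constant (camera stays on the equator)
azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0  # degrees on each side of the start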
@@ -101,6 +101,15 @@ data_source_ImplicitronDataSource_args:
     n_known_frames_for_test: null
     path_manager_factory_PathManagerFactory_args:
       silence_logs: true
+  dataset_map_provider_RenderedMeshDatasetMapProvider_args:
+    num_views: 40
+    data_file: null
+    azimuth_range: 180.0
+    resolution: 128
+    use_point_light: true
+    path_manager_factory_class_type: PathManagerFactory
+    path_manager_factory_PathManagerFactory_args:
+      silence_logs: true
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     num_workers: 0
@@ -19,6 +19,7 @@ from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
 from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
 from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2  # noqa
 from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
+from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider  # noqa


 class DataSourceBase(ReplaceableBase):
pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py (new file)
@@ -0,0 +1,219 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from os.path import dirname, join, realpath
+from typing import Optional, Tuple
+
+import torch
+from pytorch3d.implicitron.tools.config import (
+    expand_args_fields,
+    registry,
+    run_auto_creation,
+)
+from pytorch3d.io import IO
+from pytorch3d.renderer import (
+    AmbientLights,
+    BlendParams,
+    CamerasBase,
+    FoVPerspectiveCameras,
+    HardPhongShader,
+    look_at_view_transform,
+    MeshRasterizer,
+    MeshRendererWithFragments,
+    PointLights,
+    RasterizationSettings,
+)
+from pytorch3d.structures.meshes import Meshes
+
+from .dataset_map_provider import (
+    DatasetMap,
+    DatasetMapProviderBase,
+    PathManagerFactory,
+    Task,
+)
+from .single_sequence_dataset import SingleSceneDataset
+from .utils import DATASET_TYPE_KNOWN
+
+
+@registry.register
+class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
+    """
+    A simple single-scene dataset based on PyTorch3D renders of a mesh.
+    Provides `num_views` renders of the mesh as train, with no val
+    and test. The renders are generated from viewpoints sampled at uniformly
+    distributed azimuth intervals. The elevation is kept constant so that the
+    camera's vertical position coincides with the equator.
+
+    By default, uses Keenan Crane's cow model, and the camera locations are
+    set to make sense for that.
+
+    Although the rendering used to generate this dataset will use a GPU
+    if one is available, the data it produces is on the CPU just like
+    the data returned by implicitron's other dataset map providers.
+    This is because both datasets and models can be large, so implicitron's
+    GenericModel.forward (etc) expects data on the CPU and only moves
+    what it needs to the device.
+
+    For a more detailed explanation of this code, please refer to the
+    docs/tutorials/fit_textured_mesh.ipynb notebook.
+
+    Members:
+        num_views: The number of generated renders.
+        data_file: The folder that contains the mesh file. By default, finds
+            the cow mesh in the same repo as this code.
+        azimuth_range: number of degrees on each side of the start position to
+            take samples
+        resolution: the common height and width of the output images.
+        use_point_light: whether to use a particular point light as opposed
+            to ambient white.
+    """
+
+    num_views: int = 40
+    data_file: Optional[str] = None
+    azimuth_range: float = 180
+    resolution: int = 128
+    use_point_light: bool = True
+    path_manager_factory: PathManagerFactory
+    path_manager_factory_class_type: str = "PathManagerFactory"
+
+    def get_dataset_map(self) -> DatasetMap:
+        # pyre-ignore[16]
+        return DatasetMap(train=self.train_dataset, val=None, test=None)
+
+    def get_task(self) -> Task:
+        return Task.SINGLE_SEQUENCE
+
+    def get_all_train_cameras(self) -> CamerasBase:
+        # pyre-ignore[16]
+        return self.poses
+
+    def __post_init__(self) -> None:
+        super().__init__()
+        run_auto_creation(self)
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+        if self.data_file is None:
+            data_file = join(
+                dirname(dirname(dirname(dirname(realpath(__file__))))),
+                "docs",
+                "tutorials",
+                "data",
+                "cow_mesh",
+                "cow.obj",
+            )
+        else:
+            data_file = self.data_file
+        io = IO(path_manager=self.path_manager_factory.get())
+        mesh = io.load_mesh(data_file, device=device)
+        poses, images, masks = _generate_cow_renders(
+            num_views=self.num_views,
+            mesh=mesh,
+            azimuth_range=self.azimuth_range,
+            resolution=self.resolution,
+            device=device,
+            use_point_light=self.use_point_light,
+        )
+        # pyre-ignore[16]
+        self.poses = poses.cpu()
+        expand_args_fields(SingleSceneDataset)
+        # pyre-ignore[16]
+        self.train_dataset = SingleSceneDataset(  # pyre-ignore[28]
+            object_name="cow",
+            images=list(images.permute(0, 3, 1, 2).cpu()),
+            fg_probabilities=list(masks[:, None].cpu()),
+            poses=[self.poses[i] for i in range(len(poses))],
+            frame_types=[DATASET_TYPE_KNOWN] * len(poses),
+            eval_batches=None,
+        )
+
+
+@torch.no_grad()
+def _generate_cow_renders(
+    *,
+    num_views: int,
+    mesh: Meshes,
+    azimuth_range: float,
+    resolution: int,
+    device: torch.device,
+    use_point_light: bool,
+) -> Tuple[CamerasBase, torch.Tensor, torch.Tensor]:
+    """
+    Returns:
+        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
+            images are rendered.
+        images: A tensor of shape `(num_views, height, width, 3)` containing
+            the rendered images.
+        silhouettes: A tensor of shape `(num_views, height, width)` containing
+            the rendered silhouettes.
+    """
+
+    # Load obj file
+
+    # We scale normalize and center the target mesh to fit in a sphere of radius 1
+    # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh
+    # to its original center and scale. Note that normalizing the target mesh
+    # speeds up the optimization but is not necessary!
+    verts = mesh.verts_packed()
+    N = verts.shape[0]
+    center = verts.mean(0)
+    scale = max((verts - center).abs().max(0)[0])
+    mesh.offset_verts_(-(center.expand(N, 3)))
+    mesh.scale_verts_((1.0 / float(scale)))
+
+    # Get a batch of viewing angles.
+    elev = torch.linspace(0, 0, num_views)  # keep constant
+    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0
+
+    # Place a point light in front of the object. As mentioned above, the front of
+    # the cow is facing the -z direction.
+    if use_point_light:
+        lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
+    else:
+        lights = AmbientLights(device=device)
+
+    # Initialize an OpenGL perspective camera that represents a batch of different
+    # viewing angles. All the cameras helper methods support mixed type inputs and
+    # broadcasting. So we can view the camera from a distance of dist=2.7, and
+    # then specify elevation and azimuth angles for each viewpoint as tensors.
+    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
+    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+
+    # Define the settings for rasterization and shading.
+    # As we are rendering images for visualization
+    # purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to
+    # rasterize_meshes.py for explanations of these parameters. We also leave
+    # bin_size and max_faces_per_bin to their default values of None, which sets
+    # their values using heuristics and ensures that the faster coarse-to-fine
+    # rasterization method is used. Refer to docs/notes/renderer.md for an
+    # explanation of the difference between naive and coarse-to-fine rasterization.
+    raster_settings = RasterizationSettings(
+        image_size=resolution, blur_radius=0.0, faces_per_pixel=1
+    )
+
+    # Create a Phong renderer by composing a rasterizer and a shader. The textured
+    # Phong shader will interpolate the texture uv coordinates for each vertex,
+    # sample from a texture image and apply the Phong lighting model
+    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
+    rasterizer_type = MeshRasterizer
+    renderer = MeshRendererWithFragments(
+        rasterizer=rasterizer_type(cameras=cameras, raster_settings=raster_settings),
+        shader=HardPhongShader(
+            device=device, cameras=cameras, lights=lights, blend_params=blend_params
+        ),
+    )
+
+    # Create a batch of meshes by repeating the cow mesh and associated textures.
+    # Meshes has a useful `extend` method which allows us to do this very easily.
+    # This also extends the textures.
+    meshes = mesh.extend(num_views)
+
+    # Render the cow mesh from each viewing angle
+    target_images, fragments = renderer(meshes, cameras=cameras, lights=lights)
+    silhouette_binary = (fragments.pix_to_face[..., 0] >= 0).float()
+
+    return cameras, target_images[..., :3], silhouette_binary
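
As the class docstring above notes, the rendered data comes back on the CPU even when a GPU was used to render it; the consumer moves what it needs to the device. A tiny hypothetical training-loop fragment (continuing the dataset_map sketch near the top of this page; not part of the commit) illustrates the convention, using only tensor fields the new test checks:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
frame = dataset_map.train[0]             # FrameData, stored on the CPU
image = frame.image_rgb.to(device)       # (3, H, W) render
mask = frame.fg_probability.to(device)   # (1, H, W) silhouette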
@@ -1661,9 +1661,9 @@ def look_at_rotation(


 def look_at_view_transform(
-    dist: float = 1.0,
-    elev: float = 0.0,
-    azim: float = 0.0,
+    dist: _BatchFloatType = 1.0,
+    elev: _BatchFloatType = 0.0,
+    azim: _BatchFloatType = 0.0,
     degrees: bool = True,
     eye: Optional[Union[Sequence, torch.Tensor]] = None,
     at=((0, 0, 0),),  # (1, 3)
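
The widened annotations reflect that dist, elev and azim accept per-view batches as well as scalars, which is how the new provider calls the function (tensor elev and azim, scalar dist). A brief sketch of the batched call, using only arguments that appear in this diff:

import torch
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform

elev = torch.zeros(4)
azim = torch.linspace(-180.0, 180.0, 4)
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)  # scalar dist broadcasts over 4 views
cameras = FoVPerspectiveCameras(R=R, T=T)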
@@ -2,6 +2,6 @@

 This is copied version of docs/tutorials/data/cow_mesh with removed line 6159 (usemtl material_1) to test behavior without usemtl material_1 declaration.

-Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
+Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.

 ###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
@@ -90,6 +90,15 @@ dataset_map_provider_LlffDatasetMapProvider_args:
   n_known_frames_for_test: null
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true
+dataset_map_provider_RenderedMeshDatasetMapProvider_args:
+  num_views: 40
+  data_file: null
+  azimuth_range: 180.0
+  resolution: 128
+  use_point_light: true
+  path_manager_factory_class_type: PathManagerFactory
+  path_manager_factory_PathManagerFactory_args:
+    silence_logs: true
 data_loader_map_provider_SequenceDataLoaderMapProvider_args:
   batch_size: 1
   num_workers: 0
tests/implicitron/test_data_cow.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+
+import torch
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
+    RenderedMeshDatasetMapProvider,
+)
+from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.renderer import FoVPerspectiveCameras
+from tests.common_testing import TestCaseMixin
+
+
+inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
+
+
+class TestDataCow(TestCaseMixin, unittest.TestCase):
+    def test_simple(self):
+        if inside_re_worker:
+            return
+        expand_args_fields(RenderedMeshDatasetMapProvider)
+        self._runtest(use_point_light=True, num_views=4)
+        self._runtest(use_point_light=False, num_views=4)
+
+    def _runtest(self, **kwargs):
+        provider = RenderedMeshDatasetMapProvider(**kwargs)
+        dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 1.7321
+        known_matrix[0, 1, 1] = 1.7321
+        known_matrix[0, 2, 2] = 1.0101
+        known_matrix[0, 3, 2] = -1.0101
+        known_matrix[0, 2, 3] = 1
+
+        self.assertIsNone(dataset_map.val)
+        self.assertIsNone(dataset_map.test)
+        self.assertEqual(len(dataset_map.train), provider.num_views)
+
+        value = dataset_map.train[0]
+        self.assertIsInstance(value, FrameData)
+
+        self.assertEqual(value.image_rgb.shape, (3, 128, 128))
+        self.assertEqual(value.fg_probability.shape, (1, 128, 128))
+        # corner of image is background
+        self.assertEqual(value.fg_probability[0, 0, 0], 0)
+        self.assertEqual(value.fg_probability.max(), 1.0)
+        self.assertIsInstance(value.camera, FoVPerspectiveCameras)
+        self.assertEqual(len(value.camera), 1)
+        self.assertIsNone(value.camera.K)
+        matrix = value.camera.get_projection_transform().get_matrix()
+        self.assertClose(matrix, known_matrix, atol=1e-4)
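
The hard-coded known_matrix values in the test appear to follow from the FoVPerspectiveCameras defaults (fov=60, znear=1.0, zfar=100.0, aspect_ratio=1.0) — an assumption about the defaults, sketched below only to show where 1.7321 and 1.0101 come from:

import math

fov, znear, zfar = 60.0, 1.0, 100.0                # assumed FoVPerspectiveCameras defaults
focal = 1.0 / math.tan(math.radians(fov / 2.0))    # ~1.7321 -> entries [0, 0] and [1, 1]
z_scale = zfar / (zfar - znear)                    # ~1.0101 -> entry [2, 2]
z_shift = -(zfar * znear) / (zfar - znear)         # ~-1.0101 -> entry [3, 2]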