diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py
index 78679b6f..16872130 100644
--- a/pytorch3d/datasets/__init__.py
+++ b/pytorch3d/datasets/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .r2n2 import R2N2
+from .r2n2 import R2N2, BlenderCamera
 from .shapenet import ShapeNetCore
 from .utils import collate_batched_meshes
 
diff --git a/pytorch3d/datasets/r2n2/__init__.py b/pytorch3d/datasets/r2n2/__init__.py
index a98d7b00..f18dc45d 100644
--- a/pytorch3d/datasets/r2n2/__init__.py
+++ b/pytorch3d/datasets/r2n2/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .r2n2 import R2N2
+from .r2n2 import R2N2, BlenderCamera
 
 
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
index 46865b05..13214c39 100644
--- a/pytorch3d/datasets/r2n2/r2n2.py
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -11,11 +11,18 @@ import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
 from pytorch3d.io import load_obj
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.transforms import Transform3d
 from tabulate import tabulate
 
 
 SYNSET_DICT_DIR = Path(__file__).resolve().parent
 
+# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
+r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
+t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
+k = np.expand_dims(np.eye(4), axis=0)  # (1, 4, 4)
+
 
 class R2N2(ShapeNetBase):
     """
@@ -217,3 +224,27 @@ class R2N2(ShapeNetBase):
         model["images"] = torch.stack(images)
 
         return model
+
+
+class BlenderCamera(CamerasBase):
+    """
+    Camera for rendering objects with calibration matrices from the R2N2 dataset
+    (which uses Blender for rendering the views for each model).
+    """
+
+    def __init__(self, R=r, T=t, K=k, device="cpu"):
+        """
+        Args:
+            R: Rotation matrix of shape (N, 3, 3).
+            T: Translation matrix of shape (N, 3).
+            K: Intrinsic matrix of shape (N, 4, 4).
+            device: torch.device or str.
+        """
+        # The initializer formats all inputs to torch tensors and broadcasts
+        # all the inputs to have the same batch dimension where necessary.
+        super().__init__(device=device, R=R, T=T, K=K)
+
+    def get_projection_transform(self, **kwargs) -> Transform3d:
+        transform = Transform3d(device=self.device)
+        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
+        return transform
diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py
index 9b0253db..3f4a115e 100644
--- a/tests/test_r2n2.py
+++ b/tests/test_r2n2.py
@@ -11,13 +11,16 @@ import numpy as np
 import torch
 from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
-from pytorch3d.datasets import R2N2, collate_batched_meshes
+from pytorch3d.datasets import R2N2, BlenderCamera, collate_batched_meshes
 from pytorch3d.renderer import (
     OpenGLPerspectiveCameras,
     PointLights,
     RasterizationSettings,
     look_at_view_transform,
 )
+from pytorch3d.renderer.cameras import get_world_to_view_transform
+from pytorch3d.transforms import Transform3d
+from pytorch3d.transforms.so3 import so3_exponential_map
 from torch.utils.data import DataLoader
 
 
@@ -258,3 +261,21 @@
             "test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
         )
         self.assertClose(mixed_rgb, image_ref, atol=0.05)
+
+    def test_blender_camera(self):
+        """
+        Test BlenderCamera.
+        """
+        # Test get_world_to_view_transform.
+        T = torch.randn(10, 3)
+        R = so3_exponential_map(torch.randn(10, 3) * 3.0)
+        RT = get_world_to_view_transform(R=R, T=T)
+        cam = BlenderCamera(R=R, T=T)
+        RT_class = cam.get_world_to_view_transform()
+        self.assertTrue(torch.allclose(RT.get_matrix(), RT_class.get_matrix()))
+        self.assertTrue(isinstance(RT, Transform3d))
+
+        # Test getting camera center.
+        C = cam.get_camera_center()
+        C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
+        self.assertTrue(torch.allclose(C, C_, atol=1e-05))
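
A minimal usage sketch of the new `BlenderCamera`, using only the API added in this diff (`BlenderCamera(R=R, T=T)`, `get_world_to_view_transform`, `get_projection_transform`, `get_camera_center`). The batch size and random extrinsics are illustrative; `K` is left at its module-level identity default, and per the class docstring the initializer broadcasts inputs to a common batch dimension.

```python
import torch
from pytorch3d.datasets import BlenderCamera
from pytorch3d.transforms.so3 import so3_exponential_map

# Illustrative batch of 4 random extrinsics.
R = so3_exponential_map(torch.randn(4, 3))  # (4, 3, 3) rotation matrices
T = torch.randn(4, 3)                       # (4, 3) translations

cam = BlenderCamera(R=R, T=T)

# World-to-view transform built from (R, T), as checked in test_blender_camera.
world_to_view = cam.get_world_to_view_transform()  # Transform3d

# Projection transform wraps the (transposed) calibration matrix K.
proj = cam.get_projection_transform()

# Camera centers in world coordinates, shape (4, 3):
# equivalent to -T @ R^T per batch element.
centers = cam.get_camera_center()
```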