diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
index 13214c39..ecf3fe62 100644
--- a/pytorch3d/datasets/r2n2/r2n2.py
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -10,7 +10,9 @@ import numpy as np
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
+from pytorch3d.datasets.utils import compute_extrinsic_matrix
 from pytorch3d.io import load_obj
+from pytorch3d.renderer import HardPhongShader
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.transforms import Transform3d
 from tabulate import tabulate
@@ -168,6 +170,9 @@ class R2N2(ShapeNetBase):
             - label (str): synset label.
             - images: FloatTensor of shape (V, H, W, C), where V is number of views
                 returned. Returns a batch of the renderings of the models from the R2N2 dataset.
+            - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
+            - T: Translation matrix of shape (V, 3), where V is number of views returned.
+            - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
         """
         if isinstance(model_idx, tuple):
             model_idx, view_idxs = model_idx
@@ -213,7 +218,11 @@ class R2N2(ShapeNetBase):
             "rendering",
         )
 
-        images = []
+        # Read metadata file to obtain params for calibration matrices.
+        with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
+            metadata_lines = f.readlines()
+
+        images, Rs, Ts = [], [], []
         for i in model_views:
             # Read image.
             image_path = path.join(rendering_path, "%02d.png" % i)
@@ -221,10 +230,125 @@ class R2N2(ShapeNetBase):
             image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
             images.append(image.to(dtype=torch.float32))
 
+            # Get camera calibration.
+            azim, elev, yaw, dist_ratio, fov = [
+                float(v) for v in metadata_lines[i].strip().split(" ")
+            ]
+            R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
+            Rs.append(R)
+            Ts.append(T)
+
+        # Intrinsic matrix extracted from Blender, with a slight modification to work
+        # with the PyTorch3D world space. Taken from the meshrcnn codebase:
+        # https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
+        K = torch.tensor(
+            [
+                [2.1875, 0.0, 0.0, 0.0],
+                [0.0, 2.1875, 0.0, 0.0],
+                [0.0, 0.0, -1.002002, -0.2002002],
+                [0.0, 0.0, 1.0, 0.0],
+            ]
+        )
         model["images"] = torch.stack(images)
+        model["R"] = torch.stack(Rs)
+        model["T"] = torch.stack(Ts)
+        model["K"] = K.expand(len(model_views), 4, 4)
 
         return model
 
+    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
+        """
+        Helper function for calculating rotation and translation matrices from azimuth
+        angle, elevation and distance ratio.
+
+        Args:
+            azim: Rotation about the z-axis, in degrees.
+            elev: Rotation above the xy-plane, in degrees.
+            dist_ratio: Ratio of distance from the origin to the maximum camera distance.
+
+        Returns:
+            - R: Rotation matrix of shape (3, 3).
+            - T: Translation matrix of shape (3).
+        """
+        # Retrieve R, T, K of the selected view(s) by reading the metadata.
+        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
+        dist = dist_ratio * MAX_CAMERA_DISTANCE
+        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)
+
+        # Transform the mesh vertices from shapenet world to pytorch3d world.
+        shapenet_to_pytorch3d = torch.tensor(
+            [
+                [-1.0, 0.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, -1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0],
+            ],
+            dtype=torch.float32,
+        )
+        RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
+
+        # Extract rotation and translation matrices from RT.
+        R = RT[:3, :3]
+        T = RT[3, :3]
+        return R, T
+
+    def render(
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+        view_idxs: Optional[List[int]] = None,
+        shader_type=HardPhongShader,
+        device="cpu",
+        **kwargs
+    ) -> torch.Tensor:
+        """
+        Render models with BlenderCamera by default to achieve the same orientations as the
+        R2N2 renderings. Also accepts other types of cameras and any of the args that the
+        render function in the ShapeNetBase class accepts.
+
+        Args:
+            view_idxs: each model will be rendered with the orientation(s) of the specified
+                views. Only renders by view_idxs if no cameras or BlenderCamera args are
+                supplied.
+            Accepts any of the args of the render function in ShapeNetBase:
+            model_ids: List[str] of model_ids of models intended to be rendered.
+            categories: List[str] of categories intended to be rendered. categories
+                and sample_nums must be specified at the same time. categories can be given
+                in the form of synset offsets or labels, or a combination of both.
+            sample_nums: List[int] of number of models to be randomly sampled from
+                each category. Could also contain one single integer, in which case it
+                will be broadcast for every category.
+            idxs: List[int] of indices of models to be rendered in the dataset.
+            shader_type: Shader to use for rendering. Examples include HardPhongShader
+                (default), SoftPhongShader etc., or any other type of valid Shader class.
+            device: torch.device on which the tensors should be located.
+            **kwargs: Accepts any of the kwargs that the renderer supports and any of the
+                args that BlenderCamera supports.
+
+        Returns:
+            Batch of rendered images of shape (N, H, W, 3).
+        """
+        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        r = torch.cat([self[idxs[i], view_idxs]["R"] for i in range(len(idxs))])
+        t = torch.cat([self[idxs[i], view_idxs]["T"] for i in range(len(idxs))])
+        k = torch.cat([self[idxs[i], view_idxs]["K"] for i in range(len(idxs))])
+        # Initialize default cameras using R, T, K from kwargs or R, T, K of the
+        # specified views.
+        blend_cameras = BlenderCamera(
+            R=kwargs.get("R", r),
+            T=kwargs.get("T", t),
+            K=kwargs.get("K", k),
+            device=device,
+        )
+        cameras = kwargs.get("cameras", blend_cameras).to(device)
+        kwargs.pop("cameras", None)
+        # Pass down all the same inputs.
+        return super().render(
+            idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
+        )
+
 
 class BlenderCamera(CamerasBase):
     """
diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py
index 722bde09..735a8414 100644
--- a/pytorch3d/datasets/shapenet_base.py
+++ b/pytorch3d/datasets/shapenet_base.py
@@ -111,12 +111,27 @@ class ShapeNetBase(torch.utils.data.Dataset):
         Returns:
             Batch of rendered images of shape (N, H, W, 3).
""" - paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs) + idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs) + paths = [ + path.join( + self.shapenet_dir, + self.synset_ids[idx], + self.model_ids[idx], + self.model_dir, + ) + for idx in idxs + ] meshes = load_objs_as_meshes(paths, device=device, load_textures=False) meshes.textures = TexturesVertex( verts_features=torch.ones_like(meshes.verts_padded(), device=device) ) cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device) + if len(cameras) != 1 and len(cameras) % len(meshes) != 0: + raise ValueError("Mismatch between batch dims of cameras and meshes.") + if len(cameras) > 1: + # When rendering R2N2 models, if more than one views are provided, broadcast + # the meshes so that each mesh can be rendered for each of the views. + meshes = meshes.extend(len(cameras) // len(meshes)) renderer = MeshRenderer( rasterizer=MeshRasterizer( cameras=cameras, @@ -136,7 +151,7 @@ class ShapeNetBase(torch.utils.data.Dataset): categories: Optional[List[str]] = None, sample_nums: Optional[List[int]] = None, idxs: Optional[List[int]] = None, - ) -> List[str]: + ) -> List[int]: """ Helper function for converting user provided model_ids, categories and sample_nums to indices of models in the loaded dataset. If model idxs are provided, we check if @@ -206,15 +221,7 @@ class ShapeNetBase(torch.utils.data.Dataset): ) warnings.warn(msg) idxs = self._sample_idxs_from_category(sample_nums[0]) - return [ - path.join( - self.shapenet_dir, - self.synset_ids[idx], - self.model_ids[idx], - self.model_dir, - ) - for idx in idxs - ] + return idxs def _sample_idxs_from_category( self, sample_num: int = 1, category: Optional[str] = None diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py index 5c2f4bc0..43243f5c 100644 --- a/pytorch3d/datasets/utils.py +++ b/pytorch3d/datasets/utils.py @@ -1,5 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - +import math from typing import Dict, List import torch @@ -34,10 +34,77 @@ def collate_batched_meshes(batch: List[Dict]): verts=collated_dict["verts"], faces=collated_dict["faces"] ) - # If collate_batched_meshes receives R2N2 items, stack the batches of - # views of each model into a new batch of shape (N, V, H, W, 3) where - # V is the number of views. + # If collate_batched_meshes receives R2N2 items with images and that + # all models have the same number of views V, stack the batches of + # views of each model into a new batch of shape (N, V, H, W, 3). + # Otherwise leave it as a list. if "images" in collated_dict: - collated_dict["images"] = torch.stack(collated_dict["images"]) + try: + collated_dict["images"] = torch.stack(collated_dict["images"]) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of images instead of batches." + ) + + # If collate_batched_meshes receives R2N2 items with camera calibration + # matrices and that all models have the same number of views V, stack each + # type of matrices into a new batch of shape (N, V, ...). + # Otherwise leave them as lists. + if all(x in collated_dict for x in ["R", "T", "K"]): + try: + collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3) + collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3) + collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4) + except RuntimeError: + print( + "Models don't have the same number of views. 
Now returning " + "lists of calibration matrices instead of batches." + ) return collated_dict + + +def compute_extrinsic_matrix(azimuth, elevation, distance): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96 + + Compute 4x4 extrinsic matrix that converts from homogenous world coordinates + to homogenous camera coordinates. We assume that the camera is looking at the + origin. + Used in R2N2 Dataset when computing calibration matrices. + + Args: + azimuth: Rotation about the z-axis, in degrees. + elevation: Rotation above the xy-plane, in degrees. + distance: Distance from the origin. + + Returns: + FloatTensor of shape (4, 4). + """ + azimuth, elevation, distance = float(azimuth), float(elevation), float(distance) + + az_rad = -math.pi * azimuth / 180.0 + el_rad = -math.pi * elevation / 180.0 + sa = math.sin(az_rad) + ca = math.cos(az_rad) + se = math.sin(el_rad) + ce = math.cos(el_rad) + R_world2obj = torch.tensor( + [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]] + ) + R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + R_world2cam = R_obj2cam.mm(R_world2obj) + cam_location = torch.tensor([[distance, 0, 0]]).t() + T_world2cam = -(R_obj2cam.mm(cam_location)) + RT = torch.cat([R_world2cam, T_world2cam], dim=1) + RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])]) + + # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it + # rotates the model 90 degrees about the x axis. To compensate for this quirk we + # roll that rotation into the extrinsic matrix here + rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + RT = RT.mm(rot.to(RT)) + + return RT diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_0.png b/tests/data/test_r2n2_render_with_blender_calibrations_0.png new file mode 100644 index 00000000..c9c169c9 Binary files /dev/null and b/tests/data/test_r2n2_render_with_blender_calibrations_0.png differ diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_1.png b/tests/data/test_r2n2_render_with_blender_calibrations_1.png new file mode 100644 index 00000000..7c869338 Binary files /dev/null and b/tests/data/test_r2n2_render_with_blender_calibrations_1.png differ diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_2.png b/tests/data/test_r2n2_render_with_blender_calibrations_2.png new file mode 100644 index 00000000..1cbda3d4 Binary files /dev/null and b/tests/data/test_r2n2_render_with_blender_calibrations_2.png differ diff --git a/tests/data/test_r2n2_render_with_blender_calibrations_3.png b/tests/data/test_r2n2_render_with_blender_calibrations_3.png new file mode 100644 index 00000000..4d907198 Binary files /dev/null and b/tests/data/test_r2n2_render_with_blender_calibrations_3.png differ diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py index 3f4a115e..01ab3274 100644 --- a/tests/test_r2n2.py +++ b/tests/test_r2n2.py @@ -93,10 +93,18 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(faces.ndim, 2) self.assertEqual(faces.shape[-1], 3) + # Check that the intrinsic matrix and extrinsic matrix have the + # correct shapes. 
+ self.assertEqual(r2n2_obj["R"].shape[0], 24) + self.assertEqual(r2n2_obj["R"].shape[1:], (3, 3)) + self.assertEqual(r2n2_obj["T"].ndim, 2) + self.assertEqual(r2n2_obj["T"].shape[1], 3) + self.assertEqual(r2n2_obj["K"].ndim, 3) + self.assertEqual(r2n2_obj["K"].shape[1:], (4, 4)) + # Check that image batch returned by __getitem__ has the correct shape. self.assertEqual(r2n2_obj["images"].shape[0], 24) - self.assertEqual(r2n2_obj["images"].shape[1], 137) - self.assertEqual(r2n2_obj["images"].shape[2], 137) + self.assertEqual(r2n2_obj["images"].shape[1:-1], (137, 137)) self.assertEqual(r2n2_obj["images"].shape[-1], 3) self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1) self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2) @@ -113,7 +121,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): the correct shapes and types are returned. """ # Load dataset in the train split. - r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) + r2n2_dataset = R2N2("val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) # Randomly retrieve several objects from the dataset and collate them. collated_meshes = collate_batched_meshes( @@ -147,6 +155,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size) self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size) self.assertEqual(object_batch["images"].shape[0], batch_size) + self.assertEqual(object_batch["R"].shape[0], batch_size) + self.assertEqual(object_batch["T"].shape[0], batch_size) + self.assertEqual(object_batch["K"].shape[0], batch_size) def test_catch_render_arg_errors(self): """ @@ -166,6 +177,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): r2n2_dataset.render(idxs=[1000000]) self.assertTrue("are out of bounds" in str(err.exception)) + blend_cameras = BlenderCamera( + R=torch.rand((3, 3, 3)), T=torch.rand((3, 3)), K=torch.rand((3, 4, 4)) + ) + with self.assertRaises(ValueError) as err: + r2n2_dataset.render(idxs=[10, 11], cameras=blend_cameras) + self.assertTrue("Mismatch between batch dims" in str(err.exception)) + def test_render_r2n2(self): """ Test rendering objects from R2N2 selected both by indices and model_ids. @@ -279,3 +297,44 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): C = cam.get_camera_center() C_ = -torch.bmm(R, T[:, :, None])[:, :, 0] self.assertTrue(torch.allclose(C, C_, atol=1e-05)) + + def test_render_by_r2n2_calibration(self): + """ + Test rendering R2N2 models with calibration matrices from R2N2's own Blender + in batches. + """ + # Set up device and seed for random selections. + device = torch.device("cuda:0") + torch.manual_seed(39) + + # Load dataset in the train split. + r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) + model_idxs = torch.randint(1000, (2,)).tolist() + view_idxs = torch.randint(24, (2,)).tolist() + raster_settings = RasterizationSettings(image_size=512) + lights = PointLights( + location=torch.tensor([0.0, 1.0, -2.0], device=device)[None], + # TODO(nikhilar): debug the source of the discrepancy in two images when + # rendering on GPU. 
+            diffuse_color=((0, 0, 0),),
+            specular_color=((0, 0, 0),),
+            device=device,
+        )
+        r2n2_batch = r2n2_dataset.render(
+            idxs=model_idxs,
+            view_idxs=view_idxs,
+            device=device,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        for idx in range(4):
+            r2n2_batch_rgb = r2n2_batch[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((r2n2_batch_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR
+                    / ("DEBUG_r2n2_render_with_blender_calibrations_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_with_blender_calibrations_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_batch_rgb, image_ref, atol=0.05)
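
Usage note (not part of the diff): a minimal sketch of how the R/T/K entries and the BlenderCamera-based render() added above might be used, based on the API shown in the patch and its tests. The dataset paths are placeholders that must point to local copies of the data, and the shapes in the comments assume the standard 24 R2N2 views per model.

import torch
from torch.utils.data import DataLoader

from pytorch3d.datasets import R2N2, collate_batched_meshes
from pytorch3d.renderer import PointLights, RasterizationSettings

# Placeholder paths: point these at local copies of ShapeNetCore.v1, the R2N2
# renderings, and the R2N2 split file.
SHAPENET_PATH = "/path/to/ShapeNetCore.v1"
R2N2_PATH = "/path/to/r2n2"
SPLITS_PATH = "/path/to/r2n2_splits.json"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

# __getitem__ now also returns the calibration of the requested views.
sample = r2n2_dataset[6, [1, 2]]
print(sample["R"].shape, sample["T"].shape, sample["K"].shape)  # (2, 3, 3), (2, 3), (2, 4, 4)

# render() defaults to a BlenderCamera built from those calibrations, so the
# output orientations match the original R2N2 renderings: 2 models x 2 views.
images = r2n2_dataset.render(
    idxs=[6, 7],
    view_idxs=[1, 2],
    device=device,
    raster_settings=RasterizationSettings(image_size=137),
    lights=PointLights(device=device),
)
print(images.shape)  # (4, 137, 137, 3)

# collate_batched_meshes stacks R/T/K into (N, V, ...) tensors when every model
# in the batch has the same number of views.
loader = DataLoader(r2n2_dataset, batch_size=2, collate_fn=collate_batched_meshes)
batch = next(iter(loader))
print(batch["R"].shape)  # (2, 24, 3, 3)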