mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00
Return R2N2 R,T,K

Summary: Return rotation, translation and intrinsic matrices necessary to reproduce R2N2's own renderings.

Reviewed By: nikhilaravi

Differential Revision: D22462520

fbshipit-source-id: 46a3859743ebc43c7a24f75827d2be3adf3f486b

This commit is contained in:
parent c122ccb13c
commit 326e4ccb5b
@@ -10,7 +10,9 @@ import numpy as np
import torch
from PIL import Image
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.datasets.utils import compute_extrinsic_matrix
from pytorch3d.io import load_obj
from pytorch3d.renderer import HardPhongShader
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.transforms import Transform3d
from tabulate import tabulate
@@ -168,6 +170,9 @@ class R2N2(ShapeNetBase):
            - label (str): synset label.
            - images: FloatTensor of shape (V, H, W, C), where V is the number of views
                returned: a batch of the renderings of the models from the R2N2 dataset.
+            - R: Rotation matrix of shape (V, 3, 3), where V is the number of views returned.
+            - T: Translation matrix of shape (V, 3), where V is the number of views returned.
+            - K: Intrinsic matrix of shape (V, 4, 4), where V is the number of views returned.
        """
        if isinstance(model_idx, tuple):
            model_idx, view_idxs = model_idx
@@ -213,7 +218,11 @@ class R2N2(ShapeNetBase):
                "rendering",
            )

-            images = []
+            # Read metadata file to obtain params for calibration matrices.
+            with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
+                metadata_lines = f.readlines()
+
+            images, Rs, Ts = [], [], []
            for i in model_views:
                # Read image.
                image_path = path.join(rendering_path, "%02d.png" % i)
@@ -221,10 +230,125 @@ class R2N2(ShapeNetBase):
                image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
                images.append(image.to(dtype=torch.float32))

                # Get camera calibration.
                azim, elev, yaw, dist_ratio, fov = [
                    float(v) for v in metadata_lines[i].strip().split(" ")
                ]
                R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
                Rs.append(R)
                Ts.append(T)

            # Intrinsic matrix extracted from Blender, with a slight modification to
            # work with PyTorch3D world space. Taken from the meshrcnn codebase:
            # https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
            K = torch.tensor(
                [
                    [2.1875, 0.0, 0.0, 0.0],
                    [0.0, 2.1875, 0.0, 0.0],
                    [0.0, 0.0, -1.002002, -0.2002002],
                    [0.0, 0.0, 1.0, 0.0],
                ]
            )
            model["images"] = torch.stack(images)
            model["R"] = torch.stack(Rs)
            model["T"] = torch.stack(Ts)
            model["K"] = K.expand(len(model_views), 4, 4)

        return model
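
A note on the hard-coded K above: the values are stated but not derived in the commit. They are consistent with R2N2's Blender camera of 35 mm focal length on a 32 mm sensor (2 * 35 / 32 = 2.1875) and with a perspective projection using znear = 0.1 and zfar = 100; this reading is an inference, checkable as follows:

    # Inferred parameters; the commit itself only states the matrix entries.
    focal_mm, sensor_mm = 35.0, 32.0
    znear, zfar = 0.1, 100.0
    assert 2 * focal_mm / sensor_mm == 2.1875
    assert round((zfar + znear) / (zfar - znear), 6) == 1.002002
    assert round(2 * zfar * znear / (zfar - znear), 7) == 0.2002002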

    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
        """
        Helper function for calculating rotation and translation matrices from the
        azimuth angle, elevation and distance ratio.

        Args:
            azim: Rotation about the z-axis, in degrees.
            elev: Rotation above the xy-plane, in degrees.
            dist_ratio: Ratio of distance from the origin to the maximum camera distance.

        Returns:
            - R: Rotation matrix of shape (3, 3).
            - T: Translation matrix of shape (3).
        """
        # Retrieve R and T of the selected view(s) by reading the metadata.
        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
        dist = dist_ratio * MAX_CAMERA_DISTANCE
        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)

        # Transform the mesh vertices from the ShapeNet world to the PyTorch3D world.
        shapenet_to_pytorch3d = torch.tensor(
            [
                [-1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            dtype=torch.float32,
        )
        RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)

        # Extract rotation and translation matrices from RT.
        R = RT[:3, :3]
        T = RT[3, :3]
        return R, T
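
As a sanity check on the convention just established (a sketch reusing the r2n2_dataset instance from earlier; it pokes a private helper purely for illustration): PyTorch3D transforms row vectors, x_view = x_world @ R + T, so the camera center C satisfies C @ R + T = 0, i.e. C = -T @ R.t() for orthonormal R, and its norm should equal the camera distance encoded in the metadata:

    # Hypothetical metadata values.
    R, T = r2n2_dataset._compute_camera_calibration(azim=30.0, elev=25.0, dist_ratio=0.7)
    C = -T @ R.t()  # camera center in world coordinates
    assert abs(C.norm().item() - 0.7 * 1.75) < 1e-5  # dist_ratio * MAX_CAMERA_DISTANCE

The same identity is exercised against get_camera_center() by the camera-center test near the end of this diff.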

    def render(
        self,
        model_ids: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        sample_nums: Optional[List[int]] = None,
        idxs: Optional[List[int]] = None,
        view_idxs: Optional[List[int]] = None,
        shader_type=HardPhongShader,
        device="cpu",
        **kwargs
    ) -> torch.Tensor:
        """
        Render models with a BlenderCamera by default, to achieve the same orientations
        as the R2N2 renderings. Also accepts other types of cameras and any of the args
        that the render function in the ShapeNetBase class accepts.

        Args:
            view_idxs: each model will be rendered with the orientation(s) of the
                specified views. Only used if no cameras and no R, T, K args for
                BlenderCamera are supplied.
            Accepts any of the args of the render function in ShapeNetBase:
            model_ids: List[str] of model_ids of models to be rendered.
            categories: List[str] of categories to be rendered. categories
                and sample_nums must be specified at the same time. categories can be
                given in the form of synset offsets or labels, or a combination of both.
            sample_nums: List[int] of numbers of models to be randomly sampled from
                each category. Could also contain a single integer, in which case it
                is broadcast to every category.
            idxs: List[int] of indices of models to be rendered in the dataset.
            shader_type: Shader to use for rendering. Examples include HardPhongShader
                (default) and SoftPhongShader, or any other valid Shader class.
            device: torch.device on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports and any of
                the args that BlenderCamera supports.

        Returns:
            Batch of rendered images of shape (N, H, W, 3).
        """
        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
        r = torch.cat([self[idxs[i], view_idxs]["R"] for i in range(len(idxs))])
        t = torch.cat([self[idxs[i], view_idxs]["T"] for i in range(len(idxs))])
        k = torch.cat([self[idxs[i], view_idxs]["K"] for i in range(len(idxs))])
        # Initialize default camera using R, T, K from kwargs or R, T, K of the
        # specified views.
        blend_cameras = BlenderCamera(
            R=kwargs.get("R", r),
            T=kwargs.get("T", t),
            K=kwargs.get("K", k),
            device=device,
        )
        cameras = kwargs.get("cameras", blend_cameras).to(device)
        kwargs.pop("cameras", None)
        # Pass down all the same inputs.
        return super().render(
            idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
        )
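
A short sketch of the fallback logic above (same placeholder dataset as earlier; rendering settings are left at their defaults): when neither cameras nor R, T, K kwargs are supplied, every requested model is rendered under the Blender cameras of every requested view, so the batch dimension is len(idxs) * len(view_idxs):

    # Render models 3 and 4, each from views 0 and 5, with R2N2's own cameras.
    images = r2n2_dataset.render(idxs=[3, 4], view_idxs=[0, 5], device="cpu")
    images.shape  # (4, H, W, 3): 2 models x 2 views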


class BlenderCamera(CamerasBase):
    """
@@ -111,12 +111,27 @@ class ShapeNetBase(torch.utils.data.Dataset):
        Returns:
            Batch of rendered images of shape (N, H, W, 3).
        """
-        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        paths = [
+            path.join(
+                self.shapenet_dir,
+                self.synset_ids[idx],
+                self.model_ids[idx],
+                self.model_dir,
+            )
+            for idx in idxs
+        ]
        meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
        meshes.textures = TexturesVertex(
            verts_features=torch.ones_like(meshes.verts_padded(), device=device)
        )
        cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
+        if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
+            raise ValueError("Mismatch between batch dims of cameras and meshes.")
+        if len(cameras) > 1:
+            # When rendering R2N2 models, if more than one view is provided, broadcast
+            # the meshes so that each mesh can be rendered for each of the views.
+            meshes = meshes.extend(len(cameras) // len(meshes))
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras,
@@ -136,7 +151,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
        categories: Optional[List[str]] = None,
        sample_nums: Optional[List[int]] = None,
        idxs: Optional[List[int]] = None,
-    ) -> List[str]:
+    ) -> List[int]:
        """
        Helper function for converting user-provided model_ids, categories and sample_nums
        to indices of models in the loaded dataset. If model idxs are provided, we check if
@@ -206,15 +221,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
                )
                warnings.warn(msg)
            idxs = self._sample_idxs_from_category(sample_nums[0])
-        return [
-            path.join(
-                self.shapenet_dir,
-                self.synset_ids[idx],
-                self.model_ids[idx],
-                self.model_dir,
-            )
-            for idx in idxs
-        ]
+        return idxs

    def _sample_idxs_from_category(
        self, sample_num: int = 1, category: Optional[str] = None
@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math
from typing import Dict, List

import torch
@@ -34,10 +34,77 @@ def collate_batched_meshes(batch: List[Dict]):
            verts=collated_dict["verts"], faces=collated_dict["faces"]
        )

-    # If collate_batched_meshes receives R2N2 items, stack the batches of
-    # views of each model into a new batch of shape (N, V, H, W, 3) where
-    # V is the number of views.
+    # If collate_batched_meshes receives R2N2 items with images, and
+    # all models have the same number of views V, stack the batches of
+    # views of each model into a new batch of shape (N, V, H, W, 3).
+    # Otherwise leave it as a list.
    if "images" in collated_dict:
-        collated_dict["images"] = torch.stack(collated_dict["images"])
+        try:
+            collated_dict["images"] = torch.stack(collated_dict["images"])
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of images instead of batches."
+            )

+    # If collate_batched_meshes receives R2N2 items with camera calibration
+    # matrices, and all models have the same number of views V, stack each
+    # type of matrix into a new batch of shape (N, V, ...).
+    # Otherwise leave them as lists.
+    if all(x in collated_dict for x in ["R", "T", "K"]):
+        try:
+            collated_dict["R"] = torch.stack(collated_dict["R"])  # (N, V, 3, 3)
+            collated_dict["T"] = torch.stack(collated_dict["T"])  # (N, V, 3)
+            collated_dict["K"] = torch.stack(collated_dict["K"])  # (N, V, 4, 4)
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of calibration matrices instead of batches."
+            )

    return collated_dict
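
A minimal sketch of how this collate function is meant to be wired into a DataLoader (placeholder dataset paths as in the tests; every R2N2 model ships 24 renderings of size 137 x 137, so the stacked branch applies):

    from torch.utils.data import DataLoader
    from pytorch3d.datasets import R2N2, collate_batched_meshes

    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
    loader = DataLoader(r2n2_dataset, batch_size=4, collate_fn=collate_batched_meshes)
    batch = next(iter(loader))
    batch["images"].shape  # (4, 24, 137, 137, 3)
    batch["R"].shape       # (4, 24, 3, 3)
    batch["T"].shape       # (4, 24, 3)
    batch["K"].shape       # (4, 24, 4, 4)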


def compute_extrinsic_matrix(azimuth, elevation, distance):
    """
    Copied from the meshrcnn codebase:
    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96

    Compute a 4x4 extrinsic matrix that converts from homogeneous world coordinates
    to homogeneous camera coordinates. We assume that the camera is looking at the
    origin.
    Used in the R2N2 dataset when computing calibration matrices.

    Args:
        azimuth: Rotation about the z-axis, in degrees.
        elevation: Rotation above the xy-plane, in degrees.
        distance: Distance from the origin.

    Returns:
        FloatTensor of shape (4, 4).
    """
    azimuth, elevation, distance = float(azimuth), float(elevation), float(distance)

    az_rad = -math.pi * azimuth / 180.0
    el_rad = -math.pi * elevation / 180.0
    sa = math.sin(az_rad)
    ca = math.cos(az_rad)
    se = math.sin(el_rad)
    ce = math.cos(el_rad)
    R_world2obj = torch.tensor(
        [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]]
    )
    R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    R_world2cam = R_obj2cam.mm(R_world2obj)
    cam_location = torch.tensor([[distance, 0, 0]]).t()
    T_world2cam = -(R_obj2cam.mm(cam_location))
    RT = torch.cat([R_world2cam, T_world2cam], dim=1)
    RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])])

    # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it
    # rotates the model 90 degrees about the x axis. To compensate for this quirk we
    # roll that rotation into the extrinsic matrix here.
    rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    RT = RT.mm(rot.to(RT))

    return RT
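
Two properties make the returned matrix easy to sanity-check (a sketch with arbitrary inputs): the rotation block stays orthonormal through all the composed rotations, and since the Blender-compensation rotation multiplies on the right, the last column is untouched, so the translation's norm equals the input distance:

    import torch
    from pytorch3d.datasets.utils import compute_extrinsic_matrix

    RT = compute_extrinsic_matrix(azimuth=30.0, elevation=25.0, distance=1.75)
    R, t = RT[:3, :3], RT[:3, 3]
    assert torch.allclose(R @ R.t(), torch.eye(3), atol=1e-5)
    assert abs(t.norm().item() - 1.75) < 1e-5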

BIN  tests/data/test_r2n2_render_with_blender_calibrations_0.png  (new file, 3.0 KiB, binary not shown)
BIN  tests/data/test_r2n2_render_with_blender_calibrations_1.png  (new file, 2.7 KiB, binary not shown)
BIN  tests/data/test_r2n2_render_with_blender_calibrations_2.png  (new file, 2.9 KiB, binary not shown)
BIN  tests/data/test_r2n2_render_with_blender_calibrations_3.png  (new file, 3.1 KiB, binary not shown)
@@ -93,10 +93,18 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
        self.assertEqual(faces.ndim, 2)
        self.assertEqual(faces.shape[-1], 3)

+        # Check that the intrinsic and extrinsic matrices have the
+        # correct shapes.
+        self.assertEqual(r2n2_obj["R"].shape[0], 24)
+        self.assertEqual(r2n2_obj["R"].shape[1:], (3, 3))
+        self.assertEqual(r2n2_obj["T"].ndim, 2)
+        self.assertEqual(r2n2_obj["T"].shape[1], 3)
+        self.assertEqual(r2n2_obj["K"].ndim, 3)
+        self.assertEqual(r2n2_obj["K"].shape[1:], (4, 4))
+
        # Check that image batch returned by __getitem__ has the correct shape.
        self.assertEqual(r2n2_obj["images"].shape[0], 24)
-        self.assertEqual(r2n2_obj["images"].shape[1], 137)
-        self.assertEqual(r2n2_obj["images"].shape[2], 137)
+        self.assertEqual(r2n2_obj["images"].shape[1:-1], (137, 137))
        self.assertEqual(r2n2_obj["images"].shape[-1], 3)
        self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1)
        self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2)
@@ -113,7 +121,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
        the correct shapes and types are returned.
        """
        # Load dataset in the val split.
-        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # Randomly retrieve several objects from the dataset and collate them.
        collated_meshes = collate_batched_meshes(
@@ -147,6 +155,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
        self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
        self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
        self.assertEqual(object_batch["images"].shape[0], batch_size)
+        self.assertEqual(object_batch["R"].shape[0], batch_size)
+        self.assertEqual(object_batch["T"].shape[0], batch_size)
+        self.assertEqual(object_batch["K"].shape[0], batch_size)

    def test_catch_render_arg_errors(self):
        """
    def test_catch_render_arg_errors(self):
 | 
			
		||||
        """
 | 
			
		||||
@ -166,6 +177,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            r2n2_dataset.render(idxs=[1000000])
 | 
			
		||||
        self.assertTrue("are out of bounds" in str(err.exception))
 | 
			
		||||
 | 
			
		||||
        blend_cameras = BlenderCamera(
 | 
			
		||||
            R=torch.rand((3, 3, 3)), T=torch.rand((3, 3)), K=torch.rand((3, 4, 4))
 | 
			
		||||
        )
 | 
			
		||||
        with self.assertRaises(ValueError) as err:
 | 
			
		||||
            r2n2_dataset.render(idxs=[10, 11], cameras=blend_cameras)
 | 
			
		||||
        self.assertTrue("Mismatch between batch dims" in str(err.exception))
 | 
			
		||||
 | 
			
		||||
    def test_render_r2n2(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test rendering objects from R2N2 selected both by indices and model_ids.
 | 
			
		||||
@@ -279,3 +297,44 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
        C = cam.get_camera_center()
        C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
        self.assertTrue(torch.allclose(C, C_, atol=1e-05))
+
+    def test_render_by_r2n2_calibration(self):
+        """
+        Test rendering R2N2 models in batches with calibration matrices from
+        R2N2's own Blender renderings.
+        """
+        # Set up device and seed for random selections.
+        device = torch.device("cuda:0")
+        torch.manual_seed(39)
+
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        model_idxs = torch.randint(1000, (2,)).tolist()
+        view_idxs = torch.randint(24, (2,)).tolist()
+        raster_settings = RasterizationSettings(image_size=512)
+        lights = PointLights(
+            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
+            # TODO(nikhilar): debug the source of the discrepancy between the two
+            # images when rendering on GPU.
+            diffuse_color=((0, 0, 0),),
+            specular_color=((0, 0, 0),),
+            device=device,
+        )
+        r2n2_batch = r2n2_dataset.render(
+            idxs=model_idxs,
+            view_idxs=view_idxs,
+            device=device,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        for idx in range(4):
+            r2n2_batch_rgb = r2n2_batch[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((r2n2_batch_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR
+                    / ("DEBUG_r2n2_render_with_blender_calibrations_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_with_blender_calibrations_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_batch_rgb, image_ref, atol=0.05)