Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
Return R2N2 R,T,K
Summary: Return rotation, translation and intrinsic matrices necessary to reproduce R2N2's own renderings.

Reviewed By: nikhilaravi

Differential Revision: D22462520

fbshipit-source-id: 46a3859743ebc43c7a24f75827d2be3adf3f486b
parent c122ccb13c
commit 326e4ccb5b
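
A minimal usage sketch of what this change enables (not part of the diff; the dataset paths are placeholders for local copies of ShapeNetCore v1, the R2N2 renderings and the R2N2 splits file):

    from pytorch3d.datasets import R2N2

    SHAPENET_PATH = "/path/to/ShapeNetCore.v1"   # placeholder
    R2N2_PATH = "/path/to/ShapeNetRendering"     # placeholder
    SPLITS_PATH = "/path/to/r2n2_splits.json"    # placeholder
    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

    # Retrieve views 0 and 6 of model 3. The returned dict now carries the
    # per-view calibration matrices alongside the renderings.
    model = r2n2_dataset[3, [0, 6]]
    print(model["images"].shape)  # torch.Size([2, 137, 137, 3])
    print(model["R"].shape, model["T"].shape, model["K"].shape)
    # torch.Size([2, 3, 3]) torch.Size([2, 3]) torch.Size([2, 4, 4])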
@@ -10,7 +10,9 @@ import numpy as np
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
+from pytorch3d.datasets.utils import compute_extrinsic_matrix
 from pytorch3d.io import load_obj
+from pytorch3d.renderer import HardPhongShader
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.transforms import Transform3d
 from tabulate import tabulate
@@ -168,6 +170,9 @@ class R2N2(ShapeNetBase):
             - label (str): synset label.
             - images: FloatTensor of shape (V, H, W, C), where V is number of views
                 returned. Returns a batch of the renderings of the models from the R2N2 dataset.
+            - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
+            - T: Translation matrix of shape (V, 3), where V is number of views returned.
+            - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
         """
         if isinstance(model_idx, tuple):
             model_idx, view_idxs = model_idx
@@ -213,7 +218,11 @@ class R2N2(ShapeNetBase):
             "rendering",
         )
 
-        images = []
+        # Read metadata file to obtain params for calibration matrices.
+        with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
+            metadata_lines = f.readlines()
+
+        images, Rs, Ts = [], [], []
         for i in model_views:
             # Read image.
             image_path = path.join(rendering_path, "%02d.png" % i)
@@ -221,10 +230,125 @@ class R2N2(ShapeNetBase):
             image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
             images.append(image.to(dtype=torch.float32))
 
+            # Get camera calibration.
+            azim, elev, yaw, dist_ratio, fov = [
+                float(v) for v in metadata_lines[i].strip().split(" ")
+            ]
+            R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
+            Rs.append(R)
+            Ts.append(T)
+
+        # Intrinsic matrix extracted from the Blender with slight modification to work with
+        # PyTorch3D world space. Taken from meshrcnn codebase:
+        # https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
+        K = torch.tensor(
+            [
+                [2.1875, 0.0, 0.0, 0.0],
+                [0.0, 2.1875, 0.0, 0.0],
+                [0.0, 0.0, -1.002002, -0.2002002],
+                [0.0, 0.0, 1.0, 0.0],
+            ]
+        )
         model["images"] = torch.stack(images)
+        model["R"] = torch.stack(Rs)
+        model["T"] = torch.stack(Ts)
+        model["K"] = K.expand(len(model_views), 4, 4)
+
         return model
 
+    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
+        """
+        Helper function for calculating rotation and translation matrices from azimuth
+        angle, elevation and distance ratio.
+
+        Args:
+            azim: Rotation about the z-axis, in degrees.
+            elev: Rotation above the xy-plane, in degrees.
+            dist_ratio: Ratio of distance from the origin to the maximum camera distance.
+
+        Returns:
+            - R: Rotation matrix of shape (3, 3).
+            - T: Translation matrix of shape (3).
+        """
+        # Retrive R,T,K of the selected view(s) by reading the metadata.
+        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
+        dist = dist_ratio * MAX_CAMERA_DISTANCE
+        RT = compute_extrinsic_matrix(azim, elev, dist)
+
+        # Transform the mesh vertices from shapenet world to pytorch3d world.
+        shapenet_to_pytorch3d = torch.tensor(
+            [
+                [-1.0, 0.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, -1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0],
+            ],
+            dtype=torch.float32,
+        )
+        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)
+        RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
+
+        # Extract rotation and translation matrices from RT.
+        R = RT[:3, :3]
+        T = RT[3, :3]
+        return R, T
+
+    def render(
+        self,
+        model_ids: Optional[List[str]] = None,
+        categories: Optional[List[str]] = None,
+        sample_nums: Optional[List[int]] = None,
+        idxs: Optional[List[int]] = None,
+        view_idxs: Optional[List[int]] = None,
+        shader_type=HardPhongShader,
+        device="cpu",
+        **kwargs
+    ) -> torch.Tensor:
+        """
+        Render models with BlenderCamera by default to achieve the same orientations as the
+        R2N2 renderings. Also accepts other types of cameras and any of the args that the
+        render function in the ShapeNetBase class accepts.
+
+        Args:
+            view_idxs: each model will be rendered with the orientation(s) of the specified
+                views. Only render by view_idxs if no camera or args for BlenderCamera is
+                supplied.
+            Accepts any of the args of the render function in ShapnetBase:
+            model_ids: List[str] of model_ids of models intended to be rendered.
+            categories: List[str] of categories intended to be rendered. categories
+                and sample_nums must be specified at the same time. categories can be given
+                in the form of synset offsets or labels, or a combination of both.
+            sample_nums: List[int] of number of models to be randomly sampled from
+                each category. Could also contain one single integer, in which case it
+                will be broadcasted for every category.
+            idxs: List[int] of indices of models to be rendered in the dataset.
+            shader_type: Shader to use for rendering. Examples include HardPhongShader
+                (default), SoftPhongShader etc or any other type of valid Shader class.
+            device: torch.device on which the tensors should be located.
+            **kwargs: Accepts any of the kwargs that the renderer supports and any of the
+                args that BlenderCamera supports.
+
+        Returns:
+            Batch of rendered images of shape (N, H, W, 3).
+        """
+        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        r = torch.cat([self[idxs[i], view_idxs]["R"] for i in range(len(idxs))])
+        t = torch.cat([self[idxs[i], view_idxs]["T"] for i in range(len(idxs))])
+        k = torch.cat([self[idxs[i], view_idxs]["K"] for i in range(len(idxs))])
+        # Initialize default camera using R, T, K from kwargs or R, T, K of the specified views.
+        blend_cameras = BlenderCamera(
+            R=kwargs.get("R", r),
+            T=kwargs.get("T", t),
+            K=kwargs.get("K", k),
+            device=device,
+        )
+        cameras = kwargs.get("cameras", blend_cameras).to(device)
+        kwargs.pop("cameras", None)
+        # pass down all the same inputs
+        return super().render(
+            idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
+        )
 
 
 class BlenderCamera(CamerasBase):
     """
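
A hedged sketch of the new render path, reusing the placeholder dataset above: when no camera is supplied, the R, T and K of the requested views are looked up through __getitem__ and wrapped in a BlenderCamera, so the output matches the orientations of R2N2's own renderings (the index values below are illustrative):

    from pytorch3d.renderer import RasterizationSettings

    images = r2n2_dataset.render(
        idxs=[5, 9],       # two models from the loaded split
        view_idxs=[0, 6],  # render with the calibrations of views 0 and 6
        device="cpu",
        raster_settings=RasterizationSettings(image_size=137),
    )
    # One rendering per (model, view) pair, i.e. a batch of 4 images here.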
@@ -111,12 +111,27 @@ class ShapeNetBase(torch.utils.data.Dataset):
         Returns:
             Batch of rendered images of shape (N, H, W, 3).
         """
-        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
+        paths = [
+            path.join(
+                self.shapenet_dir,
+                self.synset_ids[idx],
+                self.model_ids[idx],
+                self.model_dir,
+            )
+            for idx in idxs
+        ]
         meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
         meshes.textures = TexturesVertex(
             verts_features=torch.ones_like(meshes.verts_padded(), device=device)
         )
         cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
+        if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
+            raise ValueError("Mismatch between batch dims of cameras and meshes.")
+        if len(cameras) > 1:
+            # When rendering R2N2 models, if more than one views are provided, broadcast
+            # the meshes so that each mesh can be rendered for each of the views.
+            meshes = meshes.extend(len(cameras) // len(meshes))
         renderer = MeshRenderer(
             rasterizer=MeshRasterizer(
                 cameras=cameras,
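
The broadcast added above pairs every mesh with every requested view; a small self-contained sketch of that rule, using a toy mesh that is illustrative only:

    import torch
    from pytorch3d.structures import Meshes

    verts = torch.rand(4, 3)
    faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
    meshes = Meshes(verts=[verts, verts], faces=[faces, faces])  # N = 2 models

    num_cameras = 6  # e.g. 2 models x 3 views gives a camera batch of N * V
    assert num_cameras % len(meshes) == 0
    meshes = meshes.extend(num_cameras // len(meshes))
    print(len(meshes))  # 6: each model is repeated once per view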
@@ -136,7 +151,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         categories: Optional[List[str]] = None,
         sample_nums: Optional[List[int]] = None,
         idxs: Optional[List[int]] = None,
-    ) -> List[str]:
+    ) -> List[int]:
         """
         Helper function for converting user provided model_ids, categories and sample_nums
         to indices of models in the loaded dataset. If model idxs are provided, we check if
@@ -206,15 +221,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
             )
             warnings.warn(msg)
             idxs = self._sample_idxs_from_category(sample_nums[0])
-        return [
-            path.join(
-                self.shapenet_dir,
-                self.synset_ids[idx],
-                self.model_ids[idx],
-                self.model_dir,
-            )
-            for idx in idxs
-        ]
+        return idxs
 
     def _sample_idxs_from_category(
         self, sample_num: int = 1, category: Optional[str] = None
@@ -1,5 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
+import math
 from typing import Dict, List
 
 import torch
@@ -34,10 +34,77 @@ def collate_batched_meshes(batch: List[Dict]):
             verts=collated_dict["verts"], faces=collated_dict["faces"]
         )
 
-    # If collate_batched_meshes receives R2N2 items, stack the batches of
-    # views of each model into a new batch of shape (N, V, H, W, 3) where
-    # V is the number of views.
+    # If collate_batched_meshes receives R2N2 items with images and that
+    # all models have the same number of views V, stack the batches of
+    # views of each model into a new batch of shape (N, V, H, W, 3).
+    # Otherwise leave it as a list.
     if "images" in collated_dict:
-        collated_dict["images"] = torch.stack(collated_dict["images"])
+        try:
+            collated_dict["images"] = torch.stack(collated_dict["images"])
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of images instead of batches."
+            )
+
+    # If collate_batched_meshes receives R2N2 items with camera calibration
+    # matrices and that all models have the same number of views V, stack each
+    # type of matrices into a new batch of shape (N, V, ...).
+    # Otherwise leave them as lists.
+    if all(x in collated_dict for x in ["R", "T", "K"]):
+        try:
+            collated_dict["R"] = torch.stack(collated_dict["R"])  # (N, V, 3, 3)
+            collated_dict["T"] = torch.stack(collated_dict["T"])  # (N, V, 3)
+            collated_dict["K"] = torch.stack(collated_dict["K"])  # (N, V, 4, 4)
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of calibration matrices instead of batches."
+            )
+
     return collated_dict
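
collate_batched_meshes can be passed straight to a DataLoader; a hedged sketch with the placeholder dataset from the first example (when every model in the batch exposes the same number of views, images and calibrations come back stacked, otherwise as lists):

    from torch.utils.data import DataLoader
    from pytorch3d.datasets.utils import collate_batched_meshes

    loader = DataLoader(r2n2_dataset, batch_size=4, collate_fn=collate_batched_meshes)
    batch = next(iter(loader))
    print(batch["images"].shape)  # (4, 24, 137, 137, 3) when all models have 24 views
    print(batch["R"].shape)       # (4, 24, 3, 3); "T" -> (4, 24, 3), "K" -> (4, 24, 4, 4)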
+
+
+def compute_extrinsic_matrix(azimuth, elevation, distance):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96
+
+    Compute 4x4 extrinsic matrix that converts from homogenous world coordinates
+    to homogenous camera coordinates. We assume that the camera is looking at the
+    origin.
+    Used in R2N2 Dataset when computing calibration matrices.
+
+    Args:
+        azimuth: Rotation about the z-axis, in degrees.
+        elevation: Rotation above the xy-plane, in degrees.
+        distance: Distance from the origin.
+
+    Returns:
+        FloatTensor of shape (4, 4).
+    """
+    azimuth, elevation, distance = float(azimuth), float(elevation), float(distance)
+
+    az_rad = -math.pi * azimuth / 180.0
+    el_rad = -math.pi * elevation / 180.0
+    sa = math.sin(az_rad)
+    ca = math.cos(az_rad)
+    se = math.sin(el_rad)
+    ce = math.cos(el_rad)
+    R_world2obj = torch.tensor(
+        [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]]
+    )
+    R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+    R_world2cam = R_obj2cam.mm(R_world2obj)
+    cam_location = torch.tensor([[distance, 0, 0]]).t()
+    T_world2cam = -(R_obj2cam.mm(cam_location))
+    RT = torch.cat([R_world2cam, T_world2cam], dim=1)
+    RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])])
+
+    # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it
+    # rotates the model 90 degrees about the x axis. To compensate for this quirk we
+    # roll that rotation into the extrinsic matrix here
+    rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+    RT = RT.mm(rot.to(RT))
+
+    return RT
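
For completeness, a hedged sketch of calling the helper directly. R2N2's rendering_metadata.txt stores a distance ratio, which _compute_camera_calibration above scales by MAX_CAMERA_DISTANCE = 1.75 before passing it in:

    from pytorch3d.datasets.utils import compute_extrinsic_matrix

    dist_ratio = 0.7  # illustrative value, as read from rendering_metadata.txt
    RT = compute_extrinsic_matrix(azimuth=30.0, elevation=25.0, distance=dist_ratio * 1.75)
    print(RT.shape)  # torch.Size([4, 4]) homogeneous world-to-camera matrix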
New binary files:
tests/data/test_r2n2_render_with_blender_calibrations_0.png (3.0 KiB)
tests/data/test_r2n2_render_with_blender_calibrations_1.png (2.7 KiB)
tests/data/test_r2n2_render_with_blender_calibrations_2.png (2.9 KiB)
tests/data/test_r2n2_render_with_blender_calibrations_3.png (3.1 KiB)
@@ -93,10 +93,18 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         self.assertEqual(faces.ndim, 2)
         self.assertEqual(faces.shape[-1], 3)
 
+        # Check that the intrinsic matrix and extrinsic matrix have the
+        # correct shapes.
+        self.assertEqual(r2n2_obj["R"].shape[0], 24)
+        self.assertEqual(r2n2_obj["R"].shape[1:], (3, 3))
+        self.assertEqual(r2n2_obj["T"].ndim, 2)
+        self.assertEqual(r2n2_obj["T"].shape[1], 3)
+        self.assertEqual(r2n2_obj["K"].ndim, 3)
+        self.assertEqual(r2n2_obj["K"].shape[1:], (4, 4))
+
         # Check that image batch returned by __getitem__ has the correct shape.
         self.assertEqual(r2n2_obj["images"].shape[0], 24)
-        self.assertEqual(r2n2_obj["images"].shape[1], 137)
-        self.assertEqual(r2n2_obj["images"].shape[2], 137)
+        self.assertEqual(r2n2_obj["images"].shape[1:-1], (137, 137))
         self.assertEqual(r2n2_obj["images"].shape[-1], 3)
         self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1)
         self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2)
@@ -113,7 +121,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         the correct shapes and types are returned.
         """
         # Load dataset in the train split.
-        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 
         # Randomly retrieve several objects from the dataset and collate them.
         collated_meshes = collate_batched_meshes(
@@ -147,6 +155,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
         self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
         self.assertEqual(object_batch["images"].shape[0], batch_size)
+        self.assertEqual(object_batch["R"].shape[0], batch_size)
+        self.assertEqual(object_batch["T"].shape[0], batch_size)
+        self.assertEqual(object_batch["K"].shape[0], batch_size)
 
     def test_catch_render_arg_errors(self):
         """
@@ -166,6 +177,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
                 r2n2_dataset.render(idxs=[1000000])
         self.assertTrue("are out of bounds" in str(err.exception))
 
+        blend_cameras = BlenderCamera(
+            R=torch.rand((3, 3, 3)), T=torch.rand((3, 3)), K=torch.rand((3, 4, 4))
+        )
+        with self.assertRaises(ValueError) as err:
+            r2n2_dataset.render(idxs=[10, 11], cameras=blend_cameras)
+        self.assertTrue("Mismatch between batch dims" in str(err.exception))
+
     def test_render_r2n2(self):
         """
         Test rendering objects from R2N2 selected both by indices and model_ids.
@@ -279,3 +297,44 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         C = cam.get_camera_center()
         C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
         self.assertTrue(torch.allclose(C, C_, atol=1e-05))
+
+    def test_render_by_r2n2_calibration(self):
+        """
+        Test rendering R2N2 models with calibration matrices from R2N2's own Blender
+        in batches.
+        """
+        # Set up device and seed for random selections.
+        device = torch.device("cuda:0")
+        torch.manual_seed(39)
+
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        model_idxs = torch.randint(1000, (2,)).tolist()
+        view_idxs = torch.randint(24, (2,)).tolist()
+        raster_settings = RasterizationSettings(image_size=512)
+        lights = PointLights(
+            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
+            # TODO(nikhilar): debug the source of the discrepancy in two images when
+            # rendering on GPU.
+            diffuse_color=((0, 0, 0),),
+            specular_color=((0, 0, 0),),
+            device=device,
+        )
+        r2n2_batch = r2n2_dataset.render(
+            idxs=model_idxs,
+            view_idxs=view_idxs,
+            device=device,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        for idx in range(4):
+            r2n2_batch_rgb = r2n2_batch[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((r2n2_batch_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR
+                    / ("DEBUG_r2n2_render_with_blender_calibrations_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_with_blender_calibrations_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_batch_rgb, image_ref, atol=0.05)