Add BlenderCamera

Summary: Add BlenderCamera, a camera class for rendering R2N2 models with the dataset's Blender calibration matrices (to be used for rendering in the next diff).

Reviewed By: nikhilaravi

Differential Revision: D22462515

fbshipit-source-id: 4b40ee9bba8b6d56788dd3c723036ec704668153
Luya Gao 2020-07-23 10:15:50 -07:00 committed by Facebook GitHub Bot
parent 483e538dae
commit 722c2b7149
4 changed files with 55 additions and 3 deletions

pytorch3d/datasets/__init__.py

@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from .r2n2 import R2N2
+from .r2n2 import R2N2, BlenderCamera
 from .shapenet import ShapeNetCore
 from .utils import collate_batched_meshes

pytorch3d/datasets/r2n2/__init__.py

@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from .r2n2 import R2N2
+from .r2n2 import R2N2, BlenderCamera
 __all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/datasets/r2n2/r2n2.py

@@ -11,11 +11,18 @@ import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
 from pytorch3d.io import load_obj
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.transforms import Transform3d
 from tabulate import tabulate

 SYNSET_DICT_DIR = Path(__file__).resolve().parent

+# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
+r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
+t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
+k = np.expand_dims(np.eye(4), axis=0)  # (1, 4, 4)
+
 class R2N2(ShapeNetBase):
     """
@@ -217,3 +224,27 @@ class R2N2(ShapeNetBase):
         model["images"] = torch.stack(images)
         return model
+
+
+class BlenderCamera(CamerasBase):
+    """
+    Camera for rendering objects with calibration matrices from the R2N2 dataset
+    (which uses Blender for rendering the views for each model).
+    """
+
+    def __init__(self, R=r, T=t, K=k, device="cpu"):
+        """
+        Args:
+            R: Rotation matrix of shape (N, 3, 3).
+            T: Translation matrix of shape (N, 3).
+            K: Intrinsic matrix of shape (N, 4, 4).
+            device: torch.device or str.
+        """
+        # The initializer formats all inputs to torch tensors and broadcasts
+        # all the inputs to have the same batch dimension where necessary.
+        super().__init__(device=device, R=R, T=T, K=K)
+
+    def get_projection_transform(self, **kwargs) -> Transform3d:
+        transform = Transform3d(device=self.device)
+        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
+        return transform
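Since BlenderCamera subclasses CamerasBase, it slots into the usual PyTorch3D camera API: world-to-view comes from R and T via the base class, and get_projection_transform applies K directly (transposed because Transform3d matrices act on row vectors). A minimal usage sketch with hypothetical calibration values; real R, T, K would come from the R2N2 metadata wired up in the next diff:

import torch
from pytorch3d.datasets import BlenderCamera

# Hypothetical calibration for a batch of N = 2 views (identity here).
R = torch.eye(3).expand(2, 3, 3)  # (N, 3, 3) rotations
T = torch.zeros(2, 3)             # (N, 3) translations
K = torch.eye(4).expand(2, 4, 4)  # (N, 4, 4) intrinsics

cam = BlenderCamera(R=R, T=T, K=K)
verts = torch.rand(2, 100, 3)  # batch of point clouds in world coordinates
# World -> view uses the base-class transform built from R and T;
# projection applies K (transposed for the row-vector convention).
verts_view = cam.get_world_to_view_transform().transform_points(verts)
verts_proj = cam.get_projection_transform().transform_points(verts_view)
print(verts_proj.shape)  # torch.Size([2, 100, 3])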

tests/test_r2n2.py

@@ -11,13 +11,16 @@ import numpy as np
 import torch
 from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
-from pytorch3d.datasets import R2N2, collate_batched_meshes
+from pytorch3d.datasets import R2N2, BlenderCamera, collate_batched_meshes
 from pytorch3d.renderer import (
     OpenGLPerspectiveCameras,
     PointLights,
     RasterizationSettings,
     look_at_view_transform,
 )
+from pytorch3d.renderer.cameras import get_world_to_view_transform
+from pytorch3d.transforms import Transform3d
+from pytorch3d.transforms.so3 import so3_exponential_map
 from torch.utils.data import DataLoader
@@ -258,3 +261,21 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
             "test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
         )
         self.assertClose(mixed_rgb, image_ref, atol=0.05)
+
+    def test_blender_camera(self):
+        """
+        Test BlenderCamera.
+        """
+        # Test get_world_to_view_transform.
+        T = torch.randn(10, 3)
+        R = so3_exponential_map(torch.randn(10, 3) * 3.0)
+        RT = get_world_to_view_transform(R=R, T=T)
+        cam = BlenderCamera(R=R, T=T)
+        RT_class = cam.get_world_to_view_transform()
+        self.assertTrue(torch.allclose(RT.get_matrix(), RT_class.get_matrix()))
+        self.assertTrue(isinstance(RT, Transform3d))
+
+        # Test getting camera center.
+        C = cam.get_camera_center()
+        C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
+        self.assertTrue(torch.allclose(C, C_, atol=1e-05))
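The camera-center check in the last three lines follows from PyTorch3D's row-vector world-to-view convention X_view = X_world @ R + T: the center C is the point mapped to the origin, so C @ R + T = 0 and hence C = -T @ R^T, whose column form is -R @ T^T, matching the torch.bmm expression. A standalone sketch of that identity, outside the test harness:

import torch
from pytorch3d.transforms.so3 import so3_exponential_map

R = so3_exponential_map(torch.randn(4, 3))  # batch of rotation matrices
T = torch.randn(4, 3)                       # batch of translations
C_row = -torch.bmm(T[:, None, :], R.transpose(1, 2))[:, 0]  # C = -T @ R^T
C_col = -torch.bmm(R, T[:, :, None])[:, :, 0]               # C^T = -R @ T^T
assert torch.allclose(C_row, C_col, atol=1e-6)
# C is indeed mapped to the origin by X @ R + T:
assert torch.allclose(torch.bmm(C_row[:, None, :], R)[:, 0] + T,
                      torch.zeros(4, 3), atol=1e-5)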