mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Add BlenderCamera
Summary: Adding BlenderCamera (for rendering with R2N2 Blender transformations in the next diff). Reviewed By: nikhilaravi Differential Revision: D22462515 fbshipit-source-id: 4b40ee9bba8b6d56788dd3c723036ec704668153
This commit is contained in:
parent
483e538dae
commit
722c2b7149
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from .r2n2 import R2N2
|
from .r2n2 import R2N2, BlenderCamera
|
||||||
from .shapenet import ShapeNetCore
|
from .shapenet import ShapeNetCore
|
||||||
from .utils import collate_batched_meshes
|
from .utils import collate_batched_meshes
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from .r2n2 import R2N2
|
from .r2n2 import R2N2, BlenderCamera
|
||||||
|
|
||||||
|
|
||||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||||
|
@ -11,11 +11,18 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||||
from pytorch3d.io import load_obj
|
from pytorch3d.io import load_obj
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
from pytorch3d.transforms import Transform3d
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
|
||||||
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
|
||||||
|
r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3)
|
||||||
|
t = np.expand_dims(np.zeros(3), axis=0) # (1, 3)
|
||||||
|
k = np.expand_dims(np.eye(4), axis=0) # (1, 4, 4)
|
||||||
|
|
||||||
|
|
||||||
class R2N2(ShapeNetBase):
|
class R2N2(ShapeNetBase):
|
||||||
"""
|
"""
|
||||||
@ -217,3 +224,27 @@ class R2N2(ShapeNetBase):
|
|||||||
model["images"] = torch.stack(images)
|
model["images"] = torch.stack(images)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderCamera(CamerasBase):
|
||||||
|
"""
|
||||||
|
Camera for rendering objects with calibration matrices from the R2N2 dataset
|
||||||
|
(which uses Blender for rendering the views for each model).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, R=r, T=t, K=k, device="cpu"):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
R: Rotation matrix of shape (N, 3, 3).
|
||||||
|
T: Translation matrix of shape (N, 3).
|
||||||
|
K: Intrinsic matrix of shape (N, 4, 4).
|
||||||
|
device: torch.device or str.
|
||||||
|
"""
|
||||||
|
# The initializer formats all inputs to torch tensors and broadcasts
|
||||||
|
# all the inputs to have the same batch dimension where necessary.
|
||||||
|
super().__init__(device=device, R=R, T=T, K=K)
|
||||||
|
|
||||||
|
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||||
|
transform = Transform3d(device=self.device)
|
||||||
|
transform._matrix = self.K.transpose(1, 2).contiguous() # pyre-ignore[16]
|
||||||
|
return transform
|
||||||
|
@ -11,13 +11,16 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin, load_rgb_image
|
from common_testing import TestCaseMixin, load_rgb_image
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
from pytorch3d.datasets import R2N2, BlenderCamera, collate_batched_meshes
|
||||||
from pytorch3d.renderer import (
|
from pytorch3d.renderer import (
|
||||||
OpenGLPerspectiveCameras,
|
OpenGLPerspectiveCameras,
|
||||||
PointLights,
|
PointLights,
|
||||||
RasterizationSettings,
|
RasterizationSettings,
|
||||||
look_at_view_transform,
|
look_at_view_transform,
|
||||||
)
|
)
|
||||||
|
from pytorch3d.renderer.cameras import get_world_to_view_transform
|
||||||
|
from pytorch3d.transforms import Transform3d
|
||||||
|
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
@ -258,3 +261,21 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
"test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
|
"test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
|
||||||
)
|
)
|
||||||
self.assertClose(mixed_rgb, image_ref, atol=0.05)
|
self.assertClose(mixed_rgb, image_ref, atol=0.05)
|
||||||
|
|
||||||
|
def test_blender_camera(self):
|
||||||
|
"""
|
||||||
|
Test BlenderCamera.
|
||||||
|
"""
|
||||||
|
# Test get_world_to_view_transform.
|
||||||
|
T = torch.randn(10, 3)
|
||||||
|
R = so3_exponential_map(torch.randn(10, 3) * 3.0)
|
||||||
|
RT = get_world_to_view_transform(R=R, T=T)
|
||||||
|
cam = BlenderCamera(R=R, T=T)
|
||||||
|
RT_class = cam.get_world_to_view_transform()
|
||||||
|
self.assertTrue(torch.allclose(RT.get_matrix(), RT_class.get_matrix()))
|
||||||
|
self.assertTrue(isinstance(RT, Transform3d))
|
||||||
|
|
||||||
|
# Test getting camera center.
|
||||||
|
C = cam.get_camera_center()
|
||||||
|
C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
|
||||||
|
self.assertTrue(torch.allclose(C, C_, atol=1e-05))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user