Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-12-20 22:30:35 +08:00)
Update Rasterizer and add end2end fisheye integration test
Summary:
1) Update the rasterizer/point rasterizer to accommodate FishEyeCameras. Specifically, transform_points is used in place of explicit transform compositions.
2) In the rasterizer unit tests, update the corresponding tests for rasterizer and point_rasterizer. Address comments to test fisheye against a perspective camera when distortions are turned off.
3) Address comments to add an end-to-end test for FishEyeCameras. In test_render_meshes, FishEyeCameras is added to the camera enumerations wherever possible.
4) Test renderings with FishEyeCameras of different params on the cow mesh.
5) Use transform compositions for linear cameras whenever possible.

Reviewed By: kjchalup

Differential Revision: D38932736

fbshipit-source-id: 5b7074fc001f2390f4cf43c7267a8b37fd987547
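To make points 1) and 5) concrete: the sketch below shows the kind of dispatch the summary describes, composing the explicit world-to-view and projection transforms when the camera provides them and falling back to cameras.transform_points otherwise. It is a minimal illustration under stated assumptions, not the rasterizer's actual code; the helper name project_to_ndc and the use of NotImplementedError as the signal that a camera has no linear projection transform are assumptions of this sketch, while get_world_to_view_transform, get_projection_transform, compose, and transform_points are existing PyTorch3D camera/transform methods.

import torch


def project_to_ndc(cameras, verts_world: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of pytorch3d: project world-space vertices
    # with whatever the camera supports.
    try:
        # Linear cameras expose Transform3d objects, so the world-to-view and
        # projection transforms can be composed once and reused.
        world_to_view = cameras.get_world_to_view_transform()
        view_to_ndc = cameras.get_projection_transform()
        return world_to_view.compose(view_to_ndc).transform_points(verts_world)
    except NotImplementedError:
        # Assumed fallback for nonlinear models such as FishEyeCameras, which
        # cannot be expressed as a single Transform3d and instead implement
        # transform_points directly.
        return cameras.transform_points(verts_world)

The tests below construct FishEyeCameras with world_coordinates=True so that transform_points consumes world-space vertices, keeping it interchangeable with the linear cameras in the same enumerations.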
Committed by Facebook GitHub Bot. Parent: b0515e1461. Commit: d19e6243d0.
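Point 2) lives in the rasterizer unit tests rather than in the test file shown below. As a rough sketch of that style of check, assuming that a FishEyeCameras with every distortion term disabled reduces to a pinhole projection with the same default focal length and principal point as PerspectiveCameras (the point values, the batch shape, and comparing only the xy image coordinates are illustrative choices, not the repository's actual test):

import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.fisheyecameras import FishEyeCameras

# Random points kept in front of the camera (+Z) in world space.
points = torch.randn(1, 100, 3) * 0.3 + torch.tensor([0.0, 0.0, 3.0])

fisheye = FishEyeCameras(
    use_radial=False,
    use_tangential=False,
    use_thin_prism=False,
    world_coordinates=True,
)
perspective = PerspectiveCameras()

fisheye_xy = fisheye.transform_points(points)[..., :2]
perspective_xy = perspective.transform_points(points)[..., :2]

# Expected to be near zero if the undistorted fisheye model matches the pinhole
# projection; the z components are not compared because the two camera classes
# may use different depth conventions.
print((fisheye_xy - perspective_xy).abs().max())

The repository performs this check at the rasterizer level (point 2), whereas this sketch compares projected points directly.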
@@ -12,10 +12,12 @@ import os
 import unittest
 from collections import namedtuple

+from itertools import product
+
 import numpy as np
 import torch
 from PIL import Image
-from pytorch3d.io import load_obj
+from pytorch3d.io import load_obj, load_objs_as_meshes
 from pytorch3d.renderer import (
     AmbientLights,
     FoVOrthographicCameras,
@@ -33,6 +35,7 @@ from pytorch3d.renderer import (
     TexturesUV,
     TexturesVertex,
 )
+from pytorch3d.renderer.fisheyecameras import FishEyeCameras
 from pytorch3d.renderer.mesh.shader import (
     BlendParams,
     HardFlatShader,
@@ -59,7 +62,6 @@ from .common_testing import (
     TestCaseMixin,
 )

-
 # If DEBUG=True, save out images generated in the tests for debugging.
 # All saved images have prefix DEBUG_
 DEBUG = False
@@ -107,8 +109,38 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             FoVOrthographicCameras,
             PerspectiveCameras,
             OrthographicCameras,
+            FishEyeCameras,
         ):
-            cameras = cam_type(device=device, R=R, T=T)
+            if cam_type == FishEyeCameras:
+                cam_kwargs = {
+                    "radial_params": torch.tensor(
+                        [
+                            [-1, -2, -3, 0, 0, 1],
+                        ],
+                        dtype=torch.float32,
+                    ),
+                    "tangential_params": torch.tensor(
+                        [[0.7002747019, -0.4005228974]], dtype=torch.float32
+                    ),
+                    "thin_prism_params": torch.tensor(
+                        [
+                            [-1.000134884, -1.000084822, -1.0009420014, -1.0001276838],
+                        ],
+                        dtype=torch.float32,
+                    ),
+                }
+                cameras = cam_type(
+                    device=device,
+                    R=R,
+                    T=T,
+                    use_tangential=True,
+                    use_radial=True,
+                    use_thin_prism=True,
+                    world_coordinates=True,
+                    **cam_kwargs,
+                )
+            else:
+                cameras = cam_type(device=device, R=R, T=T)

             # Init shader settings
             materials = Materials(device=device)
@@ -146,7 +178,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
                         cameras=cameras, raster_settings=raster_settings
                     )
                 elif test.rasterizer == MeshRasterizerOpenGL:
-                    if type(cameras) in [PerspectiveCameras, OrthographicCameras]:
+                    if type(cameras) in [
+                        PerspectiveCameras,
+                        OrthographicCameras,
+                        FishEyeCameras,
+                    ]:
                         # MeshRasterizerOpenGL is only compatible with FoV cameras.
                         continue
                     rasterizer = test.rasterizer(
@@ -181,8 +217,6 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
                 )

-                image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
-                self.assertClose(rgb, image_ref, atol=0.05)

                 if DEBUG:
                     debug_filename = "simple_sphere_light_%s%s%s.png" % (
                         test.debug_name,
@@ -193,6 +227,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
                     Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                         DATA_DIR / filename
                     )
+                self.assertClose(rgb, image_ref, atol=0.05)

             ########################################################
             # Move the light to the +z axis in world space so it is
@@ -429,8 +464,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             FoVOrthographicCameras,
             PerspectiveCameras,
             OrthographicCameras,
+            FishEyeCameras,
         ):
-            cameras = cam_type(device=device, R=R, T=T)
+            if cam_type == FishEyeCameras:
+                cameras = cam_type(
+                    device=device,
+                    R=R,
+                    T=T,
+                    use_tangential=False,
+                    use_radial=False,
+                    use_thin_prism=False,
+                    world_coordinates=True,
+                )
+            else:
+                cameras = cam_type(device=device, R=R, T=T)

             # Init renderer
             renderer = MeshRenderer(
@@ -1443,84 +1490,291 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
        )
        self.assertClose(rgb, image_ref, atol=0.05)

    def test_nd_sphere(self):
        """
        Test that the render can handle textures with more than 3 channels and
        not just 3 channel RGB.
        """
        torch.manual_seed(1)
        device = torch.device("cuda:0")
        C = 5
        WHITE = ((1.0,) * C,)
        BLACK = ((0.0,) * C,)

        # Init mesh
        sphere_mesh = ico_sphere(5, device)
        verts_padded = sphere_mesh.verts_padded()
        faces_padded = sphere_mesh.faces_padded()
        feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
        n_verts = feats.shape[1]
        # make some non-uniform pattern
        feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
        textures = TexturesVertex(verts_features=feats)
        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

        # No elevation or azimuth rotation
        R, T = look_at_view_transform(2.7, 0.0, 0.0)

        cameras = PerspectiveCameras(device=device, R=R, T=T)

        # Init shader settings
        materials = Materials(
            device=device,
            ambient_color=WHITE,
            diffuse_color=WHITE,
            specular_color=WHITE,
        )
        lights = AmbientLights(
            device=device,
            ambient_color=WHITE,
        )
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        blend_params = BlendParams(
            1e-4,
            1e-4,
            background_color=BLACK[0],
        )

        # only test HardFlatShader since that's the only one that makes
        # sense for classification
        shader = HardFlatShader(
            lights=lights,
            cameras=cameras,
            materials=materials,
            blend_params=blend_params,
        )
        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
        images = renderer(sphere_mesh)

        self.assertEqual(images.shape[-1], C + 1)
        self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
        self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)

        # grab last 3 color channels
        rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
        filename = "test_nd_sphere.png"

        if DEBUG:
            debug_filename = "DEBUG_%s" % filename
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / debug_filename
            )

        image_ref = load_rgb_image(filename, DATA_DIR)
        self.assertClose(rgb, image_ref, atol=0.05)
    def test_simple_sphere_fisheye_params(self):
        """
        Test output of phong and gouraud shading matches a reference image using
        the default values for the light sources.
        """
        device = torch.device("cuda:0")

        # Init mesh
        sphere_mesh = ico_sphere(5, device)
        verts_padded = sphere_mesh.verts_padded()
        faces_padded = sphere_mesh.faces_padded()
        feats = torch.ones_like(verts_padded, device=device)
        textures = TexturesVertex(verts_features=feats)
        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
        postfix = "_"

        cam_kwargs = [
            {
                "radial_params": torch.tensor(
                    [
                        [-1, -2, -3, 0, 0, 1],
                    ],
                    dtype=torch.float32,
                ),
            },
            {
                "tangential_params": torch.tensor(
                    [[0.7002747019, -0.4005228974]], dtype=torch.float32
                ),
            },
            {
                "thin_prism_params": torch.tensor(
                    [
                        [
                            -1.000134884,
                            -1.000084822,
                            -1.0009420014,
                            -1.0001276838,
                        ],
                    ],
                    dtype=torch.float32,
                ),
            },
        ]
        variants = ["radial", "tangential", "prism"]
        for test_case, variant in zip(cam_kwargs, variants):
            cameras = FishEyeCameras(
                device=device,
                R=R,
                T=T,
                use_tangential=True,
                use_radial=True,
                use_thin_prism=True,
                world_coordinates=True,
                **test_case,
            )

            # Init shader settings
            materials = Materials(device=device)
            lights = PointLights(device=device)
            lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

            raster_settings = RasterizationSettings(
                image_size=512, blur_radius=0.0, faces_per_pixel=1
            )
            blend_params = BlendParams(0.5, 1e-4, (0, 0, 0))

            # Test several shaders
            rasterizer_tests = [
                RasterizerTest(
                    MeshRasterizer, HardPhongShader, "hard_phong", "hard_phong"
                ),
                RasterizerTest(
                    MeshRasterizer, HardGouraudShader, "hard_gouraud", "hard_gouraud"
                ),
                RasterizerTest(
                    MeshRasterizer, HardFlatShader, "hard_flat", "hard_flat"
                ),
            ]
            for test in rasterizer_tests:
                shader = test.shader(
                    lights=lights,
                    cameras=cameras,
                    materials=materials,
                    blend_params=blend_params,
                )
                if test.rasterizer == MeshRasterizer:
                    rasterizer = test.rasterizer(
                        cameras=cameras, raster_settings=raster_settings
                    )

                renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
                images = renderer(sphere_mesh)

                rgb = images[0, ..., :3].squeeze().cpu()
                filename = "simple_sphere_light_%s%s%s%s%s.png" % (
                    test.reference_name,
                    postfix,
                    variant,
                    postfix,
                    FishEyeCameras.__name__,
                )

                image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
                if DEBUG:
                    debug_filename = "simple_sphere_light_%s%s%s%s%s.png" % (
                        test.debug_name,
                        postfix,
                        variant,
                        postfix,
                        FishEyeCameras.__name__,
                    )
                    filename = "DEBUG_%s" % debug_filename
                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                        DATA_DIR / filename
                    )
                self.assertClose(rgb, image_ref, atol=0.05)

            ########################################################
            # Move the light to the +z axis in world space so it is
            # behind the sphere. Note that +Z is in, +Y up,
            # +X left for both world and camera space.
            ########################################################
            lights.location[..., 2] = -2.0
            phong_shader = HardPhongShader(
                lights=lights,
                cameras=cameras,
                materials=materials,
                blend_params=blend_params,
            )
            phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
            images = phong_renderer(sphere_mesh, lights=lights)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                filename = "DEBUG_simple_sphere_dark%s%s%s%s.png" % (
                    postfix,
                    variant,
                    postfix,
                    FishEyeCameras.__name__,
                )
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )

            image_ref_phong_dark = load_rgb_image(
                "test_simple_sphere_dark%s%s%s%s.png"
                % (postfix, variant, postfix, FishEyeCameras.__name__),
                DATA_DIR,
            )
            # Soft shaders (SplatterPhong) will have a different boundary than hard
            # ones, but should be identical otherwise.
            self.assertLess((rgb - image_ref_phong_dark).quantile(0.99), 0.005)
    def test_fisheye_cow_mesh(self):
        """
        Test FishEye Camera distortions on real meshes
        """
        device = torch.device("cuda:0")
        obj_filename = os.path.join(DATA_DIR, "missing_usemtl/cow.obj")
        mesh = load_objs_as_meshes([obj_filename], device=device)
        R, T = look_at_view_transform(2.7, 0, 180)
        radial_params = torch.tensor([[-1.0, 1.0, 1.0, 0.0, 0.0, -1.0]])
        tangential_params = torch.tensor([[0.5, 0.5]])
        thin_prism_params = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
        combinations = product([False, True], repeat=3)
        for combination in combinations:
            cameras = FishEyeCameras(
                device=device,
                R=R,
                T=T,
                world_coordinates=True,
                use_radial=combination[0],
                use_tangential=combination[1],
                use_thin_prism=combination[2],
                radial_params=radial_params,
                tangential_params=tangential_params,
                thin_prism_params=thin_prism_params,
            )
            raster_settings = RasterizationSettings(
                image_size=512,
                blur_radius=0.0,
                faces_per_pixel=1,
            )
            lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(
                    cameras=cameras, raster_settings=raster_settings
                ),
                shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
            )
            images = renderer(mesh)
            rgb = images[0, ..., :3].squeeze().cpu()
            filename = "test_cow_mesh_%s_radial_%s_tangential_%s_prism_%s.png" % (
                FishEyeCameras.__name__,
                combination[0],
                combination[1],
                combination[2],
            )
            image_ref = load_rgb_image(filename, DATA_DIR)
            if DEBUG:
                filename = filename.replace("test", "DEBUG")
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref, atol=0.05)