Update Rasterizer and add end2end fisheye integration test

Summary:
1) Update rasterizer/point rasterizer to accommodate fisheyecamera. Specifically, transform_points is in placement of explicit transform compositions.

2) In rasterizer unittests, update corresponding tests for rasterizer and point_rasterizer. Address comments to test fisheye against perspective camera when distortions are turned off.

3) Address comments to add end2end test for fisheyecameras. In test_render_meshes, fisheyecameras are added to camera enuerations whenever possible.

4) Test renderings with fisheyecameras of different params on cow mesh.

5) Use compositions for linear cameras whenever possible.

Reviewed By: kjchalup

Differential Revision: D38932736

fbshipit-source-id: 5b7074fc001f2390f4cf43c7267a8b37fd987547
This commit is contained in:
Jiali Duan
2022-08-31 16:50:41 -07:00
committed by Facebook GitHub Bot
parent b0515e1461
commit d19e6243d0
63 changed files with 566 additions and 76 deletions

View File

@@ -12,10 +12,12 @@ import os
import unittest
from collections import namedtuple
from itertools import product
import numpy as np
import torch
from PIL import Image
from pytorch3d.io import load_obj
from pytorch3d.io import load_obj, load_objs_as_meshes
from pytorch3d.renderer import (
AmbientLights,
FoVOrthographicCameras,
@@ -33,6 +35,7 @@ from pytorch3d.renderer import (
TexturesUV,
TexturesVertex,
)
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.renderer.mesh.shader import (
BlendParams,
HardFlatShader,
@@ -59,7 +62,6 @@ from .common_testing import (
TestCaseMixin,
)
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
@@ -107,8 +109,38 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
FoVOrthographicCameras,
PerspectiveCameras,
OrthographicCameras,
FishEyeCameras,
):
cameras = cam_type(device=device, R=R, T=T)
if cam_type == FishEyeCameras:
cam_kwargs = {
"radial_params": torch.tensor(
[
[-1, -2, -3, 0, 0, 1],
],
dtype=torch.float32,
),
"tangential_params": torch.tensor(
[[0.7002747019, -0.4005228974]], dtype=torch.float32
),
"thin_prism_params": torch.tensor(
[
[-1.000134884, -1.000084822, -1.0009420014, -1.0001276838],
],
dtype=torch.float32,
),
}
cameras = cam_type(
device=device,
R=R,
T=T,
use_tangential=True,
use_radial=True,
use_thin_prism=True,
world_coordinates=True,
**cam_kwargs,
)
else:
cameras = cam_type(device=device, R=R, T=T)
# Init shader settings
materials = Materials(device=device)
@@ -146,7 +178,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
cameras=cameras, raster_settings=raster_settings
)
elif test.rasterizer == MeshRasterizerOpenGL:
if type(cameras) in [PerspectiveCameras, OrthographicCameras]:
if type(cameras) in [
PerspectiveCameras,
OrthographicCameras,
FishEyeCameras,
]:
# MeshRasterizerOpenGL is only compatible with FoV cameras.
continue
rasterizer = test.rasterizer(
@@ -181,8 +217,6 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
if DEBUG:
debug_filename = "simple_sphere_light_%s%s%s.png" % (
test.debug_name,
@@ -193,6 +227,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)
########################################################
# Move the light to the +z axis in world space so it is
@@ -429,8 +464,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
FoVOrthographicCameras,
PerspectiveCameras,
OrthographicCameras,
FishEyeCameras,
):
cameras = cam_type(device=device, R=R, T=T)
if cam_type == FishEyeCameras:
cameras = cam_type(
device=device,
R=R,
T=T,
use_tangential=False,
use_radial=False,
use_thin_prism=False,
world_coordinates=True,
)
else:
cameras = cam_type(device=device, R=R, T=T)
# Init renderer
renderer = MeshRenderer(
@@ -1443,84 +1490,291 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
self.assertClose(rgb, image_ref, atol=0.05)
def test_nd_sphere(self):
"""
Test that the render can handle textures with more than 3 channels and
not just 3 channel RGB.
"""
torch.manual_seed(1)
device = torch.device("cuda:0")
C = 5
WHITE = ((1.0,) * C,)
BLACK = ((0.0,) * C,)
def test_nd_sphere(self):
"""
Test that the render can handle textures with more than 3 channels and
not just 3 channel RGB.
"""
torch.manual_seed(1)
device = torch.device("cuda:0")
C = 5
WHITE = ((1.0,) * C,)
BLACK = ((0.0,) * C,)
# Init mesh
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
n_verts = feats.shape[1]
# make some non-uniform pattern
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(
verts=verts_padded, faces=faces_padded, textures=textures
# Init mesh
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
n_verts = feats.shape[1]
# make some non-uniform pattern
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
# No elevation or azimuth rotation
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
# Init shader settings
materials = Materials(
device=device,
ambient_color=WHITE,
diffuse_color=WHITE,
specular_color=WHITE,
)
lights = AmbientLights(
device=device,
ambient_color=WHITE,
)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
blend_params = BlendParams(
1e-4,
1e-4,
background_color=BLACK[0],
)
# only test HardFlatShader since that's the only one that makes
# sense for classification
shader = HardFlatShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
self.assertEqual(images.shape[-1], C + 1)
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
# grab last 3 color channels
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
filename = "test_nd_sphere.png"
if DEBUG:
debug_filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / debug_filename
)
# No elevation or azimuth rotation
R, T = look_at_view_transform(2.7, 0.0, 0.0)
image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
cameras = PerspectiveCameras(device=device, R=R, T=T)
def test_simple_sphere_fisheye_params(self):
"""
Test output of phong and gouraud shading matches a reference image using
the default values for the light sources.
"""
device = torch.device("cuda:0")
# Init mesh
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0.0, 0.0)
postfix = "_"
cam_kwargs = [
{
"radial_params": torch.tensor(
[
[-1, -2, -3, 0, 0, 1],
],
dtype=torch.float32,
),
},
{
"tangential_params": torch.tensor(
[[0.7002747019, -0.4005228974]], dtype=torch.float32
),
},
{
"thin_prism_params": torch.tensor(
[
[
-1.000134884,
-1.000084822,
-1.0009420014,
-1.0001276838,
],
],
dtype=torch.float32,
),
},
]
variants = ["radial", "tangential", "prism"]
for test_case, variant in zip(cam_kwargs, variants):
cameras = FishEyeCameras(
device=device,
R=R,
T=T,
use_tangential=True,
use_radial=True,
use_thin_prism=True,
world_coordinates=True,
**test_case,
)
# Init shader settings
materials = Materials(
device=device,
ambient_color=WHITE,
diffuse_color=WHITE,
specular_color=WHITE,
)
lights = AmbientLights(
device=device,
ambient_color=WHITE,
)
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
)
blend_params = BlendParams(
1e-4,
1e-4,
background_color=BLACK[0],
)
blend_params = BlendParams(0.5, 1e-4, (0, 0, 0))
# only test HardFlatShader since that's the only one that makes
# sense for classification
shader = HardFlatShader(
# Test several shaders
rasterizer_tests = [
RasterizerTest(
MeshRasterizer, HardPhongShader, "hard_phong", "hard_phong"
),
RasterizerTest(
MeshRasterizer, HardGouraudShader, "hard_gouraud", "hard_gouraud"
),
RasterizerTest(
MeshRasterizer, HardFlatShader, "hard_flat", "hard_flat"
),
]
for test in rasterizer_tests:
shader = test.shader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
if test.rasterizer == MeshRasterizer:
rasterizer = test.rasterizer(
cameras=cameras, raster_settings=raster_settings
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
filename = "simple_sphere_light_%s%s%s%s%s.png" % (
test.reference_name,
postfix,
variant,
postfix,
FishEyeCameras.__name__,
)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
if DEBUG:
debug_filename = "simple_sphere_light_%s%s%s%s%s.png" % (
test.debug_name,
postfix,
variant,
postfix,
FishEyeCameras.__name__,
)
filename = "DEBUG_%s" % debug_filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)
########################################################
# Move the light to the +z axis in world space so it is
# behind the sphere. Note that +Z is in, +Y up,
# +X left for both world and camera space.
########################################################
lights.location[..., 2] = -2.0
phong_shader = HardPhongShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
self.assertEqual(images.shape[-1], C + 1)
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
# grab last 3 color channels
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
filename = "test_nd_sphere.png"
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
images = phong_renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
debug_filename = "DEBUG_%s" % filename
filename = "DEBUG_simple_sphere_dark%s%s%s%s.png" % (
postfix,
variant,
postfix,
FishEyeCameras.__name__,
)
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / debug_filename
DATA_DIR / filename
)
image_ref_phong_dark = load_rgb_image(
"test_simple_sphere_dark%s%s%s%s.png"
% (postfix, variant, postfix, FishEyeCameras.__name__),
DATA_DIR,
)
# Soft shaders (SplatterPhong) will have a different boundary than hard
# ones, but should be identical otherwise.
self.assertLess((rgb - image_ref_phong_dark).quantile(0.99), 0.005)
def test_fisheye_cow_mesh(self):
"""
Test FishEye Camera distortions on real meshes
"""
device = torch.device("cuda:0")
obj_filename = os.path.join(DATA_DIR, "missing_usemtl/cow.obj")
mesh = load_objs_as_meshes([obj_filename], device=device)
R, T = look_at_view_transform(2.7, 0, 180)
radial_params = torch.tensor([[-1.0, 1.0, 1.0, 0.0, 0.0, -1.0]])
tangential_params = torch.tensor([[0.5, 0.5]])
thin_prism_params = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
combinations = product([False, True], repeat=3)
for combination in combinations:
cameras = FishEyeCameras(
device=device,
R=R,
T=T,
world_coordinates=True,
use_radial=combination[0],
use_tangential=combination[1],
use_thin_prism=combination[2],
radial_params=radial_params,
tangential_params=tangential_params,
thin_prism_params=thin_prism_params,
)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=0.0,
faces_per_pixel=1,
)
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
),
shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
)
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
filename = "test_cow_mesh_%s_radial_%s_tangential_%s_prism_%s.png" % (
FishEyeCameras.__name__,
combination[0],
combination[1],
combination[2],
)
image_ref = load_rgb_image(filename, DATA_DIR)
if DEBUG:
filename = filename.replace("test", "DEBUG")
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)