Update Rasterizer and add end2end fisheye integration test

Summary:
1) Update rasterizer/point rasterizer to accommodate fisheyecamera. Specifically, transform_points is in placement of explicit transform compositions.

2) In rasterizer unittests, update corresponding tests for rasterizer and point_rasterizer. Address comments to test fisheye against perspective camera when distortions are turned off.

3) Address comments to add end2end test for fisheyecameras. In test_render_meshes, fisheyecameras are added to camera enuerations whenever possible.

4) Test renderings with fisheyecameras of different params on cow mesh.

5) Use compositions for linear cameras whenever possible.

Reviewed By: kjchalup

Differential Revision: D38932736

fbshipit-source-id: 5b7074fc001f2390f4cf43c7267a8b37fd987547
This commit is contained in:
Jiali Duan 2022-08-31 16:50:41 -07:00 committed by Facebook GitHub Bot
parent b0515e1461
commit d19e6243d0
63 changed files with 566 additions and 76 deletions

View File

@ -6,7 +6,7 @@
import math import math
import warnings import warnings
from typing import List, Optional, Sequence, Tuple, Union from typing import Callable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -91,7 +91,7 @@ class CamerasBase(TensorProperties):
# When joining objects into a batch, they will have to agree. # When joining objects into a batch, they will have to agree.
_SHARED_FIELDS: Tuple[str, ...] = () _SHARED_FIELDS: Tuple[str, ...] = ()
def get_projection_transform(self): def get_projection_transform(self, **kwargs):
""" """
Calculate the projective transformation matrix. Calculate the projective transformation matrix.
@ -1841,3 +1841,23 @@ def get_screen_to_ndc_transform(
image_size=image_size, image_size=image_size,
).inverse() ).inverse()
return transform return transform
def try_get_projection_transform(cameras, kwargs) -> Optional[Callable]:
"""
Try block to get projection transform.
Args:
cameras instance, can be linear cameras or nonliear cameras
Returns:
If the camera implemented projection_transform, return the
projection transform; Otherwise, return None
"""
transform = None
try:
transform = cameras.get_projection_transform(**kwargs)
except NotImplementedError:
pass
return transform

View File

@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from pytorch3d.renderer.cameras import try_get_projection_transform
from .rasterize_meshes import rasterize_meshes from .rasterize_meshes import rasterize_meshes
@ -197,12 +198,19 @@ class MeshRasterizer(nn.Module):
verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points( verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
verts_world, eps=eps verts_world, eps=eps
) )
# view to NDC transform # Call transform_points instead of explicitly composing transforms to handle
# the case, where camera class does not have a projection matrix form.
verts_proj = cameras.transform_points(verts_world, eps=eps)
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs) to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose( projection_transform = try_get_projection_transform(cameras, kwargs)
to_ndc_transform if projection_transform is not None:
) projection_transform = projection_transform.compose(to_ndc_transform)
verts_ndc = projection_transform.transform_points(verts_view, eps=eps) verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
else:
# Call transform_points instead of explicitly composing transforms to handle
# the case, where camera class does not have a projection matrix form.
verts_proj = cameras.transform_points(verts_world, eps=eps)
verts_ndc = to_ndc_transform.transform_points(verts_proj, eps=eps)
verts_ndc[..., 2] = verts_view[..., 2] verts_ndc[..., 2] = verts_view[..., 2]
meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc) meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)

View File

@ -10,6 +10,7 @@ from typing import NamedTuple, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from pytorch3d.renderer.cameras import try_get_projection_transform
from pytorch3d.structures import Pointclouds from pytorch3d.structures import Pointclouds
from .rasterize_points import rasterize_points from .rasterize_points import rasterize_points
@ -103,12 +104,16 @@ class PointsRasterizer(nn.Module):
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points( pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
pts_world, eps=eps pts_world, eps=eps
) )
# view to NDC transform
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs) to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose( projection_transform = try_get_projection_transform(cameras, kwargs)
to_ndc_transform if projection_transform is not None:
) projection_transform = projection_transform.compose(to_ndc_transform)
pts_ndc = projection_transform.transform_points(pts_view, eps=eps) pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
else:
# Call transform_points instead of explicitly composing transforms to handle
# the case, where camera class does not have a projection matrix form.
pts_proj = cameras.transform_points(pts_world, eps=eps)
pts_ndc = to_ndc_transform.transform_points(pts_proj, eps=eps)
pts_ndc[..., 2] = pts_view[..., 2] pts_ndc[..., 2] = pts_view[..., 2]
point_clouds = point_clouds.update_padded(pts_ndc) point_clouds = point_clouds.update_padded(pts_ndc)

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.8 KiB

View File

@ -21,6 +21,7 @@ from pytorch3d.renderer import (
PointsRasterizer, PointsRasterizer,
RasterizationSettings, RasterizationSettings,
) )
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.renderer.opengl.rasterizer_opengl import ( from pytorch3d.renderer.opengl.rasterizer_opengl import (
_check_cameras, _check_cameras,
_check_raster_settings, _check_raster_settings,
@ -51,6 +52,9 @@ class TestMeshRasterizer(unittest.TestCase):
def test_simple_sphere(self): def test_simple_sphere(self):
self._simple_sphere(MeshRasterizer) self._simple_sphere(MeshRasterizer)
def test_simple_sphere_fisheye(self):
self._simple_sphere_fisheye_against_perspective(MeshRasterizer)
def test_simple_sphere_opengl(self): def test_simple_sphere_opengl(self):
self._simple_sphere(MeshRasterizerOpenGL) self._simple_sphere(MeshRasterizerOpenGL)
@ -155,6 +159,91 @@ class TestMeshRasterizer(unittest.TestCase):
self.assertTrue(torch.allclose(image, image_ref)) self.assertTrue(torch.allclose(image, image_ref))
def _simple_sphere_fisheye_against_perspective(self, rasterizer_type):
device = torch.device("cuda:0")
# Init mesh
sphere_mesh = ico_sphere(5, device)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
# Init Fisheye camera params
focal = torch.tensor([[1.7321]], dtype=torch.float32)
principal_point = torch.tensor([[0.0101, -0.0101]])
perspective_cameras = PerspectiveCameras(
R=R,
T=T,
focal_length=focal,
principal_point=principal_point,
device="cuda:0",
)
fisheye_cameras = FishEyeCameras(
device=device,
R=R,
T=T,
focal_length=focal,
principal_point=principal_point,
world_coordinates=True,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
)
# Init rasterizer
perspective_rasterizer = rasterizer_type(
cameras=perspective_cameras, raster_settings=raster_settings
)
fisheye_rasterizer = rasterizer_type(
cameras=fisheye_cameras, raster_settings=raster_settings
)
####################################################################################
# Test rasterizing a single mesh comparing fisheye camera against perspective camera
####################################################################################
perspective_fragments = perspective_rasterizer(sphere_mesh)
perspective_image = perspective_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
perspective_image[perspective_image >= 0] = 1.0
perspective_image[perspective_image < 0] = 0.0
if DEBUG:
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR
/ f"DEBUG_test_perspective_rasterized_sphere_{rasterizer_type.__name__}.png"
)
fisheye_fragments = fisheye_rasterizer(sphere_mesh)
fisheye_image = fisheye_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
fisheye_image[fisheye_image >= 0] = 1.0
fisheye_image[fisheye_image < 0] = 0.0
if DEBUG:
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR
/ f"DEBUG_test_fisheye_rasterized_sphere_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(fisheye_image, perspective_image))
##################################
# 2. Test with a batch of meshes
##################################
batch_size = 10
sphere_meshes = sphere_mesh.extend(batch_size)
fragments = fisheye_rasterizer(sphere_meshes)
for i in range(batch_size):
image = fragments.pix_to_face[i, ..., 0].squeeze().cpu()
image[image >= 0] = 1.0
image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, perspective_image))
def test_simple_to(self): def test_simple_to(self):
# Check that to() works without a cameras object. # Check that to() works without a cameras object.
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -412,6 +501,76 @@ class TestPointRasterizer(unittest.TestCase):
image[image < 0] = 0.0 image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, image_ref[..., 0])) self.assertTrue(torch.allclose(image, image_ref[..., 0]))
def test_simple_sphere_fisheye_against_perspective(self):
device = torch.device("cuda:0")
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
sphere_mesh = ico_sphere(1, device)
verts_padded = sphere_mesh.verts_padded()
verts_padded[..., 1] += 0.2
verts_padded[..., 0] += 0.2
pointclouds = Pointclouds(points=verts_padded)
R, T = look_at_view_transform(2.7, 0.0, 0.0)
perspective_cameras = PerspectiveCameras(
R=R,
T=T,
device=device,
)
fisheye_cameras = FishEyeCameras(
device=device,
R=R,
T=T,
world_coordinates=True,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
raster_settings = PointsRasterizationSettings(
image_size=256, radius=5e-2, points_per_pixel=1
)
#################################
# 1. Test init without cameras.
##################################
# Initialize without passing in the cameras
rasterizer = PointsRasterizer()
# Check that omitting the cameras in both initialization
# and the forward pass throws an error:
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
rasterizer(pointclouds)
########################################################################################
# 2. Test rasterizing a single pointcloud with fisheye camera agasint perspective camera
########################################################################################
perspective_fragments = rasterizer(
pointclouds, cameras=perspective_cameras, raster_settings=raster_settings
)
fisheye_fragments = rasterizer(
pointclouds, cameras=fisheye_cameras, raster_settings=raster_settings
)
# Convert idx to a binary mask
perspective_image = perspective_fragments.idx[0, ..., 0].squeeze().cpu()
perspective_image[perspective_image >= 0] = 1.0
perspective_image[perspective_image < 0] = 0.0
fisheye_image = fisheye_fragments.idx[0, ..., 0].squeeze().cpu()
fisheye_image[fisheye_image >= 0] = 1.0
fisheye_image[fisheye_image < 0] = 0.0
if DEBUG:
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_perspective_sphere_points.png"
)
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_fisheye_sphere_points.png"
)
self.assertTrue(torch.allclose(fisheye_image, perspective_image))
def test_simple_to(self): def test_simple_to(self):
# Check that to() works without a cameras object. # Check that to() works without a cameras object.
device = torch.device("cuda:0") device = torch.device("cuda:0")

View File

@ -12,10 +12,12 @@ import os
import unittest import unittest
from collections import namedtuple from collections import namedtuple
from itertools import product
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from pytorch3d.io import load_obj from pytorch3d.io import load_obj, load_objs_as_meshes
from pytorch3d.renderer import ( from pytorch3d.renderer import (
AmbientLights, AmbientLights,
FoVOrthographicCameras, FoVOrthographicCameras,
@ -33,6 +35,7 @@ from pytorch3d.renderer import (
TexturesUV, TexturesUV,
TexturesVertex, TexturesVertex,
) )
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.renderer.mesh.shader import ( from pytorch3d.renderer.mesh.shader import (
BlendParams, BlendParams,
HardFlatShader, HardFlatShader,
@ -59,7 +62,6 @@ from .common_testing import (
TestCaseMixin, TestCaseMixin,
) )
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
@ -107,8 +109,38 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
FoVOrthographicCameras, FoVOrthographicCameras,
PerspectiveCameras, PerspectiveCameras,
OrthographicCameras, OrthographicCameras,
FishEyeCameras,
): ):
cameras = cam_type(device=device, R=R, T=T) if cam_type == FishEyeCameras:
cam_kwargs = {
"radial_params": torch.tensor(
[
[-1, -2, -3, 0, 0, 1],
],
dtype=torch.float32,
),
"tangential_params": torch.tensor(
[[0.7002747019, -0.4005228974]], dtype=torch.float32
),
"thin_prism_params": torch.tensor(
[
[-1.000134884, -1.000084822, -1.0009420014, -1.0001276838],
],
dtype=torch.float32,
),
}
cameras = cam_type(
device=device,
R=R,
T=T,
use_tangential=True,
use_radial=True,
use_thin_prism=True,
world_coordinates=True,
**cam_kwargs,
)
else:
cameras = cam_type(device=device, R=R, T=T)
# Init shader settings # Init shader settings
materials = Materials(device=device) materials = Materials(device=device)
@ -146,7 +178,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
cameras=cameras, raster_settings=raster_settings cameras=cameras, raster_settings=raster_settings
) )
elif test.rasterizer == MeshRasterizerOpenGL: elif test.rasterizer == MeshRasterizerOpenGL:
if type(cameras) in [PerspectiveCameras, OrthographicCameras]: if type(cameras) in [
PerspectiveCameras,
OrthographicCameras,
FishEyeCameras,
]:
# MeshRasterizerOpenGL is only compatible with FoV cameras. # MeshRasterizerOpenGL is only compatible with FoV cameras.
continue continue
rasterizer = test.rasterizer( rasterizer = test.rasterizer(
@ -181,8 +217,6 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
) )
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR) image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
if DEBUG: if DEBUG:
debug_filename = "simple_sphere_light_%s%s%s.png" % ( debug_filename = "simple_sphere_light_%s%s%s.png" % (
test.debug_name, test.debug_name,
@ -193,6 +227,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename DATA_DIR / filename
) )
self.assertClose(rgb, image_ref, atol=0.05)
######################################################## ########################################################
# Move the light to the +z axis in world space so it is # Move the light to the +z axis in world space so it is
@ -429,8 +464,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
FoVOrthographicCameras, FoVOrthographicCameras,
PerspectiveCameras, PerspectiveCameras,
OrthographicCameras, OrthographicCameras,
FishEyeCameras,
): ):
cameras = cam_type(device=device, R=R, T=T) if cam_type == FishEyeCameras:
cameras = cam_type(
device=device,
R=R,
T=T,
use_tangential=False,
use_radial=False,
use_thin_prism=False,
world_coordinates=True,
)
else:
cameras = cam_type(device=device, R=R, T=T)
# Init renderer # Init renderer
renderer = MeshRenderer( renderer = MeshRenderer(
@ -1443,84 +1490,291 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
) )
self.assertClose(rgb, image_ref, atol=0.05) self.assertClose(rgb, image_ref, atol=0.05)
def test_nd_sphere(self): def test_nd_sphere(self):
""" """
Test that the render can handle textures with more than 3 channels and Test that the render can handle textures with more than 3 channels and
not just 3 channel RGB. not just 3 channel RGB.
""" """
torch.manual_seed(1) torch.manual_seed(1)
device = torch.device("cuda:0") device = torch.device("cuda:0")
C = 5 C = 5
WHITE = ((1.0,) * C,) WHITE = ((1.0,) * C,)
BLACK = ((0.0,) * C,) BLACK = ((0.0,) * C,)
# Init mesh # Init mesh
sphere_mesh = ico_sphere(5, device) sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded() verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded() faces_padded = sphere_mesh.faces_padded()
feats = torch.ones(*verts_padded.shape[:-1], C, device=device) feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
n_verts = feats.shape[1] n_verts = feats.shape[1]
# make some non-uniform pattern # make some non-uniform pattern
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1) feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
textures = TexturesVertex(verts_features=feats) textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes( sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
verts=verts_padded, faces=faces_padded, textures=textures
# No elevation or azimuth rotation
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
# Init shader settings
materials = Materials(
device=device,
ambient_color=WHITE,
diffuse_color=WHITE,
specular_color=WHITE,
)
lights = AmbientLights(
device=device,
ambient_color=WHITE,
)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
blend_params = BlendParams(
1e-4,
1e-4,
background_color=BLACK[0],
)
# only test HardFlatShader since that's the only one that makes
# sense for classification
shader = HardFlatShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
self.assertEqual(images.shape[-1], C + 1)
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
# grab last 3 color channels
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
filename = "test_nd_sphere.png"
if DEBUG:
debug_filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / debug_filename
) )
# No elevation or azimuth rotation image_ref = load_rgb_image(filename, DATA_DIR)
R, T = look_at_view_transform(2.7, 0.0, 0.0) self.assertClose(rgb, image_ref, atol=0.05)
cameras = PerspectiveCameras(device=device, R=R, T=T) def test_simple_sphere_fisheye_params(self):
"""
Test output of phong and gouraud shading matches a reference image using
the default values for the light sources.
"""
device = torch.device("cuda:0")
# Init mesh
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0.0, 0.0)
postfix = "_"
cam_kwargs = [
{
"radial_params": torch.tensor(
[
[-1, -2, -3, 0, 0, 1],
],
dtype=torch.float32,
),
},
{
"tangential_params": torch.tensor(
[[0.7002747019, -0.4005228974]], dtype=torch.float32
),
},
{
"thin_prism_params": torch.tensor(
[
[
-1.000134884,
-1.000084822,
-1.0009420014,
-1.0001276838,
],
],
dtype=torch.float32,
),
},
]
variants = ["radial", "tangential", "prism"]
for test_case, variant in zip(cam_kwargs, variants):
cameras = FishEyeCameras(
device=device,
R=R,
T=T,
use_tangential=True,
use_radial=True,
use_thin_prism=True,
world_coordinates=True,
**test_case,
)
# Init shader settings # Init shader settings
materials = Materials( materials = Materials(device=device)
device=device, lights = PointLights(device=device)
ambient_color=WHITE,
diffuse_color=WHITE,
specular_color=WHITE,
)
lights = AmbientLights(
device=device,
ambient_color=WHITE,
)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings( raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1 image_size=512, blur_radius=0.0, faces_per_pixel=1
) )
rasterizer = MeshRasterizer( blend_params = BlendParams(0.5, 1e-4, (0, 0, 0))
cameras=cameras, raster_settings=raster_settings
)
blend_params = BlendParams(
1e-4,
1e-4,
background_color=BLACK[0],
)
# only test HardFlatShader since that's the only one that makes # Test several shaders
# sense for classification rasterizer_tests = [
shader = HardFlatShader( RasterizerTest(
MeshRasterizer, HardPhongShader, "hard_phong", "hard_phong"
),
RasterizerTest(
MeshRasterizer, HardGouraudShader, "hard_gouraud", "hard_gouraud"
),
RasterizerTest(
MeshRasterizer, HardFlatShader, "hard_flat", "hard_flat"
),
]
for test in rasterizer_tests:
shader = test.shader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
if test.rasterizer == MeshRasterizer:
rasterizer = test.rasterizer(
cameras=cameras, raster_settings=raster_settings
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
filename = "simple_sphere_light_%s%s%s%s%s.png" % (
test.reference_name,
postfix,
variant,
postfix,
FishEyeCameras.__name__,
)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
if DEBUG:
debug_filename = "simple_sphere_light_%s%s%s%s%s.png" % (
test.debug_name,
postfix,
variant,
postfix,
FishEyeCameras.__name__,
)
filename = "DEBUG_%s" % debug_filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)
########################################################
# Move the light to the +z axis in world space so it is
# behind the sphere. Note that +Z is in, +Y up,
# +X left for both world and camera space.
########################################################
lights.location[..., 2] = -2.0
phong_shader = HardPhongShader(
lights=lights, lights=lights,
cameras=cameras, cameras=cameras,
materials=materials, materials=materials,
blend_params=blend_params, blend_params=blend_params,
) )
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
self.assertEqual(images.shape[-1], C + 1)
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
# grab last 3 color channels
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
filename = "test_nd_sphere.png"
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
images = phong_renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG: if DEBUG:
debug_filename = "DEBUG_%s" % filename filename = "DEBUG_simple_sphere_dark%s%s%s%s.png" % (
postfix,
variant,
postfix,
FishEyeCameras.__name__,
)
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / debug_filename DATA_DIR / filename
) )
image_ref_phong_dark = load_rgb_image(
"test_simple_sphere_dark%s%s%s%s.png"
% (postfix, variant, postfix, FishEyeCameras.__name__),
DATA_DIR,
)
# Soft shaders (SplatterPhong) will have a different boundary than hard
# ones, but should be identical otherwise.
self.assertLess((rgb - image_ref_phong_dark).quantile(0.99), 0.005)
def test_fisheye_cow_mesh(self):
"""
Test FishEye Camera distortions on real meshes
"""
device = torch.device("cuda:0")
obj_filename = os.path.join(DATA_DIR, "missing_usemtl/cow.obj")
mesh = load_objs_as_meshes([obj_filename], device=device)
R, T = look_at_view_transform(2.7, 0, 180)
radial_params = torch.tensor([[-1.0, 1.0, 1.0, 0.0, 0.0, -1.0]])
tangential_params = torch.tensor([[0.5, 0.5]])
thin_prism_params = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
combinations = product([False, True], repeat=3)
for combination in combinations:
cameras = FishEyeCameras(
device=device,
R=R,
T=T,
world_coordinates=True,
use_radial=combination[0],
use_tangential=combination[1],
use_thin_prism=combination[2],
radial_params=radial_params,
tangential_params=tangential_params,
thin_prism_params=thin_prism_params,
)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=0.0,
faces_per_pixel=1,
)
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
),
shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
)
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
filename = "test_cow_mesh_%s_radial_%s_tangential_%s_prism_%s.png" % (
FishEyeCameras.__name__,
combination[0],
combination[1],
combination[2],
)
image_ref = load_rgb_image(filename, DATA_DIR) image_ref = load_rgb_image(filename, DATA_DIR)
if DEBUG:
filename = filename.replace("test", "DEBUG")
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05) self.assertClose(rgb, image_ref, atol=0.05)

View File

@ -23,6 +23,7 @@ from pytorch3d.renderer.cameras import (
PerspectiveCameras, PerspectiveCameras,
) )
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.renderer.points import ( from pytorch3d.renderer.points import (
AlphaCompositor, AlphaCompositor,
NormWeightedCompositor, NormWeightedCompositor,
@ -84,6 +85,49 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
) )
self.assertClose(rgb, image_ref) self.assertClose(rgb, image_ref)
def test_simple_sphere_fisheye(self):
device = torch.device("cuda:0")
sphere_mesh = ico_sphere(1, device)
verts_padded = sphere_mesh.verts_padded()
# Shift vertices to check coordinate frames are correct.
verts_padded[..., 1] += 0.2
verts_padded[..., 0] += 0.2
pointclouds = Pointclouds(
points=verts_padded, features=torch.ones_like(verts_padded)
)
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = FishEyeCameras(
device=device,
R=R,
T=T,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
world_coordinates=True,
)
raster_settings = PointsRasterizationSettings(
image_size=256, radius=5e-2, points_per_pixel=1
)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
compositor = NormWeightedCompositor()
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
# Load reference image
filename = "render_fisheye_sphere_points.png"
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(pointclouds)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref)
def test_simple_sphere_pulsar(self): def test_simple_sphere_pulsar(self):
for device in [torch.device("cpu"), torch.device("cuda")]: for device in [torch.device("cpu"), torch.device("cuda")]:
sphere_mesh = ico_sphere(1, device) sphere_mesh = ico_sphere(1, device)