Update Rasterizer and add end2end fisheye integration test
Summary: 1) Update rasterizer/point rasterizer to accommodate fisheyecamera. Specifically, transform_points is in placement of explicit transform compositions. 2) In rasterizer unittests, update corresponding tests for rasterizer and point_rasterizer. Address comments to test fisheye against perspective camera when distortions are turned off. 3) Address comments to add end2end test for fisheyecameras. In test_render_meshes, fisheyecameras are added to camera enuerations whenever possible. 4) Test renderings with fisheyecameras of different params on cow mesh. 5) Use compositions for linear cameras whenever possible. Reviewed By: kjchalup Differential Revision: D38932736 fbshipit-source-id: 5b7074fc001f2390f4cf43c7267a8b37fd987547
@ -6,7 +6,7 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -91,7 +91,7 @@ class CamerasBase(TensorProperties):
|
||||
# When joining objects into a batch, they will have to agree.
|
||||
_SHARED_FIELDS: Tuple[str, ...] = ()
|
||||
|
||||
def get_projection_transform(self):
|
||||
def get_projection_transform(self, **kwargs):
|
||||
"""
|
||||
Calculate the projective transformation matrix.
|
||||
|
||||
@ -1841,3 +1841,23 @@ def get_screen_to_ndc_transform(
|
||||
image_size=image_size,
|
||||
).inverse()
|
||||
return transform
|
||||
|
||||
|
||||
def try_get_projection_transform(cameras, kwargs) -> Optional[Callable]:
|
||||
"""
|
||||
Try block to get projection transform.
|
||||
|
||||
Args:
|
||||
cameras instance, can be linear cameras or nonliear cameras
|
||||
|
||||
Returns:
|
||||
If the camera implemented projection_transform, return the
|
||||
projection transform; Otherwise, return None
|
||||
"""
|
||||
|
||||
transform = None
|
||||
try:
|
||||
transform = cameras.get_projection_transform(**kwargs)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return transform
|
||||
|
@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pytorch3d.renderer.cameras import try_get_projection_transform
|
||||
|
||||
from .rasterize_meshes import rasterize_meshes
|
||||
|
||||
@ -197,12 +198,19 @@ class MeshRasterizer(nn.Module):
|
||||
verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
|
||||
verts_world, eps=eps
|
||||
)
|
||||
# view to NDC transform
|
||||
# Call transform_points instead of explicitly composing transforms to handle
|
||||
# the case, where camera class does not have a projection matrix form.
|
||||
verts_proj = cameras.transform_points(verts_world, eps=eps)
|
||||
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
|
||||
projection_transform = cameras.get_projection_transform(**kwargs).compose(
|
||||
to_ndc_transform
|
||||
)
|
||||
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
|
||||
projection_transform = try_get_projection_transform(cameras, kwargs)
|
||||
if projection_transform is not None:
|
||||
projection_transform = projection_transform.compose(to_ndc_transform)
|
||||
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
|
||||
else:
|
||||
# Call transform_points instead of explicitly composing transforms to handle
|
||||
# the case, where camera class does not have a projection matrix form.
|
||||
verts_proj = cameras.transform_points(verts_world, eps=eps)
|
||||
verts_ndc = to_ndc_transform.transform_points(verts_proj, eps=eps)
|
||||
|
||||
verts_ndc[..., 2] = verts_view[..., 2]
|
||||
meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)
|
||||
|
@ -10,6 +10,7 @@ from typing import NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pytorch3d.renderer.cameras import try_get_projection_transform
|
||||
from pytorch3d.structures import Pointclouds
|
||||
|
||||
from .rasterize_points import rasterize_points
|
||||
@ -103,12 +104,16 @@ class PointsRasterizer(nn.Module):
|
||||
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
|
||||
pts_world, eps=eps
|
||||
)
|
||||
# view to NDC transform
|
||||
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
|
||||
projection_transform = cameras.get_projection_transform(**kwargs).compose(
|
||||
to_ndc_transform
|
||||
)
|
||||
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
|
||||
projection_transform = try_get_projection_transform(cameras, kwargs)
|
||||
if projection_transform is not None:
|
||||
projection_transform = projection_transform.compose(to_ndc_transform)
|
||||
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
|
||||
else:
|
||||
# Call transform_points instead of explicitly composing transforms to handle
|
||||
# the case, where camera class does not have a projection matrix form.
|
||||
pts_proj = cameras.transform_points(pts_world, eps=eps)
|
||||
pts_ndc = to_ndc_transform.transform_points(pts_proj, eps=eps)
|
||||
|
||||
pts_ndc[..., 2] = pts_view[..., 2]
|
||||
point_clouds = point_clouds.update_padded(pts_ndc)
|
||||
|
BIN
tests/data/test_FishEyeCameras_silhouette.png
Normal file
After Width: | Height: | Size: 5.5 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 16 KiB |
After Width: | Height: | Size: 18 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 16 KiB |
After Width: | Height: | Size: 17 KiB |
BIN
tests/data/test_fisheye_rasterized_sphere_MeshRasterizer.png
Normal file
After Width: | Height: | Size: 2.3 KiB |
BIN
tests/data/test_perspective_rasterized_sphere_MeshRasterizer.png
Normal file
After Width: | Height: | Size: 2.3 KiB |
BIN
tests/data/test_rasterized_fisheye_sphere_points.png
Normal file
After Width: | Height: | Size: 1.3 KiB |
BIN
tests/data/test_rasterized_perspective_sphere_points.png
Normal file
After Width: | Height: | Size: 1.3 KiB |
BIN
tests/data/test_render_fisheye_sphere_points.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
tests/data/test_simple_sphere_dark_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 1.9 KiB |
BIN
tests/data/test_simple_sphere_dark_elevated_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 2.5 KiB |
BIN
tests/data/test_simple_sphere_dark_none_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 2.0 KiB |
BIN
tests/data/test_simple_sphere_dark_prism_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 1.9 KiB |
BIN
tests/data/test_simple_sphere_dark_radial_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 1.8 KiB |
After Width: | Height: | Size: 1.9 KiB |
BIN
tests/data/test_simple_sphere_dark_tangential_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 2.1 KiB |
After Width: | Height: | Size: 2.1 KiB |
BIN
tests/data/test_simple_sphere_light_flat_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 8.6 KiB |
After Width: | Height: | Size: 8.5 KiB |
BIN
tests/data/test_simple_sphere_light_flat_none_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 13 KiB |
BIN
tests/data/test_simple_sphere_light_gouraud_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 7.9 KiB |
After Width: | Height: | Size: 6.0 KiB |
After Width: | Height: | Size: 5.8 KiB |
After Width: | Height: | Size: 9.1 KiB |
BIN
tests/data/test_simple_sphere_light_hard_flat_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 8.6 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 7.9 KiB |
After Width: | Height: | Size: 6.0 KiB |
After Width: | Height: | Size: 8.8 KiB |
After Width: | Height: | Size: 7.2 KiB |
After Width: | Height: | Size: 7.9 KiB |
After Width: | Height: | Size: 9.7 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 8.0 KiB |
After Width: | Height: | Size: 5.9 KiB |
After Width: | Height: | Size: 5.9 KiB |
After Width: | Height: | Size: 8.0 KiB |
After Width: | Height: | Size: 8.7 KiB |
After Width: | Height: | Size: 7.2 KiB |
After Width: | Height: | Size: 8.0 KiB |
After Width: | Height: | Size: 9.6 KiB |
After Width: | Height: | Size: 10 KiB |
BIN
tests/data/test_simple_sphere_light_phong_FishEyeCameras.png
Normal file
After Width: | Height: | Size: 8.0 KiB |
After Width: | Height: | Size: 5.9 KiB |
After Width: | Height: | Size: 8.8 KiB |
@ -21,6 +21,7 @@ from pytorch3d.renderer import (
|
||||
PointsRasterizer,
|
||||
RasterizationSettings,
|
||||
)
|
||||
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
|
||||
from pytorch3d.renderer.opengl.rasterizer_opengl import (
|
||||
_check_cameras,
|
||||
_check_raster_settings,
|
||||
@ -51,6 +52,9 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
def test_simple_sphere(self):
|
||||
self._simple_sphere(MeshRasterizer)
|
||||
|
||||
def test_simple_sphere_fisheye(self):
|
||||
self._simple_sphere_fisheye_against_perspective(MeshRasterizer)
|
||||
|
||||
def test_simple_sphere_opengl(self):
|
||||
self._simple_sphere(MeshRasterizerOpenGL)
|
||||
|
||||
@ -155,6 +159,91 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(image, image_ref))
|
||||
|
||||
def _simple_sphere_fisheye_against_perspective(self, rasterizer_type):
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Init mesh
|
||||
sphere_mesh = ico_sphere(5, device)
|
||||
|
||||
# Init rasterizer settings
|
||||
R, T = look_at_view_transform(2.7, 0, 0)
|
||||
|
||||
# Init Fisheye camera params
|
||||
focal = torch.tensor([[1.7321]], dtype=torch.float32)
|
||||
principal_point = torch.tensor([[0.0101, -0.0101]])
|
||||
perspective_cameras = PerspectiveCameras(
|
||||
R=R,
|
||||
T=T,
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
device="cuda:0",
|
||||
)
|
||||
fisheye_cameras = FishEyeCameras(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
world_coordinates=True,
|
||||
use_radial=False,
|
||||
use_tangential=False,
|
||||
use_thin_prism=False,
|
||||
)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||
)
|
||||
|
||||
# Init rasterizer
|
||||
perspective_rasterizer = rasterizer_type(
|
||||
cameras=perspective_cameras, raster_settings=raster_settings
|
||||
)
|
||||
fisheye_rasterizer = rasterizer_type(
|
||||
cameras=fisheye_cameras, raster_settings=raster_settings
|
||||
)
|
||||
|
||||
####################################################################################
|
||||
# Test rasterizing a single mesh comparing fisheye camera against perspective camera
|
||||
####################################################################################
|
||||
|
||||
perspective_fragments = perspective_rasterizer(sphere_mesh)
|
||||
perspective_image = perspective_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
|
||||
# Convert pix_to_face to a binary mask
|
||||
perspective_image[perspective_image >= 0] = 1.0
|
||||
perspective_image[perspective_image < 0] = 0.0
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR
|
||||
/ f"DEBUG_test_perspective_rasterized_sphere_{rasterizer_type.__name__}.png"
|
||||
)
|
||||
|
||||
fisheye_fragments = fisheye_rasterizer(sphere_mesh)
|
||||
fisheye_image = fisheye_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
|
||||
# Convert pix_to_face to a binary mask
|
||||
fisheye_image[fisheye_image >= 0] = 1.0
|
||||
fisheye_image[fisheye_image < 0] = 0.0
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR
|
||||
/ f"DEBUG_test_fisheye_rasterized_sphere_{rasterizer_type.__name__}.png"
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(fisheye_image, perspective_image))
|
||||
|
||||
##################################
|
||||
# 2. Test with a batch of meshes
|
||||
##################################
|
||||
|
||||
batch_size = 10
|
||||
sphere_meshes = sphere_mesh.extend(batch_size)
|
||||
fragments = fisheye_rasterizer(sphere_meshes)
|
||||
for i in range(batch_size):
|
||||
image = fragments.pix_to_face[i, ..., 0].squeeze().cpu()
|
||||
image[image >= 0] = 1.0
|
||||
image[image < 0] = 0.0
|
||||
self.assertTrue(torch.allclose(image, perspective_image))
|
||||
|
||||
def test_simple_to(self):
|
||||
# Check that to() works without a cameras object.
|
||||
device = torch.device("cuda:0")
|
||||
@ -412,6 +501,76 @@ class TestPointRasterizer(unittest.TestCase):
|
||||
image[image < 0] = 0.0
|
||||
self.assertTrue(torch.allclose(image, image_ref[..., 0]))
|
||||
|
||||
def test_simple_sphere_fisheye_against_perspective(self):
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
|
||||
sphere_mesh = ico_sphere(1, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
verts_padded[..., 1] += 0.2
|
||||
verts_padded[..., 0] += 0.2
|
||||
pointclouds = Pointclouds(points=verts_padded)
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
perspective_cameras = PerspectiveCameras(
|
||||
R=R,
|
||||
T=T,
|
||||
device=device,
|
||||
)
|
||||
fisheye_cameras = FishEyeCameras(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
world_coordinates=True,
|
||||
use_radial=False,
|
||||
use_tangential=False,
|
||||
use_thin_prism=False,
|
||||
)
|
||||
raster_settings = PointsRasterizationSettings(
|
||||
image_size=256, radius=5e-2, points_per_pixel=1
|
||||
)
|
||||
|
||||
#################################
|
||||
# 1. Test init without cameras.
|
||||
##################################
|
||||
|
||||
# Initialize without passing in the cameras
|
||||
rasterizer = PointsRasterizer()
|
||||
|
||||
# Check that omitting the cameras in both initialization
|
||||
# and the forward pass throws an error:
|
||||
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
|
||||
rasterizer(pointclouds)
|
||||
|
||||
########################################################################################
|
||||
# 2. Test rasterizing a single pointcloud with fisheye camera agasint perspective camera
|
||||
########################################################################################
|
||||
|
||||
perspective_fragments = rasterizer(
|
||||
pointclouds, cameras=perspective_cameras, raster_settings=raster_settings
|
||||
)
|
||||
fisheye_fragments = rasterizer(
|
||||
pointclouds, cameras=fisheye_cameras, raster_settings=raster_settings
|
||||
)
|
||||
|
||||
# Convert idx to a binary mask
|
||||
perspective_image = perspective_fragments.idx[0, ..., 0].squeeze().cpu()
|
||||
perspective_image[perspective_image >= 0] = 1.0
|
||||
perspective_image[perspective_image < 0] = 0.0
|
||||
|
||||
fisheye_image = fisheye_fragments.idx[0, ..., 0].squeeze().cpu()
|
||||
fisheye_image[fisheye_image >= 0] = 1.0
|
||||
fisheye_image[fisheye_image < 0] = 0.0
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_test_rasterized_perspective_sphere_points.png"
|
||||
)
|
||||
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_test_rasterized_fisheye_sphere_points.png"
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(fisheye_image, perspective_image))
|
||||
|
||||
def test_simple_to(self):
|
||||
# Check that to() works without a cameras object.
|
||||
device = torch.device("cuda:0")
|
||||
|
@ -12,10 +12,12 @@ import os
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.io import load_obj
|
||||
from pytorch3d.io import load_obj, load_objs_as_meshes
|
||||
from pytorch3d.renderer import (
|
||||
AmbientLights,
|
||||
FoVOrthographicCameras,
|
||||
@ -33,6 +35,7 @@ from pytorch3d.renderer import (
|
||||
TexturesUV,
|
||||
TexturesVertex,
|
||||
)
|
||||
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
|
||||
from pytorch3d.renderer.mesh.shader import (
|
||||
BlendParams,
|
||||
HardFlatShader,
|
||||
@ -59,7 +62,6 @@ from .common_testing import (
|
||||
TestCaseMixin,
|
||||
)
|
||||
|
||||
|
||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||
# All saved images have prefix DEBUG_
|
||||
DEBUG = False
|
||||
@ -107,8 +109,38 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
FoVOrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
FishEyeCameras,
|
||||
):
|
||||
cameras = cam_type(device=device, R=R, T=T)
|
||||
if cam_type == FishEyeCameras:
|
||||
cam_kwargs = {
|
||||
"radial_params": torch.tensor(
|
||||
[
|
||||
[-1, -2, -3, 0, 0, 1],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"tangential_params": torch.tensor(
|
||||
[[0.7002747019, -0.4005228974]], dtype=torch.float32
|
||||
),
|
||||
"thin_prism_params": torch.tensor(
|
||||
[
|
||||
[-1.000134884, -1.000084822, -1.0009420014, -1.0001276838],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
}
|
||||
cameras = cam_type(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
use_tangential=True,
|
||||
use_radial=True,
|
||||
use_thin_prism=True,
|
||||
world_coordinates=True,
|
||||
**cam_kwargs,
|
||||
)
|
||||
else:
|
||||
cameras = cam_type(device=device, R=R, T=T)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(device=device)
|
||||
@ -146,7 +178,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
)
|
||||
elif test.rasterizer == MeshRasterizerOpenGL:
|
||||
if type(cameras) in [PerspectiveCameras, OrthographicCameras]:
|
||||
if type(cameras) in [
|
||||
PerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
FishEyeCameras,
|
||||
]:
|
||||
# MeshRasterizerOpenGL is only compatible with FoV cameras.
|
||||
continue
|
||||
rasterizer = test.rasterizer(
|
||||
@ -181,8 +217,6 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
if DEBUG:
|
||||
debug_filename = "simple_sphere_light_%s%s%s.png" % (
|
||||
test.debug_name,
|
||||
@ -193,6 +227,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
########################################################
|
||||
# Move the light to the +z axis in world space so it is
|
||||
@ -429,8 +464,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
FoVOrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
FishEyeCameras,
|
||||
):
|
||||
cameras = cam_type(device=device, R=R, T=T)
|
||||
if cam_type == FishEyeCameras:
|
||||
cameras = cam_type(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
use_tangential=False,
|
||||
use_radial=False,
|
||||
use_thin_prism=False,
|
||||
world_coordinates=True,
|
||||
)
|
||||
else:
|
||||
cameras = cam_type(device=device, R=R, T=T)
|
||||
|
||||
# Init renderer
|
||||
renderer = MeshRenderer(
|
||||
@ -1443,84 +1490,291 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_nd_sphere(self):
|
||||
"""
|
||||
Test that the render can handle textures with more than 3 channels and
|
||||
not just 3 channel RGB.
|
||||
"""
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda:0")
|
||||
C = 5
|
||||
WHITE = ((1.0,) * C,)
|
||||
BLACK = ((0.0,) * C,)
|
||||
def test_nd_sphere(self):
|
||||
"""
|
||||
Test that the render can handle textures with more than 3 channels and
|
||||
not just 3 channel RGB.
|
||||
"""
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda:0")
|
||||
C = 5
|
||||
WHITE = ((1.0,) * C,)
|
||||
BLACK = ((0.0,) * C,)
|
||||
|
||||
# Init mesh
|
||||
sphere_mesh = ico_sphere(5, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
faces_padded = sphere_mesh.faces_padded()
|
||||
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
|
||||
n_verts = feats.shape[1]
|
||||
# make some non-uniform pattern
|
||||
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
|
||||
textures = TexturesVertex(verts_features=feats)
|
||||
sphere_mesh = Meshes(
|
||||
verts=verts_padded, faces=faces_padded, textures=textures
|
||||
# Init mesh
|
||||
sphere_mesh = ico_sphere(5, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
faces_padded = sphere_mesh.faces_padded()
|
||||
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
|
||||
n_verts = feats.shape[1]
|
||||
# make some non-uniform pattern
|
||||
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
|
||||
textures = TexturesVertex(verts_features=feats)
|
||||
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
|
||||
|
||||
# No elevation or azimuth rotation
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
|
||||
cameras = PerspectiveCameras(device=device, R=R, T=T)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(
|
||||
device=device,
|
||||
ambient_color=WHITE,
|
||||
diffuse_color=WHITE,
|
||||
specular_color=WHITE,
|
||||
)
|
||||
lights = AmbientLights(
|
||||
device=device,
|
||||
ambient_color=WHITE,
|
||||
)
|
||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
blend_params = BlendParams(
|
||||
1e-4,
|
||||
1e-4,
|
||||
background_color=BLACK[0],
|
||||
)
|
||||
|
||||
# only test HardFlatShader since that's the only one that makes
|
||||
# sense for classification
|
||||
shader = HardFlatShader(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
|
||||
self.assertEqual(images.shape[-1], C + 1)
|
||||
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
|
||||
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
|
||||
|
||||
# grab last 3 color channels
|
||||
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
|
||||
filename = "test_nd_sphere.png"
|
||||
|
||||
if DEBUG:
|
||||
debug_filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / debug_filename
|
||||
)
|
||||
|
||||
# No elevation or azimuth rotation
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
image_ref = load_rgb_image(filename, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
cameras = PerspectiveCameras(device=device, R=R, T=T)
|
||||
def test_simple_sphere_fisheye_params(self):
|
||||
"""
|
||||
Test output of phong and gouraud shading matches a reference image using
|
||||
the default values for the light sources.
|
||||
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Init mesh
|
||||
sphere_mesh = ico_sphere(5, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
faces_padded = sphere_mesh.faces_padded()
|
||||
feats = torch.ones_like(verts_padded, device=device)
|
||||
textures = TexturesVertex(verts_features=feats)
|
||||
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
|
||||
|
||||
# Init rasterizer settings
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
postfix = "_"
|
||||
|
||||
cam_kwargs = [
|
||||
{
|
||||
"radial_params": torch.tensor(
|
||||
[
|
||||
[-1, -2, -3, 0, 0, 1],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
},
|
||||
{
|
||||
"tangential_params": torch.tensor(
|
||||
[[0.7002747019, -0.4005228974]], dtype=torch.float32
|
||||
),
|
||||
},
|
||||
{
|
||||
"thin_prism_params": torch.tensor(
|
||||
[
|
||||
[
|
||||
-1.000134884,
|
||||
-1.000084822,
|
||||
-1.0009420014,
|
||||
-1.0001276838,
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
},
|
||||
]
|
||||
variants = ["radial", "tangential", "prism"]
|
||||
for test_case, variant in zip(cam_kwargs, variants):
|
||||
cameras = FishEyeCameras(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
use_tangential=True,
|
||||
use_radial=True,
|
||||
use_thin_prism=True,
|
||||
world_coordinates=True,
|
||||
**test_case,
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(
|
||||
device=device,
|
||||
ambient_color=WHITE,
|
||||
diffuse_color=WHITE,
|
||||
specular_color=WHITE,
|
||||
)
|
||||
lights = AmbientLights(
|
||||
device=device,
|
||||
ambient_color=WHITE,
|
||||
)
|
||||
materials = Materials(device=device)
|
||||
lights = PointLights(device=device)
|
||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
rasterizer = MeshRasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
)
|
||||
blend_params = BlendParams(
|
||||
1e-4,
|
||||
1e-4,
|
||||
background_color=BLACK[0],
|
||||
)
|
||||
blend_params = BlendParams(0.5, 1e-4, (0, 0, 0))
|
||||
|
||||
# only test HardFlatShader since that's the only one that makes
|
||||
# sense for classification
|
||||
shader = HardFlatShader(
|
||||
# Test several shaders
|
||||
rasterizer_tests = [
|
||||
RasterizerTest(
|
||||
MeshRasterizer, HardPhongShader, "hard_phong", "hard_phong"
|
||||
),
|
||||
RasterizerTest(
|
||||
MeshRasterizer, HardGouraudShader, "hard_gouraud", "hard_gouraud"
|
||||
),
|
||||
RasterizerTest(
|
||||
MeshRasterizer, HardFlatShader, "hard_flat", "hard_flat"
|
||||
),
|
||||
]
|
||||
for test in rasterizer_tests:
|
||||
shader = test.shader(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
if test.rasterizer == MeshRasterizer:
|
||||
rasterizer = test.rasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
)
|
||||
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
filename = "simple_sphere_light_%s%s%s%s%s.png" % (
|
||||
test.reference_name,
|
||||
postfix,
|
||||
variant,
|
||||
postfix,
|
||||
FishEyeCameras.__name__,
|
||||
)
|
||||
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
if DEBUG:
|
||||
debug_filename = "simple_sphere_light_%s%s%s%s%s.png" % (
|
||||
test.debug_name,
|
||||
postfix,
|
||||
variant,
|
||||
postfix,
|
||||
FishEyeCameras.__name__,
|
||||
)
|
||||
filename = "DEBUG_%s" % debug_filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
########################################################
|
||||
# Move the light to the +z axis in world space so it is
|
||||
# behind the sphere. Note that +Z is in, +Y up,
|
||||
# +X left for both world and camera space.
|
||||
########################################################
|
||||
lights.location[..., 2] = -2.0
|
||||
phong_shader = HardPhongShader(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
|
||||
self.assertEqual(images.shape[-1], C + 1)
|
||||
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
|
||||
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
|
||||
|
||||
# grab last 3 color channels
|
||||
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
|
||||
filename = "test_nd_sphere.png"
|
||||
|
||||
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
|
||||
images = phong_renderer(sphere_mesh, lights=lights)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
debug_filename = "DEBUG_%s" % filename
|
||||
filename = "DEBUG_simple_sphere_dark%s%s%s%s.png" % (
|
||||
postfix,
|
||||
variant,
|
||||
postfix,
|
||||
FishEyeCameras.__name__,
|
||||
)
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / debug_filename
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
image_ref_phong_dark = load_rgb_image(
|
||||
"test_simple_sphere_dark%s%s%s%s.png"
|
||||
% (postfix, variant, postfix, FishEyeCameras.__name__),
|
||||
DATA_DIR,
|
||||
)
|
||||
# Soft shaders (SplatterPhong) will have a different boundary than hard
|
||||
# ones, but should be identical otherwise.
|
||||
self.assertLess((rgb - image_ref_phong_dark).quantile(0.99), 0.005)
|
||||
|
||||
def test_fisheye_cow_mesh(self):
|
||||
"""
|
||||
Test FishEye Camera distortions on real meshes
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
obj_filename = os.path.join(DATA_DIR, "missing_usemtl/cow.obj")
|
||||
mesh = load_objs_as_meshes([obj_filename], device=device)
|
||||
R, T = look_at_view_transform(2.7, 0, 180)
|
||||
radial_params = torch.tensor([[-1.0, 1.0, 1.0, 0.0, 0.0, -1.0]])
|
||||
tangential_params = torch.tensor([[0.5, 0.5]])
|
||||
thin_prism_params = torch.tensor([[0.5, 0.5, 0.5, 0.5]])
|
||||
combinations = product([False, True], repeat=3)
|
||||
for combination in combinations:
|
||||
cameras = FishEyeCameras(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
world_coordinates=True,
|
||||
use_radial=combination[0],
|
||||
use_tangential=combination[1],
|
||||
use_thin_prism=combination[2],
|
||||
radial_params=radial_params,
|
||||
tangential_params=tangential_params,
|
||||
thin_prism_params=thin_prism_params,
|
||||
)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=0.0,
|
||||
faces_per_pixel=1,
|
||||
)
|
||||
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
),
|
||||
shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
|
||||
)
|
||||
images = renderer(mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
filename = "test_cow_mesh_%s_radial_%s_tangential_%s_prism_%s.png" % (
|
||||
FishEyeCameras.__name__,
|
||||
combination[0],
|
||||
combination[1],
|
||||
combination[2],
|
||||
)
|
||||
image_ref = load_rgb_image(filename, DATA_DIR)
|
||||
if DEBUG:
|
||||
filename = filename.replace("test", "DEBUG")
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
@ -23,6 +23,7 @@ from pytorch3d.renderer.cameras import (
|
||||
PerspectiveCameras,
|
||||
)
|
||||
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
|
||||
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
|
||||
from pytorch3d.renderer.points import (
|
||||
AlphaCompositor,
|
||||
NormWeightedCompositor,
|
||||
@ -84,6 +85,49 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertClose(rgb, image_ref)
|
||||
|
||||
def test_simple_sphere_fisheye(self):
|
||||
device = torch.device("cuda:0")
|
||||
sphere_mesh = ico_sphere(1, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
# Shift vertices to check coordinate frames are correct.
|
||||
verts_padded[..., 1] += 0.2
|
||||
verts_padded[..., 0] += 0.2
|
||||
pointclouds = Pointclouds(
|
||||
points=verts_padded, features=torch.ones_like(verts_padded)
|
||||
)
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
cameras = FishEyeCameras(
|
||||
device=device,
|
||||
R=R,
|
||||
T=T,
|
||||
use_radial=False,
|
||||
use_tangential=False,
|
||||
use_thin_prism=False,
|
||||
world_coordinates=True,
|
||||
)
|
||||
raster_settings = PointsRasterizationSettings(
|
||||
image_size=256, radius=5e-2, points_per_pixel=1
|
||||
)
|
||||
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
compositor = NormWeightedCompositor()
|
||||
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
||||
|
||||
# Load reference image
|
||||
filename = "render_fisheye_sphere_points.png"
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
|
||||
for bin_size in [0, None]:
|
||||
# Check both naive and coarse to fine produce the same output.
|
||||
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||
images = renderer(pointclouds)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref)
|
||||
|
||||
def test_simple_sphere_pulsar(self):
|
||||
for device in [torch.device("cpu"), torch.device("cuda")]:
|
||||
sphere_mesh = ico_sphere(1, device)
|
||||
|