mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Implicit/Volume renderer
Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`. Reviewed By: gkioxari Differential Revision: D24418791 fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e6a32bfc37
commit
b466c381da
403
tests/test_render_implicit.py
Normal file
403
tests/test_render_implicit.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer import (
|
||||
BlendParams,
|
||||
EmissionAbsorptionRaymarcher,
|
||||
GridRaysampler,
|
||||
ImplicitRenderer,
|
||||
Materials,
|
||||
MeshRasterizer,
|
||||
MeshRenderer,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
PointLights,
|
||||
RasterizationSettings,
|
||||
RayBundle,
|
||||
SoftPhongShader,
|
||||
TexturesVertex,
|
||||
ray_bundle_to_ray_points,
|
||||
)
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.utils import ico_sphere
|
||||
from test_render_volumes import init_cameras
|
||||
|
||||
|
||||
DEBUG = False
|
||||
if DEBUG:
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def spherical_volumetric_function(
|
||||
ray_bundle: RayBundle,
|
||||
sphere_centroid: torch.Tensor,
|
||||
sphere_diameter: float,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Volumetric function of a simple RGB sphere with diameter `sphere_diameter`
|
||||
and centroid `sphere_centroid`.
|
||||
"""
|
||||
# convert the ray bundle to world points
|
||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||
batch_size = rays_points_world.shape[0]
|
||||
|
||||
# surface_vectors = vectors from world coords towards the sphere centroid
|
||||
surface_vectors = (
|
||||
rays_points_world.view(batch_size, -1, 3) - sphere_centroid[:, None]
|
||||
)
|
||||
|
||||
# the squared distance of each ray point to the centroid of the sphere
|
||||
surface_dist = (
|
||||
(surface_vectors ** 2)
|
||||
.sum(-1, keepdim=True)
|
||||
.view(*rays_points_world.shape[:-1], 1)
|
||||
)
|
||||
|
||||
# set all ray densities within the sphere_diameter distance from the centroid to 1
|
||||
rays_densities = torch.sigmoid(-100.0 * (surface_dist - sphere_diameter ** 2))
|
||||
|
||||
# ray colors are proportional to the normalized surface_vectors
|
||||
rays_features = (
|
||||
torch.nn.functional.normalize(
|
||||
surface_vectors.view(rays_points_world.shape), dim=-1
|
||||
)
|
||||
* 0.5
|
||||
+ 0.5
|
||||
)
|
||||
|
||||
return rays_densities, rays_features
|
||||
|
||||
|
||||
class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
@staticmethod
|
||||
def renderer(
|
||||
batch_size=10,
|
||||
raymarcher_type=EmissionAbsorptionRaymarcher,
|
||||
n_rays_per_image=10,
|
||||
n_pts_per_ray=10,
|
||||
sphere_diameter=0.75,
|
||||
):
|
||||
# generate NDC camera extrinsics and intrinsics
|
||||
cameras = init_cameras(batch_size, image_size=None, ndc=True)
|
||||
|
||||
# get rand offset of the volume
|
||||
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
|
||||
|
||||
# init the mc raysampler
|
||||
raysampler = MonteCarloRaysampler(
|
||||
min_x=-1.0,
|
||||
max_x=1.0,
|
||||
min_y=-1.0,
|
||||
max_y=1.0,
|
||||
n_rays_per_image=n_rays_per_image,
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
min_depth=0.1,
|
||||
max_depth=2.0,
|
||||
).to(cameras.device)
|
||||
|
||||
# get the raymarcher
|
||||
raymarcher = raymarcher_type()
|
||||
|
||||
# get the implicit renderer
|
||||
renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher)
|
||||
|
||||
def run_renderer():
|
||||
renderer(
|
||||
cameras=cameras,
|
||||
volumetric_function=spherical_volumetric_function,
|
||||
sphere_centroid=sphere_centroid,
|
||||
sphere_diameter=sphere_diameter,
|
||||
)
|
||||
|
||||
return run_renderer
|
||||
|
||||
def test_input_types(self):
|
||||
"""
|
||||
Check that ValueErrors are thrown where expected.
|
||||
"""
|
||||
# check the constructor
|
||||
for bad_raysampler in (None, 5, []):
|
||||
for bad_raymarcher in (None, 5, []):
|
||||
with self.assertRaises(ValueError):
|
||||
ImplicitRenderer(
|
||||
raysampler=bad_raysampler, raymarcher=bad_raymarcher
|
||||
)
|
||||
|
||||
# init a trivial renderer
|
||||
renderer = ImplicitRenderer(
|
||||
raysampler=NDCGridRaysampler(
|
||||
image_width=100,
|
||||
image_height=100,
|
||||
n_pts_per_ray=10,
|
||||
min_depth=0.1,
|
||||
max_depth=1.0,
|
||||
),
|
||||
raymarcher=EmissionAbsorptionRaymarcher(),
|
||||
)
|
||||
|
||||
# get default cameras
|
||||
cameras = init_cameras()
|
||||
|
||||
for bad_volumetric_function in (None, 5, []):
|
||||
with self.assertRaises(ValueError):
|
||||
renderer(cameras=cameras, volumetric_function=bad_volumetric_function)
|
||||
|
||||
def test_compare_with_meshes_renderer(
|
||||
self, batch_size=11, image_size=100, sphere_diameter=0.6
|
||||
):
|
||||
"""
|
||||
Generate a spherical RGB volumetric function and its corresponding mesh
|
||||
and check whether MeshesRenderer returns the same images as the
|
||||
corresponding ImplicitRenderer.
|
||||
"""
|
||||
|
||||
# generate NDC camera extrinsics and intrinsics
|
||||
cameras = init_cameras(
|
||||
batch_size, image_size=[image_size, image_size], ndc=True
|
||||
)
|
||||
|
||||
# get rand offset of the volume
|
||||
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
|
||||
sphere_centroid.requires_grad = True
|
||||
|
||||
# init the grid raysampler with the ndc grid
|
||||
raysampler = NDCGridRaysampler(
|
||||
image_width=image_size,
|
||||
image_height=image_size,
|
||||
n_pts_per_ray=256,
|
||||
min_depth=0.1,
|
||||
max_depth=2.0,
|
||||
)
|
||||
|
||||
# get the EA raymarcher
|
||||
raymarcher = EmissionAbsorptionRaymarcher()
|
||||
|
||||
# jitter the camera intrinsics a bit for each render
|
||||
cameras_randomized = cameras.clone()
|
||||
cameras_randomized.principal_point = (
|
||||
torch.randn_like(cameras.principal_point) * 0.3
|
||||
)
|
||||
cameras_randomized.focal_length = (
|
||||
cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2
|
||||
)
|
||||
|
||||
# the list of differentiable camera vars
|
||||
cam_vars = ("R", "T", "focal_length", "principal_point")
|
||||
# enable the gradient caching for the camera variables
|
||||
for cam_var in cam_vars:
|
||||
getattr(cameras_randomized, cam_var).requires_grad = True
|
||||
|
||||
# get the implicit renderer
|
||||
images_opacities = ImplicitRenderer(
|
||||
raysampler=raysampler, raymarcher=raymarcher
|
||||
)(
|
||||
cameras=cameras_randomized,
|
||||
volumetric_function=spherical_volumetric_function,
|
||||
sphere_centroid=sphere_centroid,
|
||||
sphere_diameter=sphere_diameter,
|
||||
)[
|
||||
0
|
||||
]
|
||||
|
||||
# check that the renderer does not erase gradients
|
||||
loss = images_opacities.sum()
|
||||
loss.backward()
|
||||
for check_var in (
|
||||
*[getattr(cameras_randomized, cam_var) for cam_var in cam_vars],
|
||||
sphere_centroid,
|
||||
):
|
||||
self.assertIsNotNone(check_var.grad)
|
||||
|
||||
# instantiate the corresponding spherical mesh
|
||||
ico = ico_sphere(level=4, device=cameras.device).extend(batch_size)
|
||||
verts = (
|
||||
torch.nn.functional.normalize(ico.verts_padded(), dim=-1) * sphere_diameter
|
||||
+ sphere_centroid[:, None]
|
||||
)
|
||||
meshes = Meshes(
|
||||
verts=verts,
|
||||
faces=ico.faces_padded(),
|
||||
textures=TexturesVertex(
|
||||
verts_features=(
|
||||
torch.nn.functional.normalize(verts, dim=-1) * 0.5 + 0.5
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# instantiate the corresponding mesh renderer
|
||||
lights = PointLights(device=cameras.device, location=[[0.0, 0.0, 0.0]])
|
||||
renderer_textured = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras_randomized,
|
||||
raster_settings=RasterizationSettings(
|
||||
image_size=image_size, blur_radius=1e-3, faces_per_pixel=10
|
||||
),
|
||||
),
|
||||
shader=SoftPhongShader(
|
||||
device=cameras.device,
|
||||
cameras=cameras_randomized,
|
||||
lights=lights,
|
||||
materials=Materials(
|
||||
ambient_color=((2.0, 2.0, 2.0),),
|
||||
diffuse_color=((0.0, 0.0, 0.0),),
|
||||
specular_color=((0.0, 0.0, 0.0),),
|
||||
shininess=64,
|
||||
device=cameras.device,
|
||||
),
|
||||
blend_params=BlendParams(
|
||||
sigma=1e-3, gamma=1e-4, background_color=(0.0, 0.0, 0.0)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# get the mesh render
|
||||
images_opacities_meshes = renderer_textured(
|
||||
meshes, cameras=cameras_randomized, lights=lights
|
||||
)
|
||||
|
||||
if DEBUG:
|
||||
outdir = tempfile.gettempdir() + "/test_implicit_vs_mesh_renderer"
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
frames = []
|
||||
for (image_opacity, image_opacity_mesh) in zip(
|
||||
images_opacities, images_opacities_meshes
|
||||
):
|
||||
image, opacity = image_opacity.split([3, 1], dim=-1)
|
||||
image_mesh, opacity_mesh = image_opacity_mesh.split([3, 1], dim=-1)
|
||||
diff_image = (
|
||||
((image - image_mesh) * 0.5 + 0.5)
|
||||
.mean(dim=2, keepdim=True)
|
||||
.repeat(1, 1, 3)
|
||||
)
|
||||
image_pil = Image.fromarray(
|
||||
(
|
||||
torch.cat(
|
||||
(
|
||||
image,
|
||||
image_mesh,
|
||||
diff_image,
|
||||
opacity.repeat(1, 1, 3),
|
||||
opacity_mesh.repeat(1, 1, 3),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
* 255.0
|
||||
).astype(np.uint8)
|
||||
)
|
||||
frames.append(image_pil)
|
||||
|
||||
# export gif
|
||||
outfile = os.path.join(outdir, "implicit_vs_mesh_render.gif")
|
||||
frames[0].save(
|
||||
outfile,
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=batch_size // 15,
|
||||
loop=0,
|
||||
)
|
||||
print(f"exported {outfile}")
|
||||
|
||||
# export concatenated frames
|
||||
outfile_cat = os.path.join(outdir, "implicit_vs_mesh_render.png")
|
||||
Image.fromarray(np.concatenate([np.array(f) for f in frames], axis=0)).save(
|
||||
outfile_cat
|
||||
)
|
||||
print(f"exported {outfile_cat}")
|
||||
|
||||
# compare the renders
|
||||
diff = (images_opacities - images_opacities_meshes).abs().mean(dim=-1)
|
||||
mu_diff = diff.mean(dim=(1, 2))
|
||||
std_diff = diff.std(dim=(1, 2))
|
||||
self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2)
|
||||
self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
|
||||
|
||||
def test_rotating_gif(
|
||||
self, n_frames=50, fps=15, image_size=(100, 100), sphere_diameter=0.5
|
||||
):
|
||||
"""
|
||||
Render a gif animation of a rotating sphere (runs only if `DEBUG==True`).
|
||||
"""
|
||||
|
||||
if not DEBUG:
|
||||
# do not run this if debug is False
|
||||
return
|
||||
|
||||
# generate camera extrinsics and intrinsics
|
||||
cameras = init_cameras(n_frames, image_size=image_size)
|
||||
|
||||
# init the grid raysampler
|
||||
raysampler = GridRaysampler(
|
||||
min_x=0.5,
|
||||
max_x=image_size[1] - 0.5,
|
||||
min_y=0.5,
|
||||
max_y=image_size[0] - 0.5,
|
||||
image_width=image_size[1],
|
||||
image_height=image_size[0],
|
||||
n_pts_per_ray=256,
|
||||
min_depth=0.1,
|
||||
max_depth=2.0,
|
||||
)
|
||||
|
||||
# get the EA raymarcher
|
||||
raymarcher = EmissionAbsorptionRaymarcher()
|
||||
|
||||
# get the implicit render
|
||||
renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher)
|
||||
|
||||
# get the (0) centroid of the sphere
|
||||
sphere_centroid = torch.zeros(n_frames, 3, device=cameras.device) * 0.1
|
||||
|
||||
# run the renderer
|
||||
images_opacities = renderer(
|
||||
cameras=cameras,
|
||||
volumetric_function=spherical_volumetric_function,
|
||||
sphere_centroid=sphere_centroid,
|
||||
sphere_diameter=sphere_diameter,
|
||||
)[0]
|
||||
|
||||
# split output to the alpha channel and rendered images
|
||||
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
|
||||
|
||||
# export the gif
|
||||
outdir = tempfile.gettempdir() + "/test_implicit_renderer_gifs"
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
frames = []
|
||||
for image, opacity in zip(images, opacities):
|
||||
image_pil = Image.fromarray(
|
||||
(
|
||||
torch.cat(
|
||||
(image, opacity[..., None].clamp(0.0, 1.0).repeat(1, 1, 3)),
|
||||
dim=1,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
* 255.0
|
||||
).astype(np.uint8)
|
||||
)
|
||||
frames.append(image_pil)
|
||||
outfile = os.path.join(outdir, "rotating_sphere.gif")
|
||||
frames[0].save(
|
||||
outfile,
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=n_frames // fps,
|
||||
loop=0,
|
||||
)
|
||||
print(f"exported {outfile}")
|
||||
Reference in New Issue
Block a user