mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: There are a couple of options for supporting non square images: 1) NDC stays at [-1, 1] in both directions with the distance calculations all modified by (W/H). There are a lot of distance based calculations (e.g. triangle areas for barycentric coordinates etc) so this requires changes in many places. 2) NDC is scaled by (W/H) so the smallest side has [-1, 1]. In this case none of the distance calculations need to be updated and only the pixel to NDC calculation needs to be modified. I decided to go with option 2 after trying option 1! API Changes: - Image size can now be specified optionally as a tuple TODO: - add a benchmark test for the non square case. Reviewed By: jcjohnson Differential Revision: D24404975 fbshipit-source-id: 545efb67c822d748ec35999b35762bce58db2cf4
440 lines
16 KiB
Python
440 lines
16 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
import unittest
|
|
from itertools import product
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from common_testing import TestCaseMixin, load_rgb_image
|
|
from PIL import Image
|
|
from pytorch3d.io import load_obj
|
|
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
|
from pytorch3d.renderer.lighting import PointLights
|
|
from pytorch3d.renderer.materials import Materials
|
|
from pytorch3d.renderer.mesh import TexturesUV
|
|
from pytorch3d.renderer.mesh.rasterize_meshes import (
|
|
rasterize_meshes,
|
|
rasterize_meshes_python,
|
|
)
|
|
from pytorch3d.renderer.mesh.rasterizer import (
|
|
Fragments,
|
|
MeshRasterizer,
|
|
RasterizationSettings,
|
|
)
|
|
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
|
from pytorch3d.renderer.mesh.shader import BlendParams, SoftPhongShader
|
|
from pytorch3d.structures import Meshes
|
|
|
|
|
|
DEBUG = False
|
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
|
|
|
# Verts/Faces for a simple mesh with two faces.
|
|
verts0 = torch.tensor(
|
|
[
|
|
[-0.7, -0.70, 1.0],
|
|
[0.0, -0.1, 1.0],
|
|
[0.7, -0.7, 1.0],
|
|
[-0.7, 0.1, 1.0],
|
|
[0.0, 0.7, 1.0],
|
|
[0.7, 0.1, 1.0],
|
|
],
|
|
dtype=torch.float32,
|
|
)
|
|
faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64)
|
|
|
|
|
|
class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase):
|
|
def test_image_size_arg(self):
|
|
meshes = Meshes(verts=[verts0], faces=[faces0])
|
|
|
|
with self.assertRaises(ValueError) as cm:
|
|
rasterize_meshes(
|
|
meshes,
|
|
(100, 200, 3),
|
|
0.0001,
|
|
faces_per_pixel=1,
|
|
)
|
|
self.assertTrue("tuple/list of (H, W)" in cm.msg)
|
|
|
|
with self.assertRaises(ValueError) as cm:
|
|
rasterize_meshes(
|
|
meshes,
|
|
(0, 10),
|
|
0.0001,
|
|
faces_per_pixel=1,
|
|
)
|
|
self.assertTrue("sizes must be positive" in cm.msg)
|
|
|
|
with self.assertRaises(ValueError) as cm:
|
|
rasterize_meshes(
|
|
meshes,
|
|
(100.5, 120.5),
|
|
0.0001,
|
|
faces_per_pixel=1,
|
|
)
|
|
self.assertTrue("sizes must be integers" in cm.msg)
|
|
|
|
|
|
class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
|
|
@staticmethod
|
|
def _clone_mesh(verts0, faces0, device, batch_size):
|
|
"""
|
|
Helper function to detach and clone the verts/faces.
|
|
This is needed in order to set up the tensors for
|
|
gradient computation in different tests.
|
|
"""
|
|
verts = verts0.detach().clone()
|
|
verts.requires_grad = True
|
|
meshes = Meshes(verts=[verts], faces=[faces0])
|
|
meshes = meshes.to(device).extend(batch_size)
|
|
return verts, meshes
|
|
|
|
def _rasterize(self, meshes, image_size, bin_size, blur):
|
|
"""
|
|
Simple wrapper around the rasterize function to return
|
|
the fragment data.
|
|
"""
|
|
face_idxs, zbuf, bary_coords, pix_dists = rasterize_meshes(
|
|
meshes,
|
|
image_size,
|
|
blur,
|
|
faces_per_pixel=1,
|
|
bin_size=bin_size,
|
|
)
|
|
return Fragments(
|
|
pix_to_face=face_idxs,
|
|
zbuf=zbuf,
|
|
bary_coords=bary_coords,
|
|
dists=pix_dists,
|
|
)
|
|
|
|
@staticmethod
|
|
def _save_debug_image(fragments, image_size, bin_size, blur):
|
|
"""
|
|
Save a mask image from the rasterization output for debugging.
|
|
"""
|
|
H, W = image_size
|
|
# Save out the last image for debugging
|
|
rgb = (fragments.pix_to_face[-1, ..., :3].cpu() > -1).squeeze()
|
|
suffix = "square" if H == W else "non_square"
|
|
filename = "triangle_%s_bin_size_%s_blur_%.3f_%dx%d.png"
|
|
filename = filename % (suffix, str(bin_size), blur, H, W)
|
|
if DEBUG:
|
|
filename = "DEBUG_%s" % filename
|
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
|
DATA_DIR / filename
|
|
)
|
|
|
|
def _check_fragments(self, frag_1, frag_2):
|
|
"""
|
|
Helper function to check that the tensors in
|
|
the Fragments frag_1 and frag_2 are the same.
|
|
"""
|
|
self.assertClose(frag_1.pix_to_face, frag_2.pix_to_face)
|
|
self.assertClose(frag_1.dists, frag_2.dists)
|
|
self.assertClose(frag_1.bary_coords, frag_2.bary_coords)
|
|
self.assertClose(frag_1.zbuf, frag_2.zbuf)
|
|
|
|
def _compare_square_with_nonsq(
|
|
self,
|
|
image_size,
|
|
blur,
|
|
device,
|
|
verts0,
|
|
faces0,
|
|
nonsq_fragment_gradtensor_list,
|
|
batch_size=1,
|
|
):
|
|
"""
|
|
Calculate the output from rasterizing a square image with the minimum of (H, W).
|
|
Then compare this with the same square region in the non square image.
|
|
The input mesh faces given by faces0 and verts0 are contained within the
|
|
[-1, 1] range of the image so all the relevant pixels will be within the square region.
|
|
|
|
`nonsq_fragment_gradtensor_list` is a list of fragments and verts grad tensors
|
|
from rasterizing non square images.
|
|
"""
|
|
# Rasterize the square version of the image
|
|
H, W = image_size
|
|
S = min(H, W)
|
|
verts_square, meshes_sq = self._clone_mesh(verts0, faces0, device, batch_size)
|
|
square_fragments = self._rasterize(
|
|
meshes_sq, image_size=(S, S), bin_size=0, blur=blur
|
|
)
|
|
# Save debug image
|
|
self._save_debug_image(square_fragments, (S, S), 0, blur)
|
|
|
|
# Extract the values in the square image which are non zero.
|
|
square_mask = square_fragments.pix_to_face > -1
|
|
square_dists = square_fragments.dists[square_mask]
|
|
square_zbuf = square_fragments.zbuf[square_mask]
|
|
square_bary = square_fragments.bary_coords[square_mask]
|
|
|
|
# Retain gradients on the output of fragments to check
|
|
# intermediate values with the non square outputs.
|
|
square_fragments.dists.retain_grad()
|
|
square_fragments.bary_coords.retain_grad()
|
|
square_fragments.zbuf.retain_grad()
|
|
|
|
# Calculate gradient for the square image
|
|
torch.manual_seed(231)
|
|
grad_zbuf = torch.randn_like(square_zbuf)
|
|
grad_dist = torch.randn_like(square_dists)
|
|
grad_bary = torch.randn_like(square_bary)
|
|
loss0 = (
|
|
(grad_dist * square_dists).sum()
|
|
+ (grad_zbuf * square_zbuf).sum()
|
|
+ (grad_bary * square_bary).sum()
|
|
)
|
|
loss0.backward()
|
|
|
|
# Now compare against the non square outputs provided
|
|
# in the nonsq_fragment_gradtensor_list list
|
|
for fragments, grad_tensor, _name in nonsq_fragment_gradtensor_list:
|
|
# Check that there are the same number of non zero pixels
|
|
# in both the square and non square images.
|
|
non_square_mask = fragments.pix_to_face > -1
|
|
self.assertEqual(non_square_mask.sum().item(), square_mask.sum().item())
|
|
|
|
# Check dists, zbuf and bary match the square image
|
|
non_square_dists = fragments.dists[non_square_mask]
|
|
non_square_zbuf = fragments.zbuf[non_square_mask]
|
|
non_square_bary = fragments.bary_coords[non_square_mask]
|
|
self.assertClose(square_dists, non_square_dists)
|
|
self.assertClose(square_zbuf, non_square_zbuf)
|
|
self.assertClose(
|
|
square_bary,
|
|
non_square_bary,
|
|
atol=2e-7,
|
|
)
|
|
|
|
# Retain gradients to compare values with outputs from
|
|
# square image
|
|
fragments.dists.retain_grad()
|
|
fragments.bary_coords.retain_grad()
|
|
fragments.zbuf.retain_grad()
|
|
loss1 = (
|
|
(grad_dist * non_square_dists).sum()
|
|
+ (grad_zbuf * non_square_zbuf).sum()
|
|
+ (grad_bary * non_square_bary).sum()
|
|
)
|
|
loss1.sum().backward()
|
|
|
|
# Get the non zero values in the intermediate gradients
|
|
# and compare with the values from the square image
|
|
non_square_grad_dists = fragments.dists.grad[non_square_mask]
|
|
non_square_grad_bary = fragments.bary_coords.grad[non_square_mask]
|
|
non_square_grad_zbuf = fragments.zbuf.grad[non_square_mask]
|
|
|
|
self.assertClose(
|
|
non_square_grad_dists,
|
|
square_fragments.dists.grad[square_mask],
|
|
)
|
|
self.assertClose(
|
|
non_square_grad_bary,
|
|
square_fragments.bary_coords.grad[square_mask],
|
|
)
|
|
self.assertClose(
|
|
non_square_grad_zbuf,
|
|
square_fragments.zbuf.grad[square_mask],
|
|
)
|
|
|
|
# Finally check the gradients of the input vertices for
|
|
# the square and non square case
|
|
self.assertClose(verts_square.grad, grad_tensor.grad, rtol=2e-4)
|
|
|
|
def test_gpu(self):
|
|
"""
|
|
Test that the output of rendering non square images
|
|
gives the same result as square images. i.e. the
|
|
dists, zbuf, bary are all the same for the square
|
|
region which is present in both images.
|
|
"""
|
|
# Test both cases: (W > H), (H > W)
|
|
image_sizes = [(64, 128), (128, 64), (128, 256), (256, 128)]
|
|
|
|
devices = ["cuda:0"]
|
|
blurs = [0.0, 0.001]
|
|
batch_sizes = [1, 4]
|
|
test_cases = product(image_sizes, blurs, devices, batch_sizes)
|
|
|
|
for image_size, blur, device, batch_size in test_cases:
|
|
# Initialize the verts grad tensor and the meshes objects
|
|
verts_nonsq_naive, meshes_nonsq_naive = self._clone_mesh(
|
|
verts0, faces0, device, batch_size
|
|
)
|
|
verts_nonsq_binned, meshes_nonsq_binned = self._clone_mesh(
|
|
verts0, faces0, device, batch_size
|
|
)
|
|
|
|
# Get the outputs for both naive and coarse to fine rasterization
|
|
fragments_naive = self._rasterize(
|
|
meshes_nonsq_naive,
|
|
image_size,
|
|
blur=blur,
|
|
bin_size=0,
|
|
)
|
|
fragments_binned = self._rasterize(
|
|
meshes_nonsq_binned,
|
|
image_size,
|
|
blur=blur,
|
|
bin_size=None,
|
|
)
|
|
|
|
# Save out debug images if needed
|
|
self._save_debug_image(fragments_naive, image_size, 0, blur)
|
|
self._save_debug_image(fragments_binned, image_size, None, blur)
|
|
|
|
# Check naive and binned fragments give the same outputs
|
|
self._check_fragments(fragments_naive, fragments_binned)
|
|
|
|
# Here we want to compare the square image with the naive and the
|
|
# coarse to fine methods outputs
|
|
nonsq_fragment_gradtensor_list = [
|
|
(fragments_naive, verts_nonsq_naive, "naive"),
|
|
(fragments_binned, verts_nonsq_binned, "coarse-to-fine"),
|
|
]
|
|
|
|
self._compare_square_with_nonsq(
|
|
image_size,
|
|
blur,
|
|
device,
|
|
verts0,
|
|
faces0,
|
|
nonsq_fragment_gradtensor_list,
|
|
batch_size,
|
|
)
|
|
|
|
def test_cpu(self):
|
|
"""
|
|
Test that the output of rendering non square images
|
|
gives the same result as square images. i.e. the
|
|
dists, zbuf, bary are all the same for the square
|
|
region which is present in both images.
|
|
|
|
In this test we compare between the naive C++ implementation
|
|
and the naive python implementation as the Coarse/Fine
|
|
method is not fully implemented in C++
|
|
"""
|
|
# Test both when (W > H) and (H > W).
|
|
# Using smaller image sizes here as the Python rasterizer is really slow.
|
|
image_sizes = [(32, 64), (64, 32)]
|
|
devices = ["cpu"]
|
|
blurs = [0.0, 0.001]
|
|
batch_sizes = [1]
|
|
test_cases = product(image_sizes, blurs, devices, batch_sizes)
|
|
|
|
for image_size, blur, device, batch_size in test_cases:
|
|
# Initialize the verts grad tensor and the meshes objects
|
|
verts_nonsq_naive, meshes_nonsq_naive = self._clone_mesh(
|
|
verts0, faces0, device, batch_size
|
|
)
|
|
verts_nonsq_python, meshes_nonsq_python = self._clone_mesh(
|
|
verts0, faces0, device, batch_size
|
|
)
|
|
|
|
# Compare Naive CPU with Python as Coarse/Fine rasteriztation
|
|
# is not implemented for CPU
|
|
fragments_naive = self._rasterize(
|
|
meshes_nonsq_naive, image_size, bin_size=0, blur=blur
|
|
)
|
|
face_idxs, zbuf, bary_coords, pix_dists = rasterize_meshes_python(
|
|
meshes_nonsq_python,
|
|
image_size,
|
|
blur,
|
|
faces_per_pixel=1,
|
|
)
|
|
fragments_python = Fragments(
|
|
pix_to_face=face_idxs,
|
|
zbuf=zbuf,
|
|
bary_coords=bary_coords,
|
|
dists=pix_dists,
|
|
)
|
|
|
|
# Save debug images if DEBUG is set to true at the top of the file.
|
|
self._save_debug_image(fragments_naive, image_size, 0, blur)
|
|
self._save_debug_image(fragments_python, image_size, "python", blur)
|
|
|
|
# List of non square outputs to compare with the square output
|
|
nonsq_fragment_gradtensor_list = [
|
|
(fragments_naive, verts_nonsq_naive, "naive"),
|
|
(fragments_python, verts_nonsq_python, "python"),
|
|
]
|
|
self._compare_square_with_nonsq(
|
|
image_size,
|
|
blur,
|
|
device,
|
|
verts0,
|
|
faces0,
|
|
nonsq_fragment_gradtensor_list,
|
|
batch_size,
|
|
)
|
|
|
|
def test_render_cow(self):
|
|
"""
|
|
Test a larger textured mesh is rendered correctly in a non square image.
|
|
"""
|
|
device = torch.device("cuda:0")
|
|
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
|
obj_filename = obj_dir / "cow_mesh/cow.obj"
|
|
|
|
# Load mesh + texture
|
|
verts, faces, aux = load_obj(
|
|
obj_filename, device=device, load_textures=True, texture_wrap=None
|
|
)
|
|
tex_map = list(aux.texture_images.values())[0]
|
|
tex_map = tex_map[None, ...].to(faces.textures_idx.device)
|
|
textures = TexturesUV(
|
|
maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
|
|
)
|
|
mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)
|
|
|
|
# Init rasterizer settings
|
|
R, T = look_at_view_transform(2.7, 0, 180)
|
|
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
|
|
|
raster_settings = RasterizationSettings(
|
|
image_size=(512, 1024), blur_radius=0.0, faces_per_pixel=1
|
|
)
|
|
|
|
# Init shader settings
|
|
materials = Materials(device=device)
|
|
lights = PointLights(device=device)
|
|
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
|
blend_params = BlendParams(
|
|
sigma=1e-1,
|
|
gamma=1e-4,
|
|
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
|
|
)
|
|
|
|
# Init renderer
|
|
renderer = MeshRenderer(
|
|
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
|
shader=SoftPhongShader(
|
|
lights=lights,
|
|
cameras=cameras,
|
|
materials=materials,
|
|
blend_params=blend_params,
|
|
),
|
|
)
|
|
|
|
# Load reference image
|
|
image_ref = load_rgb_image("test_cow_image_rectangle.png", DATA_DIR)
|
|
|
|
for bin_size in [0, None]:
|
|
# Check both naive and coarse to fine produce the same output.
|
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
|
images = renderer(mesh)
|
|
rgb = images[0, ..., :3].squeeze().cpu()
|
|
|
|
if DEBUG:
|
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
|
DATA_DIR / "DEBUG_cow_image_rectangle.png"
|
|
)
|
|
|
|
# NOTE some pixels can be flaky
|
|
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
|
self.assertTrue(cond1)
|