mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 11:26:24 +08:00
CUDA kernel for interpolate_face_attributes
Summary: When rendering meshes with Phong shading, interpolate_face_attributes was taking up a nontrivial fraction of the overall shading time. This diff replaces our Python implementation of this function with a CUDA implementation. Reviewed By: nikhilaravi Differential Revision: D21610763 fbshipit-source-id: 2bb362a28f698541812aeab539047264b125ebb8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0505e5f4a9
commit
26d2cc24c1
185
tests/test_interpolate_face_attributes.py
Normal file
185
tests/test_interpolate_face_attributes.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops.interp_face_attrs import (
|
||||
interpolate_face_attributes,
|
||||
interpolate_face_attributes_python,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||
from pytorch3d.renderer.mesh.texturing import (
|
||||
interpolate_texture_map,
|
||||
interpolate_vertex_colors,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, Textures
|
||||
|
||||
|
||||
class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
|
||||
def _test_interp_face_attrs(self, interp_fun, device):
|
||||
pix_to_face = [0, 2, -1, 0, 1, -1]
|
||||
barycentric_coords = [
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.5, 0.5, 0.0],
|
||||
[0.8, 0.0, 0.2],
|
||||
[0.25, 0.5, 0.25],
|
||||
]
|
||||
face_attrs = [
|
||||
[[1, 2], [3, 4], [5, 6]],
|
||||
[[7, 8], [9, 10], [11, 12]],
|
||||
[[13, 14], [15, 16], [17, 18]],
|
||||
]
|
||||
pix_attrs = [
|
||||
[1, 2],
|
||||
[15, 16],
|
||||
[0, 0],
|
||||
[2, 3],
|
||||
[0.8 * 7 + 0.2 * 11, 0.8 * 8 + 0.2 * 12],
|
||||
[0, 0],
|
||||
]
|
||||
N, H, W, K, D = 1, 2, 1, 3, 2
|
||||
pix_to_face = torch.tensor(pix_to_face, dtype=torch.int64, device=device)
|
||||
pix_to_face = pix_to_face.view(N, H, W, K)
|
||||
barycentric_coords = torch.tensor(
|
||||
barycentric_coords, dtype=torch.float32, device=device
|
||||
)
|
||||
barycentric_coords = barycentric_coords.view(N, H, W, K, 3)
|
||||
face_attrs = torch.tensor(face_attrs, dtype=torch.float32, device=device)
|
||||
pix_attrs = torch.tensor(pix_attrs, dtype=torch.float32, device=device)
|
||||
pix_attrs = pix_attrs.view(N, H, W, K, D)
|
||||
|
||||
args = (pix_to_face, barycentric_coords, face_attrs)
|
||||
pix_attrs_actual = interp_fun(*args)
|
||||
self.assertClose(pix_attrs_actual, pix_attrs)
|
||||
|
||||
def test_python(self):
|
||||
device = torch.device("cuda:0")
|
||||
self._test_interp_face_attrs(interpolate_face_attributes_python, device)
|
||||
|
||||
def test_cuda(self):
|
||||
device = torch.device("cuda:0")
|
||||
self._test_interp_face_attrs(interpolate_face_attributes, device)
|
||||
|
||||
def test_python_vs_cuda(self):
|
||||
N, H, W, K = 2, 32, 32, 5
|
||||
F = 1000
|
||||
D = 3
|
||||
device = get_random_cuda_device()
|
||||
torch.manual_seed(598)
|
||||
pix_to_face = torch.randint(-F, F, (N, H, W, K), device=device)
|
||||
barycentric_coords = torch.randn(
|
||||
N, H, W, K, 3, device=device, requires_grad=True
|
||||
)
|
||||
face_attrs = torch.randn(F, 3, D, device=device, requires_grad=True)
|
||||
grad_pix_attrs = torch.randn(N, H, W, K, D, device=device)
|
||||
args = (pix_to_face, barycentric_coords, face_attrs)
|
||||
|
||||
# Run the python version
|
||||
pix_attrs_py = interpolate_face_attributes_python(*args)
|
||||
pix_attrs_py.backward(gradient=grad_pix_attrs)
|
||||
grad_bary_py = barycentric_coords.grad.clone()
|
||||
grad_face_attrs_py = face_attrs.grad.clone()
|
||||
|
||||
# Clear gradients
|
||||
barycentric_coords.grad.zero_()
|
||||
face_attrs.grad.zero_()
|
||||
|
||||
# Run the CUDA version
|
||||
pix_attrs_cu = interpolate_face_attributes(*args)
|
||||
pix_attrs_cu.backward(gradient=grad_pix_attrs)
|
||||
grad_bary_cu = barycentric_coords.grad.clone()
|
||||
grad_face_attrs_cu = face_attrs.grad.clone()
|
||||
|
||||
# Check they are the same
|
||||
self.assertClose(pix_attrs_py, pix_attrs_cu, rtol=2e-3)
|
||||
self.assertClose(grad_bary_py, grad_bary_cu, rtol=1e-4)
|
||||
self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)
|
||||
|
||||
def test_interpolate_attributes(self):
|
||||
"""
|
||||
This tests both interpolate_vertex_colors as well as
|
||||
interpolate_face_attributes.
|
||||
"""
|
||||
verts = torch.randn((4, 3), dtype=torch.float32)
|
||||
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
|
||||
vert_tex = torch.tensor(
|
||||
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
|
||||
)
|
||||
tex = Textures(verts_rgb=vert_tex[None, :])
|
||||
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
|
||||
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
|
||||
barycentric_coords = torch.tensor(
|
||||
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
|
||||
).view(1, 1, 1, 2, -1)
|
||||
expected_vals = torch.tensor(
|
||||
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
|
||||
).view(1, 1, 1, 2, -1)
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face,
|
||||
bary_coords=barycentric_coords,
|
||||
zbuf=torch.ones_like(pix_to_face),
|
||||
dists=torch.ones_like(pix_to_face),
|
||||
)
|
||||
texels = interpolate_vertex_colors(fragments, mesh)
|
||||
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
|
||||
|
||||
def test_interpolate_attributes_grad(self):
|
||||
verts = torch.randn((4, 3), dtype=torch.float32)
|
||||
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
|
||||
vert_tex = torch.tensor(
|
||||
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
|
||||
dtype=torch.float32,
|
||||
requires_grad=True,
|
||||
)
|
||||
tex = Textures(verts_rgb=vert_tex[None, :])
|
||||
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
|
||||
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
|
||||
barycentric_coords = torch.tensor(
|
||||
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
|
||||
).view(1, 1, 1, 2, -1)
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face,
|
||||
bary_coords=barycentric_coords,
|
||||
zbuf=torch.ones_like(pix_to_face),
|
||||
dists=torch.ones_like(pix_to_face),
|
||||
)
|
||||
grad_vert_tex = torch.tensor(
|
||||
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
texels = interpolate_vertex_colors(fragments, mesh)
|
||||
texels.sum().backward()
|
||||
self.assertTrue(hasattr(vert_tex, "grad"))
|
||||
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
|
||||
|
||||
def test_interpolate_face_attributes_fail(self):
|
||||
# 1. A face can only have 3 verts
|
||||
# i.e. face_attributes must have shape (F, 3, D)
|
||||
face_attributes = torch.ones(1, 4, 3)
|
||||
pix_to_face = torch.ones((1, 1, 1, 1))
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face,
|
||||
bary_coords=pix_to_face[..., None].expand(-1, -1, -1, -1, 3),
|
||||
zbuf=pix_to_face,
|
||||
dists=pix_to_face,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_attributes
|
||||
)
|
||||
|
||||
# 2. pix_to_face must have shape (N, H, W, K)
|
||||
pix_to_face = torch.ones((1, 1, 1, 1, 3))
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face,
|
||||
bary_coords=pix_to_face,
|
||||
zbuf=pix_to_face,
|
||||
dists=pix_to_face,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_attributes
|
||||
)
|
||||
Reference in New Issue
Block a user