CUDA kernel for interpolate_face_attributes

Summary: When rendering meshes with Phong shading, interpolate_face_attributes was taking up a nontrivial fraction of the overall shading time. This diff replaces our Python implementation of this function with a CUDA implementation.

Reviewed By: nikhilaravi

Differential Revision: D21610763

fbshipit-source-id: 2bb362a28f698541812aeab539047264b125ebb8
Justin Johnson 2020-07-13 12:58:07 -07:00 committed by Facebook GitHub Bot
parent 0505e5f4a9
commit 26d2cc24c1
11 changed files with 630 additions and 140 deletions
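For orientation, here is an illustrative usage sketch of the public op this commit accelerates (not part of the diff; sizes and values are made up):

import torch
from pytorch3d.ops import interpolate_face_attributes

N, H, W, K, F, D = 2, 64, 64, 1, 100, 3             # hypothetical sizes
device = torch.device("cuda")
pix_to_face = torch.randint(-1, F, (N, H, W, K), device=device)  # -1 = no face
bary = torch.rand(N, H, W, K, 3, device=device)
bary = bary / bary.sum(dim=-1, keepdim=True)         # normalize barycentric weights
face_attrs = torch.randn(F, 3, D, device=device)     # e.g. per-vertex normals
pix_vals = interpolate_face_attributes(pix_to_face, bary, face_attrs)
print(pix_vals.shape)                                 # torch.Size([2, 64, 64, 1, 3])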

View File

@@ -6,6 +6,7 @@
#include "compositing/weighted_sum.h"
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "interp_face_attrs/interp_face_attrs.h"
#include "knn/knn.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_edge.h"
@@ -18,6 +19,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
m.def("packed_to_padded", &PackedToPadded);
m.def("padded_to_packed", &PaddedToPacked);
m.def("interp_face_attrs_forward", &InterpFaceAttrsForward);
m.def("interp_face_attrs_backward", &InterpFaceAttrsBackward);
#ifdef WITH_CUDA
m.def("knn_check_version", &KnnCheckVersion);
#endif

View File

@@ -0,0 +1,161 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <tuple>
template <typename scalar_t>
__global__ void InterpFaceAttrsForwardKernel(
const int64_t* __restrict__ pix_to_face, // (P,)
const scalar_t* __restrict__ barycentric_coords, // (P, 3)
const scalar_t* __restrict__ face_attrs, // (F, 3, D)
scalar_t* pix_attrs, // (P, D)
const size_t P,
const size_t F,
const size_t D) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;
const int64_t f = pix_to_face[p];
if (f < 0) {
continue;
}
scalar_t pix_attr = 0.0;
for (int i = 0; i < 3; ++i) {
scalar_t weight = barycentric_coords[p * 3 + i];
scalar_t vert_attr = face_attrs[f * 3 * D + i * D + d];
pix_attr += weight * vert_attr;
}
pix_attrs[p * D + d] = pix_attr;
}
}
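In words, each kernel iteration handles one flattened (pixel, channel) element and computes pix_attrs[p, d] = sum_i barycentric_coords[p, i] * face_attrs[pix_to_face[p], i, d]. A minimal PyTorch sketch of the same computation for a single covered pixel (toy values, not part of the diff):

import torch

bary = torch.tensor([0.2, 0.3, 0.5])                 # weights for one pixel
face_attrs = torch.rand(4, 3, 2)                      # F = 4 faces, 3 verts, D = 2
f = 2                                                  # this pixel is covered by face 2
out = (bary[:, None] * face_attrs[f]).sum(dim=0)      # shape (D,)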
at::Tensor InterpFaceAttrsForwardCuda(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs) {
// Make sure all inputs are on the same device
at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1},
barycentric_coords_t{barycentric_coords, "barycentric_coords", 2},
face_attrs_t{face_attrs, "face_attributes", 3};
at::CheckedFrom c = "InterpFaceAttrsForwardCuda";
at::checkAllSameGPU(c, {pix_to_face_t, barycentric_coords_t, face_attrs_t});
at::checkAllSameType(c, {barycentric_coords_t, face_attrs_t});
// Set the device for the kernel launch based on the input
at::cuda::CUDAGuard device_guard(pix_to_face.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto P = pix_to_face.size(0);
const auto F = face_attrs.size(0);
const auto D = face_attrs.size(2);
TORCH_CHECK(
barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3,
"barycentric_coords must have size (P, 3)");
TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)");
auto pix_attrs = at::zeros({P, D}, face_attrs.options());
const int threads = 1024;
const int blocks = 512;
AT_DISPATCH_FLOATING_TYPES(
face_attrs.scalar_type(), "interp_face_attrs_cuda", ([&] {
InterpFaceAttrsForwardKernel<<<blocks, threads, 0, stream>>>(
pix_to_face.contiguous().data_ptr<int64_t>(),
barycentric_coords.contiguous().data_ptr<scalar_t>(),
face_attrs.contiguous().data_ptr<scalar_t>(),
pix_attrs.contiguous().data_ptr<scalar_t>(),
P,
F,
D);
}));
AT_CUDA_CHECK(cudaGetLastError());
return pix_attrs;
}
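The launch configuration is fixed at 512 blocks of 1024 threads; the grid-stride loop in the kernel lets those threads cover any number of (pixel, channel) elements. Rough back-of-envelope arithmetic with made-up sizes:

threads, blocks = 1024, 512
total_threads = threads * blocks               # 524,288 threads in the grid
P, D = 2 * 128 * 128 * 10, 3                    # e.g. N*H*W*K pixels, 3 channels
elems = P * D                                   # 983,040 elements to fill
iters_per_thread = -(-elems // total_threads)   # ceil division -> 2 loop iterations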
template <typename scalar_t>
__global__ void InterpFaceAttrsBackwardKernel(
const int64_t* __restrict__ pix_to_face, // (P,)
const scalar_t* __restrict__ barycentric_coords, // (P, 3)
const scalar_t* __restrict__ face_attrs, // (F, 3, D)
const scalar_t* __restrict__ grad_pix_attrs, // (P, D)
scalar_t* __restrict__ grad_barycentric_coords, // (P, 3)
scalar_t* __restrict__ grad_face_attrs, // (F, 3, D)
const size_t P,
const size_t F,
const size_t D) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;
const int64_t f = pix_to_face[p];
if (f < 0) {
continue;
}
scalar_t upstream_grad = grad_pix_attrs[p * D + d];
for (int i = 0; i < 3; ++i) {
scalar_t weight = barycentric_coords[p * 3 + i];
scalar_t vert_attr = face_attrs[f * 3 * D + i * D + d];
scalar_t grad_bary_down = vert_attr * upstream_grad;
scalar_t grad_face_down = weight * upstream_grad;
atomicAdd(grad_barycentric_coords + p * 3 + i, grad_bary_down);
atomicAdd(grad_face_attrs + f * 3 * D + i * D + d, grad_face_down);
}
}
}
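Since the forward output is linear in both the barycentric weights and the vertex attributes, the gradients reduce to the products accumulated above; atomicAdd is needed because a pixel's D channels all touch the same three weights, and many pixels can map to the same face. A small sketch (toy data, single pixel, not part of the diff) checking the closed-form gradients against autograd:

import torch

bary = torch.rand(3, requires_grad=True)         # weights for one pixel
vert = torch.rand(3, 2, requires_grad=True)      # attrs of the face's 3 verts, D = 2
up = torch.rand(2)                                # upstream gradient for this pixel
out = (bary[:, None] * vert).sum(dim=0)
out.backward(up)
assert torch.allclose(bary.grad, vert @ up)                    # grad_bary[i] = sum_d vert[i,d] * up[d]
assert torch.allclose(vert.grad, bary[:, None] * up[None, :])  # grad_vert[i,d] = bary[i] * up[d]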
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs,
const at::Tensor& grad_pix_attrs) {
// Make sure all inputs are on the same device
at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1},
barycentric_coords_t{barycentric_coords, "barycentric_coords", 2},
face_attrs_t{face_attrs, "face_attributes", 3},
grad_pix_attrs_t{grad_pix_attrs, "pix_attrs", 4};
at::CheckedFrom c = "InterpFaceAttrsBackwardCuda";
at::checkAllSameGPU(
c, {pix_to_face_t, barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});
at::checkAllSameType(
c, {barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});
// Set the device for the kernel launch based on the input
at::cuda::CUDAGuard device_guard(pix_to_face.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto P = pix_to_face.size(0);
const auto F = face_attrs.size(0);
const auto D = face_attrs.size(2);
TORCH_CHECK(
barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3,
"barycentric_coords must have size (P, 3)");
TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)");
TORCH_CHECK(
grad_pix_attrs.size(0) == P && grad_pix_attrs.size(1) == D,
"grad_pix_attrs must have size (P, D)");
auto grad_barycentric_coords = at::zeros_like(barycentric_coords);
auto grad_face_attrs = at::zeros_like(face_attrs);
const int threads = 1024;
const int blocks = 512;
// Only allow float for now.
// TODO: Add support for double once we fix atomicAdd
// clang-format off
InterpFaceAttrsBackwardKernel<<<blocks, threads, 0, stream>>>(
pix_to_face.contiguous().data_ptr<int64_t>(),
barycentric_coords.contiguous().data_ptr<float>(),
face_attrs.contiguous().data_ptr<float>(),
grad_pix_attrs.contiguous().data_ptr<float>(),
grad_barycentric_coords.contiguous().data_ptr<float>(),
grad_face_attrs.contiguous().data_ptr<float>(),
P, F, D);
AT_CUDA_CHECK(cudaGetLastError());
// clang-format on
return std::make_tuple(grad_barycentric_coords, grad_face_attrs);
}
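Unlike the forward launcher, the backward launcher above does not dispatch over scalar types; it pulls out float pointers directly, matching the TODO about atomicAdd for double. A hedged, caller-side sketch of what that constraint implies (hypothetical helper, not part of the diff):

import torch

def check_backward_dtypes(*tensors):
    # Hypothetical guard: the CUDA backward currently supports float32 only.
    for t in tensors:
        if t.dtype != torch.float32:
            raise TypeError("interp_face_attrs backward supports float32 only for now")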

View File

@@ -0,0 +1,105 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// Interpolates per-face attributes (forward pass)
//
// Inputs:
// pix_to_face: LongTensor of shape (P,) giving a face index for each pixel.
// Each element should be < F, the total number of faces.
// Face indices < 0 indicate that the pixel is not covered by a face.
// barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords.
// face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional
// value for each vertex of each face.
//
// Returns:
// pix_attributes: FloatTensor of shape (P, D) giving an interpolated value
// for each pixel.
// CPU implementation
at::Tensor InterpFaceAttrsForwardCpu(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs) {
AT_ERROR("Not Implemented");
return pix_to_face;
}
#ifdef WITH_CUDA
// Cuda implementation.
at::Tensor InterpFaceAttrsForwardCuda(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs);
#endif
// General implementation
at::Tensor InterpFaceAttrsForward(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs) {
if (pix_to_face.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(face_attrs);
CHECK_CUDA(barycentric_coords);
return InterpFaceAttrsForwardCuda(
pix_to_face, barycentric_coords, face_attrs);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
}
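Note that InterpFaceAttrsForwardCpu above is only a stub that raises; the Python wrapper added later in this diff routes CPU tensors to a pure-Python implementation instead, so the public op still works on CPU. An illustrative call (made-up values):

import torch
from pytorch3d.ops import interpolate_face_attributes

out = interpolate_face_attributes(
    torch.zeros(1, 1, 1, 1, dtype=torch.int64),   # every pixel maps to face 0
    torch.full((1, 1, 1, 1, 3), 1.0 / 3.0),       # uniform barycentric weights
    torch.randn(1, 3, 4),                          # one face, D = 4
)
print(out.shape)                                    # torch.Size([1, 1, 1, 1, 4])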
// Interpolates per-face attributes (backward pass)
//
// Inputs:
// pix_to_face: LongTensor of shape (P,) giving a face index for each pixel.
// Each element should be < F, the total number of faces.
// Face indices < 0 indicate that the pixel is not covered by a face.
// barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords.
// face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional
// value for each vertex of each face.
// grad_pix_attrs: Upstream gradients of shape (P, D)
//
// Returns a tuple of:
// grad_barycentric_coords: FloatTensor of shape (P, 3)
// grad_face_attrs: FloatTensor of shape (F, 3, D)
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCpu(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs,
const at::Tensor& grad_pix_attrs) {
AT_ERROR("Not Implemented");
return std::make_tuple(pix_to_face, pix_to_face);
}
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs,
const at::Tensor& grad_pix_attrs);
std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
const at::Tensor& pix_to_face,
const at::Tensor& barycentric_coords,
const at::Tensor& face_attrs,
const at::Tensor& grad_pix_attrs) {
if (pix_to_face.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(face_attrs);
CHECK_CUDA(barycentric_coords);
CHECK_CUDA(grad_pix_attrs);
return InterpFaceAttrsBackwardCuda(
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return InterpFaceAttrsBackwardCpu(
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
}

View File

@@ -3,6 +3,7 @@
from .cubify import cubify
from .graph_conv import GraphConv
from .interp_face_attrs import interpolate_face_attributes
from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals
from .packed_to_padded import packed_to_padded, padded_to_packed

View File

@@ -0,0 +1,95 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
def interpolate_face_attributes(
pix_to_face: torch.Tensor,
barycentric_coords: torch.Tensor,
face_attributes: torch.Tensor,
) -> torch.Tensor:
"""
Interpolate arbitrary face attributes using the barycentric coordinates
for each pixel in the rasterized output.
Args:
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which overlap each
pixel in the image. A value < 0 indicates that the pixel does not
overlap any face and should be skipped.
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordinates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
face_attributes: packed attributes of shape (total_faces, 3, D),
specifying the value of the attribute for each
vertex in the face.
Returns:
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
value of the face attribute for each pixel.
"""
# Check shapes
F, FV, D = face_attributes.shape
if FV != 3:
raise ValueError("Faces can only have three vertices; got %r" % FV)
N, H, W, K, _ = barycentric_coords.shape
if pix_to_face.shape != (N, H, W, K):
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
raise ValueError(msg % (pix_to_face.shape,))
# On CPU use the python version
# TODO: Implement a C++ version of this function
if not pix_to_face.is_cuda:
args = (pix_to_face, barycentric_coords, face_attributes)
return interpolate_face_attributes_python(*args)
# Otherwise flatten and call the custom autograd function
N, H, W, K = pix_to_face.shape
pix_to_face = pix_to_face.view(-1)
barycentric_coords = barycentric_coords.view(N * H * W * K, 3)
args = (pix_to_face, barycentric_coords, face_attributes)
out = _InterpFaceAttrs.apply(*args)
out = out.view(N, H, W, K, -1)
return out
class _InterpFaceAttrs(Function):
@staticmethod
def forward(ctx, pix_to_face, barycentric_coords, face_attrs):
args = (pix_to_face, barycentric_coords, face_attrs)
ctx.save_for_backward(*args)
return _C.interp_face_attrs_forward(*args)
@staticmethod
@once_differentiable
def backward(ctx, grad_pix_attrs):
args = ctx.saved_tensors
args = args + (grad_pix_attrs,)
grads = _C.interp_face_attrs_backward(*args)
grad_pix_to_face = None
grad_barycentric_coords = grads[0]
grad_face_attrs = grads[1]
return grad_pix_to_face, grad_barycentric_coords, grad_face_attrs
def interpolate_face_attributes_python(
pix_to_face: torch.Tensor,
barycentric_coords: torch.Tensor,
face_attributes: torch.Tensor,
) -> torch.Tensor:
F, FV, D = face_attributes.shape
N, H, W, K, _ = barycentric_coords.shape
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = pix_to_face < 0
pix_to_face = pix_to_face.clone()
pix_to_face[mask] = 0
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
pixel_vals[mask] = 0 # Replace masked values in output.
return pixel_vals
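A toy illustration (made-up numbers, not part of the diff) of the gather step above, which picks the three vertex attributes of each pixel's face before weighting them:

import torch

face_attributes = torch.arange(12.0).view(2, 3, 2)   # F = 2 faces, 3 verts, D = 2
pix_to_face = torch.tensor([1, 0, 1])                 # P = 3 pixels
idx = pix_to_face.view(3, 1, 1).expand(3, 3, 2)
picked = face_attributes.gather(0, idx)               # (3, 3, 2): each pixel's face attrs
# picked[0] == face_attributes[1], picked[1] == face_attributes[0], ...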

View File

@@ -4,8 +4,7 @@
from typing import Tuple
import torch
from .texturing import interpolate_face_attributes
from pytorch3d.ops import interpolate_face_attributes
def _apply_lighting(

View File

@@ -3,10 +3,9 @@
import torch
import torch.nn.functional as F
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.textures import Textures
from .utils import interpolate_face_attributes
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
"""

View File

@@ -2,6 +2,7 @@
import torch
from pytorch3d.ops import interpolate_face_attributes
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
@@ -25,51 +26,6 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor:
return clipped
def interpolate_face_attributes(
pix_to_face: torch.Tensor,
barycentric_coords: torch.Tensor,
face_attributes: torch.Tensor,
) -> torch.Tensor:
"""
Interpolate arbitrary face attributes using the barycentric coordinates
for each pixel in the rasterized output.
Args:
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
face_attributes: packed attributes of shape (total_faces, 3, D),
specifying the value of the attribute for each
vertex in the face.
Returns:
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
value of the face attribute for each pixel.
"""
F, FV, D = face_attributes.shape
if FV != 3:
raise ValueError("Faces can only have three vertices; got %r" % FV)
N, H, W, K, _ = barycentric_coords.shape
if pix_to_face.shape != (N, H, W, K):
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
raise ValueError(msg % (pix_to_face.shape,))
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
pix_to_face[mask] = 0
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
# pyre-fixme[16]: `Tensor` has no attribute `gather`.
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
pixel_vals[mask] = 0 # Replace masked values in output.
return pixel_vals
def _interpolate_zbuf(
pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
) -> torch.Tensor:

View File

@@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d.ops.interp_face_attrs import (
interpolate_face_attributes,
interpolate_face_attributes_python,
)
def _generate_data(N, S, K, F, D, device, requires_grad=False):
pix_to_face = torch.randint(-10, F, (N, S, S, K), device=device)
barycentric_coords = torch.randn(
N, S, S, K, 3, device=device, requires_grad=requires_grad
)
face_attrs = torch.randn(F, 3, D, device=device, requires_grad=requires_grad)
grad_pix_attrs = torch.randn(N, S, S, K, D, device=device)
return pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs
def _bm_forward(N, S, F, K, D, impl):
# The runtime depends on the values of pix_to_face, so for proper
# benchmarking we should probably average over multiple samples of
# pix_to_face. That doesn't easily fit into fvcore benchmarking, so
# instead we just set a manual seed to make sure that different impls
# use the same data.
torch.manual_seed(0)
device = torch.device("cuda")
data = _generate_data(N, S, K, F, D, device, requires_grad=False)
args = data[:3]
torch.cuda.synchronize()
if impl == "cuda":
fun = interpolate_face_attributes
elif impl == "python":
fun = interpolate_face_attributes_python
return lambda: fun(*args)
def _bm_forward_backward(N, S, F, K, D, impl):
torch.manual_seed(0)
device = torch.device("cuda")
data = _generate_data(N, S, K, F, D, device, requires_grad=True)
args, grad = data[:3], data[3]
torch.cuda.synchronize()
if impl == "cuda":
fun = interpolate_face_attributes
elif impl == "python":
fun = interpolate_face_attributes_python
def run():
out = fun(*args)
out.backward(gradient=grad)
return run
def bm_interpolate_face_attribues() -> None:
# For now only benchmark on GPU
if not torch.cuda.is_available():
return
Ns = [1, 4]
Ss = [128]
Ks = [1, 10, 40]
Fs = [5000]
Ds = [1, 3, 16]
impls = ["python", "cuda"]
test_cases = product(Ns, Ss, Ks, Fs, Ds, impls)
kwargs_list = []
for case in test_cases:
N, S, K, F, D, impl = case
kwargs_list.append({"N": N, "S": S, "K": K, "F": F, "D": D, "impl": impl})
benchmark(_bm_forward, "FORWARD", kwargs_list, warmup_iters=3)
benchmark(_bm_forward_backward, "FORWARD+BACKWARD", kwargs_list, warmup_iters=3)
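For reference, one minimal way to invoke this benchmark directly (assumes a CUDA device is available; fvcore's benchmark helper prints the timing tables):

if __name__ == "__main__":
    bm_interpolate_face_attribues()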

View File

@@ -0,0 +1,185 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops.interp_face_attrs import (
interpolate_face_attributes,
interpolate_face_attributes_python,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import (
interpolate_texture_map,
interpolate_vertex_colors,
)
from pytorch3d.structures import Meshes, Textures
class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
def _test_interp_face_attrs(self, interp_fun, device):
pix_to_face = [0, 2, -1, 0, 1, -1]
barycentric_coords = [
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.5, 0.5, 0.0],
[0.8, 0.0, 0.2],
[0.25, 0.5, 0.25],
]
face_attrs = [
[[1, 2], [3, 4], [5, 6]],
[[7, 8], [9, 10], [11, 12]],
[[13, 14], [15, 16], [17, 18]],
]
pix_attrs = [
[1, 2],
[15, 16],
[0, 0],
[2, 3],
[0.8 * 7 + 0.2 * 11, 0.8 * 8 + 0.2 * 12],
[0, 0],
]
N, H, W, K, D = 1, 2, 1, 3, 2
pix_to_face = torch.tensor(pix_to_face, dtype=torch.int64, device=device)
pix_to_face = pix_to_face.view(N, H, W, K)
barycentric_coords = torch.tensor(
barycentric_coords, dtype=torch.float32, device=device
)
barycentric_coords = barycentric_coords.view(N, H, W, K, 3)
face_attrs = torch.tensor(face_attrs, dtype=torch.float32, device=device)
pix_attrs = torch.tensor(pix_attrs, dtype=torch.float32, device=device)
pix_attrs = pix_attrs.view(N, H, W, K, D)
args = (pix_to_face, barycentric_coords, face_attrs)
pix_attrs_actual = interp_fun(*args)
self.assertClose(pix_attrs_actual, pix_attrs)
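# Editor's note (not part of the test): hand-checking two expected rows above.
# Pixel 3 -> face 0, weights (0.5, 0.5, 0.0): 0.5*[1,2] + 0.5*[3,4] = [2, 3]
# Pixel 4 -> face 1, weights (0.8, 0.0, 0.2): 0.8*[7,8] + 0.2*[11,12] = [7.8, 8.8]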
def test_python(self):
device = torch.device("cuda:0")
self._test_interp_face_attrs(interpolate_face_attributes_python, device)
def test_cuda(self):
device = torch.device("cuda:0")
self._test_interp_face_attrs(interpolate_face_attributes, device)
def test_python_vs_cuda(self):
N, H, W, K = 2, 32, 32, 5
F = 1000
D = 3
device = get_random_cuda_device()
torch.manual_seed(598)
pix_to_face = torch.randint(-F, F, (N, H, W, K), device=device)
barycentric_coords = torch.randn(
N, H, W, K, 3, device=device, requires_grad=True
)
face_attrs = torch.randn(F, 3, D, device=device, requires_grad=True)
grad_pix_attrs = torch.randn(N, H, W, K, D, device=device)
args = (pix_to_face, barycentric_coords, face_attrs)
# Run the python version
pix_attrs_py = interpolate_face_attributes_python(*args)
pix_attrs_py.backward(gradient=grad_pix_attrs)
grad_bary_py = barycentric_coords.grad.clone()
grad_face_attrs_py = face_attrs.grad.clone()
# Clear gradients
barycentric_coords.grad.zero_()
face_attrs.grad.zero_()
# Run the CUDA version
pix_attrs_cu = interpolate_face_attributes(*args)
pix_attrs_cu.backward(gradient=grad_pix_attrs)
grad_bary_cu = barycentric_coords.grad.clone()
grad_face_attrs_cu = face_attrs.grad.clone()
# Check they are the same
self.assertClose(pix_attrs_py, pix_attrs_cu, rtol=2e-3)
self.assertClose(grad_bary_py, grad_bary_cu, rtol=1e-4)
self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)
def test_interpolate_attributes(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
expected_vals = torch.tensor(
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = interpolate_vertex_colors(fragments, mesh)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_interpolate_attributes_grad(self):
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
dtype=torch.float32,
requires_grad=True,
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
grad_vert_tex = torch.tensor(
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = interpolate_vertex_colors(fragments, mesh)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
def test_interpolate_face_attributes_fail(self):
# 1. A face can only have 3 verts
# i.e. face_attributes must have shape (F, 3, D)
face_attributes = torch.ones(1, 4, 3)
pix_to_face = torch.ones((1, 1, 1, 1))
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=pix_to_face[..., None].expand(-1, -1, -1, -1, 3),
zbuf=pix_to_face,
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
# 2. pix_to_face must have shape (N, H, W, K)
pix_to_face = torch.ones((1, 1, 1, 1, 3))
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=pix_to_face,
zbuf=pix_to_face,
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)

View File

@@ -7,103 +7,13 @@ import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import (
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
)
from pytorch3d.renderer.mesh.texturing import interpolate_texture_map
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures.utils import list_to_padded
from test_meshes import TestMeshes
class TestTexturing(TestCaseMixin, unittest.TestCase):
def test_interpolate_attributes(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
expected_vals = torch.tensor(
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = interpolate_vertex_colors(fragments, mesh)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_interpolate_attributes_grad(self):
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
dtype=torch.float32,
requires_grad=True,
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
grad_vert_tex = torch.tensor(
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = interpolate_vertex_colors(fragments, mesh)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
def test_interpolate_face_attributes_fail(self):
# 1. A face can only have 3 verts
# i.e. face_attributes must have shape (F, 3, D)
face_attributes = torch.ones(1, 4, 3)
pix_to_face = torch.ones((1, 1, 1, 1))
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=pix_to_face[..., None].expand(-1, -1, -1, -1, 3),
zbuf=pix_to_face,
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
# 2. pix_to_face must have shape (N, H, W, K)
pix_to_face = torch.ones((1, 1, 1, 1, 3))
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=pix_to_face,
zbuf=pix_to_face,
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
def test_interpolate_texture_map(self):
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32