From 26d2cc24c1382047a81dd182f9621a17184e0a95 Mon Sep 17 00:00:00 2001 From: Justin Johnson Date: Mon, 13 Jul 2020 12:58:07 -0700 Subject: [PATCH] CUDA kernel for interpolate_face_attributes Summary: When rendering meshes with Phong shading, interpolate_face_attributes was taking up a nontrivial fraction of the overall shading time. This diff replaces our Python implementation of this function with a CUDA implementation. Reviewed By: nikhilaravi Differential Revision: D21610763 fbshipit-source-id: 2bb362a28f698541812aeab539047264b125ebb8 --- pytorch3d/csrc/ext.cpp | 3 + .../interp_face_attrs/interp_face_attrs.cu | 161 +++++++++++++++ .../interp_face_attrs/interp_face_attrs.h | 105 ++++++++++ pytorch3d/ops/__init__.py | 1 + pytorch3d/ops/interp_face_attrs.py | 95 +++++++++ pytorch3d/renderer/mesh/shading.py | 3 +- pytorch3d/renderer/mesh/texturing.py | 3 +- pytorch3d/renderer/mesh/utils.py | 46 +---- tests/bm_interpolate_face_attributes.py | 76 +++++++ tests/test_interpolate_face_attributes.py | 185 ++++++++++++++++++ tests/test_texturing.py | 92 +-------- 11 files changed, 630 insertions(+), 140 deletions(-) create mode 100644 pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu create mode 100644 pytorch3d/csrc/interp_face_attrs/interp_face_attrs.h create mode 100644 pytorch3d/ops/interp_face_attrs.py create mode 100644 tests/bm_interpolate_face_attributes.py create mode 100644 tests/test_interpolate_face_attributes.py diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 4dc9454a..55da2ee7 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -6,6 +6,7 @@ #include "compositing/weighted_sum.h" #include "face_areas_normals/face_areas_normals.h" #include "gather_scatter/gather_scatter.h" +#include "interp_face_attrs/interp_face_attrs.h" #include "knn/knn.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "point_mesh/point_mesh_edge.h" @@ -18,6 +19,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals_backward", &FaceAreasNormalsBackward); m.def("packed_to_padded", &PackedToPadded); m.def("padded_to_packed", &PaddedToPacked); + m.def("interp_face_attrs_forward", &InterpFaceAttrsForward); + m.def("interp_face_attrs_backward", &InterpFaceAttrsBackward); #ifdef WITH_CUDA m.def("knn_check_version", &KnnCheckVersion); #endif diff --git a/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu b/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu new file mode 100644 index 00000000..66a2534c --- /dev/null +++ b/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu @@ -0,0 +1,161 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include +#include + +template +__global__ void InterpFaceAttrsForwardKernel( + const int64_t* __restrict__ pix_to_face, // (P,) + const scalar_t* __restrict__ barycentric_coords, // (P, 3) + const scalar_t* __restrict__ face_attrs, // (F, 3, D) + scalar_t* pix_attrs, // (P, D) + const size_t P, + const size_t F, + const size_t D) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_threads = blockDim.x * gridDim.x; + for (int pd = tid; pd < P * D; pd += num_threads) { + const int p = pd / D; + const int d = pd % D; + const int64_t f = pix_to_face[p]; + if (f < 0) { + continue; + } + scalar_t pix_attr = 0.0; + for (int i = 0; i < 3; ++i) { + scalar_t weight = barycentric_coords[p * 3 + i]; + scalar_t vert_attr = face_attrs[f * 3 * D + i * D + d]; + pix_attr += weight * vert_attr; + } + pix_attrs[p * D + d] = pix_attr; + } +} + +at::Tensor InterpFaceAttrsForwardCuda( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs) { + // Make sure all inputs are on the same device + at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1}, + barycentric_coords_t{barycentric_coords, "barycentric_coords", 2}, + face_attrs_t{face_attrs, "face_attributes", 3}; + at::CheckedFrom c = "InterpFaceAttrsForwardCuda"; + at::checkAllSameGPU(c, {pix_to_face_t, barycentric_coords_t, face_attrs_t}); + at::checkAllSameType(c, {barycentric_coords_t, face_attrs_t}); + + // Set the device for the kernel launch based on the input + at::cuda::CUDAGuard device_guard(pix_to_face.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const auto P = pix_to_face.size(0); + const auto F = face_attrs.size(0); + const auto D = face_attrs.size(2); + + TORCH_CHECK( + barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3, + "barycentric_coords must have size (P, 3)"); + TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)"); + + auto pix_attrs = at::zeros({P, D}, face_attrs.options()); + const int threads = 1024; + const int blocks = 512; + AT_DISPATCH_FLOATING_TYPES( + face_attrs.scalar_type(), "interp_face_attrs_cuda", ([&] { + InterpFaceAttrsForwardKernel<<>>( + pix_to_face.contiguous().data_ptr(), + barycentric_coords.contiguous().data_ptr(), + face_attrs.contiguous().data_ptr(), + pix_attrs.contiguous().data_ptr(), + P, + F, + D); + })); + AT_CUDA_CHECK(cudaGetLastError()); + return pix_attrs; +} + +template +__global__ void InterpFaceAttrsBackwardKernel( + const int64_t* __restrict__ pix_to_face, // (P,) + const scalar_t* __restrict__ barycentric_coords, // (P, 3) + const scalar_t* __restrict__ face_attrs, // (F, 3, D) + const scalar_t* __restrict__ grad_pix_attrs, // (P, D) + scalar_t* __restrict__ grad_barycentric_coords, // (P, 3) + scalar_t* __restrict__ grad_face_attrs, // (F, 3, D) + const size_t P, + const size_t F, + const size_t D) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_threads = blockDim.x * gridDim.x; + for (int pd = tid; pd < P * D; pd += num_threads) { + const int p = pd / D; + const int d = pd % D; + const int64_t f = pix_to_face[p]; + if (f < 0) { + continue; + } + scalar_t upstream_grad = grad_pix_attrs[p * D + d]; + for (int i = 0; i < 3; ++i) { + scalar_t weight = barycentric_coords[p * 3 + i]; + scalar_t vert_attr = face_attrs[f * 3 * D + i * D + d]; + scalar_t grad_bary_down = vert_attr * upstream_grad; + scalar_t grad_face_down = weight * upstream_grad; + atomicAdd(grad_barycentric_coords + p * 3 + i, grad_bary_down); + atomicAdd(grad_face_attrs + f * 3 * D + i * D + d, grad_face_down); + } + } +} + +std::tuple InterpFaceAttrsBackwardCuda( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs, + const at::Tensor& grad_pix_attrs) { + // Make sure all inputs are on the same device + at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1}, + barycentric_coords_t{barycentric_coords, "barycentric_coords", 2}, + face_attrs_t{face_attrs, "face_attributes", 3}, + grad_pix_attrs_t{grad_pix_attrs, "pix_attrs", 4}; + at::CheckedFrom c = "InterpFaceAttrsBackwarduda"; + at::checkAllSameGPU( + c, {pix_to_face_t, barycentric_coords_t, face_attrs_t, grad_pix_attrs_t}); + at::checkAllSameType( + c, {barycentric_coords_t, face_attrs_t, grad_pix_attrs_t}); + + // Set the device for the kernel launch based on the input + at::cuda::CUDAGuard device_guard(pix_to_face.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const auto P = pix_to_face.size(0); + const auto F = face_attrs.size(0); + const auto D = face_attrs.size(2); + + TORCH_CHECK( + barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3, + "barycentric_coords must have size (P, 3)"); + TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)"); + TORCH_CHECK( + grad_pix_attrs.size(0) == P && grad_pix_attrs.size(1) == D, + "grad_pix_attrs must have size (P, D)"); + + auto grad_barycentric_coords = at::zeros_like(barycentric_coords); + auto grad_face_attrs = at::zeros_like(face_attrs); + const int threads = 1024; + const int blocks = 512; + // Only allow float for now. + // TODO: Add support for double once we fix atomicAdd + // clang-format off + InterpFaceAttrsBackwardKernel<<>>( + pix_to_face.contiguous().data_ptr(), + barycentric_coords.contiguous().data_ptr(), + face_attrs.contiguous().data_ptr(), + grad_pix_attrs.contiguous().data_ptr(), + grad_barycentric_coords.contiguous().data_ptr(), + grad_face_attrs.contiguous().data_ptr(), + P, F, D); + AT_CUDA_CHECK(cudaGetLastError()); + // clang-format on + return std::make_tuple(grad_barycentric_coords, grad_face_attrs); +} diff --git a/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.h b/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.h new file mode 100644 index 00000000..ace48ca2 --- /dev/null +++ b/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.h @@ -0,0 +1,105 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#pragma once +#include +#include +#include "utils/pytorch3d_cutils.h" + +// Interpolates per-face attributes (forward pass) +// +// Inputs: +// pix_to_face: LongTensor of shape (P,) giving a face index for each pixel. +// Each element should be < F, the total number of faces. +// Face indices < 0 indicate that the pixel is not covered by a face. +// barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords. +// face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional +// value for each vertex of each face. +// +// Returns: +// pix_attributes: FloatTensor of shape (P, D) giving an interpolated value +// for each pixel. + +// CPU implementation +at::Tensor InterpFaceAttrsForwardCpu( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs) { + AT_ERROR("Not Implemented"); + return pix_to_face; +} + +#ifdef WITH_CUDA +// Cuda implementation. +at::Tensor InterpFaceAttrsForwardCuda( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs); +#endif + +// General implementation +at::Tensor InterpFaceAttrsForward( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs) { + if (pix_to_face.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CUDA(face_attrs); + CHECK_CUDA(barycentric_coords); + return InterpFaceAttrsForwardCuda( + pix_to_face, barycentric_coords, face_attrs); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs); +} + +// Interpolates per-face attributes (backward pass) +// +// Inputs: +// pix_to_face: LongTensor of shape (P,) giving a face index for each pixel. +// Each element should be < F, the total number of faces. +// Face indices < 0 indicate that the pixel is not covered by a face. +// barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords. +// face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional +// value for each vertex of each face. +// grad_pix_attrs: Upstream gradients of shape (P, D) +// +// Returns a tuple of: +// grad_barycentric_coords: FloatTensor of shape (P, 3) +// grad_face_attrs: FloatTensor of shape (F, 3, D) + +std::tuple InterpFaceAttrsBackwardCpu( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs, + const at::Tensor& grad_pix_attrs) { + AT_ERROR("Not Implemented"); + return std::make_tuple(pix_to_face, pix_to_face); +} + +std::tuple InterpFaceAttrsBackwardCuda( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs, + const at::Tensor& grad_pix_attrs); + +std::tuple InterpFaceAttrsBackward( + const at::Tensor& pix_to_face, + const at::Tensor& barycentric_coords, + const at::Tensor& face_attrs, + const at::Tensor& grad_pix_attrs) { + if (pix_to_face.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CUDA(face_attrs); + CHECK_CUDA(barycentric_coords); + CHECK_CUDA(grad_pix_attrs); + return InterpFaceAttrsBackwardCuda( + pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return InterpFaceAttrsBackwardCpu( + pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs); +} diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 7cbaf389..baf7164a 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -3,6 +3,7 @@ from .cubify import cubify from .graph_conv import GraphConv +from .interp_face_attrs import interpolate_face_attributes from .knn import knn_gather, knn_points from .mesh_face_areas_normals import mesh_face_areas_normals from .packed_to_padded import packed_to_padded, padded_to_packed diff --git a/pytorch3d/ops/interp_face_attrs.py b/pytorch3d/ops/interp_face_attrs.py new file mode 100644 index 00000000..f72ea82f --- /dev/null +++ b/pytorch3d/ops/interp_face_attrs.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import torch +from pytorch3d import _C +from torch.autograd import Function +from torch.autograd.function import once_differentiable + + +def interpolate_face_attributes( + pix_to_face: torch.Tensor, + barycentric_coords: torch.Tensor, + face_attributes: torch.Tensor, +) -> torch.Tensor: + """ + Interpolate arbitrary face attributes using the barycentric coordinates + for each pixel in the rasterized output. + + Args: + pix_to_face: LongTensor of shape (...) specifying the indices + of the faces (in the packed representation) which overlap each + pixel in the image. A value < 0 indicates that the pixel does not + overlap any face and should be skipped. + barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + face_attributes: packed attributes of shape (total_faces, 3, D), + specifying the value of the attribute for each + vertex in the face. + + Returns: + pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated + value of the face attribute for each pixel. + """ + # Check shapes + F, FV, D = face_attributes.shape + if FV != 3: + raise ValueError("Faces can only have three vertices; got %r" % FV) + N, H, W, K, _ = barycentric_coords.shape + if pix_to_face.shape != (N, H, W, K): + msg = "pix_to_face must have shape (batch_size, H, W, K); got %r" + raise ValueError(msg % (pix_to_face.shape,)) + + # On CPU use the python version + # TODO: Implement a C++ version of this function + if not pix_to_face.is_cuda: + args = (pix_to_face, barycentric_coords, face_attributes) + return interpolate_face_attributes_python(*args) + + # Otherwise flatten and call the custom autograd function + N, H, W, K = pix_to_face.shape + pix_to_face = pix_to_face.view(-1) + barycentric_coords = barycentric_coords.view(N * H * W * K, 3) + args = (pix_to_face, barycentric_coords, face_attributes) + out = _InterpFaceAttrs.apply(*args) + out = out.view(N, H, W, K, -1) + return out + + +class _InterpFaceAttrs(Function): + @staticmethod + def forward(ctx, pix_to_face, barycentric_coords, face_attrs): + args = (pix_to_face, barycentric_coords, face_attrs) + ctx.save_for_backward(*args) + return _C.interp_face_attrs_forward(*args) + + @staticmethod + @once_differentiable + def backward(ctx, grad_pix_attrs): + args = ctx.saved_tensors + args = args + (grad_pix_attrs,) + grads = _C.interp_face_attrs_backward(*args) + grad_pix_to_face = None + grad_barycentric_coords = grads[0] + grad_face_attrs = grads[1] + return grad_pix_to_face, grad_barycentric_coords, grad_face_attrs + + +def interpolate_face_attributes_python( + pix_to_face: torch.Tensor, + barycentric_coords: torch.Tensor, + face_attributes: torch.Tensor, +) -> torch.Tensor: + F, FV, D = face_attributes.shape + N, H, W, K, _ = barycentric_coords.shape + + # Replace empty pixels in pix_to_face with 0 in order to interpolate. + mask = pix_to_face < 0 + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + return pixel_vals diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index a3b20fa5..b6ff84ac 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -4,8 +4,7 @@ from typing import Tuple import torch - -from .texturing import interpolate_face_attributes +from pytorch3d.ops import interpolate_face_attributes def _apply_lighting( diff --git a/pytorch3d/renderer/mesh/texturing.py b/pytorch3d/renderer/mesh/texturing.py index 0d1b0564..b2ac7eba 100644 --- a/pytorch3d/renderer/mesh/texturing.py +++ b/pytorch3d/renderer/mesh/texturing.py @@ -3,10 +3,9 @@ import torch import torch.nn.functional as F +from pytorch3d.ops import interpolate_face_attributes from pytorch3d.structures.textures import Textures -from .utils import interpolate_face_attributes - def interpolate_texture_map(fragments, meshes) -> torch.Tensor: """ diff --git a/pytorch3d/renderer/mesh/utils.py b/pytorch3d/renderer/mesh/utils.py index d065d2eb..f61f4faf 100644 --- a/pytorch3d/renderer/mesh/utils.py +++ b/pytorch3d/renderer/mesh/utils.py @@ -2,6 +2,7 @@ import torch +from pytorch3d.ops import interpolate_face_attributes def _clip_barycentric_coordinates(bary) -> torch.Tensor: @@ -25,51 +26,6 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor: return clipped -def interpolate_face_attributes( - pix_to_face: torch.Tensor, - barycentric_coords: torch.Tensor, - face_attributes: torch.Tensor, -) -> torch.Tensor: - """ - Interpolate arbitrary face attributes using the barycentric coordinates - for each pixel in the rasterized output. - - Args: - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices - of the faces (in the packed representation) which - overlap each pixel in the image. - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying - the barycentric coordianates of each pixel - relative to the faces (in the packed - representation) which overlap the pixel. - face_attributes: packed attributes of shape (total_faces, 3, D), - specifying the value of the attribute for each - vertex in the face. - - Returns: - pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated - value of the face attribute for each pixel. - """ - F, FV, D = face_attributes.shape - if FV != 3: - raise ValueError("Faces can only have three vertices; got %r" % FV) - N, H, W, K, _ = barycentric_coords.shape - if pix_to_face.shape != (N, H, W, K): - msg = "pix_to_face must have shape (batch_size, H, W, K); got %r" - raise ValueError(msg % (pix_to_face.shape,)) - - # Replace empty pixels in pix_to_face with 0 in order to interpolate. - mask = pix_to_face == -1 - pix_to_face = pix_to_face.clone() - pix_to_face[mask] = 0 - idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) - # pyre-fixme[16]: `Tensor` has no attribute `gather`. - pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D) - pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. - return pixel_vals - - def _interpolate_zbuf( pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes ) -> torch.Tensor: diff --git a/tests/bm_interpolate_face_attributes.py b/tests/bm_interpolate_face_attributes.py new file mode 100644 index 00000000..60fb8d1c --- /dev/null +++ b/tests/bm_interpolate_face_attributes.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from itertools import product + +import torch +from fvcore.common.benchmark import benchmark +from pytorch3d.ops.interp_face_attrs import ( + interpolate_face_attributes, + interpolate_face_attributes_python, +) + + +def _generate_data(N, S, K, F, D, device, requires_grad=False): + pix_to_face = torch.randint(-10, F, (N, S, S, K), device=device) + barycentric_coords = torch.randn( + N, S, S, K, 3, device=device, requires_grad=requires_grad + ) + face_attrs = torch.randn(F, 3, D, device=device, requires_grad=requires_grad) + grad_pix_attrs = torch.randn(N, S, S, K, D, device=device) + return pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs + + +def _bm_forward(N, S, F, K, D, impl): + # The runtime depends on the values of pix_to_face. So for proper + # benchmarking we should probably take the average of multiple + # values of pix to face. But this doesn't easily fit into fvcore + # benchmarking, so instead we'll just set a manual seed to make sure + # that different impls will use the same data. + torch.manual_seed(0) + device = torch.device("cuda") + data = _generate_data(N, S, K, F, D, device, requires_grad=False) + args = data[:3] + torch.cuda.synchronize() + if impl == "cuda": + fun = interpolate_face_attributes + elif impl == "python": + fun = interpolate_face_attributes_python + return lambda: fun(*args) + + +def _bm_forward_backward(N, S, F, K, D, impl): + torch.manual_seed(0) + device = torch.device("cuda") + data = _generate_data(N, S, K, F, D, device, requires_grad=True) + args, grad = data[:3], data[3] + torch.cuda.synchronize() + if impl == "cuda": + fun = interpolate_face_attributes + elif impl == "python": + fun = interpolate_face_attributes_python + + def run(): + out = fun(*args) + out.backward(gradient=grad) + + return run + + +def bm_interpolate_face_attribues() -> None: + # For now only benchmark on GPU + if not torch.cuda.is_available(): + return + + Ns = [1, 4] + Ss = [128] + Ks = [1, 10, 40] + Fs = [5000] + Ds = [1, 3, 16] + impls = ["python", "cuda"] + test_cases = product(Ns, Ss, Ks, Fs, Ds, impls) + kwargs_list = [] + for case in test_cases: + N, S, K, F, D, impl = case + kwargs_list.append({"N": N, "S": S, "K": K, "F": F, "D": D, "impl": impl}) + benchmark(_bm_forward, "FORWARD", kwargs_list, warmup_iters=3) + benchmark(_bm_forward_backward, "FORWARD+BACKWARD", kwargs_list, warmup_iters=3) diff --git a/tests/test_interpolate_face_attributes.py b/tests/test_interpolate_face_attributes.py new file mode 100644 index 00000000..7eaf9ef9 --- /dev/null +++ b/tests/test_interpolate_face_attributes.py @@ -0,0 +1,185 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import torch +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.ops.interp_face_attrs import ( + interpolate_face_attributes, + interpolate_face_attributes_python, +) +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.renderer.mesh.texturing import ( + interpolate_texture_map, + interpolate_vertex_colors, +) +from pytorch3d.structures import Meshes, Textures + + +class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase): + def _test_interp_face_attrs(self, interp_fun, device): + pix_to_face = [0, 2, -1, 0, 1, -1] + barycentric_coords = [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.5, 0.5, 0.0], + [0.8, 0.0, 0.2], + [0.25, 0.5, 0.25], + ] + face_attrs = [ + [[1, 2], [3, 4], [5, 6]], + [[7, 8], [9, 10], [11, 12]], + [[13, 14], [15, 16], [17, 18]], + ] + pix_attrs = [ + [1, 2], + [15, 16], + [0, 0], + [2, 3], + [0.8 * 7 + 0.2 * 11, 0.8 * 8 + 0.2 * 12], + [0, 0], + ] + N, H, W, K, D = 1, 2, 1, 3, 2 + pix_to_face = torch.tensor(pix_to_face, dtype=torch.int64, device=device) + pix_to_face = pix_to_face.view(N, H, W, K) + barycentric_coords = torch.tensor( + barycentric_coords, dtype=torch.float32, device=device + ) + barycentric_coords = barycentric_coords.view(N, H, W, K, 3) + face_attrs = torch.tensor(face_attrs, dtype=torch.float32, device=device) + pix_attrs = torch.tensor(pix_attrs, dtype=torch.float32, device=device) + pix_attrs = pix_attrs.view(N, H, W, K, D) + + args = (pix_to_face, barycentric_coords, face_attrs) + pix_attrs_actual = interp_fun(*args) + self.assertClose(pix_attrs_actual, pix_attrs) + + def test_python(self): + device = torch.device("cuda:0") + self._test_interp_face_attrs(interpolate_face_attributes_python, device) + + def test_cuda(self): + device = torch.device("cuda:0") + self._test_interp_face_attrs(interpolate_face_attributes, device) + + def test_python_vs_cuda(self): + N, H, W, K = 2, 32, 32, 5 + F = 1000 + D = 3 + device = get_random_cuda_device() + torch.manual_seed(598) + pix_to_face = torch.randint(-F, F, (N, H, W, K), device=device) + barycentric_coords = torch.randn( + N, H, W, K, 3, device=device, requires_grad=True + ) + face_attrs = torch.randn(F, 3, D, device=device, requires_grad=True) + grad_pix_attrs = torch.randn(N, H, W, K, D, device=device) + args = (pix_to_face, barycentric_coords, face_attrs) + + # Run the python version + pix_attrs_py = interpolate_face_attributes_python(*args) + pix_attrs_py.backward(gradient=grad_pix_attrs) + grad_bary_py = barycentric_coords.grad.clone() + grad_face_attrs_py = face_attrs.grad.clone() + + # Clear gradients + barycentric_coords.grad.zero_() + face_attrs.grad.zero_() + + # Run the CUDA version + pix_attrs_cu = interpolate_face_attributes(*args) + pix_attrs_cu.backward(gradient=grad_pix_attrs) + grad_bary_cu = barycentric_coords.grad.clone() + grad_face_attrs_cu = face_attrs.grad.clone() + + # Check they are the same + self.assertClose(pix_attrs_py, pix_attrs_cu, rtol=2e-3) + self.assertClose(grad_bary_py, grad_bary_cu, rtol=1e-4) + self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3) + + def test_interpolate_attributes(self): + """ + This tests both interpolate_vertex_colors as well as + interpolate_face_attributes. + """ + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + vert_tex = torch.tensor( + [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 + ) + tex = Textures(verts_rgb=vert_tex[None, :]) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + expected_vals = torch.tensor( + [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + texels = interpolate_vertex_colors(fragments, mesh) + self.assertTrue(torch.allclose(texels, expected_vals[None, :])) + + def test_interpolate_attributes_grad(self): + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + vert_tex = torch.tensor( + [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], + dtype=torch.float32, + requires_grad=True, + ) + tex = Textures(verts_rgb=vert_tex[None, :]) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + grad_vert_tex = torch.tensor( + [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]], + dtype=torch.float32, + ) + texels = interpolate_vertex_colors(fragments, mesh) + texels.sum().backward() + self.assertTrue(hasattr(vert_tex, "grad")) + self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :])) + + def test_interpolate_face_attributes_fail(self): + # 1. A face can only have 3 verts + # i.e. face_attributes must have shape (F, 3, D) + face_attributes = torch.ones(1, 4, 3) + pix_to_face = torch.ones((1, 1, 1, 1)) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=pix_to_face[..., None].expand(-1, -1, -1, -1, 3), + zbuf=pix_to_face, + dists=pix_to_face, + ) + with self.assertRaises(ValueError): + interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_attributes + ) + + # 2. pix_to_face must have shape (N, H, W, K) + pix_to_face = torch.ones((1, 1, 1, 1, 3)) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=pix_to_face, + zbuf=pix_to_face, + dists=pix_to_face, + ) + with self.assertRaises(ValueError): + interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_attributes + ) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index f5b0ddc3..c1abfbbf 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -7,103 +7,13 @@ import torch import torch.nn.functional as F from common_testing import TestCaseMixin from pytorch3d.renderer.mesh.rasterizer import Fragments -from pytorch3d.renderer.mesh.texturing import ( - interpolate_face_attributes, - interpolate_texture_map, - interpolate_vertex_colors, -) +from pytorch3d.renderer.mesh.texturing import interpolate_texture_map from pytorch3d.structures import Meshes, Textures from pytorch3d.structures.utils import list_to_padded from test_meshes import TestMeshes class TestTexturing(TestCaseMixin, unittest.TestCase): - def test_interpolate_attributes(self): - """ - This tests both interpolate_vertex_colors as well as - interpolate_face_attributes. - """ - verts = torch.randn((4, 3), dtype=torch.float32) - faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) - vert_tex = torch.tensor( - [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 - ) - tex = Textures(verts_rgb=vert_tex[None, :]) - mesh = Meshes(verts=[verts], faces=[faces], textures=tex) - pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) - barycentric_coords = torch.tensor( - [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 - ).view(1, 1, 1, 2, -1) - expected_vals = torch.tensor( - [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32 - ).view(1, 1, 1, 2, -1) - fragments = Fragments( - pix_to_face=pix_to_face, - bary_coords=barycentric_coords, - zbuf=torch.ones_like(pix_to_face), - dists=torch.ones_like(pix_to_face), - ) - texels = interpolate_vertex_colors(fragments, mesh) - self.assertTrue(torch.allclose(texels, expected_vals[None, :])) - - def test_interpolate_attributes_grad(self): - verts = torch.randn((4, 3), dtype=torch.float32) - faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) - vert_tex = torch.tensor( - [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], - dtype=torch.float32, - requires_grad=True, - ) - tex = Textures(verts_rgb=vert_tex[None, :]) - mesh = Meshes(verts=[verts], faces=[faces], textures=tex) - pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) - barycentric_coords = torch.tensor( - [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 - ).view(1, 1, 1, 2, -1) - fragments = Fragments( - pix_to_face=pix_to_face, - bary_coords=barycentric_coords, - zbuf=torch.ones_like(pix_to_face), - dists=torch.ones_like(pix_to_face), - ) - grad_vert_tex = torch.tensor( - [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]], - dtype=torch.float32, - ) - texels = interpolate_vertex_colors(fragments, mesh) - texels.sum().backward() - self.assertTrue(hasattr(vert_tex, "grad")) - self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :])) - - def test_interpolate_face_attributes_fail(self): - # 1. A face can only have 3 verts - # i.e. face_attributes must have shape (F, 3, D) - face_attributes = torch.ones(1, 4, 3) - pix_to_face = torch.ones((1, 1, 1, 1)) - fragments = Fragments( - pix_to_face=pix_to_face, - bary_coords=pix_to_face[..., None].expand(-1, -1, -1, -1, 3), - zbuf=pix_to_face, - dists=pix_to_face, - ) - with self.assertRaises(ValueError): - interpolate_face_attributes( - fragments.pix_to_face, fragments.bary_coords, face_attributes - ) - - # 2. pix_to_face must have shape (N, H, W, K) - pix_to_face = torch.ones((1, 1, 1, 1, 3)) - fragments = Fragments( - pix_to_face=pix_to_face, - bary_coords=pix_to_face, - zbuf=pix_to_face, - dists=pix_to_face, - ) - with self.assertRaises(ValueError): - interpolate_face_attributes( - fragments.pix_to_face, fragments.bary_coords, face_attributes - ) - def test_interpolate_texture_map(self): barycentric_coords = torch.tensor( [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32