Mirror of https://github.com/facebookresearch/pytorch3d.git
CUDA kernel for interpolate_face_attributes
Summary: When rendering meshes with Phong shading, interpolate_face_attributes was taking up a nontrivial fraction of the overall shading time. This diff replaces our Python implementation of this function with a CUDA implementation.

Reviewed By: nikhilaravi

Differential Revision: D21610763

fbshipit-source-id: 2bb362a28f698541812aeab539047264b125ebb8
Committed by Facebook GitHub Bot · parent 0505e5f4a9 · commit 26d2cc24c1
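To see the change from the caller's side, here is a rough benchmark sketch (not part of the commit; the sizes, device, and timing harness are assumptions) comparing the new fused op against the pure-Python interpolation it replaces:

# Hypothetical benchmark sketch, assuming a CUDA-enabled pytorch3d build.
import torch
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes_python

device = torch.device("cuda")
N, H, W, K, F, D = 4, 256, 256, 4, 10000, 3    # assumed sizes, not from the diff
pix_to_face = torch.randint(-1, F, (N, H, W, K), device=device)
bary = torch.rand(N, H, W, K, 3, device=device)
attrs = torch.rand(F, 3, D, device=device)      # e.g. per-vertex normals (D=3)

def time_fn(fn, iters=20):
    # Average GPU time per call, in milliseconds, using CUDA events.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(pix_to_face, bary, attrs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("CUDA op: %.3f ms" % time_fn(interpolate_face_attributes))
print("Python:  %.3f ms" % time_fn(interpolate_face_attributes_python))

Both entry points used here, interpolate_face_attributes and interpolate_face_attributes_python, are introduced in pytorch3d/ops/interp_face_attrs.py below.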
pytorch3d/csrc/ext.cpp
@@ -6,6 +6,7 @@
 #include "compositing/weighted_sum.h"
 #include "face_areas_normals/face_areas_normals.h"
 #include "gather_scatter/gather_scatter.h"
+#include "interp_face_attrs/interp_face_attrs.h"
 #include "knn/knn.h"
 #include "packed_to_padded_tensor/packed_to_padded_tensor.h"
 #include "point_mesh/point_mesh_edge.h"
@@ -18,6 +19,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
   m.def("packed_to_padded", &PackedToPadded);
   m.def("padded_to_packed", &PaddedToPacked);
+  m.def("interp_face_attrs_forward", &InterpFaceAttrsForward);
+  m.def("interp_face_attrs_backward", &InterpFaceAttrsBackward);
 #ifdef WITH_CUDA
   m.def("knn_check_version", &KnnCheckVersion);
 #endif
pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu (new file, 161 lines)
@@ -0,0 +1,161 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <tuple>

template <typename scalar_t>
__global__ void InterpFaceAttrsForwardKernel(
    const int64_t* __restrict__ pix_to_face, // (P,)
    const scalar_t* __restrict__ barycentric_coords, // (P, 3)
    const scalar_t* __restrict__ face_attrs, // (F, 3, D)
    scalar_t* pix_attrs, // (P, D)
    const size_t P,
    const size_t F,
    const size_t D) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int pd = tid; pd < P * D; pd += num_threads) {
    const int p = pd / D;
    const int d = pd % D;
    const int64_t f = pix_to_face[p];
    if (f < 0) {
      continue;
    }
    scalar_t pix_attr = 0.0;
    for (int i = 0; i < 3; ++i) {
      scalar_t weight = barycentric_coords[p * 3 + i];
      scalar_t vert_attr = face_attrs[f * 3 * D + i * D + d];
      pix_attr += weight * vert_attr;
    }
    pix_attrs[p * D + d] = pix_attr;
  }
}

at::Tensor InterpFaceAttrsForwardCuda(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs) {
  // Make sure all inputs are on the same device
  at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1},
      barycentric_coords_t{barycentric_coords, "barycentric_coords", 2},
      face_attrs_t{face_attrs, "face_attributes", 3};
  at::CheckedFrom c = "InterpFaceAttrsForwardCuda";
  at::checkAllSameGPU(c, {pix_to_face_t, barycentric_coords_t, face_attrs_t});
  at::checkAllSameType(c, {barycentric_coords_t, face_attrs_t});

  // Set the device for the kernel launch based on the input
  at::cuda::CUDAGuard device_guard(pix_to_face.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const auto P = pix_to_face.size(0);
  const auto F = face_attrs.size(0);
  const auto D = face_attrs.size(2);

  TORCH_CHECK(
      barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3,
      "barycentric_coords must have size (P, 3)");
  TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)");

  auto pix_attrs = at::zeros({P, D}, face_attrs.options());
  const int threads = 1024;
  const int blocks = 512;
  AT_DISPATCH_FLOATING_TYPES(
      face_attrs.scalar_type(), "interp_face_attrs_cuda", ([&] {
        InterpFaceAttrsForwardKernel<<<blocks, threads, 0, stream>>>(
            pix_to_face.contiguous().data_ptr<int64_t>(),
            barycentric_coords.contiguous().data_ptr<scalar_t>(),
            face_attrs.contiguous().data_ptr<scalar_t>(),
            pix_attrs.contiguous().data_ptr<scalar_t>(),
            P,
            F,
            D);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
  return pix_attrs;
}

template <typename scalar_t>
__global__ void InterpFaceAttrsBackwardKernel(
    const int64_t* __restrict__ pix_to_face, // (P,)
    const scalar_t* __restrict__ barycentric_coords, // (P, 3)
    const scalar_t* __restrict__ face_attrs, // (F, 3, D)
    const scalar_t* __restrict__ grad_pix_attrs, // (P, D)
    scalar_t* __restrict__ grad_barycentric_coords, // (P, 3)
    scalar_t* __restrict__ grad_face_attrs, // (F, 3, D)
    const size_t P,
    const size_t F,
    const size_t D) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int pd = tid; pd < P * D; pd += num_threads) {
    const int p = pd / D;
    const int d = pd % D;
    const int64_t f = pix_to_face[p];
    if (f < 0) {
      continue;
    }
    scalar_t upstream_grad = grad_pix_attrs[p * D + d];
    for (int i = 0; i < 3; ++i) {
      scalar_t weight = barycentric_coords[p * 3 + i];
      scalar_t vert_attr = face_attrs[f * 3 * D + i * D + d];
      scalar_t grad_bary_down = vert_attr * upstream_grad;
      scalar_t grad_face_down = weight * upstream_grad;
      atomicAdd(grad_barycentric_coords + p * 3 + i, grad_bary_down);
      atomicAdd(grad_face_attrs + f * 3 * D + i * D + d, grad_face_down);
    }
  }
}

std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs,
    const at::Tensor& grad_pix_attrs) {
  // Make sure all inputs are on the same device
  at::TensorArg pix_to_face_t{pix_to_face, "pix_to_face", 1},
      barycentric_coords_t{barycentric_coords, "barycentric_coords", 2},
      face_attrs_t{face_attrs, "face_attributes", 3},
      grad_pix_attrs_t{grad_pix_attrs, "pix_attrs", 4};
  at::CheckedFrom c = "InterpFaceAttrsBackwardCuda";
  at::checkAllSameGPU(
      c, {pix_to_face_t, barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});
  at::checkAllSameType(
      c, {barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});

  // Set the device for the kernel launch based on the input
  at::cuda::CUDAGuard device_guard(pix_to_face.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const auto P = pix_to_face.size(0);
  const auto F = face_attrs.size(0);
  const auto D = face_attrs.size(2);

  TORCH_CHECK(
      barycentric_coords.size(0) == P && barycentric_coords.size(1) == 3,
      "barycentric_coords must have size (P, 3)");
  TORCH_CHECK(face_attrs.size(1) == 3, "face_attrs must have size (F, 3, D)");
  TORCH_CHECK(
      grad_pix_attrs.size(0) == P && grad_pix_attrs.size(1) == D,
      "grad_pix_attrs must have size (P, D)");

  auto grad_barycentric_coords = at::zeros_like(barycentric_coords);
  auto grad_face_attrs = at::zeros_like(face_attrs);
  const int threads = 1024;
  const int blocks = 512;
  // Only allow float for now.
  // TODO: Add support for double once we fix atomicAdd
  // clang-format off
  InterpFaceAttrsBackwardKernel<<<blocks, threads, 0, stream>>>(
      pix_to_face.contiguous().data_ptr<int64_t>(),
      barycentric_coords.contiguous().data_ptr<float>(),
      face_attrs.contiguous().data_ptr<float>(),
      grad_pix_attrs.contiguous().data_ptr<float>(),
      grad_barycentric_coords.contiguous().data_ptr<float>(),
      grad_face_attrs.contiguous().data_ptr<float>(),
      P, F, D);
  AT_CUDA_CHECK(cudaGetLastError());
  // clang-format on
  return std::make_tuple(grad_barycentric_coords, grad_face_attrs);
}
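For readers of the backward kernel above, the following is a reference PyTorch sketch (an illustration only; the helper name and vectorized form are not part of the commit) of the same math: each pixel's upstream gradient feeds the barycentric-weight gradient through a dot product with the face's vertex attributes, and feeds the face-attribute gradient through an accumulating scatter that plays the role of the kernel's atomicAdd.

import torch

def interp_face_attrs_backward_reference(pix_to_face, barycentric_coords,
                                          face_attrs, grad_pix_attrs):
    # Vectorized PyTorch equivalent of InterpFaceAttrsBackwardKernel.
    grad_bary = torch.zeros_like(barycentric_coords)   # (P, 3)
    grad_face = torch.zeros_like(face_attrs)           # (F, 3, D)
    valid = pix_to_face >= 0                            # skip empty pixels (f < 0)
    p = valid.nonzero(as_tuple=True)[0]                 # valid pixel indices
    f = pix_to_face[p]                                  # face index per valid pixel
    # d(loss)/d(weight_i) = <face_attrs[f, i, :], grad_pix_attrs[p, :]>
    grad_bary[p] = torch.einsum("pid,pd->pi", face_attrs[f], grad_pix_attrs[p])
    # d(loss)/d(face_attrs[f, i, d]) accumulates weight_i * grad_pix_attrs[p, d]
    # over every pixel p that hits face f -- the atomicAdd in the kernel.
    contrib = barycentric_coords[p][:, :, None] * grad_pix_attrs[p][:, None, :]
    grad_face.index_add_(0, f, contrib)
    return grad_bary, grad_face

Because many pixels can map to the same face, the face-attribute gradient needs an accumulating scatter; that is why the kernel uses atomicAdd and why the backward launch currently supports float only.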
pytorch3d/csrc/interp_face_attrs/interp_face_attrs.h (new file, 105 lines)
@@ -0,0 +1,105 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// Interpolates per-face attributes (forward pass)
//
// Inputs:
//    pix_to_face: LongTensor of shape (P,) giving a face index for each pixel.
//        Each element should be < F, the total number of faces.
//        Face indices < 0 indicate that the pixel is not covered by a face.
//    barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords.
//    face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional
//        value for each vertex of each face.
//
// Returns:
//    pix_attributes: FloatTensor of shape (P, D) giving an interpolated value
//        for each pixel.

// CPU implementation
at::Tensor InterpFaceAttrsForwardCpu(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs) {
  AT_ERROR("Not Implemented");
  return pix_to_face;
}

#ifdef WITH_CUDA
// Cuda implementation.
at::Tensor InterpFaceAttrsForwardCuda(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs);
#endif

// General implementation
at::Tensor InterpFaceAttrsForward(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs) {
  if (pix_to_face.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(face_attrs);
    CHECK_CUDA(barycentric_coords);
    return InterpFaceAttrsForwardCuda(
        pix_to_face, barycentric_coords, face_attrs);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
}

// Interpolates per-face attributes (backward pass)
//
// Inputs:
//    pix_to_face: LongTensor of shape (P,) giving a face index for each pixel.
//        Each element should be < F, the total number of faces.
//        Face indices < 0 indicate that the pixel is not covered by a face.
//    barycentric_coords: FloatTensor of shape (P, 3) giving barycentric coords.
//    face_attrs: FloatTensor of shape (F, 3, D) giving a D-dimensional
//        value for each vertex of each face.
//    grad_pix_attrs: Upstream gradients of shape (P, D)
//
// Returns a tuple of:
//    grad_barycentric_coords: FloatTensor of shape (P, 3)
//    grad_face_attrs: FloatTensor of shape (F, 3, D)

std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCpu(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs,
    const at::Tensor& grad_pix_attrs) {
  AT_ERROR("Not Implemented");
  return std::make_tuple(pix_to_face, pix_to_face);
}

std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs,
    const at::Tensor& grad_pix_attrs);

std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
    const at::Tensor& pix_to_face,
    const at::Tensor& barycentric_coords,
    const at::Tensor& face_attrs,
    const at::Tensor& grad_pix_attrs) {
  if (pix_to_face.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(face_attrs);
    CHECK_CUDA(barycentric_coords);
    CHECK_CUDA(grad_pix_attrs);
    return InterpFaceAttrsBackwardCuda(
        pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return InterpFaceAttrsBackwardCpu(
      pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
}
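A minimal sketch of the flattened (P,), (P, 3), (F, 3, D) contract documented in this header, assuming a CUDA-enabled build and using the binding names registered in this diff; the toy sizes are arbitrary:

import torch
from pytorch3d import _C

P, F, D = 8, 5, 3
device = torch.device("cuda")
pix_to_face = torch.randint(-1, F, (P,), device=device)   # -1 marks an empty pixel
bary = torch.rand(P, 3, device=device)                     # (P, 3) barycentric weights
face_attrs = torch.rand(F, 3, D, device=device)            # (F, 3, D) per-vertex values

out = _C.interp_face_attrs_forward(pix_to_face, bary, face_attrs)
assert out.shape == (P, D)

grad_out = torch.ones_like(out)                             # upstream gradient, (P, D)
grad_bary, grad_face = _C.interp_face_attrs_backward(
    pix_to_face, bary, face_attrs, grad_out)
assert grad_bary.shape == (P, 3) and grad_face.shape == (F, 3, D)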
pytorch3d/ops/__init__.py
@@ -3,6 +3,7 @@

 from .cubify import cubify
 from .graph_conv import GraphConv
+from .interp_face_attrs import interpolate_face_attributes
 from .knn import knn_gather, knn_points
 from .mesh_face_areas_normals import mesh_face_areas_normals
 from .packed_to_padded import packed_to_padded, padded_to_packed
pytorch3d/ops/interp_face_attrs.py (new file, 95 lines)
@@ -0,0 +1,95 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable


def interpolate_face_attributes(
    pix_to_face: torch.Tensor,
    barycentric_coords: torch.Tensor,
    face_attributes: torch.Tensor,
) -> torch.Tensor:
    """
    Interpolate arbitrary face attributes using the barycentric coordinates
    for each pixel in the rasterized output.

    Args:
        pix_to_face: LongTensor of shape (...) specifying the indices
            of the faces (in the packed representation) which overlap each
            pixel in the image. A value < 0 indicates that the pixel does not
            overlap any face and should be skipped.
        barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
            the barycentric coordinates of each pixel
            relative to the faces (in the packed
            representation) which overlap the pixel.
        face_attributes: packed attributes of shape (total_faces, 3, D),
            specifying the value of the attribute for each
            vertex in the face.

    Returns:
        pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
            value of the face attribute for each pixel.
    """
    # Check shapes
    F, FV, D = face_attributes.shape
    if FV != 3:
        raise ValueError("Faces can only have three vertices; got %r" % FV)
    N, H, W, K, _ = barycentric_coords.shape
    if pix_to_face.shape != (N, H, W, K):
        msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
        raise ValueError(msg % (pix_to_face.shape,))

    # On CPU use the python version
    # TODO: Implement a C++ version of this function
    if not pix_to_face.is_cuda:
        args = (pix_to_face, barycentric_coords, face_attributes)
        return interpolate_face_attributes_python(*args)

    # Otherwise flatten and call the custom autograd function
    N, H, W, K = pix_to_face.shape
    pix_to_face = pix_to_face.view(-1)
    barycentric_coords = barycentric_coords.view(N * H * W * K, 3)
    args = (pix_to_face, barycentric_coords, face_attributes)
    out = _InterpFaceAttrs.apply(*args)
    out = out.view(N, H, W, K, -1)
    return out


class _InterpFaceAttrs(Function):
    @staticmethod
    def forward(ctx, pix_to_face, barycentric_coords, face_attrs):
        args = (pix_to_face, barycentric_coords, face_attrs)
        ctx.save_for_backward(*args)
        return _C.interp_face_attrs_forward(*args)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_pix_attrs):
        args = ctx.saved_tensors
        args = args + (grad_pix_attrs,)
        grads = _C.interp_face_attrs_backward(*args)
        grad_pix_to_face = None
        grad_barycentric_coords = grads[0]
        grad_face_attrs = grads[1]
        return grad_pix_to_face, grad_barycentric_coords, grad_face_attrs


def interpolate_face_attributes_python(
    pix_to_face: torch.Tensor,
    barycentric_coords: torch.Tensor,
    face_attributes: torch.Tensor,
) -> torch.Tensor:
    F, FV, D = face_attributes.shape
    N, H, W, K, _ = barycentric_coords.shape

    # Replace empty pixels in pix_to_face with 0 in order to interpolate.
    mask = pix_to_face < 0
    pix_to_face = pix_to_face.clone()
    pix_to_face[mask] = 0
    idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
    pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
    pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
    pixel_vals[mask] = 0  # Replace masked values in output.
    return pixel_vals
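As a sanity check of the custom autograd Function above, a small sketch (an illustration with assumed toy sizes and a CUDA device, not part of the commit) can compare values and gradients of the CUDA path against interpolate_face_attributes_python:

import torch
from pytorch3d.ops.interp_face_attrs import (
    interpolate_face_attributes,
    interpolate_face_attributes_python,
)

torch.manual_seed(0)
device = torch.device("cuda")
N, H, W, K, F, D = 2, 8, 8, 2, 30, 3            # assumed toy sizes
pix_to_face = torch.randint(-1, F, (N, H, W, K), device=device)
bary = torch.rand(N, H, W, K, 3, device=device, requires_grad=True)
attrs = torch.rand(F, 3, D, device=device, requires_grad=True)

out_cuda = interpolate_face_attributes(pix_to_face, bary, attrs)
out_py = interpolate_face_attributes_python(pix_to_face, bary, attrs)
assert torch.allclose(out_cuda, out_py, atol=1e-5)

# Backpropagate the same cotangent through both paths and compare gradients.
g = torch.rand_like(out_cuda)
grads_cuda = torch.autograd.grad((out_cuda * g).sum(), (bary, attrs))
grads_py = torch.autograd.grad((out_py * g).sum(), (bary, attrs))
for gc, gp in zip(grads_cuda, grads_py):
    assert torch.allclose(gc, gp, atol=1e-4)

Small numerical differences in the face-attribute gradient are expected, since the kernel accumulates contributions with atomicAdd in float32 and the summation order is nondeterministic.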
pytorch3d/renderer/mesh/shading.py
@@ -4,8 +4,7 @@
 from typing import Tuple

 import torch
-
-from .texturing import interpolate_face_attributes
+from pytorch3d.ops import interpolate_face_attributes


 def _apply_lighting(
pytorch3d/renderer/mesh/texturing.py
@@ -3,10 +3,9 @@

 import torch
 import torch.nn.functional as F
+from pytorch3d.ops import interpolate_face_attributes
 from pytorch3d.structures.textures import Textures
-
-from .utils import interpolate_face_attributes


 def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
     """
pytorch3d/renderer/mesh/utils.py
@@ -2,6 +2,7 @@


 import torch
+from pytorch3d.ops import interpolate_face_attributes


 def _clip_barycentric_coordinates(bary) -> torch.Tensor:
@@ -25,51 +26,6 @@ def _clip_barycentric_coordinates(bary) -> torch.Tensor:
     return clipped


-def interpolate_face_attributes(
-    pix_to_face: torch.Tensor,
-    barycentric_coords: torch.Tensor,
-    face_attributes: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Interpolate arbitrary face attributes using the barycentric coordinates
-    for each pixel in the rasterized output.
-
-    Args:
-        pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
-            of the faces (in the packed representation) which
-            overlap each pixel in the image.
-        barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
-            the barycentric coordianates of each pixel
-            relative to the faces (in the packed
-            representation) which overlap the pixel.
-        face_attributes: packed attributes of shape (total_faces, 3, D),
-            specifying the value of the attribute for each
-            vertex in the face.
-
-    Returns:
-        pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
-            value of the face attribute for each pixel.
-    """
-    F, FV, D = face_attributes.shape
-    if FV != 3:
-        raise ValueError("Faces can only have three vertices; got %r" % FV)
-    N, H, W, K, _ = barycentric_coords.shape
-    if pix_to_face.shape != (N, H, W, K):
-        msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
-        raise ValueError(msg % (pix_to_face.shape,))
-
-    # Replace empty pixels in pix_to_face with 0 in order to interpolate.
-    mask = pix_to_face == -1
-    pix_to_face = pix_to_face.clone()
-    pix_to_face[mask] = 0
-    idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
-    # pyre-fixme[16]: `Tensor` has no attribute `gather`.
-    pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
-    pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
-    pixel_vals[mask] = 0  # Replace masked values in output.
-    return pixel_vals
-
-
 def _interpolate_zbuf(
     pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
 ) -> torch.Tensor: