Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
facebook-github-bot
2020-01-23 11:53:41 -08:00
commit dbf06b504b
211 changed files with 47362 additions and 0 deletions

3
pytorch3d/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Version string of the package.
__version__ = "0.1"

27
pytorch3d/csrc/ext.cpp Normal file
View File

@@ -0,0 +1,27 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
// Python bindings for the extension module. Each m.def call binds a
// Python-visible name to the address of a C++ function; the functions are
// presumably declared in the headers included by this file (only some of those
// headers are visible here).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Mesh / graph / pointcloud utility operators.
m.def("face_areas_normals", &face_areas_normals);
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
m.def("nn_points_idx", &nn_points_idx);
m.def("gather_scatter", &gather_scatter);
// Rasterization entry points (forward and backward).
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
m.def("rasterize_meshes", &RasterizeMeshes);
// These are only visible for testing; users should not call them directly
m.def("_rasterize_points_coarse", &RasterizePointsCoarse);
m.def("_rasterize_points_naive", &RasterizePointsNaive);
m.def("_rasterize_meshes_naive", &RasterizeMeshesNaive);
m.def("_rasterize_meshes_coarse", &RasterizeMeshesCoarse);
m.def("_rasterize_meshes_fine", &RasterizeMeshesFine);
}

View File

@@ -0,0 +1,80 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <tuple>
// CUDA kernel computing the area and the unit normal of every face.
//
// Args:
//   verts: vertex positions, read as a flat array with 3 values per vertex.
//   faces: vertex indices, read as a flat array with 3 indices per face.
//   face_areas: output buffer receiving one value per face.
//   face_normals: output buffer receiving 3 values per face.
//   V: number of vertices (not used by the kernel; indices are not checked).
//   F: number of faces.
//
// Faces are distributed over all threads of a 1D grid; each thread handles
// faces tid, tid + stride, ... so any 1D launch configuration is valid.
template <typename scalar_t>
__global__ void face_areas_kernel(
    const scalar_t* __restrict__ verts,
    const long* __restrict__ faces,
    scalar_t* __restrict__ face_areas,
    scalar_t* __restrict__ face_normals,
    const size_t V,
    const size_t F) {
  // Widen before multiplying so the products cannot wrap in 32 bits.
  const size_t tid =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

  // Constants typed as scalar_t so float instantiations stay in single
  // precision instead of being promoted to double by bare literals.
  const scalar_t kHalf = static_cast<scalar_t>(0.5);
  const scalar_t kMinNorm = static_cast<scalar_t>(1e-6);

  for (size_t f = tid; f < F; f += stride) {
    const long i0 = faces[3 * f + 0];
    const long i1 = faces[3 * f + 1];
    const long i2 = faces[3 * f + 2];

    const scalar_t v0_x = verts[3 * i0 + 0];
    const scalar_t v0_y = verts[3 * i0 + 1];
    const scalar_t v0_z = verts[3 * i0 + 2];

    const scalar_t v1_x = verts[3 * i1 + 0];
    const scalar_t v1_y = verts[3 * i1 + 1];
    const scalar_t v1_z = verts[3 * i1 + 2];

    const scalar_t v2_x = verts[3 * i2 + 0];
    const scalar_t v2_y = verts[3 * i2 + 1];
    const scalar_t v2_z = verts[3 * i2 + 2];

    // Edge vectors a = v1 - v0 and b = v2 - v0.
    const scalar_t ax = v1_x - v0_x;
    const scalar_t ay = v1_y - v0_y;
    const scalar_t az = v1_z - v0_z;

    const scalar_t bx = v2_x - v0_x;
    const scalar_t by = v2_y - v0_y;
    const scalar_t bz = v2_z - v0_z;

    // c = a x b; its length is twice the triangle area.
    const scalar_t cx = ay * bz - az * by;
    const scalar_t cy = az * bx - ax * bz;
    const scalar_t cz = ax * by - ay * bx;

    scalar_t norm = sqrt(cx * cx + cy * cy + cz * cz);
    face_areas[f] = norm * kHalf;

    // Clamp the length from below so degenerate faces do not divide by zero.
    norm = (norm < kMinNorm) ? kMinNorm : norm; // max(norm, 1e-6)
    face_normals[3 * f + 0] = cx / norm;
    face_normals[3 * f + 1] = cy / norm;
    face_normals[3 * f + 2] = cz / norm;
  }
}
// Host wrapper that launches face_areas_kernel.
//
// Args:
//   verts: tensor whose first dimension is the number of vertices; the kernel
//          reads 3 values per vertex.
//   faces: tensor whose first dimension is the number of faces; the kernel
//          reads 3 `long` indices per face.
//
// Returns:
//   (areas, normals) with shapes (F,) and (F, 3), allocated with the options
//   of `verts`.
//
// Raises an error if the kernel launch is reported as failed.
std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
    at::Tensor verts,
    at::Tensor faces) {
  // The kernel indexes raw pointers assuming a dense row-major layout, so
  // make sure strided views are materialized first.
  verts = verts.contiguous();
  faces = faces.contiguous();

  const auto V = verts.size(0);
  const auto F = faces.size(0);

  at::Tensor areas = at::empty({F}, verts.options());
  at::Tensor normals = at::empty({F, 3}, verts.options());

  // Nothing to compute for an empty face list.
  if (F == 0) {
    return std::make_tuple(areas, normals);
  }

  const int blocks = 64;
  const int threads = 512;
  AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_kernel", ([&] {
                               face_areas_kernel<scalar_t><<<blocks, threads>>>(
                                   verts.data_ptr<scalar_t>(),
                                   faces.data_ptr<long>(),
                                   areas.data_ptr<scalar_t>(),
                                   normals.data_ptr<scalar_t>(),
                                   V,
                                   F);
                             }));

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    AT_ERROR("face_areas_kernel launch failed: ", cudaGetErrorString(err));
  }

  return std::make_tuple(areas, normals);
}

View File

@@ -0,0 +1,36 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <tuple>
// Compute areas of mesh faces using packed representation.
//
// Inputs:
// verts: FloatTensor of shape (V, 3) giving vertex positions.
// faces: LongTensor of shape (F, 3) giving faces.
//
// Returns:
// areas: FloatTensor of shape (F,) where areas[f] is the area of faces[f].
// normals: FloatTensor of shape (F, 3) where normals[f] is the normal of
// faces[f]
//
// CUDA implementation; only declared in this header.
std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
    at::Tensor verts,
    at::Tensor faces);

// Exposed entry point. Forwards to the CUDA implementation when both tensors
// are CUDA tensors; every other combination raises an error.
std::tuple<at::Tensor, at::Tensor> face_areas_normals(
    at::Tensor verts,
    at::Tensor faces) {
  const bool on_gpu = verts.type().is_cuda() && faces.type().is_cuda();
  if (!on_gpu) {
    AT_ERROR("Not implemented on the CPU.");
  }
#ifdef WITH_CUDA
  return face_areas_cuda(verts, faces);
#else
  AT_ERROR("Not compiled with GPU support.");
#endif
}

View File

@@ -0,0 +1,69 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
// TODO(T47953967) to make this cuda kernel support all datatypes.
//
// CUDA kernel that, for every edge (v0, v1), atomically adds the feature row
// of v1 in `input` to the row of v0 in `output`. When `directed` is false the
// reverse contribution (row of v0 added to row of v1) is accumulated as well.
//
// Args:
//   input: features, read as a flat array with D values per vertex.
//   edges: vertex index pairs, read as a flat array with 2 indices per edge.
//   output: accumulation buffer with D values per vertex; the kernel only adds
//           to it, so the caller is responsible for its initial contents.
//   directed: if false, also accumulate in the reverse direction.
//   backward: if true, the two vertices of each edge swap roles.
//   V: number of vertices (not used by the kernel; indices are not checked).
//   D: feature dimension.
//   E: number of edges.
//
// Edges are strided over blocks and feature entries over the threads of a
// block, so any 1D launch configuration is valid.
__global__ void gather_scatter_kernel(
    const float* __restrict__ input,
    const long* __restrict__ edges,
    float* __restrict__ output,
    bool directed,
    bool backward,
    const size_t V,
    const size_t D,
    const size_t E) {
  // All loop counters are size_t: they are compared against size_t bounds and
  // used in products (2 * e, v * D) that could overflow a 32-bit int.
  const size_t tid = threadIdx.x;

  // Reverse the vertex order if backward.
  const size_t v0_idx = backward ? 1 : 0;
  const size_t v1_idx = backward ? 0 : 1;

  // Edges are split evenly across the blocks.
  for (size_t e = blockIdx.x; e < E; e += gridDim.x) {
    // Get indices of vertices which form the edge.
    const long v0 = edges[2 * e + v0_idx];
    const long v1 = edges[2 * e + v1_idx];

    // Split vertex features evenly across threads.
    // This implementation will be quite wasteful when D<128 since there will be
    // a lot of threads doing nothing.
    for (size_t d = tid; d < D; d += blockDim.x) {
      atomicAdd(output + v0 * D + d, input[v1 * D + d]);
      if (!directed) {
        atomicAdd(output + v1 * D + d, input[v0 * D + d]);
      }
    }
    // No block barrier is needed here: the kernel uses no shared memory and
    // all cross-thread accumulation goes through atomicAdd.
  }
}
// Host wrapper that launches gather_scatter_kernel.
//
// Args:
//   input: float tensor of shape (V, D).
//   edges: tensor of shape (E, 2) holding `long` vertex indices.
//   directed: forwarded to the kernel.
//   backward: forwarded to the kernel.
//
// Returns:
//   A zero-initialized tensor of shape (V, D) into which the kernel has
//   accumulated the neighbor features.
//
// Raises an error if the kernel launch is reported as failed.
at::Tensor gather_scatter_cuda(
    const at::Tensor input,
    const at::Tensor edges,
    bool directed,
    bool backward) {
  // The kernel indexes raw pointers assuming a dense row-major layout.
  const at::Tensor input_c = input.contiguous();
  const at::Tensor edges_c = edges.contiguous();

  const auto num_vertices = input_c.size(0);
  const auto input_feature_dim = input_c.size(1);
  const auto num_edges = edges_c.size(0);

  auto output =
      at::zeros({num_vertices, input_feature_dim}, input_c.options());

  // With no edges the block count below would be 0, which is an invalid
  // launch configuration; the all-zero output is already the correct result.
  if (num_edges == 0 || output.numel() == 0) {
    return output;
  }

  const size_t threads = 128;
  const size_t max_blocks = 1920;
  const size_t edge_count = static_cast<size_t>(num_edges);
  const size_t blocks = edge_count < max_blocks ? edge_count : max_blocks;

  gather_scatter_kernel<<<blocks, threads>>>(
      input_c.data_ptr<float>(),
      edges_c.data_ptr<long>(),
      output.data_ptr<float>(),
      directed,
      backward,
      num_vertices,
      input_feature_dim,
      num_edges);

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    AT_ERROR("gather_scatter_kernel launch failed: ", cudaGetErrorString(err));
  }

  return output;
}

View File

@@ -0,0 +1,43 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
// Fused gather scatter operation for aggregating features of neighbor nodes
// in a graph. This gather scatter operation is specific to graphs as edge
// indices are used as input.
//
// Args:
// input: float32 Tensor of shape (V, D) where V is the number of vertices
// and D is the feature dimension.
// edges: int64 Tensor of shape (E, 2) giving the indices of the vertices that
// make up the edge. E is the number of edges.
// directed: Bool indicating if edges in the graph are directed. For a
// directed graph v0 -> v1 the updated feature for v0 depends on v1.
// backward: Bool indicating if the operation is the backward pass.
//
// Returns:
// output: float32 Tensor of same shape as input.
// CUDA implementation; only declared in this header.
at::Tensor gather_scatter_cuda(
    const at::Tensor input,
    const at::Tensor edges,
    bool directed,
    bool backward);

// Exposed entry point. Forwards to the CUDA implementation when both tensors
// are CUDA tensors; every other combination raises an error.
at::Tensor gather_scatter(
    const at::Tensor input,
    const at::Tensor edges,
    bool directed,
    bool backward) {
  const bool on_gpu = input.type().is_cuda() && edges.type().is_cuda();
  if (!on_gpu) {
    AT_ERROR("Not implemented on the CPU");
  }
#ifdef WITH_CUDA
  return gather_scatter_cuda(input, edges, directed, backward);
#else
  AT_ERROR("Not compiled with GPU support.");
#endif
}

View File

@@ -0,0 +1,265 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <float.h>
// Final stage of a block-wide argmin reduction, executed by a single warp.
//
// Reduces the first 64 entries of (min_dists, min_idxs) so that entry 0 ends
// up holding the smallest distance and its index. Ties keep the entry with
// the lower position (the comparison is a strict `>`).
//
// Preconditions:
//   - called by all 32 threads of one warp with tid in [0, 32);
//   - both arrays hold at least 64 valid entries.
//
// Each step only lets threads with tid < s update their own slot, and the
// warp is synchronized after every step. This avoids relying on implicit
// lockstep execution, which is not guaranteed on architectures with
// independent thread scheduling. Requires __syncwarp (CUDA 9+).
template <typename scalar_t>
__device__ void warp_reduce(
    volatile scalar_t* min_dists,
    volatile long* min_idxs,
    const size_t tid) {
#pragma unroll
  for (size_t s = 32; s > 0; s >>= 1) {
    if (tid < s) {
      // Slot tid + s is >= s, so no active thread writes it during this step.
      if (min_dists[tid] > min_dists[tid + s]) {
        min_idxs[tid] = min_idxs[tid + s];
        min_dists[tid] = min_dists[tid + s];
      }
    }
    // Must be reached by every thread of the warp, hence outside the branch.
    __syncwarp();
  }
}
// Nearest-neighbor search between two batches of D-dimensional pointclouds.
//
// Launch layout: one block per output element, with blockIdx.x selecting the
// point i of points1 and blockIdx.y selecting the batch element n. The threads
// of the block share the scan over the P2 candidate points and then combine
// their partial results with a shared-memory argmin reduction.
//
// Dynamic shared memory layout (in this order):
//   scalar_t[D_2]        copy of points1[n, i, :]
//   scalar_t[blockDim.x] per-thread best squared distance
//   long[blockDim.x]     per-thread index of that best candidate
//
// Args:
//   points1: First set of points, of shape (N, P1, D).
//   points2: Second set of points, of shape (N, P2, D).
//   idx: Output memory buffer of shape (N, P1).
//   N: Batch size.
//   P1: Number of points in points1.
//   P2: Number of points in points2.
//   D: Dimension of each point.
//   D_2: Number of scalar_t slots reserved for the point copy; this is D
//        rounded up so that the buffers that follow stay aligned.
//
template <typename scalar_t>
__global__ void nearest_neighbor_kernel(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    long* __restrict__ idx,
    const size_t N,
    const size_t P1,
    const size_t P2,
    const size_t D,
    const size_t D_2) {
  // CUDA only allows a single dynamic shared buffer, so it is sliced and cast
  // by hand into the three logical buffers described above.
  extern __shared__ char shared_buf[];
  scalar_t* query = (scalar_t*)shared_buf;
  scalar_t* best_dists = &query[D_2];
  long* best_idxs = (long*)&best_dists[blockDim.x];

  const size_t batch = blockIdx.y;
  const size_t point = blockIdx.x;
  const size_t lane = threadIdx.x;

  // A single thread stages the query point; everyone waits for it.
  if (lane == 0) {
    for (size_t d = 0; d < D; d++) {
      query[d] = points1[batch * (P1 * D) + point * D + d];
    }
  }
  __syncthreads();

  // Serial scan: this thread looks at candidates lane, lane + blockDim.x, ...
  scalar_t best = FLT_MAX;
  size_t best_j = 0;
  for (size_t j = lane; j < P2; j += blockDim.x) {
    scalar_t dist = 0;
    for (size_t d = 0; d < D; d++) {
      const scalar_t q_d = query[d];
      const scalar_t c_d = points2[batch * (P2 * D) + j * D + d];
      const scalar_t diff = q_d - c_d;
      dist += diff * diff;
    }
    // The first candidate seen by this thread seeds the running minimum.
    if (j == lane) {
      best = dist;
    }
    if (dist <= best) {
      best_j = j;
      best = dist;
    }
  }
  best_dists[lane] = best;
  best_idxs[lane] = best_j;
  __syncthreads();

  // Tree reduction in shared memory down to the last 64 entries.
  int s = blockDim.x / 2;
  while (s > 32) {
    if (lane < s) {
      if (best_dists[lane] > best_dists[lane + s]) {
        best_dists[lane] = best_dists[lane + s];
        best_idxs[lane] = best_idxs[lane + s];
      }
    }
    __syncthreads();
    s >>= 1;
  }

  // The remaining steps are handled by the first warp.
  if (lane < 32) {
    warp_reduce<scalar_t>(best_dists, best_idxs, lane);
  }

  // Thread 0 publishes the result.
  if (lane == 0) {
    idx[batch * P1 + point] = best_idxs[0];
  }
}
// Nearest-neighbor search between two batches of 3-dimensional pointclouds.
// Specialization of nearest_neighbor_kernel for D = 3: the query point is
// held in local variables instead of shared memory.
//
// Launch layout: one block per output element, with blockIdx.x selecting the
// point i of points1 and blockIdx.y selecting the batch element n.
//
// Dynamic shared memory layout (in this order):
//   scalar_t[blockDim.x] per-thread best squared distance
//   long[blockDim.x]     per-thread index of that best candidate
//
// Args:
//   points1: First set of pointclouds, of shape (N, P1, 3).
//   points2: Second set of pointclouds, of shape (N, P2, 3).
//   idx: Output memory buffer of shape (N, P1).
//   N: Batch size.
//   P1: Number of points in points1.
//   P2: Number of points in points2.
//
template <typename scalar_t>
__global__ void nearest_neighbor_kernel_D3(
    const scalar_t* __restrict__ points1,
    const scalar_t* __restrict__ points2,
    long* __restrict__ idx,
    const size_t N,
    const size_t P1,
    const size_t P2) {
  // One dynamic shared buffer, sliced into the two logical buffers above.
  extern __shared__ char shared_buf[];
  scalar_t* best_dists = (scalar_t*)shared_buf;
  long* best_idxs = (long*)&best_dists[blockDim.x];

  const size_t D = 3;
  const size_t batch = blockIdx.y;
  const size_t point = blockIdx.x;
  const size_t lane = threadIdx.x;

  // Load the query point points1[n, i] once per thread.
  const size_t q_base = batch * (P1 * D) + point * D;
  const scalar_t qx = points1[q_base + 0];
  const scalar_t qy = points1[q_base + 1];
  const scalar_t qz = points1[q_base + 2];

  // Serial scan: this thread looks at candidates lane, lane + blockDim.x, ...
  scalar_t best = FLT_MAX;
  size_t best_j = 0;
  for (size_t j = lane; j < P2; j += blockDim.x) {
    const size_t c_base = batch * (P2 * D) + j * D;
    scalar_t dx = qx - points2[c_base + 0];
    scalar_t dy = qy - points2[c_base + 1];
    scalar_t dz = qz - points2[c_base + 2];
    scalar_t dist = dx * dx + dy * dy + dz * dz;
    // The first candidate seen by this thread seeds the running minimum.
    if (j == lane) {
      best = dist;
    }
    if (dist <= best) {
      best_j = j;
      best = dist;
    }
  }
  best_dists[lane] = best;
  best_idxs[lane] = best_j;
  // All per-thread results must be visible before the reduction starts.
  __syncthreads();

  // Tree reduction in shared memory down to the last 64 entries.
  int s = blockDim.x / 2;
  while (s > 32) {
    if (lane < s) {
      if (best_dists[lane] > best_dists[lane + s]) {
        best_dists[lane] = best_dists[lane + s];
        best_idxs[lane] = best_idxs[lane + s];
      }
    }
    // Keep every step of the tree separated by a block barrier.
    __syncthreads();
    s >>= 1;
  }

  // The remaining steps are handled by the first warp.
  if (lane < 32) {
    warp_reduce<scalar_t>(best_dists, best_idxs, lane);
  }

  // Thread 0 publishes the result.
  if (lane == 0) {
    idx[batch * P1 + point] = best_idxs[0];
  }
}
// Host wrapper that launches the nearest-neighbor kernels.
//
// Args:
//   p1: tensor of shape (N, P1, D).
//   p2: tensor of shape (N, P2, D).
//
// Returns:
//   idx: int64 tensor of shape (N, P1) filled by the kernel with, for each
//        point of p1, the index of its nearest point in p2.
//
// The specialized kernel is used when D == 3 and the general kernel
// otherwise. Raises an error if the batch sizes or last dimensions differ,
// or if the kernel launch is reported as failed.
at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2) {
  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
  const auto D = p1.size(2);

  // The kernels index p2 with the batch index of p1.
  AT_ASSERTM(p2.size(0) == N, "Point sets must have the same batch size.");
  AT_ASSERTM(p2.size(2) == D, "Point sets must have same last dimension.");
  auto idx = at::empty({N, P1}, p1.options().dtype(at::kLong));

  // A grid with a zero dimension is an invalid launch configuration, and
  // there is nothing to compute for an empty output anyway.
  if (idx.numel() == 0) {
    return idx;
  }

  // On P100 with pointclouds of size (16, 5000, 3), 128 threads per block
  // gives best results.
  const int threads = 128;
  const dim3 blocks(P1, N);

  if (D == 3) {
    // Use the specialized kernel for D=3.
    AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
                                 // scalar_t[threads] + long[threads], matching
                                 // the layout the kernel slices out.
                                 size_t shared_size =
                                     threads * sizeof(scalar_t) +
                                     threads * sizeof(long);
                                 nearest_neighbor_kernel_D3<scalar_t>
                                     <<<blocks, threads, shared_size>>>(
                                         p1.data_ptr<scalar_t>(),
                                         p2.data_ptr<scalar_t>(),
                                         idx.data_ptr<long>(),
                                         N,
                                         P1,
                                         P2);
                               }));
  } else {
    // Use the general kernel for all other D.
    AT_DISPATCH_FLOATING_TYPES(
        p1.type(), "nearest_neighbor_v3_cuda", ([&] {
          // To avoid misaligned memory access, the size of shared buffers
          // need to be rounded to the next even size.
          size_t D_2 = D + (D % 2);
          // scalar_t[D_2] + scalar_t[threads] + long[threads], matching the
          // layout the kernel slices out.
          size_t shared_size = (D_2 + threads) * sizeof(scalar_t);
          shared_size += threads * sizeof(long);
          nearest_neighbor_kernel<scalar_t><<<blocks, threads, shared_size>>>(
              p1.data_ptr<scalar_t>(),
              p2.data_ptr<scalar_t>(),
              idx.data_ptr<long>(),
              N,
              P1,
              P2,
              D,
              D_2);
        }));
  }

  // Kernel launches do not return a status; query it explicitly. This also
  // reports grids or shared-memory requests that exceed the device limits.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    AT_ERROR("nearest neighbor kernel launch failed: ", cudaGetErrorString(err));
  }

  return idx;
}

View File

@@ -0,0 +1,37 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
// Compute indices of nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
//
// Args:
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
// containing P1 points of dimension D.
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
// containing P2 points of dimension D.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1), where
// p1_neighbor_idx[n, i] = j means that the nearest neighbor
// to p1[n, i] in the cloud p2[n] is p2[n, j].
//
// CUDA implementation; only declared in this header.
at::Tensor nn_points_idx_cuda(at::Tensor p1, at::Tensor p2);

// Exposed entry point. When both tensors are CUDA tensors they are checked
// with CHECK_CONTIGUOUS_CUDA and forwarded to the CUDA implementation; every
// other combination raises an error.
at::Tensor nn_points_idx(at::Tensor p1, at::Tensor p2) {
  const bool on_gpu = p1.type().is_cuda() && p2.type().is_cuda();
  if (!on_gpu) {
    AT_ERROR("Not implemented on the CPU.");
  }
#ifdef WITH_CUDA
  CHECK_CONTIGUOUS_CUDA(p1);
  CHECK_CONTIGUOUS_CUDA(p2);
  return nn_points_idx_cuda(p1, p2);
#else
  AT_ERROR("Not compiled with GPU support.");
#endif
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
// CUDA kernel copying the packed values of each batch element into its row of
// the padded output.
//
// Launch layout: blockIdx.x selects the batch element; the threads of the
// block stride over the values belonging to that element.
//
// Args:
//   inputs: packed values, num_inputs entries.
//   first_idxs: for each batch element, the position in `inputs` where its
//               values start.
//   inputs_padded: output, read as batch_size rows of max_size entries. The
//                  kernel only writes the copied values; padding entries are
//                  left untouched.
//   batch_size: number of batch elements.
//   max_size: row length of inputs_padded.
//   num_inputs: total number of packed values.
//
// Elements whose range is empty or reversed are skipped, and at most max_size
// values are copied per element so a row is never overrun.
template <typename scalar_t>
__global__ void packed_to_padded_tensor_kernel(
    const scalar_t* __restrict__ inputs,
    const long* __restrict__ first_idxs,
    scalar_t* __restrict__ inputs_padded,
    const size_t batch_size,
    const size_t max_size,
    const size_t num_inputs) {
  const size_t tid = threadIdx.x;
  const size_t batch_idx = blockIdx.x;

  // Guard against grids larger than the batch.
  if (batch_idx >= batch_size) {
    return;
  }

  // The range of this element ends where the next one starts, or at the end
  // of `inputs` for the last element.
  const long start = first_idxs[batch_idx];
  const long end = batch_idx + 1 < batch_size
      ? first_idxs[batch_idx + 1]
      : static_cast<long>(num_inputs);

  // A non-positive length must not be converted to an unsigned loop bound.
  if (end <= start) {
    return;
  }

  size_t count = static_cast<size_t>(end - start);
  if (count > max_size) {
    count = max_size;
  }

  for (size_t f = tid; f < count; f += blockDim.x) {
    inputs_padded[batch_idx * max_size + f] = inputs[start + f];
  }
}
// Host wrapper that launches packed_to_padded_tensor_kernel.
//
// Args:
//   inputs: packed tensor whose first dimension is the number of values.
//   first_idxs: tensor of `long` start positions, one per batch element.
//   max_size: row length of the padded output.
//
// Returns:
//   A zero-initialized tensor of shape (batch_size, max_size) into which the
//   kernel has copied the values of each batch element.
//
// Raises an error if the kernel launch is reported as failed.
at::Tensor packed_to_padded_tensor_cuda(
    at::Tensor inputs,
    at::Tensor first_idxs,
    const long max_size) {
  // The kernel indexes raw pointers assuming a dense layout.
  inputs = inputs.contiguous();
  first_idxs = first_idxs.contiguous();

  const auto num_inputs = inputs.size(0);
  const auto batch_size = first_idxs.size(0);

  at::Tensor inputs_padded =
      at::zeros({batch_size, max_size}, inputs.options());

  // An empty batch would request a grid with 0 blocks, which is an invalid
  // launch configuration; an empty output needs no work either way.
  if (inputs_padded.numel() == 0) {
    return inputs_padded;
  }

  const int threads = 512;
  const int blocks = batch_size;

  AT_DISPATCH_FLOATING_TYPES(
      inputs.type(), "packed_to_padded_tensor_kernel", ([&] {
        packed_to_padded_tensor_kernel<scalar_t><<<blocks, threads>>>(
            inputs.data_ptr<scalar_t>(),
            first_idxs.data_ptr<long>(),
            inputs_padded.data_ptr<scalar_t>(),
            batch_size,
            max_size,
            num_inputs);
      }));

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    AT_ERROR(
        "packed_to_padded_tensor_kernel launch failed: ",
        cudaGetErrorString(err));
  }

  return inputs_padded;
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
// Converts a packed tensor into a padded tensor, restoring the batch dimension.
// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
//
// Inputs:
// inputs: FloatTensor of shape (F,), representing the packed batch tensor.
// e.g. areas for faces in a batch of meshes.
// first_idxs: LongTensor of shape (N,) where N is the number of
// elements in the batch and `first_idxs[i] = f`
// means that the inputs for batch element i begin at
// `inputs[f]`.
// max_size: Max length of an element in the batch.
// Returns:
// inputs_padded: FloatTensor of shape (N, max_size) where max_size is the
// argument described above. The values for batch element i which start at
// `inputs[first_idxs[i]]` will be copied to
// `inputs_padded[i, :]`, with zeros padding out the extra
// inputs.
//
// CUDA implementation; only declared in this header.
at::Tensor packed_to_padded_tensor_cuda(
    at::Tensor inputs,
    at::Tensor first_idxs,
    const long max_size);

// Exposed entry point. Forwards to the CUDA implementation when `inputs` is a
// CUDA tensor and raises an error otherwise.
//
// `first_idxs` is dereferenced on the device by the CUDA implementation, so it
// must be a CUDA tensor as well; this is checked up front instead of letting
// the kernel read through a host pointer.
at::Tensor packed_to_padded_tensor(
    at::Tensor inputs,
    at::Tensor first_idxs,
    const long max_size) {
  if (inputs.type().is_cuda()) {
#ifdef WITH_CUDA
    AT_ASSERTM(
        first_idxs.type().is_cuda(), "first_idxs must be a CUDA tensor.");
    return packed_to_padded_tensor_cuda(inputs, first_idxs, max_size);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  AT_ERROR("Not implemented on the CPU.");
}

View File

@@ -0,0 +1,12 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
// Argument-validation helpers. Each macro raises through AT_ASSERTM with a
// message that starts with the spelled-out argument expression.

// Asserts that tensor `x` is a CUDA tensor.
#define CHECK_CUDA(x) \
  AT_ASSERTM((x).type().is_cuda(), #x " must be a CUDA tensor.")

// Asserts that tensor `x` is contiguous.
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM((x).is_contiguous(), #x " must be contiguous.")

// Asserts both of the above. Wrapped in do/while(0) so the macro behaves as a
// single statement, e.g. inside an unbraced if/else.
#define CHECK_CONTIGUOUS_CUDA(x) \
  do {                           \
    CHECK_CUDA(x);               \
    CHECK_CONTIGUOUS(x);         \
  } while (0)

View File

@@ -0,0 +1,86 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <thrust/tuple.h>
// Arithmetic helpers for float2. All operators act component-wise.

// a - b.
__device__ inline float2 operator-(const float2& a, const float2& b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  return make_float2(dx, dy);
}

// a + b.
__device__ inline float2 operator+(const float2& a, const float2& b) {
  const float sx = a.x + b.x;
  const float sy = a.y + b.y;
  return make_float2(sx, sy);
}

// a / b, component by component.
__device__ inline float2 operator/(const float2& a, const float2& b) {
  const float qx = a.x / b.x;
  const float qy = a.y / b.y;
  return make_float2(qx, qy);
}

// Every component of a divided by the scalar b.
__device__ inline float2 operator/(const float2& a, const float b) {
  const float qx = a.x / b;
  const float qy = a.y / b;
  return make_float2(qx, qy);
}

// a * b, component by component.
__device__ inline float2 operator*(const float2& a, const float2& b) {
  const float px = a.x * b.x;
  const float py = a.y * b.y;
  return make_float2(px, py);
}

// Every component of b scaled by the scalar a.
__device__ inline float2 operator*(const float a, const float2& b) {
  const float px = a * b.x;
  const float py = a * b.y;
  return make_float2(px, py);
}

// Dot product of a and b.
__device__ inline float dot(const float2& a, const float2& b) {
  const float result = a.x * b.x + a.y * b.y;
  return result;
}
// Gradient of the 2D dot product with respect to both operands.
//
// Args:
//   a, b: The two operands of the dot product.
//   grad_dot: Upstream gradient for the dot product value.
//
// Returns:
//   (float2 grad_a, float2 grad_b) where grad_a = grad_dot * b and
//   grad_b = grad_dot * a.
//
__device__ inline thrust::tuple<float2, float2>
DotBackward(const float2& a, const float2& b, const float& grad_dot) {
  const float2 grad_a = grad_dot * b;
  const float2 grad_b = grad_dot * a;
  return thrust::make_tuple(grad_a, grad_b);
}
// Sum of the two components of a.
__device__ inline float sum(const float2& a) {
  const float total = a.x + a.y;
  return total;
}
// Arithmetic helpers for float3. All operators act component-wise.

// a - b.
__device__ inline float3 operator-(const float3& a, const float3& b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  const float dz = a.z - b.z;
  return make_float3(dx, dy, dz);
}

// a + b.
__device__ inline float3 operator+(const float3& a, const float3& b) {
  const float sx = a.x + b.x;
  const float sy = a.y + b.y;
  const float sz = a.z + b.z;
  return make_float3(sx, sy, sz);
}

// a / b, component by component.
__device__ inline float3 operator/(const float3& a, const float3& b) {
  const float qx = a.x / b.x;
  const float qy = a.y / b.y;
  const float qz = a.z / b.z;
  return make_float3(qx, qy, qz);
}

// Every component of a divided by the scalar b.
__device__ inline float3 operator/(const float3& a, const float b) {
  const float qx = a.x / b;
  const float qy = a.y / b;
  const float qz = a.z / b;
  return make_float3(qx, qy, qz);
}

// a * b, component by component.
__device__ inline float3 operator*(const float3& a, const float3& b) {
  const float px = a.x * b.x;
  const float py = a.y * b.y;
  const float pz = a.z * b.z;
  return make_float3(px, py, pz);
}

// Every component of b scaled by the scalar a.
__device__ inline float3 operator*(const float a, const float3& b) {
  const float px = a * b.x;
  const float py = a * b.y;
  const float pz = a * b.z;
  return make_float3(px, py, pz);
}

// Dot product of a and b.
__device__ inline float dot(const float3& a, const float3& b) {
  const float result = a.x * b.x + a.y * b.y + a.z * b.z;
  return result;
}

// Sum of the three components of a.
__device__ inline float sum(const float3& a) {
  const float total = a.x + a.y + a.z;
  return total;
}

View File

@@ -0,0 +1,350 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <float.h>
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include "float_math.cuh"
// Set epsilon for preventing floating point errors and division by 0.
// In this file it is added to the triangle area before dividing by it and
// used as the threshold below which a squared segment length is treated as
// zero. The literal has no suffix, so `auto` deduces double here.
const auto kEpsilon = 1e-30;
// Edge function: tells on which side of the 2D line through v0 and v1 the
// point p lies, via the sign of the returned value.
//
// Args:
//   p: vec2 Coordinates of a point.
//   v0, v1: vec2 Coordinates of the end points of the edge.
//
// Returns:
//   area: The signed area of the parallelogram given by the vectors
//         A = p - v0
//         B = v1 - v0
//
__device__ inline float
EdgeFunctionForward(const float2& p, const float2& v0, const float2& v1) {
  const float ax = p.x - v0.x;
  const float ay = p.y - v0.y;
  const float bx = v1.x - v0.x;
  const float by = v1.y - v0.y;
  return ax * by - ay * bx;
}
// Backward pass for the edge function: partial derivatives with respect to
// each of the three input points, scaled by the upstream gradient.
//
// Args:
//   p: vec2 Coordinates of a point.
//   v0, v1: vec2 Coordinates of the end points of the edge.
//   grad_edge: Upstream gradient for output from edge function.
//
// Returns:
//   tuple of gradients for each of the input points:
//     (float2 d_edge_dp, float2 d_edge_dv0, float2 d_edge_dv1)
//
__device__ inline thrust::tuple<float2, float2, float2> EdgeFunctionBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_edge) {
  // Local derivatives of (p - v0) x (v1 - v0) for each point.
  const float2 wrt_p = make_float2(v1.y - v0.y, v0.x - v1.x);
  const float2 wrt_v0 = make_float2(p.y - v1.y, v1.x - p.x);
  const float2 wrt_v1 = make_float2(v0.y - p.y, p.x - v0.x);

  // Chain rule with the upstream gradient.
  const float2 grad_p = grad_edge * wrt_p;
  const float2 grad_v0 = grad_edge * wrt_v0;
  const float2 grad_v1 = grad_edge * wrt_v1;
  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass: barycentric coordinates of a 2D point with respect to a
// triangle, computed as ratios of edge functions to the triangle's signed
// area (offset by kEpsilon to avoid a zero denominator).
//
// Args:
//   p: Coordinates of a point.
//   v0, v1, v2: Coordinates of the triangle vertices.
//
// Returns
//   bary: (w0, w1, w2) barycentric coordinates.
//
__device__ inline float3 BarycentricCoordsForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  const float denom = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // Each weight uses the edge opposite to its vertex.
  const float e0 = EdgeFunctionForward(p, v1, v2);
  const float e1 = EdgeFunctionForward(p, v2, v0);
  const float e2 = EdgeFunctionForward(p, v0, v1);
  return make_float3(e0 / denom, e1 / denom, e2 / denom);
}
// The backward pass for computing the barycentric coordinates of a point
// relative to a triangle.
//
// Args:
//   p: Coordinates of a point.
//   v0, v1, v2: (x, y) coordinates of the triangle vertices.
//   grad_bary_upstream: vec3<T> Upstream gradient for each of the
//                       barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//   tuple of gradients for the point and each of the triangle vertices:
//     (float2 grad_p, float2 grad_v0, float2 grad_v1, float2 grad_v2)
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
BarycentricCoordsBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float3& grad_bary_upstream) {
  // Recompute the forward quantities: w_i = e_i / area.
  const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // Plain multiplication keeps the square in single precision instead of
  // going through the double-precision pow overload.
  const float area2 = area * area;
  const float e0 = EdgeFunctionForward(p, v1, v2);
  const float e1 = EdgeFunctionForward(p, v2, v0);
  const float e2 = EdgeFunctionForward(p, v0, v1);

  const float grad_w0 = grad_bary_upstream.x;
  const float grad_w1 = grad_bary_upstream.y;
  const float grad_w2 = grad_bary_upstream.z;

  // Calculate component of the gradient from each of w0, w1 and w2.
  // e.g. for w0:
  // dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
  //               + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
  //
  // For every EdgeFunctionBackward(a, b, c, g) call below, tuple element 0 is
  // the gradient for a, element 1 for b and element 2 for c. The area is
  // EdgeFunctionForward(v2, v0, v1), so its elements map to (v2, v0, v1).

  // w0 = e0 / area with e0 = EdgeFunctionForward(p, v1, v2).
  const float dw0_darea = -e0 / (area2);
  const float dw0_e0 = 1 / area;
  const float dloss_d_w0area = grad_w0 * dw0_darea;
  const float dloss_e0 = grad_w0 * dw0_e0;
  auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
  auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
  const float2 dw0_p = thrust::get<0>(de0_dv);
  const float2 dw0_dv0 = thrust::get<1>(dw0area_dv);
  const float2 dw0_dv1 = thrust::get<1>(de0_dv) + thrust::get<2>(dw0area_dv);
  const float2 dw0_dv2 = thrust::get<2>(de0_dv) + thrust::get<0>(dw0area_dv);

  // w1 = e1 / area with e1 = EdgeFunctionForward(p, v2, v0).
  const float dw1_darea = -e1 / (area2);
  const float dw1_e1 = 1 / area;
  const float dloss_d_w1area = grad_w1 * dw1_darea;
  const float dloss_e1 = grad_w1 * dw1_e1;
  auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
  auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
  const float2 dw1_p = thrust::get<0>(de1_dv);
  const float2 dw1_dv0 = thrust::get<2>(de1_dv) + thrust::get<1>(dw1area_dv);
  const float2 dw1_dv1 = thrust::get<2>(dw1area_dv);
  const float2 dw1_dv2 = thrust::get<1>(de1_dv) + thrust::get<0>(dw1area_dv);

  // w2 = e2 / area with e2 = EdgeFunctionForward(p, v0, v1).
  const float dw2_darea = -e2 / (area2);
  const float dw2_e2 = 1 / area;
  const float dloss_d_w2area = grad_w2 * dw2_darea;
  const float dloss_e2 = grad_w2 * dw2_e2;
  auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
  auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
  const float2 dw2_p = thrust::get<0>(de2_dv);
  const float2 dw2_dv0 = thrust::get<1>(de2_dv) + thrust::get<1>(dw2area_dv);
  const float2 dw2_dv1 = thrust::get<2>(de2_dv) + thrust::get<2>(dw2area_dv);
  const float2 dw2_dv2 = thrust::get<0>(dw2area_dv);

  // Total gradient for each input is the sum over the three weights.
  const float2 dbary_p = dw0_p + dw1_p + dw2_p;
  const float2 dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
  const float2 dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
  const float2 dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;

  return thrust::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
}
// Forward pass of perspective correction for barycentric coordinates.
//
// Each screen-space weight is multiplied by the z-coordinates of the two
// other vertices, and the three products are normalized by their sum.
//
// Args:
//   bary: Screen-space barycentric coordinates for a point
//   z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//
// Returns
//   World-space barycentric coordinates
//
__device__ inline float3 BarycentricPerspectiveCorrectionForward(
    const float3& bary,
    const float z0,
    const float z1,
    const float z2) {
  const float num0 = bary.x * z1 * z2;
  const float num1 = z0 * bary.y * z2;
  const float num2 = z0 * z1 * bary.z;
  const float total = num0 + num1 + num2;
  return make_float3(num0 / total, num1 / total, num2 / total);
}
// Backward pass of perspective correction for barycentric coordinates.
//
// Args:
//   bary: Screen-space barycentric coordinates for a point
//   z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//   grad_out: Upstream gradient of the loss with respect to the corrected
//             barycentric coordinates.
//
// Returns a tuple of:
//   grad_bary: Downstream gradient of the loss with respect to the
//              uncorrected barycentric coordinates.
//   grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
//                              to the z-coordinates of the triangle verts
__device__ inline thrust::tuple<float3, float, float, float>
BarycentricPerspectiveCorrectionBackward(
    const float3& bary,
    const float z0,
    const float z1,
    const float z2,
    const float3& grad_out) {
  // Forward quantities: w_i = num_i / total.
  const float num0 = bary.x * z1 * z2;
  const float num1 = z0 * bary.y * z2;
  const float num2 = z0 * z1 * bary.z;
  const float total = num0 + num1 + num2;

  // Gradient reaching the shared denominator, then each numerator. Every
  // numerator receives the denominator's gradient too, since it is part of
  // the sum.
  const float g_total_num =
      -num0 * grad_out.x - num1 * grad_out.y - num2 * grad_out.z;
  const float g_total = g_total_num / (total * total);
  const float g_num0 = g_total + grad_out.x / total;
  const float g_num1 = g_total + grad_out.y / total;
  const float g_num2 = g_total + grad_out.z / total;

  // Gradient for the screen-space weights.
  const float g_bx = g_num0 * z1 * z2;
  const float g_by = g_num1 * z0 * z2;
  const float g_bz = g_num2 * z0 * z1;
  const float3 grad_bary = make_float3(g_bx, g_by, g_bz);

  // Gradient for each z: it appears in the two numerators of the other
  // vertices.
  const float grad_z0 = g_num1 * bary.y * z2 + g_num2 * bary.z * z1;
  const float grad_z1 = g_num0 * bary.x * z2 + g_num2 * bary.z * z0;
  const float grad_z2 = g_num0 * bary.x * z1 + g_num1 * bary.y * z0;

  return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Return the squared minimum distance between the line segment (a, b) and
// the point p.
//
// Args:
//   p: Coordinates of a point.
//   a, b: Coordinates of the end points of the line segment.
//
// Returns:
//   Squared Euclidean distance from p to the closest point of the segment.
//   If the squared segment length is at most kEpsilon the segment is treated
//   as a single point and the squared distance from p to b is returned.
//
__device__ inline float
PointLineDistanceForward(const float2& p, const float2& a, const float2& b) {
  const float2 ba = b - a;
  const float l2 = dot(ba, ba);

  // Handle the degenerate segment before dividing by its squared length.
  if (l2 <= kEpsilon) {
    return dot(p - b, p - b);
  }

  // Parameter of the projection of p onto the line, clamped to the segment.
  const float t =
      __saturatef(dot(ba, p - a) / l2); // clamp to the interval [+0.0, 1.0]
  const float2 p_proj = a + t * ba;
  const float2 d = (p_proj - p);
  return dot(d, d); // squared distance
}
// Backward pass for the squared point to line-segment distance in 2D.
//
// Args:
//   p: Coordinates of a point.
//   v0, v1: Coordinates of the end points of the line segment.
//   grad_dist: Upstream gradient for the (squared) distance.
//
// Returns:
//   tuple of gradients for each of the input points:
//     (float2 grad_p, float2 grad_v0, float2 grad_v1)
//
// The clamped projection parameter is treated as a constant, as in the
// original formulation. For a degenerate segment (squared length at most
// kEpsilon) the forward pass returns the squared distance from p to v1, so
// the gradient is taken of that expression and v0 receives zero.
__device__ inline thrust::tuple<float2, float2, float2>
PointLineDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_dist) {
  // Redo some of the forward pass calculations.
  const float2 v1v0 = v1 - v0;
  const float t_bot = dot(v1v0, v1v0);

  // Degenerate segment: mirror the forward pass and avoid dividing by t_bot.
  if (t_bot <= kEpsilon) {
    const float2 grad_p = grad_dist * 2.0f * (p - v1);
    const float2 grad_v0 = make_float2(0.0f, 0.0f);
    const float2 grad_v1 = grad_dist * 2.0f * (v1 - p);
    return thrust::make_tuple(grad_p, grad_v0, grad_v1);
  }

  const float2 pv0 = p - v0;
  const float t_top = dot(v1v0, pv0);
  const float tt = __saturatef(t_top / t_bot);
  const float2 p_proj = (1.0f - tt) * v0 + tt * v1;

  // d(dist)/dp = 2 (p - p_proj); the end points share the opposite gradient
  // in proportion to their interpolation weights (1 - tt) and tt.
  const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
  const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
  const float2 grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);

  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass for the distance between a point and the boundary of a
// triangle in 2D.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
//     The smallest of the three values returned by PointLineDistanceForward
//     for the edges (v0, v1), (v0, v2) and (v1, v2).
//
__device__ inline float PointTriangleDistanceForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  // Distance from p to each of the three edges.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);
  // Keep the running minimum over the edges.
  float closest = fminf(d01, d02);
  closest = fminf(closest, d12);
  return closest;
}
// Backward pass for the point to triangle distance.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//     grad_dist: Upstream gradient for the distance.
//
// Returns:
//     tuple of gradients for the point and each of the triangle vertices:
//         (float2 grad_p, float2 grad_v0, float2 grad_v1, float2 grad_v2)
//     Only the point and the two end points of the closest edge receive a
//     non-zero gradient from this function.
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
PointTriangleDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float& grad_dist) {
  // Recompute the forward distances to find out which edge was selected.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);

  const bool e01_closest = d01 <= d02 && d01 <= d12;
  const bool e02_closest = d02 <= d01 && d02 <= d12;
  const bool e12_closest = d12 <= d01 && d12 <= d02;

  // Gradient for the vertex which is not part of the closest edge.
  const float2 zero = make_float2(0.0f, 0.0f);

  // Ties are resolved in the order e01, e02, e12.
  if (e01_closest) {
    // Closest edge is v1 - v0.
    auto g = PointLineDistanceBackward(p, v0, v1, grad_dist);
    return thrust::make_tuple(
        thrust::get<0>(g), thrust::get<1>(g), thrust::get<2>(g), zero);
  }
  if (e02_closest) {
    // Closest edge is v2 - v0.
    auto g = PointLineDistanceBackward(p, v0, v2, grad_dist);
    return thrust::make_tuple(
        thrust::get<0>(g), thrust::get<1>(g), zero, thrust::get<2>(g));
  }
  if (e12_closest) {
    // Closest edge is v2 - v1.
    auto g = PointLineDistanceBackward(p, v1, v2, grad_dist);
    return thrust::make_tuple(
        thrust::get<0>(g), zero, thrust::get<1>(g), thrust::get<2>(g));
  }
  // None of the comparisons held: all gradients stay zero.
  return thrust::make_tuple(zero, zero, zero, zero);
}

View File

@@ -0,0 +1,397 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <algorithm>
#include <type_traits>
#include "vec2.h"
#include "vec3.h"
// Set epsilon for preventing floating point errors and division by 0.
// It is added to the triangle area before dividing in the barycentric
// coordinate functions, and it is the threshold at or below which the squared
// length of a segment is treated as degenerate in PointLineDistanceForward.
// NOTE(review): the literal has no suffix, so `auto` deduces double here -
// confirm that mixing it with float template instantiations is intended.
const auto kEpsilon = 1e-30;
// Edge function: tells on which side of the 2D edge (v0 -> v1) a point p
// lies.
//
// Args:
//     p: vec2 Coordinates of a point.
//     v0, v1: vec2 Coordinates of the end points of the edge.
//
// Returns:
//     area: The signed area of the parallelogram spanned by the vectors
//               A = p - v0
//               B = v1 - v0
//           i.e. the 2D cross product A x B = A.x * B.y - A.y * B.x.
//
//           A positive area means that p is on the right side of the edge
//           when looking from v0 towards v1, a negative area means that it
//           is on the left side:
//
//                       v1
//                      /
//                 -   /   +
//                    /
//                  v0
//
template <typename T>
T EdgeFunctionForward(const vec2<T>& p, const vec2<T>& v0, const vec2<T>& v1) {
  // Components of A = p - v0.
  const T ax = p.x - v0.x;
  const T ay = p.y - v0.y;
  // Components of B = v1 - v0.
  const T bx = v1.x - v0.x;
  const T by = v1.y - v0.y;
  return ax * by - ay * bx;
}
// Backward pass for the edge function: returns the upstream gradient
// multiplied by the partial derivatives of the edge function with respect to
// each of the input points.
//
// Args:
//     p: vec2 Coordinates of a point.
//     v0, v1: vec2 Coordinates of the end points of the edge.
//     grad_edge: Upstream gradient for output from edge function.
//
// Returns:
//     tuple of gradients for each of the input points:
//         (vec2<T> d_edge_dp, vec2<T> d_edge_dv0, vec2<T> d_edge_dv1)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> EdgeFunctionBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const T grad_edge) {
  // Partial derivatives of
  //   edge = (p.x - v0.x) * (v1.y - v0.y) - (p.y - v0.y) * (v1.x - v0.x)
  // with respect to the (x, y) components of each point.
  const vec2<T> jac_p(v1.y - v0.y, v0.x - v1.x);
  const vec2<T> jac_v0(p.y - v1.y, v1.x - p.x);
  const vec2<T> jac_v1(v0.y - p.y, p.x - v0.x);

  // Chain rule with the upstream gradient.
  const vec2<T> grad_p = grad_edge * jac_p;
  const vec2<T> grad_v0 = grad_edge * jac_v0;
  const vec2<T> grad_v1 = grad_edge * jac_v1;
  return std::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass: barycentric coordinates of a point relative to a 2D triangle.
// Ref:
// https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-rendering-a-triangle/barycentric-coordinates
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the triangle vertices.
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates. Each coordinate is the edge
//         function of p with respect to the edge opposite the vertex, divided
//         by the edge function of the full triangle (offset by kEpsilon).
//
template <typename T>
vec3<T> BarycentricCoordinatesForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2) {
  // Normalizer; kEpsilon keeps the divisions below away from 0.
  const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // Un-normalized weights from the three sub-triangles containing p.
  const T e0 = EdgeFunctionForward(p, v1, v2);
  const T e1 = EdgeFunctionForward(p, v2, v0);
  const T e2 = EdgeFunctionForward(p, v0, v1);
  return vec3<T>(e0 / area, e1 / area, e2 / area);
}
// Backward pass for the barycentric coordinates of a point relative to a
// triangle.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: (x, y) coordinates of the triangle vertices.
//     grad_bary_upstream: vec3<T> Upstream gradient for each of the
//         barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//     tuple of gradients for the point and each of the triangle vertices:
//         (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>> BarycentricCoordsBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2,
    const vec3<T>& grad_bary_upstream) {
  // Recompute the forward pass: w_i = edge_i / area.
  const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  const T area_sq = pow(area, 2.0f);
  const T inv_area = 1.0f / area;
  const T edge0 = EdgeFunctionForward(p, v1, v2);
  const T edge1 = EdgeFunctionForward(p, v2, v0);
  const T edge2 = EdgeFunctionForward(p, v0, v1);

  // Each w_i depends on the inputs through its numerator edge_i and through
  // the shared denominator area:
  //     dloss/dv = dloss/dw_i * dw_i/dedge_i * dedge_i/dv
  //              + dloss/dw_i * dw_i/darea   * darea/dv
  // EdgeFunctionBackward returns gradients in the order of its own arguments,
  // so the tuple slots are re-mapped onto (p, v0, v1, v2) below.

  // ---- w0 = edge0 / area, with edge0 = E(p, v1, v2), area = E(v2, v0, v1).
  const T g0_area = grad_bary_upstream.x * (-edge0 / (area_sq));
  const T g0_edge = grad_bary_upstream.x * inv_area;
  auto num0 = EdgeFunctionBackward(p, v1, v2, g0_edge); // (p, v1, v2)
  auto den0 = EdgeFunctionBackward(v2, v0, v1, g0_area); // (v2, v0, v1)
  const vec2<T> w0_p = std::get<0>(num0);
  const vec2<T> w0_v0 = std::get<1>(den0);
  const vec2<T> w0_v1 = std::get<1>(num0) + std::get<2>(den0);
  const vec2<T> w0_v2 = std::get<2>(num0) + std::get<0>(den0);

  // ---- w1 = edge1 / area, with edge1 = E(p, v2, v0).
  const T g1_area = grad_bary_upstream.y * (-edge1 / (area_sq));
  const T g1_edge = grad_bary_upstream.y * inv_area;
  auto num1 = EdgeFunctionBackward(p, v2, v0, g1_edge); // (p, v2, v0)
  auto den1 = EdgeFunctionBackward(v2, v0, v1, g1_area); // (v2, v0, v1)
  const vec2<T> w1_p = std::get<0>(num1);
  const vec2<T> w1_v0 = std::get<2>(num1) + std::get<1>(den1);
  const vec2<T> w1_v1 = std::get<2>(den1);
  const vec2<T> w1_v2 = std::get<1>(num1) + std::get<0>(den1);

  // ---- w2 = edge2 / area, with edge2 = E(p, v0, v1).
  const T g2_area = grad_bary_upstream.z * (-edge2 / (area_sq));
  const T g2_edge = grad_bary_upstream.z * inv_area;
  auto num2 = EdgeFunctionBackward(p, v0, v1, g2_edge); // (p, v0, v1)
  auto den2 = EdgeFunctionBackward(v2, v0, v1, g2_area); // (v2, v0, v1)
  const vec2<T> w2_p = std::get<0>(num2);
  const vec2<T> w2_v0 = std::get<1>(num2) + std::get<1>(den2);
  const vec2<T> w2_v1 = std::get<2>(num2) + std::get<2>(den2);
  const vec2<T> w2_v2 = std::get<0>(den2);

  // Sum the contributions of the three coordinates per input point.
  const vec2<T> total_p = w0_p + w1_p + w2_p;
  const vec2<T> total_v0 = w0_v0 + w1_v0 + w2_v0;
  const vec2<T> total_v1 = w0_v1 + w1_v1 + w2_v1;
  const vec2<T> total_v2 = w0_v2 + w1_v2 + w2_v2;
  return std::make_tuple(total_p, total_v0, total_v1, total_v2);
}
// Forward pass for the perspective correction of barycentric coordinates.
//
// Args:
//     bary: Screen-space barycentric coordinates for a point
//     z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//
// Returns
//     World-space barycentric coordinates: every screen-space coordinate is
//     weighted by the product of the z values of the two other vertices and
//     the result is normalized by the sum of the three weighted values.
//
template <typename T>
inline vec3<T> BarycentricPerspectiveCorrectionForward(
    const vec3<T>& bary,
    const T z0,
    const T z1,
    const T z2) {
  // Un-normalized corrected coordinates.
  const T num0 = bary.x * z1 * z2;
  const T num1 = bary.y * z0 * z2;
  const T num2 = bary.z * z0 * z1;
  // Common normalizer.
  const T norm = num0 + num1 + num2;
  return vec3<T>(num0 / norm, num1 / norm, num2 / norm);
}
// Backward pass for the perspective correction of barycentric coordinates.
//
// Args:
//     bary: Screen-space barycentric coordinates for a point
//     z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//     grad_out: Upstream gradient of the loss with respect to the corrected
//         barycentric coordinates.
//
// Returns a tuple of:
//     grad_bary: Downstream gradient of the loss with respect to the
//         uncorrected barycentric coordinates.
//     grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
//         to the z-coordinates of the triangle verts
template <typename T>
inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
    const vec3<T>& bary,
    const T z0,
    const T z1,
    const T z2,
    const vec3<T>& grad_out) {
  // Forward pass: w_i = num_i / norm with norm = num0 + num1 + num2.
  const T num0 = bary.x * z1 * z2;
  const T num1 = bary.y * z0 * z2;
  const T num2 = bary.z * z0 * z1;
  const T norm = num0 + num1 + num2;

  // Gradient flowing into the normalizer, shared by all three numerators.
  const T grad_norm_top =
      -num0 * grad_out.x - num1 * grad_out.y - num2 * grad_out.z;
  const T grad_norm = grad_norm_top / (norm * norm);

  // Total gradient for each numerator: through the normalizer + direct term.
  const T grad_num0 = grad_norm + grad_out.x / norm;
  const T grad_num1 = grad_norm + grad_out.y / norm;
  const T grad_num2 = grad_norm + grad_out.z / norm;

  // Gradient for the screen-space barycentric coordinates.
  const vec3<T> grad_bary(
      grad_num0 * z1 * z2, grad_num1 * z0 * z2, grad_num2 * z0 * z1);

  // Each z value appears in the two numerators of the other two vertices.
  const T grad_z0 = grad_num1 * bary.y * z2 + grad_num2 * bary.z * z1;
  const T grad_z1 = grad_num0 * bary.x * z2 + grad_num2 * bary.z * z0;
  const T grad_z2 = grad_num0 * bary.x * z1 + grad_num1 * bary.y * z0;
  return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate the squared minimum distance between a line segment (v1 - v0)
// and a point p.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1: Coordinates of the end points of the line segment.
//
// Returns:
//     squared euclidean distance of the point to the segment. If the squared
//     length of the segment is <= kEpsilon the segment is treated as a single
//     point and the squared distance from p to v1 is returned.
//
// Consider the line extending the segment - this can be parameterized as:
//     v0 + t (v1 - v0).
//
// First find the projection of point p onto the line. It falls where:
//     t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2
// where . is the dot product.
//
// The parameter t is clamped to [0, 1] to handle points whose projection
// falls outside the segment (v1 - v0).
//
// Once the projection of the point on the segment is known, the squared
// distance from p to the projection gives the result.
//
template <typename T>
T PointLineDistanceForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1) {
  const vec2<T> v1v0 = v1 - v0;
  const T l2 = dot(v1v0, v1v0);
  if (l2 <= kEpsilon) {
    // Degenerate segment. Return the *squared* distance so that both return
    // paths use the same units (the non-degenerate path below and the CUDA
    // implementation both return squared distances).
    return dot(p - v1, p - v1);
  }
  const T t = dot(v1v0, p - v0) / l2;
  // Clamp with bounds of type T so that std::min / std::max deduce a single
  // type for every instantiation of the template, not only for float.
  const T tt = std::min(std::max(t, static_cast<T>(0)), static_cast<T>(1));
  const vec2<T> p_proj = v0 + tt * v1v0;
  return dot(p - p_proj, p - p_proj);
}
// Backward pass for the squared point to line segment distance in 2D.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1: Coordinates of the end points of the line segment.
//     grad_dist: Upstream gradient for the (squared) distance.
//
// Returns:
//     tuple of gradients for each of the input points:
//         (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> PointLineDistanceBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const T& grad_dist) {
  const T zero = static_cast<T>(0);
  const T one = static_cast<T>(1);

  // Redo some of the forward pass calculations.
  const vec2<T> v1v0 = v1 - v0;
  const vec2<T> pv0 = p - v0;
  const T t_bot = dot(v1v0, v1v0);
  const T t_top = dot(v1v0, pv0);

  // For a degenerate segment the forward pass measures the distance to v1,
  // which corresponds to a projection parameter of 1. Using it here keeps
  // the gradients consistent with the forward pass and avoids the 0 / 0
  // division (and the resulting NaN gradients). Otherwise clamp the
  // parameter to [0, 1] with bounds of type T so that std::min / std::max
  // deduce a single type for every instantiation of the template.
  const T tt = (t_bot <= kEpsilon)
      ? one
      : std::min(std::max(t_top / t_bot, zero), one);

  // The squared distance is |p_proj - p|^2 with p_proj = (1 - tt) v0 + tt v1,
  // so each input receives 2 * (p_proj - p) scaled by its weight in p_proj
  // (and by -1 for p itself).
  const vec2<T> p_proj = (1.0f - tt) * v0 + tt * v1;
  const vec2<T> grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
  const vec2<T> grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);
  const vec2<T> grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
  return std::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass for the distance between a point and the boundary of a
// triangle in 2D.
// Ref: https://www.randygaul.net/2014/07/23/distance-point-to-line-segment/
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
//     The smallest of the three values returned by PointLineDistanceForward
//     for the edges (v0, v1), (v0, v2) and (v1, v2).
//
template <typename T>
T PointTriangleDistanceForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2) {
  // Distance from p to each of the three edges.
  const T d01 = PointLineDistanceForward(p, v0, v1);
  const T d02 = PointLineDistanceForward(p, v0, v2);
  const T d12 = PointLineDistanceForward(p, v1, v2);
  // Keep the running minimum over the edges.
  const T closest_of_first_two = std::min(d01, d02);
  return std::min(closest_of_first_two, d12);
}
// Backward pass for the point to triangle distance.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//     grad_dist: Upstream gradient for the distance.
//
// Returns:
//     tuple of gradients for the point and each of the triangle vertices:
//         (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
//     Only the point and the two end points of the closest edge receive a
//     non-zero gradient from this function.
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>>
PointTriangleDistanceBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2,
    const T& grad_dist) {
  // Recompute the forward distances to find out which edge was selected.
  const T d01 = PointLineDistanceForward(p, v0, v1);
  const T d02 = PointLineDistanceForward(p, v0, v2);
  const T d12 = PointLineDistanceForward(p, v1, v2);

  const bool e01_closest = d01 <= d02 && d01 <= d12;
  const bool e02_closest = d02 <= d01 && d02 <= d12;
  const bool e12_closest = d12 <= d01 && d12 <= d02;

  // Gradient for the vertex which is not part of the closest edge.
  const vec2<T> zero(0.0f, 0.0f);

  // Ties are resolved in the order e01, e02, e12.
  if (e01_closest) {
    // Closest edge is v1 - v0.
    auto g = PointLineDistanceBackward(p, v0, v1, grad_dist);
    return std::make_tuple(std::get<0>(g), std::get<1>(g), std::get<2>(g), zero);
  }
  if (e02_closest) {
    // Closest edge is v2 - v0.
    auto g = PointLineDistanceBackward(p, v0, v2, grad_dist);
    return std::make_tuple(std::get<0>(g), std::get<1>(g), zero, std::get<2>(g));
  }
  if (e12_closest) {
    // Closest edge is v2 - v1.
    auto g = PointLineDistanceBackward(p, v1, v2, grad_dist);
    return std::make_tuple(std::get<0>(g), zero, std::get<1>(g), std::get<2>(g));
  }
  // None of the comparisons held: all gradients stay zero.
  return std::make_tuple(zero, zero, zero, zero);
}

View File

@@ -0,0 +1,803 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <float.h>
#include <math.h>
#include <thrust/tuple.h>
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "float_math.cuh"
#include "geometry_utils.cuh"
#include "rasterize_points/bitmask.cuh"
#include "rasterize_points/rasterization_utils.cuh"
namespace {
// A structure for holding details about one face which was rasterized to a
// pixel. CheckPixelInsideFace fills it by aggregate initialization in the
// order {z, idx, dist, bary}, so the order of the fields matters.
struct Pixel {
// Interpolated z value of the face at the pixel.
float z;
// Index of the face in face_verts.
int64_t idx;
// Signed distance value: negated when the pixel is inside the face.
float dist;
// Barycentric coordinates of the pixel with respect to the face.
float3 bary;
};
// Order pixels by their z value only.
__device__ bool operator<(const Pixel& a, const Pixel& b) {
return a.z < b.z;
}
// Return the smallest of three floats.
__device__ float FloatMin3(const float p1, const float p2, const float p3) {
  const float min_of_last_two = fminf(p2, p3);
  return fminf(p1, min_of_last_two);
}
// Return the largest of three floats.
__device__ float FloatMax3(const float p1, const float p2, const float p3) {
  const float max_of_last_two = fmaxf(p2, p3);
  return fmaxf(p1, max_of_last_two);
}
// Load the xyz coordinates of the three vertices of the face with index
// face_idx. Every face occupies 9 consecutive floats in face_verts, laid out
// as (x0, y0, z0, x1, y1, z1, x2, y2, z2).
__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
    const float* face_verts,
    int face_idx) {
  // Start of the 9 floats which belong to this face.
  const float* face = face_verts + face_idx * 9;
  const float3 v0xyz = make_float3(face[0], face[1], face[2]);
  const float3 v1xyz = make_float3(face[3], face[4], face[5]);
  const float3 v2xyz = make_float3(face[6], face[7], face[8]);
  return thrust::make_tuple(v0xyz, v1xyz, v2xyz);
}
// Axis aligned bounding box of the face given by the vertices v0, v1, v2.
// Returns one float2 per axis, in the order (x, y, z); each float2 stores
// (min, max) of the three vertices along that axis.
__device__ thrust::tuple<float2, float2, float2>
GetFaceBoundingBox(float3 v0, float3 v1, float3 v2) {
  const float2 x_range =
      make_float2(FloatMin3(v0.x, v1.x, v2.x), FloatMax3(v0.x, v1.x, v2.x));
  const float2 y_range =
      make_float2(FloatMin3(v0.y, v1.y, v2.y), FloatMax3(v0.y, v1.y, v2.y));
  const float2 z_range =
      make_float2(FloatMin3(v0.z, v1.z, v2.z), FloatMax3(v0.z, v1.z, v2.z));
  return thrust::make_tuple(x_range, y_range, z_range);
}
// Check if the point pxy lies outside the xy bounding box of the face given
// by the vertices v0, v1, v2, after the box has been enlarged by blur_radius
// on every side.
//
// Args:
//     v0, v1, v2: xyz coordinates of the face vertices (only x and y are
//         used for the test).
//     blur_radius: Margin added to the bounding box in x and y.
//     pxy: xy coordinates of the point.
//
// Returns:
//     true if the point is outside the enlarged bounding box.
//
__device__ bool CheckPointOutsideBoundingBox(
    float3 v0,
    float3 v1,
    float3 v2,
    float blur_radius,
    float2 pxy) {
  const auto bbox = GetFaceBoundingBox(v0, v1, v2);
  const float2 xlims = thrust::get<0>(bbox); // (xmin, xmax)
  const float2 ylims = thrust::get<1>(bbox); // (ymin, ymax)
  // The z limits of the bounding box are not needed for the 2D test.

  const float x_min = xlims.x - blur_radius;
  const float y_min = ylims.x - blur_radius;
  const float x_max = xlims.y + blur_radius;
  const float y_max = ylims.y + blur_radius;

  // Check if the current point is outside the triangle bounding box.
  return (pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min);
}
// Test the pixel at xy location pxy against the face with index face_idx in
// face_verts and, if the face covers the pixel (or lies within the blur
// region around it), record the face in q.
//
// q is an unsorted list of at most K Pixel entries holding the K faces with
// the smallest z values found so far for this pixel. The auxiliary variables
// q_size (number of valid entries), q_max_z (largest z in q) and q_max_idx
// (position of that entry in q) are updated in place together with q.
//
// This code is shared between RasterizeMeshesNaiveCudaKernel and
// RasterizeMeshesFineCudaKernel.
template <typename FaceQ>
__device__ void CheckPixelInsideFace(
    const float* face_verts, // (F, 3, 3)
    int face_idx,
    int& q_size,
    float& q_max_z,
    int& q_max_idx,
    FaceQ& q,
    float blur_radius,
    float2 pxy, // Coordinates of the pixel
    int K,
    bool perspective_correct) {
  const auto verts = GetSingleFaceVerts(face_verts, face_idx);
  const float3 v0 = thrust::get<0>(verts);
  const float3 v1 = thrust::get<1>(verts);
  const float3 v2 = thrust::get<2>(verts);

  // Only xy is needed for barycentric coordinates and distance calculations.
  const float2 v0xy = make_float2(v0.x, v0.y);
  const float2 v1xy = make_float2(v1.x, v1.y);
  const float2 v2xy = make_float2(v2.x, v2.y);

  // Skip the face if all of its vertices are behind the camera.
  if (FloatMax3(v0.z, v1.z, v2.z) < 0) {
    return;
  }
  // Skip the face if the pixel is outside of its bounding box. blur_radius is
  // compared against squared distances below, hence the sqrt for the bbox.
  if (CheckPointOutsideBoundingBox(v0, v1, v2, sqrt(blur_radius), pxy)) {
    return;
  }
  // Skip the face if its area is (almost) zero.
  const float face_area = EdgeFunctionForward(v0xy, v1xy, v2xy);
  if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) {
    return;
  }

  // Barycentric coordinates of the pixel, optionally perspective corrected.
  const float3 bary_screen = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
  float3 bary = bary_screen;
  if (perspective_correct) {
    bary =
        BarycentricPerspectiveCorrectionForward(bary_screen, v0.z, v1.z, v2.z);
  }

  // Interpolated depth of the face at the pixel.
  const float pz = bary.x * v0.z + bary.y * v1.z + bary.z * v2.z;
  if (pz < 0) {
    return; // Face is behind the image plane.
  }

  // Squared distance of the pixel to the boundary of the face. The sign is
  // flipped to negative when the pixel is inside the face.
  const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
  const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
  const float signed_dist = inside ? -dist : dist;

  // Pixels outside the face only count if they are within the blur region.
  if (!inside && dist >= blur_radius) {
    return;
  }

  if (q_size < K) {
    // There is still room: append and keep track of the farthest entry.
    q[q_size] = {pz, face_idx, signed_dist, bary};
    if (pz > q_max_z) {
      q_max_z = pz;
      q_max_idx = q_size;
    }
    q_size++;
    return;
  }

  // The list is full: the new face is only kept if it is closer than the
  // farthest entry.
  if (!(pz < q_max_z)) {
    return;
  }
  // Replace the farthest entry and scan the list for the new farthest one.
  q[q_max_idx] = {pz, face_idx, signed_dist, bary};
  q_max_z = pz;
  for (int slot = 0; slot < K; slot++) {
    if (q[slot].z > q_max_z) {
      q_max_z = q[slot].z;
      q_max_idx = slot;
    }
  }
}
} // namespace
// ****************************************************************************
// * NAIVE RASTERIZATION *
// ****************************************************************************
// Naive rasterization: every pixel of every image in the batch is tested
// against every face of the corresponding mesh.
//
// Launch layout: 1D grid of 1D blocks. The loop over the N * H * W pixels
// advances by the total number of launched threads, so any launch
// configuration covers all pixels.
//
// Args:
//     face_verts: xyz coordinates of the face vertices, 9 floats per face.
//     mesh_to_face_first_idx: Index of the first face of each mesh.
//     num_faces_per_mesh: Number of faces of each mesh.
//     blur_radius: Threshold on the squared distance between a pixel and a
//         face for pixels which are outside the face.
//     perspective_correct: Whether to apply perspective correction to the
//         barycentric coordinates.
//     N, H, W: Batch size, image height and image width.
//     K: Number of faces to keep per pixel.
//     face_idxs, zbuf, pix_dists: Outputs of shape (N, H, W, K).
//     bary: Output of shape (N, H, W, K, 3).
//
// Only the first q_size <= K slots of each pixel are written; the remaining
// slots keep whatever value the output buffers were initialized with.
__global__ void RasterizeMeshesNaiveCudaKernel(
    const float* face_verts,
    const int64_t* mesh_to_face_first_idx,
    const int64_t* num_faces_per_mesh,
    float blur_radius,
    bool perspective_correct,
    int N,
    int H,
    int W,
    int K,
    int64_t* face_idxs,
    float* zbuf,
    float* pix_dists,
    float* bary) {
  // Simple version: One thread per output pixel
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockDim.x * blockIdx.x + threadIdx.x;

  for (int i = tid; i < N * H * W; i += num_threads) {
    // Convert linear index to 3D index
    const int n = i / (H * W); // batch index.
    const int pix_idx = i % (H * W);
    // The image is stored row-major with W pixels per row, so the row index
    // is obtained by dividing by W (not H).
    const int yi = pix_idx / W;
    const int xi = pix_idx % W;

    // screen coordinates to ndc coordinates of pixel.
    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const float2 pxy = make_float2(xf, yf);

    // For keeping track of the K closest points we want a data structure
    // that (1) gives O(1) access to the closest point for easy comparisons,
    // and (2) allows insertion of new elements. In the CPU version we use
    // std::priority_queue; then (2) is O(log K). We can't use STL
    // containers in CUDA; we could roll our own max heap in an array, but
    // that would likely have a lot of warp divergence so we do something
    // simpler instead: keep the elements in an unsorted array, but keep
    // track of the max value and the index of the max value. Then (1) is
    // still O(1) time, while (2) is O(K) with a clean loop. Since K is
    // bounded by kMaxPointsPerPixel this should be fast enough for our
    // purposes.
    Pixel q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;

    // Using the batch index of the thread get the start and stop
    // indices for the faces.
    const int64_t face_start_idx = mesh_to_face_first_idx[n];
    const int64_t face_stop_idx = face_start_idx + num_faces_per_mesh[n];

    // Loop through the faces in the mesh.
    for (int f = face_start_idx; f < face_stop_idx; ++f) {
      // Check if the pixel pxy is inside the face bounding box and if it is,
      // update q, q_size, q_max_z and q_max_idx in place.
      CheckPixelInsideFace(
          face_verts,
          f,
          q_size,
          q_max_z,
          q_max_idx,
          q,
          blur_radius,
          pxy,
          K,
          perspective_correct);
    }

    // TODO: make sorting an option as only top k is needed, not sorted values.
    BubbleSort(q, q_size);

    // Offset of pixel (n, yi, xi) in the (N, H, W, K) outputs: the stride of
    // a row is W * K.
    const int idx = n * H * W * K + yi * W * K + xi * K;
    for (int k = 0; k < q_size; ++k) {
      face_idxs[idx + k] = q[k].idx;
      zbuf[idx + k] = q[k].z;
      pix_dists[idx + k] = q[k].dist;
      bary[(idx + k) * 3 + 0] = q[k].bary.x;
      bary[(idx + k) * 3 + 1] = q[k].bary.y;
      bary[(idx + k) * 3 + 2] = q[k].bary.z;
    }
  }
}
// Host entry point for the naive CUDA mesh rasterization.
//
// Args:
//     face_verts: Tensor of shape (num_faces, 3, 3) with the xyz coordinates
//         of the vertices of every face.
//     mesh_to_faces_packed_first_idx: Index of the first face of each mesh.
//     num_faces_per_mesh: Number of faces of each mesh; its first dimension
//         is used as the batch size N.
//     image_size: Height and width of the (square) output images.
//     blur_radius: Forwarded to the kernel.
//     num_closest: Number of faces K to keep per pixel; must not exceed
//         kMaxPointsPerPixel.
//     perspective_correct: Forwarded to the kernel.
//
// Returns:
//     tuple (face_idxs, zbuf, bary, pix_dists) of tensors with shapes
//     (N, H, W, K), (N, H, W, K), (N, H, W, K, 3) and (N, H, W, K). All
//     tensors are initialized with -1; the kernel overwrites only the slots
//     for which it found a face.
//
// Raises an error if the inputs have inconsistent shapes, if num_closest is
// too large or if CUDA reports an error after the kernel launch.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_faces_packed_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    const int image_size,
    const float blur_radius,
    const int num_closest,
    bool perspective_correct) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) {
    AT_ERROR(
        "num_faces_per_mesh must have same size first dimension as mesh_to_faces_packed_first_idx");
  }
  if (num_closest > kMaxPointsPerPixel) {
    std::stringstream ss;
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }

  const int N = num_faces_per_mesh.size(0); // batch size.
  const int H = image_size; // Assume square images.
  const int W = image_size;
  const int K = num_closest;

  auto long_opts = face_verts.options().dtype(torch::kInt64);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);

  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);

  // Keep the contiguous inputs in named tensors so that the memory handed to
  // the kernel is owned for the rest of this function.
  const torch::Tensor face_verts_c = face_verts.contiguous();
  const torch::Tensor first_idx_c = mesh_to_faces_packed_first_idx.contiguous();
  const torch::Tensor num_faces_c = num_faces_per_mesh.contiguous();

  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizeMeshesNaiveCudaKernel<<<blocks, threads>>>(
      face_verts_c.data<float>(),
      first_idx_c.data<int64_t>(),
      num_faces_c.data<int64_t>(),
      blur_radius,
      perspective_correct,
      N,
      H,
      W,
      K,
      face_idxs.contiguous().data<int64_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>(),
      bary.contiguous().data<float>());

  // Kernel launches do not report errors themselves; query the runtime so
  // that a bad launch is not silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    AT_ERROR(
        "CUDA error after launching RasterizeMeshesNaiveCudaKernel: ",
        cudaGetErrorString(launch_err));
  }
  return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// Backward pass of the mesh rasterization: accumulates the gradients of the
// face vertices from the upstream gradients of zbuf, the barycentric
// coordinates and the pixel to face distances.
//
// Launch layout: 1D grid of 1D blocks. The loop over the N * H * W pixels
// advances by the total number of launched threads, so any launch
// configuration covers all pixels. Several pixels can reference the same
// face, therefore the results are accumulated with atomicAdd.
//
// TODO: benchmark parallelizing over faces_verts instead of over pixels.
__global__ void RasterizeMeshesBackwardCudaKernel(
    const float* face_verts, // (F, 3, 3)
    const int64_t* pix_to_face, // (N, H, W, K)
    bool perspective_correct,
    int N,
    int F,
    int H,
    int W,
    int K,
    const float* grad_zbuf, // (N, H, W, K)
    const float* grad_bary, // (N, H, W, K, 3)
    const float* grad_dists, // (N, H, W, K)
    float* grad_face_verts) { // (F, 3, 3)
  // Parallelize over each pixel in images of
  // size H * W, for each image in the batch of size N.
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
    // Convert linear index to 3D index
    const int n = t_i / (H * W); // batch index.
    const int pix_idx = t_i % (H * W);
    // The image is stored row-major with W pixels per row, so the row index
    // is obtained by dividing by W (not H).
    const int yi = pix_idx / W;
    const int xi = pix_idx % W;
    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const float2 pxy = make_float2(xf, yf);

    // Loop over all the faces for this pixel.
    for (int k = 0; k < K; k++) {
      // Index into (N, H, W, K, :) grad tensors: the stride of a row is
      // W * K.
      const int i =
          n * H * W * K + yi * W * K + xi * K + k; // pixel index + face index

      const int f = pix_to_face[i];
      if (f < 0) {
        continue; // padded face.
      }
      // Get xyz coordinates of the three face vertices.
      const auto v012 = GetSingleFaceVerts(face_verts, f);
      const float3 v0 = thrust::get<0>(v012);
      const float3 v1 = thrust::get<1>(v012);
      const float3 v2 = thrust::get<2>(v012);

      // Only need xy for barycentric coordinate and distance calculations.
      const float2 v0xy = make_float2(v0.x, v0.y);
      const float2 v1xy = make_float2(v1.x, v1.y);
      const float2 v2xy = make_float2(v2.x, v2.y);

      // Get upstream gradients for the face.
      const float grad_dist_upstream = grad_dists[i];
      const float grad_zbuf_upstream = grad_zbuf[i];
      const float grad_bary_upstream_w0 = grad_bary[i * 3 + 0];
      const float grad_bary_upstream_w1 = grad_bary[i * 3 + 1];
      const float grad_bary_upstream_w2 = grad_bary[i * 3 + 2];
      const float3 grad_bary_upstream = make_float3(
          grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);

      // Recompute the barycentric coordinates of the forward pass.
      const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
      const float3 bary = !perspective_correct
          ? bary0
          : BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
      // The forward pass negates the distance for pixels inside the face, so
      // the upstream gradient has to be negated in the same way.
      const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
      const float sign = inside ? -1.0f : 1.0f;

      // TODO(T52813608) Add support for non-square images.
      auto grad_dist_f = PointTriangleDistanceBackward(
          pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
      const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
      const float2 ddist_d_v1 = thrust::get<2>(grad_dist_f);
      const float2 ddist_d_v2 = thrust::get<3>(grad_dist_f);

      // Upstream gradient for barycentric coords from zbuf calculation:
      // zbuf = bary_w0 * z0 + bary_w1 * z1 + bary_w2 * z2
      // Therefore
      // d_zbuf/d_bary_w0 = z0
      // d_zbuf/d_bary_w1 = z1
      // d_zbuf/d_bary_w2 = z2
      const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);

      // Total upstream barycentric gradients are the sum of
      // external upstream gradients and contribution from zbuf.
      const float3 grad_bary_f_sum =
          (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);

      // Undo the perspective correction (if any) to get the gradient for the
      // screen-space barycentric coordinates and for the z values.
      float3 grad_bary0 = grad_bary_f_sum;
      float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
      if (perspective_correct) {
        auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
            bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
        grad_bary0 = thrust::get<0>(perspective_grads);
        dz0_persp = thrust::get<1>(perspective_grads);
        dz1_persp = thrust::get<2>(perspective_grads);
        dz2_persp = thrust::get<3>(perspective_grads);
      }
      auto grad_bary_f =
          BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
      const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
      const float2 dbary_d_v1 = thrust::get<2>(grad_bary_f);
      const float2 dbary_d_v2 = thrust::get<3>(grad_bary_f);

      // Accumulate into the 9 floats (x, y, z per vertex) of face f.
      atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
      atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
      atomicAdd(
          grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
      atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
      atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
      atomicAdd(
          grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
      atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
      atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
      atomicAdd(
          grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
    }
  }
}
// Host entry point for the backward pass of the CUDA mesh rasterization.
//
// Args:
//     face_verts: (F, 3, 3) xyz coordinates of the vertices of every face.
//     pix_to_face: (N, H, W, K) face index for every pixel slot; negative
//         values mark unused slots.
//     grad_zbuf: (N, H, W, K) upstream gradient for zbuf.
//     grad_bary: (N, H, W, K, 3) upstream gradient for the barycentric
//         coordinates.
//     grad_dists: (N, H, W, K) upstream gradient for the pixel to face
//         distances.
//     perspective_correct: Must match the value used in the forward pass.
//
// Returns:
//     grad_face_verts: (F, 3, 3) gradient for face_verts.
//
// Raises an error if CUDA reports an error after the kernel launch.
torch::Tensor RasterizeMeshesBackwardCuda(
    const torch::Tensor& face_verts, // (F, 3, 3)
    const torch::Tensor& pix_to_face, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
    const torch::Tensor& grad_dists, // (N, H, W, K)
    bool perspective_correct) {
  const int F = face_verts.size(0);
  const int N = pix_to_face.size(0);
  const int H = pix_to_face.size(1);
  const int W = pix_to_face.size(2);
  const int K = pix_to_face.size(3);

  // The kernel accumulates with atomicAdd, so the output starts at zero.
  torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());

  // Keep the contiguous inputs in named tensors so that the memory handed to
  // the kernel is owned for the rest of this function.
  const torch::Tensor face_verts_c = face_verts.contiguous();
  const torch::Tensor pix_to_face_c = pix_to_face.contiguous();
  const torch::Tensor grad_zbuf_c = grad_zbuf.contiguous();
  const torch::Tensor grad_bary_c = grad_bary.contiguous();
  const torch::Tensor grad_dists_c = grad_dists.contiguous();

  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizeMeshesBackwardCudaKernel<<<blocks, threads>>>(
      face_verts_c.data<float>(),
      pix_to_face_c.data<int64_t>(),
      perspective_correct,
      N,
      F,
      H,
      W,
      K,
      grad_zbuf_c.data<float>(),
      grad_bary_c.data<float>(),
      grad_dists_c.data<float>(),
      grad_face_verts.contiguous().data<float>());

  // Kernel launches do not report errors themselves; query the runtime so
  // that a bad launch is not silently ignored.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    AT_ERROR(
        "CUDA error after launching RasterizeMeshesBackwardCudaKernel: ",
        cudaGetErrorString(launch_err));
  }
  return grad_face_verts;
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
// Coarse rasterization kernel: for every mesh in the batch, records which
// faces overlap which image bins.
//
// Work distribution: each block walks over (mesh, chunk-of-faces) pairs with a
// stride of gridDim.x; inside a chunk each thread first handles faces, then
// (after a block barrier) handles bins.
//
// Shared memory: the dynamic shared buffer backs a BitMask of logical shape
// (num_bins, num_bins, chunk_size), so the launch must provide enough bytes
// for that many bits.
//
// NOTE: the number of bins per side and the half-pixel size are both derived
// from W only, so this presumes square images (H == W).
//
// Outputs:
//  faces_per_bin: indexed as [n][by][bx]; incremented atomically by the number
//                 of faces of each chunk that touch the bin. The counter keeps
//                 counting even when a bin is full, so a value larger than
//                 max_faces_per_bin indicates that faces were dropped.
//  bin_faces: indexed as [n][by][bx][slot]; receives the face indices. At most
//             max_faces_per_bin entries are written per bin.
__global__ void RasterizeMeshesCoarseCudaKernel(
    const float* face_verts,
    const int64_t* mesh_to_face_first_idx,
    const int64_t* num_faces_per_mesh,
    const float blur_radius,
    const int N,
    const int F,
    const int H,
    const int W,
    const int bin_size,
    const int chunk_size,
    const int max_faces_per_bin,
    int* faces_per_bin,
    int* bin_faces) {
  extern __shared__ char sbuf[];
  const int M = max_faces_per_bin;
  const int num_bins = 1 + (W - 1) / bin_size; // Integer divide round up
  const float half_pix = 1.0f / W; // Size of half a pixel in NDC units
  // The bounding boxes are padded by sqrt(blur_radius); the value does not
  // depend on the face, so compute it once.
  const float blur_extent = sqrtf(blur_radius);

  // This is a boolean array of shape (num_bins, num_bins, chunk_size)
  // stored in shared memory that will track whether each face in the chunk
  // falls into each bin of the image.
  BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);

  // Have each block handle a chunk of faces
  const int chunks_per_batch = 1 + (F - 1) / chunk_size;
  const int num_chunks = N * chunks_per_batch;
  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
    const int batch_idx = chunk / chunks_per_batch; // batch index
    const int chunk_idx = chunk % chunks_per_batch;
    const int face_start_idx = chunk_idx * chunk_size;

    binmask.block_clear();
    const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
    const int64_t mesh_face_stop_idx =
        mesh_face_start_idx + num_faces_per_mesh[batch_idx];

    // Have each thread handle a different face within the chunk
    for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
      const int f_idx = face_start_idx + f;

      // Check if face index corresponds to the mesh in the batch given by
      // batch_idx
      if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
        continue;
      }

      // Get xyz coordinates of the three face vertices.
      const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
      const float3 v0 = thrust::get<0>(v012);
      const float3 v1 = thrust::get<1>(v012);
      const float3 v2 = thrust::get<2>(v012);

      // Compute screen-space bbox for the triangle expanded by blur.
      const float xmin = FloatMin3(v0.x, v1.x, v2.x) - blur_extent;
      const float ymin = FloatMin3(v0.y, v1.y, v2.y) - blur_extent;
      const float xmax = FloatMax3(v0.x, v1.x, v2.x) + blur_extent;
      const float ymax = FloatMax3(v0.y, v1.y, v2.y) + blur_extent;
      const float zmax = FloatMax3(v0.z, v1.z, v2.z);

      if (zmax < 0) {
        continue; // Face is behind the camera.
      }

      // Brute-force search over all bins; TODO(T54294966) something smarter.
      for (int by = 0; by < num_bins; ++by) {
        // Y coordinate of the top and bottom of the bin.
        // PixToNdc gives the location of the center of each pixel, so we
        // need to add/subtract a half pixel to get the true extent of the bin.
        const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
        const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
        const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);

        for (int bx = 0; bx < num_bins; ++bx) {
          // X coordinate of the left and right of the bin.
          const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
          const float bin_x_max =
              PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
          const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
          if (y_overlap && x_overlap) {
            binmask.set(by, bx, f);
          }
        }
      }
    }
    // All threads reach this barrier: the `continue` statements above only
    // skip iterations of the per-thread face loop.
    __syncthreads();

    // Now we have processed every face in the current chunk. We need to
    // count the number of faces in each bin so we can write the indices
    // out to global memory. We have each thread handle a different bin.
    for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
      const int by = byx / num_bins;
      const int bx = byx % num_bins;
      const int count = binmask.count(by, bx);
      const int faces_per_bin_idx =
          batch_idx * num_bins * num_bins + by * num_bins + bx;

      // This atomically increments the (global) number of faces found
      // in the current bin, and gets the previous value of the counter;
      // this effectively allocates space in the bin_faces array for the
      // faces in the current chunk that fall into this bin.
      const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);

      // Offset of slot 0 of this bin inside bin_faces.
      const int bin_base = batch_idx * num_bins * num_bins * M +
          by * num_bins * M + bx * M;

      // Now loop over the binmask and write the active bits for this bin
      // out to bin_faces. Writing stops once the bin's M slots are used up so
      // that an over-full bin can never spill into the neighbouring bin or
      // past the end of the allocation.
      // TODO(T54296346) find the correct method for handling errors in
      // CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin.
      // Either decrease bin size or increase max_faces_per_bin
      int slot = start;
      for (int f = 0; f < chunk_size && slot < M; ++f) {
        if (binmask.get(by, bx, f)) {
          bin_faces[bin_base + slot] = face_start_idx + f;
          ++slot;
        }
      }
    }
    __syncthreads();
  }
}
// Host-side launcher for coarse rasterization on the GPU.
//
// Args:
//  face_verts: float32 tensor of shape (num_faces, 3, 3).
//  mesh_to_face_first_idx: int64 tensor with one entry per mesh.
//  num_faces_per_mesh: int64 tensor with one entry per mesh; its length is
//                      used as the batch size N.
//  image_size: side length of the (square) image in pixels.
//  blur_radius: forwarded to the kernel.
//  bin_size: side length of a bin in pixels; must be positive.
//  max_faces_per_bin: capacity M of each bin.
//
// Returns:
//  bin_faces: int32 tensor of shape (N, num_bins, num_bins, M), pre-filled
//             with -1 and then populated by the kernel.
//
// Raises (via AT_ERROR) on malformed inputs, when the image would need 22 or
// more bins per side, or when the kernel launch fails.
torch::Tensor RasterizeMeshesCoarseCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    const int image_size,
    const float blur_radius,
    const int bin_size,
    const int max_faces_per_bin) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  // The kernel reads both per-mesh tensors at every batch index.
  if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
    AT_ERROR(
        "num_faces_per_mesh and mesh_to_face_first_idx must have the same "
        "first dimension");
  }
  // bin_size is used as a divisor both below and inside the kernel.
  if (bin_size <= 0) {
    AT_ERROR("bin_size must be positive");
  }

  const int W = image_size;
  const int H = image_size;
  const int F = face_verts.size(0);
  const int N = num_faces_per_mesh.size(0);
  const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
  const int M = max_faces_per_bin;

  if (num_bins >= 22) {
    std::stringstream ss;
    ss << "Got " << num_bins << "; that's too many!"
       << " (the number of bins per side must be below 22; use a larger"
       << " bin_size)";
    AT_ERROR(ss.str());
  }

  auto opts = face_verts.options().dtype(torch::kInt32);
  torch::Tensor faces_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
  torch::Tensor bin_faces = torch::full({N, num_bins, num_bins, M}, -1, opts);

  const int chunk_size = 512;
  // One bit per (bin_y, bin_x, face-in-chunk) entry of the kernel's bit mask.
  const size_t shared_size = num_bins * num_bins * chunk_size / 8;
  const size_t blocks = 64;
  const size_t threads = 512;
  RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size>>>(
      face_verts.contiguous().data<float>(),
      mesh_to_face_first_idx.contiguous().data<int64_t>(),
      num_faces_per_mesh.contiguous().data<int64_t>(),
      blur_radius,
      N,
      F,
      H,
      W,
      bin_size,
      chunk_size,
      M,
      faces_per_bin.contiguous().data<int32_t>(),
      bin_faces.contiguous().data<int32_t>());

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    AT_ERROR(
        "RasterizeMeshesCoarseCudaKernel launch failed: ",
        cudaGetErrorString(launch_status));
  }
  return bin_faces;
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
// Fine rasterization kernel: for every pixel, examines only the faces listed
// for the pixel's bin in bin_faces and writes the per-pixel results.
//
// Work distribution: threads stride over N * B * B * bin_size * bin_size
// candidate pixels; candidates that fall outside the H x W image are skipped.
//
// Outputs are indexed as [n][yi][xi][k] (bary additionally by [0..2]). Only
// the first q_size entries of a pixel are written; the remaining entries keep
// whatever value the caller initialized them with.
__global__ void RasterizeMeshesFineCudaKernel(
    const float* face_verts, // (F, 3, 3)
    const int32_t* bin_faces, // (N, B, B, M)
    const float blur_radius,
    const int bin_size,
    const bool perspective_correct,
    const int N,
    const int F,
    const int B,
    const int M,
    const int H,
    const int W,
    const int K,
    int64_t* face_idxs, // (N, H, W, K)
    float* zbuf, // (N, H, W, K)
    float* pix_dists, // (N, H, W, K)
    float* bary // (N, H, W, K, 3)
) {
  // This can be more than H * W if the image size % bin_size != 0
  int num_pixels = N * B * B * bin_size * bin_size;
  int num_threads = gridDim.x * blockDim.x;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    // Convert linear index into bin and pixel indices. We make the within
    // block pixel ids move the fastest, so that adjacent threads will fall
    // into the same bin; this should give them coalesced memory reads when
    // they read from faces and bin_faces.
    int i = pid;
    const int n = i / (B * B * bin_size * bin_size);
    i %= B * B * bin_size * bin_size;
    const int by = i / (B * bin_size * bin_size);
    i %= B * bin_size * bin_size;
    const int bx = i / (bin_size * bin_size);
    i %= bin_size * bin_size;
    const int yi = i / bin_size + by * bin_size;
    const int xi = i % bin_size + bx * bin_size;

    if (yi >= H || xi >= W)
      continue;

    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const float2 pxy = make_float2(xf, yf);

    // This part looks like the naive rasterization kernel, except we use
    // bin_faces to only look at a subset of faces already known to fall
    // in this bin. TODO abstract out this logic into some data structure
    // that is shared by both kernels?
    Pixel q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;
    for (int m = 0; m < M; m++) {
      const int f = bin_faces[n * B * B * M + by * B * M + bx * M + m];
      if (f < 0) {
        continue; // bin_faces uses -1 as a sentinel value.
      }
      // Check if the pixel pxy is inside the face bounding box and if it is,
      // update q, q_size, q_max_z and q_max_idx in place.
      CheckPixelInsideFace(
          face_verts,
          f,
          q_size,
          q_max_z,
          q_max_idx,
          q,
          blur_radius,
          pxy,
          K,
          perspective_correct);
    }

    // Now we've looked at all the faces for this bin, so we can write
    // output for the current pixel.
    // TODO: make sorting an option as only top k is needed, not sorted values.
    BubbleSort(q, q_size);

    // Row-major offset into an (N, H, W, K) tensor: a row of the image holds
    // W pixels, so the row stride is W * K (not H * K).
    const int pix_idx = n * H * W * K + yi * W * K + xi * K;
    for (int k = 0; k < q_size; k++) {
      face_idxs[pix_idx + k] = q[k].idx;
      zbuf[pix_idx + k] = q[k].z;
      pix_dists[pix_idx + k] = q[k].dist;
      bary[(pix_idx + k) * 3 + 0] = q[k].bary.x;
      bary[(pix_idx + k) * 3 + 1] = q[k].bary.y;
      bary[(pix_idx + k) * 3 + 2] = q[k].bary.z;
    }
  }
}
// Host-side launcher for fine rasterization on the GPU.
//
// Args:
//  face_verts: float32 tensor of shape (num_faces, 3, 3).
//  bin_faces: int32 tensor of shape (N, B, B, M).
//  image_size: side length of the (square) output image in pixels.
//  blur_radius: forwarded to the kernel.
//  bin_size: side length of a bin in pixels; must be positive.
//  faces_per_pixel: number of entries K stored per pixel; must not exceed
//                   kMaxPointsPerPixel.
//  perspective_correct: forwarded to the kernel.
//
// Returns:
//  (face_idxs, zbuf, bary, pix_dists) with shapes (N, H, W, K), (N, H, W, K),
//  (N, H, W, K, 3) and (N, H, W, K). All four are pre-filled with -1.
//
// Raises (via AT_ERROR) on malformed inputs or when the kernel launch fails.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& bin_faces,
    const int image_size,
    const float blur_radius,
    const int bin_size,
    const int faces_per_pixel,
    bool perspective_correct) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (bin_faces.ndimension() != 4) {
    AT_ERROR("bin_faces must have 4 dimensions");
  }
  // The kernel uses a single B for both bin axes.
  if (bin_faces.size(1) != bin_faces.size(2)) {
    AT_ERROR("bin_faces must have dimensions (N, B, B, M)");
  }
  // bin_size is used as a divisor inside the kernel.
  if (bin_size <= 0) {
    AT_ERROR("bin_size must be positive");
  }

  const int F = face_verts.size(0);
  const int N = bin_faces.size(0);
  const int B = bin_faces.size(1);
  const int M = bin_faces.size(3);
  const int K = faces_per_pixel;
  const int H = image_size; // Assume square images only.
  const int W = image_size;

  // The kernel keeps its per-pixel candidates in a fixed-size array of
  // kMaxPointsPerPixel entries.
  if (K > kMaxPointsPerPixel) {
    AT_ERROR("Must have faces_per_pixel <= ", kMaxPointsPerPixel);
  }

  auto long_opts = face_verts.options().dtype(torch::kInt64);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);

  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizeMeshesFineCudaKernel<<<blocks, threads>>>(
      face_verts.contiguous().data<float>(),
      bin_faces.contiguous().data<int32_t>(),
      blur_radius,
      bin_size,
      perspective_correct,
      N,
      F,
      B,
      M,
      H,
      W,
      K,
      face_idxs.contiguous().data<int64_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>(),
      bary.contiguous().data<float>());

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    AT_ERROR(
        "RasterizeMeshesFineCudaKernel launch failed: ",
        cudaGetErrorString(launch_status));
  }
  return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}

View File

@@ -0,0 +1,411 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
// ****************************************************************************
// * FORWARD PASS *
// ****************************************************************************
// CPU implementation of the naive forward pass (declaration only).
// RasterizeMeshesNaive below forwards to it when face_verts is not a CUDA
// tensor.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct);
// CUDA implementation of the naive forward pass (declaration only).
// RasterizeMeshesNaive below forwards to it when face_verts is a CUDA tensor
// and passes its faces_per_pixel argument in the num_closest position.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int num_closest,
bool perspective_correct);
// Forward pass for rasterizing a batch of meshes.
//
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
// faces_verts of the first face in each mesh in
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// Assume square images only.
// blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
// faces_per_pixel: the number of closest faces to rasterize per pixel.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns:
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// zbuf: float32 Tensor of shape (N, H, W, K) giving the depth of each of
// the closest faces for each pixel.
// barycentric_coords: float tensor of shape (N, H, W, K, 3) giving
// barycentric coordinates of the pixel with respect to
// each of the closest faces along the z axis, padded
// with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// dists: float tensor of shape (N, H, W, K) giving the euclidean distance
// in the (NDC) x/y plane between each pixel and its K closest
// faces along the z axis padded with -1 for pixels hit by fewer than
// faces_per_pixel faces.
inline std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaive(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int faces_per_pixel,
    bool perspective_correct) {
  // TODO: Better type checking.
  // The backend is chosen solely from where face_verts lives.
  const bool on_gpu = face_verts.type().is_cuda();
  if (!on_gpu) {
    return RasterizeMeshesNaiveCpu(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        faces_per_pixel,
        perspective_correct);
  }
  return RasterizeMeshesNaiveCuda(
      face_verts,
      mesh_to_face_first_idx,
      num_faces_per_mesh,
      image_size,
      blur_radius,
      faces_per_pixel,
      perspective_correct);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// Device-specific implementations of the backward pass (declarations only).
// The gradient parameters are named in the order in which
// RasterizeMeshesBackward passes them: grad_zbuf, grad_bary, grad_dists.
torch::Tensor RasterizeMeshesBackwardCpu(
    const torch::Tensor& face_verts,
    const torch::Tensor& pix_to_face,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_bary,
    const torch::Tensor& grad_dists,
    bool perspective_correct);
torch::Tensor RasterizeMeshesBackwardCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& pix_to_face,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_bary,
    const torch::Tensor& grad_dists,
    bool perspective_correct);
// Args:
// face_verts: float32 Tensor of shape (F, 3, 3) (from forward pass) giving
// (packed) vertex positions for faces in all the meshes in
// the batch.
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// grad_zbuf: Tensor of shape (N, H, W, K) giving upstream gradients
// d(loss)/d(zbuf) of the zbuf tensor from the forward pass.
// grad_bary: Tensor of shape (N, H, W, K, 3) giving upstream gradients
// d(loss)/d(bary) of the barycentric_coords tensor returned by
// the forward pass.
// grad_dists: Tensor of shape (N, H, W, K) giving upstream gradients
// d(loss)/d(dists) of the dists tensor from the forward pass.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns:
// grad_face_verts: float32 Tensor of shape (F, 3, 3) giving downstream
// gradients for the face vertices.
// Defined in a header, so it must be inline: otherwise every translation unit
// that includes this file emits its own definition and linking fails.
inline torch::Tensor RasterizeMeshesBackward(
    const torch::Tensor& face_verts,
    const torch::Tensor& pix_to_face,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_bary,
    const torch::Tensor& grad_dists,
    bool perspective_correct) {
  // Pick the backend from where face_verts lives.
  if (face_verts.type().is_cuda()) {
    return RasterizeMeshesBackwardCuda(
        face_verts,
        pix_to_face,
        grad_zbuf,
        grad_bary,
        grad_dists,
        perspective_correct);
  } else {
    return RasterizeMeshesBackwardCpu(
        face_verts,
        pix_to_face,
        grad_zbuf,
        grad_bary,
        grad_dists,
        perspective_correct);
  }
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
// CPU implementation of coarse rasterization (declaration only).
// RasterizeMeshesCoarse below forwards to it when face_verts is not a CUDA
// tensor.
torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
// CUDA implementation of coarse rasterization (declaration only).
// RasterizeMeshesCoarse below forwards to it when face_verts is a CUDA tensor.
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
// faces_verts of the first face in each mesh in
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
// bin_size: Size of each bin within the image (in pixels)
// max_faces_per_bin: Maximum number of faces to count in each bin.
//
// Returns:
// bin_face_idxs: Tensor of shape (N, num_bins, num_bins, max_faces_per_bin) giving the
// indices of faces that fall into each bin.
// Defined in a header, so it must be inline: otherwise every translation unit
// that includes this file emits its own definition and linking fails.
inline torch::Tensor RasterizeMeshesCoarse(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int bin_size,
    int max_faces_per_bin) {
  // Pick the backend from where face_verts lives.
  if (face_verts.type().is_cuda()) {
    return RasterizeMeshesCoarseCuda(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        bin_size,
        max_faces_per_bin);
  } else {
    return RasterizeMeshesCoarseCpu(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        bin_size,
        max_faces_per_bin);
  }
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
// CUDA implementation of fine rasterization (declaration only).
// RasterizeMeshesFine below forwards to it when face_verts is a CUDA tensor;
// no CPU counterpart is declared in this header.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
int image_size,
float blur_radius,
int bin_size,
int faces_per_pixel,
bool perspective_correct);
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
// that fall into each bin (output from coarse rasterization).
// image_size: Size in pixels of the output image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
// bin_size: Size of each bin within the image (in pixels)
// faces_per_pixel: the number of closest faces to rasterize per pixel.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns (same as rasterize_meshes):
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// zbuf: float32 Tensor of shape (N, H, W, K) giving the depth of each of
// the closest faces for each pixel.
// barycentric_coords: float tensor of shape (N, H, W, K, 3) giving
// barycentric coordinates of the pixel with respect to
// each of the closest faces along the z axis, padded
// with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// dists: float tensor of shape (N, H, W, K) giving the euclidean distance
// in the (NDC) x/y plane between each pixel and its K closest
// faces along the z axis padded with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// Defined in a header, so it must be inline: otherwise every translation unit
// that includes this file emits its own definition and linking fails.
inline std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
    const torch::Tensor& face_verts,
    const torch::Tensor& bin_faces,
    int image_size,
    float blur_radius,
    int bin_size,
    int faces_per_pixel,
    bool perspective_correct) {
  if (face_verts.type().is_cuda()) {
    return RasterizeMeshesFineCuda(
        face_verts,
        bin_faces,
        image_size,
        blur_radius,
        bin_size,
        faces_per_pixel,
        perspective_correct);
  } else {
    // Only a CUDA implementation is available; AT_ERROR raises, so no value
    // is returned on this path.
    AT_ERROR("NOT IMPLEMENTED");
  }
}
// ****************************************************************************
// * MAIN ENTRY POINT *
// ****************************************************************************
// This is the main entry point for the forward pass of the mesh rasterizer;
// it uses either naive or coarse-to-fine rasterization based on bin_size.
//
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
// faces_verts of the first face in each mesh in
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
// faces_per_pixel: the number of closest faces to rasterize per pixel.
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
// bin_size=0 uses naive rasterization instead.
// max_faces_per_bin: The maximum number of faces allowed to fall into each
// bin when using coarse-to-fine rasterization.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns:
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// zbuf: float32 Tensor of shape (N, H, W, K) giving the depth of each of
// the closest faces for each pixel.
// barycentric_coords: float tensor of shape (N, H, W, K, 3) giving
// barycentric coordinates of the pixel with respect to
// each of the closest faces along the z axis, padded
// with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// dists: float tensor of shape (N, H, W, K) giving the euclidean distance
// in the (NDC) x/y plane between each pixel and its K closest
// faces along the z axis padded with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// Defined in a header, so it must be inline: otherwise every translation unit
// that includes this file emits its own definition and linking fails.
inline std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshes(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int faces_per_pixel,
    int bin_size,
    int max_faces_per_bin,
    bool perspective_correct) {
  // Coarse-to-fine is used only when both binning parameters are positive;
  // any other combination falls back to the naive implementation.
  if (bin_size > 0 && max_faces_per_bin > 0) {
    // Use coarse-to-fine rasterization
    auto bin_faces = RasterizeMeshesCoarse(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        bin_size,
        max_faces_per_bin);
    return RasterizeMeshesFine(
        face_verts,
        bin_faces,
        image_size,
        blur_radius,
        bin_size,
        faces_per_pixel,
        perspective_correct);
  } else {
    // Use the naive per-pixel implementation
    return RasterizeMeshesNaive(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        faces_per_pixel,
        perspective_correct);
  }
}

View File

@@ -0,0 +1,471 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <algorithm>
#include <list>
#include <queue>
#include <tuple>
#include "geometry_utils.h"
#include "vec2.h"
#include "vec3.h"
// Map pixel index i of an axis with S pixels to the NDC coordinate of the
// pixel's center, i.e. (2 * i + 1) / S - 1.
float PixToNdc(int i, int S) {
  // Offset of the pixel center from the left/top edge, in half-pixel units.
  const float half_pixel_steps = static_cast<float>(2 * i) + 1.0f;
  // Scale to the axis and shift so that the result is relative to -1.
  return half_pixel_steps / static_cast<float>(S) - 1.0f;
}
// Return the three components stored at face[vertex_index][0..2] as a tuple
// (x, y, z). Face may be any type that supports two levels of operator[].
template <typename Face>
auto ExtractVerts(const Face& face, const int vertex_index) {
  const auto& vertex = face[vertex_index];
  return std::make_tuple(vertex[0], vertex[1], vertex[2]);
}
// Compute the axis-aligned bounding box of every face.
//
// Args:
//  face_verts: tensor read through a <float, 3> accessor and indexed as
//              [face][vertex][x/y/z].
//
// Returns:
//  float32 tensor of shape (total_F, 6); row f holds
//  (x_min, y_min, x_max, y_max, z_min, z_max) of face f.
auto ComputeFaceBoundingBoxes(const torch::Tensor& face_verts) {
  const int total_F = face_verts.size(0);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
  auto face_verts_a = face_verts.accessor<float, 3>();
  torch::Tensor face_bboxes = torch::full({total_F, 6}, -2.0, float_opts);
  // Write through an accessor: assigning via face_bboxes[f][j] would build
  // temporary tensors and dispatch an op for every single scalar.
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();

  // Loop through all the faces
  for (int f = 0; f < total_F; ++f) {
    const auto& face = face_verts_a[f];
    float x0, x1, x2, y0, y1, y2, z0, z1, z2;
    std::tie(x0, y0, z0) = ExtractVerts(face, 0);
    std::tie(x1, y1, z1) = ExtractVerts(face, 1);
    std::tie(x2, y2, z2) = ExtractVerts(face, 2);

    face_bboxes_a[f][0] = std::min({x0, x1, x2}); // x_min
    face_bboxes_a[f][1] = std::min({y0, y1, y2}); // y_min
    face_bboxes_a[f][2] = std::max({x0, x1, x2}); // x_max
    face_bboxes_a[f][3] = std::max({y0, y1, y2}); // y_max
    face_bboxes_a[f][4] = std::min({z0, z1, z2}); // z_min
    face_bboxes_a[f][5] = std::max({z0, z1, z2}); // z_max
  }

  return face_bboxes;
}
// Test the point (px, py) against the x/y bounding box stored in
// face_bbox[0..3] as (x_min, y_min, x_max, y_max), after growing the box by
// blur_radius on every side.
// Returns true when the point lies strictly outside the grown box; points on
// the boundary count as inside.
template <typename Face>
bool CheckPointOutsideBoundingBox(
    const Face& face_bbox,
    float blur_radius,
    float px,
    float py) {
  // Grown box limits.
  const float left = face_bbox[0] - blur_radius;
  const float top = face_bbox[1] - blur_radius;
  const float right = face_bbox[2] + blur_radius;
  const float bottom = face_bbox[3] + blur_radius;

  // Evaluate each axis separately, then combine.
  const bool outside_x = (px > right) || (px < left);
  const bool outside_y = (py > bottom) || (py < top);
  return outside_x || outside_y;
}
// Evaluate EdgeFunctionForward on the (x, y) coordinates of the three
// vertices of every face; the z coordinates are ignored.
//
// Args:
//  face_verts: tensor read through a <float, 3> accessor and indexed as
//              [face][vertex][x/y/z].
//
// Returns:
//  float32 tensor of shape (total_F) holding the value returned by
//  EdgeFunctionForward for each face. Every entry is overwritten, so the -1
//  used to initialize the tensor never reaches the caller.
auto ComputeFaceAreas(const torch::Tensor& face_verts) {
  const int total_F = face_verts.size(0);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
  auto face_verts_a = face_verts.accessor<float, 3>();
  torch::Tensor face_areas = torch::full({total_F}, -1, float_opts);
  // Write through an accessor: assigning via face_areas[f] would build a
  // temporary tensor and dispatch an op for every single scalar.
  auto face_areas_a = face_areas.accessor<float, 1>();

  // Loop through all the faces
  for (int f = 0; f < total_F; ++f) {
    const auto& face = face_verts_a[f];
    float x0, x1, x2, y0, y1, y2;
    std::tie(x0, y0, std::ignore) = ExtractVerts(face, 0);
    std::tie(x1, y1, std::ignore) = ExtractVerts(face, 1);
    std::tie(x2, y2, std::ignore) = ExtractVerts(face, 2);

    const vec2<float> v0(x0, y0);
    const vec2<float> v1(x1, y1);
    const vec2<float> v2(x2, y2);

    face_areas_a[f] = EdgeFunctionForward(v0, v1, v2);
  }

  return face_areas;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
}
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
AT_ERROR(
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
}
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
const int H = image_size;
const int W = image_size;
const int K = faces_per_pixel;
auto long_opts = face_verts.options().dtype(torch::kInt64);
auto float_opts = face_verts.options().dtype(torch::kFloat32);
// Initialize output tensors.
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor barycentric_coords =
torch::full({N, H, W, K, 3}, -1, float_opts);
auto face_verts_a = face_verts.accessor<float, 3>();
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
auto face_areas = ComputeFaceAreas(face_verts);
auto face_areas_a = face_areas.accessor<float, 1>();
for (int n = 0; n < N; ++n) {
// Loop through each mesh in the batch.
// Get the start index of the faces in faces_packed and the num faces
// in the mesh to avoid having to loop through all the faces.
const int face_start_idx = mesh_to_face_first_idx[n].item().to<int32_t>();
const int face_stop_idx =
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
// Iterate through the horizontal lines of the image from top to bottom.
for (int yi = 0; yi < H; ++yi) {
// Y coordinate of the top of the pixel.
const float yf = PixToNdc(yi, H);
// Iterate through pixels on this horizontal line, left to right.
for (int xi = 0; xi < W; ++xi) {
// X coordinate of the left of the pixel.
const float xf = PixToNdc(xi, W);
// Use a priority queue to hold values:
// (z, idx, r, bary.x, bary.y. bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>>
q;
// Loop through the faces in the mesh.
for (int f = face_start_idx; f < face_stop_idx; ++f) {
// Get coordinates of three face vertices.
const auto& face = face_verts_a[f];
float x0, x1, x2, y0, y1, y2, z0, z1, z2;
std::tie(x0, y0, z0) = ExtractVerts(face, 0);
std::tie(x1, y1, z1) = ExtractVerts(face, 1);
std::tie(x2, y2, z2) = ExtractVerts(face, 2);
const vec2<float> v0(x0, y0);
const vec2<float> v1(x1, y1);
const vec2<float> v2(x2, y2);
// Skip faces with zero area.
const float face_area = face_areas_a[f];
if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) {
continue;
}
// Skip if point is outside the face bounding box.
const auto face_bbox = face_bboxes_a[f];
const bool outside_bbox = CheckPointOutsideBoundingBox(
face_bbox, std::sqrt(blur_radius), xf, yf);
if (outside_bbox) {
continue;
}
// Compute barycentric coordinates and use this to get the
// depth of the point on the triangle.
const vec2<float> pxy(xf, yf);
const vec3<float> bary0 =
BarycentricCoordinatesForward(pxy, v0, v1, v2);
const vec3<float> bary = !perspective_correct
? bary0
: BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
// Use barycentric coordinates to get the depth of the current pixel
const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2);
if (pz < 0) {
continue; // Point is behind the image plane so ignore.
}
// Compute absolute distance of the point to the triangle.
// If the point is inside the triangle then the distance
// is negative.
const float dist = PointTriangleDistanceForward(pxy, v0, v1, v2);
// Use the bary coordinates to determine if the point is
// inside the face.
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
const float signed_dist = inside ? -dist : dist;
// Check if pixel is outside blur region
if (!inside && dist >= blur_radius) {
continue;
}
// The current pixel lies inside the current face.
q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z);
if (static_cast<int>(q.size()) > K) {
q.pop();
}
}
while (!q.empty()) {
auto t = q.top();
q.pop();
const int i = q.size();
zbuf_a[n][yi][xi][i] = std::get<0>(t);
face_idxs_a[n][yi][xi][i] = std::get<1>(t);
pix_dists_a[n][yi][xi][i] = std::get<2>(t);
barycentric_coords_a[n][yi][xi][i][0] = std::get<3>(t);
barycentric_coords_a[n][yi][xi][i][1] = std::get<4>(t);
barycentric_coords_a[n][yi][xi][i][2] = std::get<5>(t);
}
}
}
}
return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
}
// Backward pass of the CPU mesh rasterizer.
//
// For every (n, y, x, k) slot of pix_to_face holding a non-negative face
// index, the upstream gradients of the depth, barycentric coordinates and
// signed pixel-to-face distance are propagated to the nine vertex coordinates
// of that face and accumulated into the returned tensor.
//
// Args:
//   face_verts: (F, 3, 3) float tensor; read through a CPU accessor.
//   pix_to_face: (N, H, W, K) int64 tensor of face indices; negative entries
//       are treated as padding and skipped.
//   grad_zbuf: (N, H, W, K) upstream gradient for the depth buffer.
//   grad_bary: (N, H, W, K, 3) upstream gradient for barycentric coordinates.
//   grad_dists: (N, H, W, K) upstream gradient for the signed distances.
//   perspective_correct: whether the forward pass applied
//       BarycentricPerspectiveCorrectionForward; if so its backward is applied.
//
// Returns:
//   (F, 3, 3) tensor with the same options as face_verts holding the
//   accumulated gradient for every face vertex.
torch::Tensor RasterizeMeshesBackwardCpu(
    const torch::Tensor& face_verts, // (F, 3, 3)
    const torch::Tensor& pix_to_face, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
    const torch::Tensor& grad_dists, // (N, H, W, K)
    bool perspective_correct) {
  const int F = face_verts.size(0);
  const int N = pix_to_face.size(0);
  const int H = pix_to_face.size(1);
  const int W = pix_to_face.size(2);
  const int K = pix_to_face.size(3);

  torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
  auto face_verts_a = face_verts.accessor<float, 3>();
  auto pix_to_face_a = pix_to_face.accessor<int64_t, 4>();
  auto grad_dists_a = grad_dists.accessor<float, 4>();
  auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
  auto grad_bary_a = grad_bary.accessor<float, 5>();
  // Accumulate through an accessor: indexing the Tensor itself
  // (grad_face_verts[f][i][j] += v) creates temporary Tensors and dispatches
  // an ATen op for every scalar update in the innermost loop.
  auto grad_face_verts_a = grad_face_verts.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    // Iterate through the horizontal lines of the image from top to bottom.
    for (int y = 0; y < H; ++y) {
      // NDC y coordinate of this pixel row, as given by PixToNdc.
      const float yf = PixToNdc(y, H);
      // Iterate through pixels on this horizontal line, left to right.
      for (int x = 0; x < W; ++x) {
        // NDC x coordinate of this pixel column, as given by PixToNdc.
        const float xf = PixToNdc(x, W);
        const vec2<float> pxy(xf, yf);

        // Iterate through the faces that hit this pixel.
        for (int k = 0; k < K; ++k) {
          // Get face index from forward pass output.
          const int f = pix_to_face_a[n][y][x][k];
          if (f < 0) {
            continue; // padded face.
          }
          // Get coordinates of the three face vertices.
          const auto face_verts_f = face_verts_a[f];
          const float x0 = face_verts_f[0][0];
          const float y0 = face_verts_f[0][1];
          const float z0 = face_verts_f[0][2];
          const float x1 = face_verts_f[1][0];
          const float y1 = face_verts_f[1][1];
          const float z1 = face_verts_f[1][2];
          const float x2 = face_verts_f[2][0];
          const float y2 = face_verts_f[2][1];
          const float z2 = face_verts_f[2][2];
          const vec2<float> v0xy(x0, y0);
          const vec2<float> v1xy(x1, y1);
          const vec2<float> v2xy(x2, y2);

          // Get upstream gradients for the face.
          const float grad_dist_upstream = grad_dists_a[n][y][x][k];
          const float grad_zbuf_upstream = grad_zbuf_a[n][y][x][k];
          const auto grad_bary_upstream_w012 = grad_bary_a[n][y][x][k];
          const float grad_bary_upstream_w0 = grad_bary_upstream_w012[0];
          const float grad_bary_upstream_w1 = grad_bary_upstream_w012[1];
          const float grad_bary_upstream_w2 = grad_bary_upstream_w012[2];
          const vec3<float> grad_bary_upstream(
              grad_bary_upstream_w0,
              grad_bary_upstream_w1,
              grad_bary_upstream_w2);

          // Recompute the forward barycentric coordinates for this pixel.
          const vec3<float> bary0 =
              BarycentricCoordinatesForward(pxy, v0xy, v1xy, v2xy);
          const vec3<float> bary = !perspective_correct
              ? bary0
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);

          // Distances inside the face are negative so get the
          // correct sign to apply to the upstream gradient.
          const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
          const float sign = inside ? -1.0f : 1.0f;

          // TODO(T52813608) Add support for non-square images.
          const auto grad_dist_f = PointTriangleDistanceBackward(
              pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
          const auto ddist_d_v0 = std::get<1>(grad_dist_f);
          const auto ddist_d_v1 = std::get<2>(grad_dist_f);
          const auto ddist_d_v2 = std::get<3>(grad_dist_f);

          // Upstream gradient for barycentric coords from zbuf calculation:
          // zbuf = bary_w0 * z0 + bary_w1 * z1 + bary_w2 * z2
          // Therefore
          // d_zbuf/d_bary_w0 = z0
          // d_zbuf/d_bary_w1 = z1
          // d_zbuf/d_bary_w2 = z2
          const vec3<float> d_zbuf_d_bary(z0, z1, z2);

          // Total upstream barycentric gradients are the sum of
          // external upstream gradients and contribution from zbuf.
          vec3<float> grad_bary_f_sum =
              (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);

          vec3<float> grad_bary0 = grad_bary_f_sum;
          if (perspective_correct) {
            // The perspective correction depends on z0, z1, z2 directly, so
            // its backward also contributes to the z gradients.
            auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
                bary0, z0, z1, z2, grad_bary_f_sum);
            grad_bary0 = std::get<0>(perspective_grads);
            grad_face_verts_a[f][0][2] += std::get<1>(perspective_grads);
            grad_face_verts_a[f][1][2] += std::get<2>(perspective_grads);
            grad_face_verts_a[f][2][2] += std::get<3>(perspective_grads);
          }
          auto grad_bary_f =
              BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
          const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
          const vec2<float> dbary_d_v1 = std::get<2>(grad_bary_f);
          const vec2<float> dbary_d_v2 = std::get<3>(grad_bary_f);

          // Update output gradient buffer.
          grad_face_verts_a[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
          grad_face_verts_a[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
          grad_face_verts_a[f][0][2] += grad_zbuf_upstream * bary.x;
          grad_face_verts_a[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
          grad_face_verts_a[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
          grad_face_verts_a[f][1][2] += grad_zbuf_upstream * bary.y;
          grad_face_verts_a[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
          grad_face_verts_a[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
          grad_face_verts_a[f][2][2] += grad_zbuf_upstream * bary.z;
        }
      }
    }
  }
  return grad_face_verts;
}
// Coarse CPU rasterization: assign faces to square image bins.
//
// The image is divided into bins of bin_size x bin_size pixels. For every mesh
// in the batch and every bin, the indices of all faces of that mesh whose
// xy bounding box (expanded by sqrt(blur_radius)) overlaps the bin are
// recorded.
//
// Args:
//   face_verts: tensor of shape (num_faces, 3, 3).
//   mesh_to_face_first_idx: per-mesh index of the first face of the mesh.
//   num_faces_per_mesh: 1D tensor with the number of faces of each mesh.
//   image_size: side length of the (square) image in pixels.
//   blur_radius: bounding boxes are expanded by sqrt(blur_radius).
//   bin_size: side length of a bin in pixels; must be positive.
//   max_faces_per_bin: capacity M of each bin.
//
// Returns:
//   int32 tensor of shape (N, BH, BW, M) holding face indices; unused slots
//   are -1.
//
// Raises (via AT_ERROR) if the inputs have the wrong shape, if bin_size is not
// positive, or if more than max_faces_per_bin faces fall into one bin.
torch::Tensor RasterizeMeshesCoarseCpu(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int bin_size,
    int max_faces_per_bin) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (num_faces_per_mesh.ndimension() != 1) {
    AT_ERROR("num_faces_per_mesh can only have one dimension");
  }
  // bin_size is used as a divisor below; zero would turn the bin count into
  // a float-to-int conversion of infinity.
  if (bin_size <= 0) {
    AT_ERROR("bin_size must be positive");
  }

  const int N = num_faces_per_mesh.size(0); // batch size.
  const int M = max_faces_per_bin;

  // Assume square images. TODO(T52813608) Support non square images.
  const float height = image_size;
  const float width = image_size;
  // Number of bins per side. The division is carried out in float and the
  // result truncated on conversion to int, which rounds the bin count up.
  const int BH = 1 + (height - 1) / bin_size;
  const int BW = 1 + (width - 1) / bin_size;

  auto opts = face_verts.options().dtype(torch::kInt32);
  torch::Tensor bin_faces = torch::full({N, BH, BW, M}, -1, opts);
  auto bin_faces_a = bin_faces.accessor<int32_t, 4>();

  // Precompute all face bounding boxes.
  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();

  const float pixel_width = 2.0f / image_size;
  const float bin_width = pixel_width * bin_size;
  // Margin by which every bounding box is grown; loop invariant.
  const float blur_margin = std::sqrt(blur_radius);

  // Iterate through the meshes in the batch.
  for (int n = 0; n < N; ++n) {
    const int face_start_idx = mesh_to_face_first_idx[n].item().to<int32_t>();
    const int face_stop_idx =
        (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());

    float bin_y_min = -1.0f;
    float bin_y_max = bin_y_min + bin_width;

    // Iterate through the horizontal bins from top to bottom.
    for (int by = 0; by < BH; ++by) {
      float bin_x_min = -1.0f;
      float bin_x_max = bin_x_min + bin_width;

      // Iterate through bins on this horizontal line, left to right.
      for (int bx = 0; bx < BW; ++bx) {
        int32_t faces_hit = 0;

        for (int32_t f = face_start_idx; f < face_stop_idx; ++f) {
          // Get bounding box and expand by blur radius.
          const float face_x_min = face_bboxes_a[f][0] - blur_margin;
          const float face_y_min = face_bboxes_a[f][1] - blur_margin;
          const float face_x_max = face_bboxes_a[f][2] + blur_margin;
          const float face_y_max = face_bboxes_a[f][3] + blur_margin;
          const float face_z_max = face_bboxes_a[f][5];

          if (face_z_max < 0) {
            continue; // Face is behind the camera.
          }

          // Overlap test: the bin's upper edge is inclusive
          // (face_min <= bin_max) while its lower edge is exclusive
          // (bin_min < face_max).
          const bool x_overlap =
              (face_x_min <= bin_x_max) && (bin_x_min < face_x_max);
          const bool y_overlap =
              (face_y_min <= bin_y_max) && (bin_y_min < face_y_max);

          if (x_overlap && y_overlap) {
            // Got too many faces for this bin, so throw an error.
            if (faces_hit >= max_faces_per_bin) {
              AT_ERROR("Got too many faces per bin");
            }
            // The current face overlaps the current bin, so record it.
            bin_faces_a[n][by][bx][faces_hit] = f;
            faces_hit++;
          }
        }

        // Shift the bin to the right for the next loop iteration.
        bin_x_min = bin_x_max;
        bin_x_max = bin_x_min + bin_width;
      }
      // Shift the bin down for the next loop iteration.
      bin_y_min = bin_y_max;
      bin_y_max = bin_y_min + bin_width;
    }
  }
  return bin_faces;
}

View File

@@ -0,0 +1,59 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <type_traits>
// Minimal fixed-size 2D vector with public x/y components, used for
// representing 2D coordinates. Only float and double element types are
// accepted (enforced through the defaulted second template parameter).
// TODO: switch to Eigen if more functionality is needed.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, float>::value || std::is_same<T, double>::value>>
struct vec2 {
  // Element type of the vector.
  using scalar_t = T;

  T x, y;

  // Build a vector from its two components.
  vec2(T x_in, T y_in) : x(x_in), y(y_in) {}
};
// Component-wise sum of two 2D vectors.
template <typename T>
inline vec2<T> operator+(const vec2<T>& a, const vec2<T>& b) {
  const T sum_x = a.x + b.x;
  const T sum_y = a.y + b.y;
  return vec2<T>(sum_x, sum_y);
}
// Component-wise difference (a - b) of two 2D vectors.
template <typename T>
inline vec2<T> operator-(const vec2<T>& a, const vec2<T>& b) {
  const T diff_x = a.x - b.x;
  const T diff_y = a.y - b.y;
  return vec2<T>(diff_x, diff_y);
}
// Scale a 2D vector by a scalar placed on the left (a * b).
template <typename T>
inline vec2<T> operator*(const T a, const vec2<T>& b) {
  const T scaled_x = a * b.x;
  const T scaled_y = a * b.y;
  return vec2<T>(scaled_x, scaled_y);
}
// Divide both components of a 2D vector by a scalar.
// Raises through AT_ERROR when the divisor compares equal to zero.
template <typename T>
inline vec2<T> operator/(const vec2<T>& a, const T b) {
  const bool zero_divisor = (b == 0.0);
  if (zero_divisor) {
    // Prevent divide-by-zero errors.
    AT_ERROR("denominator in vec2 division is 0");
  }
  const T quot_x = a.x / b;
  const T quot_y = a.y / b;
  return vec2<T>(quot_x, quot_y);
}
// Inner product of two 2D vectors.
template <typename T>
inline T dot(const vec2<T>& a, const vec2<T>& b) {
  const T result = a.x * b.x + a.y * b.y;
  return result;
}
template <typename T>
inline T norm(const vec2<T>& a, const vec2<T>& b) {
const vec2<T> ba = b - a;
return sqrt(dot(ba, ba));
}
// Stream a 2D vector as "vec2(x, y)" and return the stream for chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const vec2<T>& v) {
  return os << "vec2(" << v.x << ", " << v.y << ")";
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// Minimal fixed-size 3D vector with public x/y/z components, used for
// representing 3D coordinates. Only float and double element types are
// accepted (enforced through the defaulted second template parameter).
// TODO: switch to Eigen if more functionality is needed.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, float>::value || std::is_same<T, double>::value>>
struct vec3 {
  // Element type of the vector.
  using scalar_t = T;

  T x, y, z;

  // Build a vector from its three components.
  vec3(T x_in, T y_in, T z_in) : x(x_in), y(y_in), z(z_in) {}
};
// Component-wise sum of two 3D vectors.
template <typename T>
inline vec3<T> operator+(const vec3<T>& a, const vec3<T>& b) {
  const T sum_x = a.x + b.x;
  const T sum_y = a.y + b.y;
  const T sum_z = a.z + b.z;
  return vec3<T>(sum_x, sum_y, sum_z);
}
// Component-wise difference (a - b) of two 3D vectors.
template <typename T>
inline vec3<T> operator-(const vec3<T>& a, const vec3<T>& b) {
  const T diff_x = a.x - b.x;
  const T diff_y = a.y - b.y;
  const T diff_z = a.z - b.z;
  return vec3<T>(diff_x, diff_y, diff_z);
}
// Divide all three components of a 3D vector by a scalar.
// Raises through AT_ERROR when the divisor compares equal to zero.
template <typename T>
inline vec3<T> operator/(const vec3<T>& a, const T b) {
  const bool zero_divisor = (b == 0.0);
  if (zero_divisor) {
    // Prevent divide-by-zero errors.
    AT_ERROR("denominator in vec3 division is 0");
  }
  const T quot_x = a.x / b;
  const T quot_y = a.y / b;
  const T quot_z = a.z / b;
  return vec3<T>(quot_x, quot_y, quot_z);
}
// Scale a 3D vector by a scalar placed on the left (a * b).
template <typename T>
inline vec3<T> operator*(const T a, const vec3<T>& b) {
  const T scaled_x = a * b.x;
  const T scaled_y = a * b.y;
  const T scaled_z = a * b.z;
  return vec3<T>(scaled_x, scaled_y, scaled_z);
}
// Element-wise (Hadamard) product of two 3D vectors.
template <typename T>
inline vec3<T> operator*(const vec3<T>& a, const vec3<T>& b) {
  const T prod_x = a.x * b.x;
  const T prod_y = a.y * b.y;
  const T prod_z = a.z * b.z;
  return vec3<T>(prod_x, prod_y, prod_z);
}
// Inner product of two 3D vectors.
template <typename T>
inline T dot(const vec3<T>& a, const vec3<T>& b) {
  const T result = a.x * b.x + a.y * b.y + a.z * b.z;
  return result;
}
// Cross product a x b of two 3D vectors.
template <typename T>
inline vec3<T> cross(const vec3<T>& a, const vec3<T>& b) {
  const T cx = a.y * b.z - a.z * b.y;
  const T cy = a.z * b.x - a.x * b.z;
  const T cz = a.x * b.y - a.y * b.x;
  return vec3<T>(cx, cy, cz);
}
// Stream a 3D vector as "vec3(x, y, z)" and return the stream for chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const vec3<T>& v) {
  return os << "vec3(" << v.x << ", " << v.y << ", " << v.z << ")";
}

View File

@@ -0,0 +1,73 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#define BINMASK_H
// A BitMask represents a bool array of shape (H, W, N). Values are packed into
// the bits of unsigned ints; a single unsigned int has B = 8 * sizeof(unsigned
// int) bits, so each (y, x) cell uses D = ceil(N / B) words and the whole mask
// occupies H * W * D words. The storage is NOT owned by the BitMask: the
// caller must pass a buffer of at least H * W * D unsigned ints (the kernels
// in this project place it in shared memory).
//
// Threading requirements visible from the implementation:
//  - The constructor and block_clear() call __syncthreads(), so they must be
//    reached by every thread of the block.
//  - block_clear() strides over the buffer with threadIdx.x / blockDim.x only.
//  - set() / unset() use atomics and may be called concurrently; get() and
//    count() are plain reads, so callers must synchronize between writing and
//    reading.
class BitMask {
 public:
  // Wrap `data` as an (H, W, N) bit array and zero it.
  // D is rounded up so that N need not be a multiple of the word size; with
  // plain truncating division, bits d >= (N / B) * B would land in the words
  // of the neighbouring (y, x) cell.
  __device__ BitMask(unsigned int* data, int H, int W, int N)
      : data(data),
        H(H),
        W(W),
        B(8 * sizeof(unsigned int)),
        D((N + B - 1) / B) {
    // TODO: check if the data is null.
    block_clear(); // clear the data
  }

  // Use all threads in the current block to clear all bits of this BitMask
  __device__ void block_clear() {
    for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
      data[i] = 0;
    }
    __syncthreads();
  }

  // Index of the word that holds bit d of cell (y, x).
  __device__ int _get_elem_idx(int y, int x, int d) const {
    return y * W * D + x * D + d / B;
  }

  // Position of bit d inside its word.
  __device__ int _get_bit_idx(int d) const {
    return d % B;
  }

  // Turn on a single bit (y, x, d)
  __device__ void set(int y, int x, int d) {
    int elem_idx = _get_elem_idx(y, x, d);
    int bit_idx = _get_bit_idx(d);
    const unsigned int mask = 1U << bit_idx;
    atomicOr(data + elem_idx, mask);
  }

  // Turn off a single bit (y, x, d)
  __device__ void unset(int y, int x, int d) {
    int elem_idx = _get_elem_idx(y, x, d);
    int bit_idx = _get_bit_idx(d);
    const unsigned int mask = ~(1U << bit_idx);
    atomicAnd(data + elem_idx, mask);
  }

  // Check whether the bit (y, x, d) is on or off
  __device__ bool get(int y, int x, int d) const {
    int elem_idx = _get_elem_idx(y, x, d);
    int bit_idx = _get_bit_idx(d);
    return (data[elem_idx] >> bit_idx) & 1U;
  }

  // Compute the number of bits set in the row (y, x, :)
  __device__ int count(int y, int x) const {
    int total = 0;
    for (int i = 0; i < D; ++i) {
      int elem_idx = y * W * D + x * D + i;
      unsigned int elem = data[elem_idx];
      total += __popc(elem);
    }
    return total;
  }

 private:
  unsigned int* data;
  // Declaration order matters: B is initialized before D, which uses it.
  int H, W, B, D;
};

View File

@@ -0,0 +1,33 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// Map a pixel index 0 <= i < S to a normalized device coordinate in [-1, 1].
// The NDC range is split into S equally wide pixels and the value returned is
// the *center* of pixel i.
__device__ inline float PixToNdc(int i, int S) {
  // (2 * i + 1) / S is the pixel offset plus half a pixel, in NDC units
  // (one pixel is 2 / S wide); -1 shifts it to the start of the NDC range.
  const float twice_center = 2 * i + 1.0f;
  return -1 + twice_center / S;
}
// The maximum number of points per pixel that we can return. Since we use
// thread-local arrays to hold and sort points, the maximum size of the array
// needs to be known at compile time. There might be some fancy template magic
// we could use to make this more dynamic, but for now just fix a constant.
// TODO: is the current limit (150) enough? Would increasing it have
// performance considerations?
const int32_t kMaxPointsPerPixel = 150;
// In-place ascending bubble sort of arr[0..n) using operator< of T.
// Meant for small thread-local arrays, where a simple, regular loop structure
// matters more than asymptotic complexity.
template <typename T>
__device__ inline void BubbleSort(T* arr, int n) {
  // After each outer pass the largest remaining element has bubbled up to
  // position `last`, so the unsorted prefix shrinks by one.
  for (int last = n - 1; last > 0; --last) {
    for (int j = 0; j < last; ++j) {
      if (arr[j + 1] < arr[j]) {
        T held = arr[j];
        arr[j] = arr[j + 1];
        arr[j + 1] = held;
      }
    }
  }
}

View File

@@ -0,0 +1,511 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include <sstream>
#include <tuple>
#include "rasterize_points/bitmask.cuh"
#include "rasterize_points/rasterization_utils.cuh"
namespace {
// Record describing one point that covers a pixel.
struct Pix {
  float z; // Depth of the reference point.
  int32_t idx; // Index of the reference point.
  float dist2; // Squared Euclidean distance to the reference point.
};

// Pix records are ordered by depth only.
__device__ inline bool operator<(const Pix& a, const Pix& b) {
  return b.z > a.z;
}

// Test whether the pixel at NDC location (xf, yf) is covered by point p of
// batch element n, and if so try to add the point to the pixel's candidate
// list q.
//
// q is an unsorted list holding at most K entries: the K points with the
// smallest z seen so far. q_size is the number of valid entries, q_max_z the
// largest depth currently stored and q_max_idx the slot holding it; all three
// are updated in place together with q. Points with negative z are ignored, as
// are points whose squared xy distance to the pixel is not below radius2.
//
// Shared between RasterizePointsNaiveCudaKernel and
// RasterizePointsFineCudaKernel.
template <typename PointQ>
__device__ void CheckPixelInsidePoint(
    const float* points, // (N, P, 3)
    const int p,
    int& q_size,
    float& q_max_z,
    int& q_max_idx,
    PointQ& q,
    const float radius2,
    const float xf,
    const float yf,
    const int n,
    const int P,
    const int K) {
  const int base = n * P * 3 + p * 3;
  const float px = points[base + 0];
  const float py = points[base + 1];
  const float pz = points[base + 2];
  if (pz < 0) {
    return; // Don't render points behind the camera
  }

  const float dx = xf - px;
  const float dy = yf - py;
  const float dist2 = dx * dx + dy * dy;
  if (!(dist2 < radius2)) {
    return; // Pixel is not covered by this point.
  }

  if (q_size < K) {
    // There is still room: append and keep the running maximum up to date.
    q[q_size] = {pz, p, dist2};
    if (pz > q_max_z) {
      q_max_z = pz;
      q_max_idx = q_size;
    }
    q_size++;
    return;
  }

  if (pz < q_max_z) {
    // The list is full but this point is closer than the farthest entry:
    // replace that entry, then rescan to find the new farthest one.
    q[q_max_idx] = {pz, p, dist2};
    q_max_z = pz;
    for (int slot = 0; slot < K; slot++) {
      if (q[slot].z > q_max_z) {
        q_max_z = q[slot].z;
        q_max_idx = slot;
      }
    }
  }
}
} // namespace
// ****************************************************************************
// * NAIVE RASTERIZATION *
// ****************************************************************************
// Naive point rasterization: every output pixel tests every point of its
// batch element.
//
// Work is distributed with a grid-stride loop over the N * S * S pixels, so
// any 1D launch configuration is valid. For each pixel the (up to) K nearest
// points in depth are collected, sorted by increasing z and written to the
// first slots of the pixel's K-entry output row; remaining slots are left
// untouched. K must not exceed kMaxPointsPerPixel (size of the per-thread
// candidate array).
__global__ void RasterizePointsNaiveCudaKernel(
    const float* points, // (N, P, 3)
    const float radius,
    const int N,
    const int P,
    const int S,
    const int K,
    int32_t* point_idxs, // (N, S, S, K)
    float* zbuf, // (N, S, S, K)
    float* pix_dists) { // (N, S, S, K)
  const int stride = gridDim.x * blockDim.x;
  const int first = blockDim.x * blockIdx.x + threadIdx.x;
  const float radius2 = radius * radius;

  for (int linear = first; linear < N * S * S; linear += stride) {
    // Decompose the linear pixel index into (batch, row, column).
    const int n = linear / (S * S);
    const int within_image = linear % (S * S);
    const int yi = within_image / S;
    const int xi = within_image % S;

    const float xf = PixToNdc(xi, S);
    const float yf = PixToNdc(yi, S);

    // Candidate list for this pixel. STL containers are not available in
    // device code, so instead of a priority queue the candidates are kept in
    // an unsorted thread-local array together with the value and slot of the
    // current maximum depth: looking up the farthest candidate is O(1) and
    // replacing it costs one O(K) rescan.
    // TODO(jcjohns) Abstract this out into a standalone data structure
    Pix q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;

    int p = 0;
    while (p < P) {
      CheckPixelInsidePoint(
          points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
      ++p;
    }

    // Order the surviving candidates front to back and emit them.
    BubbleSort(q, q_size);
    const int out_base = n * S * S * K + yi * S * K + xi * K;
    for (int k = 0; k < q_size; ++k) {
      const int out = out_base + k;
      point_idxs[out] = q[k].idx;
      zbuf[out] = q[k].z;
      pix_dists[out] = q[k].dist2;
    }
  }
}
// Host entry point for naive CUDA point rasterization.
//
// Args:
//   points: tensor read as (N, P, 3) floats on the device.
//   image_size: side length S of the square output image.
//   radius: point radius in NDC units.
//   points_per_pixel: number K of points kept per pixel; must not exceed
//       kMaxPointsPerPixel.
//
// Returns:
//   (point_idxs, zbuf, pix_dists), each of shape (N, S, S, K) and initialized
//   to -1; point_idxs is int32, the other two float32.
//
// Raises (via AT_ERROR) if points_per_pixel is too large or if the kernel
// launch reports an error.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsNaiveCuda(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int S = image_size;
  const int K = points_per_pixel;
  if (K > kMaxPointsPerPixel) {
    std::stringstream ss;
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }

  auto int_opts = points.options().dtype(torch::kInt32);
  auto float_opts = points.options().dtype(torch::kFloat32);
  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);

  // The kernel uses a grid-stride loop, so a fixed launch shape is valid for
  // any problem size.
  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizePointsNaiveCudaKernel<<<blocks, threads>>>(
      points.contiguous().data<float>(),
      radius,
      N,
      P,
      S,
      K,
      point_idxs.contiguous().data<int32_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>());

  // Kernel launches do not return a status; a failed launch would otherwise
  // go unnoticed and the -1 filled outputs would be returned as if valid.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    std::stringstream ss;
    ss << "RasterizePointsNaiveCudaKernel launch failed: "
       << cudaGetErrorString(launch_status);
    AT_ERROR(ss.str());
  }
  return std::make_tuple(point_idxs, zbuf, pix_dists);
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
// Coarse point rasterization: for every image bin, list the points whose
// radius-expanded xy extent overlaps the bin.
//
// Launch requirements visible from the code:
//  - 1D blocks; each block processes chunks of `chunk_size` points in a loop
//    strided by gridDim.x.
//  - Dynamic shared memory large enough for a BitMask of shape
//    (num_bins, num_bins, chunk_size), where
//    num_bins = 1 + (S - 1) / bin_size.
//
// Outputs:
//  - points_per_bin (N, num_bins, num_bins): incremented atomically by the
//    number of points found for each bin. It counts ALL points found, even
//    those that did not fit into bin_points.
//  - bin_points (N, num_bins, num_bins, M): point indices per bin. At most M
//    entries are written per bin; points beyond that capacity are dropped
//    rather than written past the end of the bin's row. A value of
//    points_per_bin greater than M therefore signals an overflow.
__global__ void RasterizePointsCoarseCudaKernel(
    const float* points,
    const float radius,
    const int N,
    const int P,
    const int S,
    const int bin_size,
    const int chunk_size,
    const int max_points_per_bin,
    int* points_per_bin,
    int* bin_points) {
  extern __shared__ char sbuf[];
  const int M = max_points_per_bin;
  const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up
  const float half_pix = 1.0f / S; // Size of half a pixel in NDC units

  // This is a boolean array of shape (num_bins, num_bins, chunk_size)
  // stored in shared memory that will track whether each point in the chunk
  // falls into each bin of the image.
  BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);

  // Have each block handle a chunk of points and build a 3D bitmask in
  // shared memory to mark which points hit which bins. In this first phase,
  // each thread processes one point at a time. After processing the chunk,
  // one thread is assigned per bin, and the thread counts and writes the
  // points for the bin out to global memory.
  const int chunks_per_batch = 1 + (P - 1) / chunk_size;
  const int num_chunks = N * chunks_per_batch;
  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
    const int batch_idx = chunk / chunks_per_batch;
    const int chunk_idx = chunk % chunks_per_batch;
    const int point_start_idx = chunk_idx * chunk_size;

    binmask.block_clear();

    // Have each thread handle a different point within the chunk
    for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
      const int p_idx = point_start_idx + p;
      if (p_idx >= P)
        break;
      const float px = points[batch_idx * P * 3 + p_idx * 3 + 0];
      const float py = points[batch_idx * P * 3 + p_idx * 3 + 1];
      const float pz = points[batch_idx * P * 3 + p_idx * 3 + 2];
      if (pz < 0)
        continue; // Don't render points behind the camera
      const float px0 = px - radius;
      const float px1 = px + radius;
      const float py0 = py - radius;
      const float py1 = py + radius;

      // Brute-force search over all bins; TODO something smarter?
      // For example we could compute the exact bin where the point falls,
      // then check neighboring bins. This way we wouldn't have to check
      // all bins (however then we might have more warp divergence?)
      for (int by = 0; by < num_bins; ++by) {
        // Get y extent for the bin. PixToNdc gives us the location of
        // the center of each pixel, so we need to add/subtract a half
        // pixel to get the true extent of the bin.
        const float by0 = PixToNdc(by * bin_size, S) - half_pix;
        const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
        const bool y_overlap = (py0 <= by1) && (by0 <= py1);
        if (!y_overlap) {
          continue;
        }
        for (int bx = 0; bx < num_bins; ++bx) {
          // Get x extent for the bin; again we need to adjust the
          // output of PixToNdc by half a pixel.
          const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
          const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
          const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
          if (x_overlap) {
            binmask.set(by, bx, p);
          }
        }
      }
    }
    // All bits of this chunk must be set before any thread reads them.
    __syncthreads();

    // Now we have processed every point in the current chunk. We need to
    // count the number of points in each bin so we can write the indices
    // out to global memory. We have each thread handle a different bin.
    for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
      const int by = byx / num_bins;
      const int bx = byx % num_bins;
      const int count = binmask.count(by, bx);
      const int points_per_bin_idx =
          batch_idx * num_bins * num_bins + by * num_bins + bx;

      // This atomically increments the (global) number of points found
      // in the current bin, and gets the previous value of the counter;
      // this effectively allocates space in the bin_points array for the
      // points in the current chunk that fall into this bin.
      const int start = atomicAdd(points_per_bin + points_per_bin_idx, count);

      // Now loop over the binmask and write the active bits for this bin
      // out to bin_points. `slot` is the position inside this bin's row of
      // M entries; writes are skipped once the row is full so that an
      // over-full bin cannot spill into the next bin's row or past the end
      // of the tensor.
      const int bin_base = batch_idx * num_bins * num_bins * M +
          by * num_bins * M + bx * M;
      int slot = start;
      for (int p = 0; p < chunk_size; ++p) {
        if (slot >= M) {
          break; // Bin is full; remaining points of this chunk are dropped.
        }
        if (binmask.get(by, bx, p)) {
          bin_points[bin_base + slot] = point_start_idx + p;
          slot++;
        }
      }
    }
    // Keep the bitmask intact until every thread has finished reading it;
    // the next iteration clears it.
    __syncthreads();
  }
}
// Host entry point for coarse CUDA point rasterization.
//
// Args:
//   points: tensor read as (N, P, 3) floats on the device.
//   image_size: side length of the square image in pixels.
//   radius: point radius in NDC units.
//   bin_size: side length of a bin in pixels.
//   max_points_per_bin: capacity M of each bin.
//
// Returns:
//   int32 tensor of shape (N, num_bins, num_bins, M) with the indices of the
//   points assigned to each bin; unused slots are -1.
//
// Raises (via AT_ERROR) if there are 22 or more bins per side (shared memory
// budget), if the kernel launch fails, or if any bin received more than
// max_points_per_bin points.
torch::Tensor RasterizePointsCoarseCuda(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int max_points_per_bin) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
  const int M = max_points_per_bin;
  if (num_bins >= 22) {
    // Make sure we do not use too much shared memory.
    std::stringstream ss;
    ss << "Got " << num_bins << "; that's too many!";
    AT_ERROR(ss.str());
  }

  auto opts = points.options().dtype(torch::kInt32);
  torch::Tensor points_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
  torch::Tensor bin_points = torch::full({N, num_bins, num_bins, M}, -1, opts);

  // One bit per (bin, point-in-chunk); chunk_size is a multiple of 32 so the
  // byte count below is exact.
  const int chunk_size = 512;
  const size_t shared_size = num_bins * num_bins * chunk_size / 8;
  const size_t blocks = 64;
  const size_t threads = 512;
  RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
      points.contiguous().data<float>(),
      radius,
      N,
      P,
      image_size,
      bin_size,
      chunk_size,
      M,
      points_per_bin.contiguous().data<int32_t>(),
      bin_points.contiguous().data<int32_t>());

  // Kernel launches do not return a status; check explicitly.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    std::stringstream ss;
    ss << "RasterizePointsCoarseCudaKernel launch failed: "
       << cudaGetErrorString(launch_status);
    AT_ERROR(ss.str());
  }

  // points_per_bin holds the number of points the kernel found for each bin.
  // A count above M means the bin could not hold all of its points, so the
  // result would be incomplete; report it instead of returning it.
  // Reading the maximum back to the host synchronizes with the kernel.
  if (points_per_bin.numel() > 0) {
    const int max_count = points_per_bin.max().item().to<int32_t>();
    if (max_count > M) {
      std::stringstream ss;
      ss << "Got " << max_count
         << " points in a single bin but max_points_per_bin is " << M;
      AT_ERROR(ss.str());
    }
  }
  return bin_points;
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
// Fine point rasterization: like the naive kernel, but each pixel only tests
// the points that the coarse pass listed for the bin containing the pixel.
//
// Work is distributed with a grid-stride loop over N * B * B * bin_size^2
// pixel slots (B bins per side); slots that fall outside the S x S image are
// skipped. For each pixel the (up to) K nearest points in depth are sorted by
// increasing z and written to the first slots of the pixel's K-entry output
// row. K must not exceed kMaxPointsPerPixel.
__global__ void RasterizePointsFineCudaKernel(
    const float* points, // (N, P, 3)
    const int32_t* bin_points, // (N, B, B, M)
    const float radius,
    const int bin_size,
    const int N,
    const int P,
    const int B,
    const int M,
    const int S,
    const int K,
    int32_t* point_idxs, // (N, S, S, K)
    float* zbuf, // (N, S, S, K)
    float* pix_dists) { // (N, S, S, K)
  // Sizes of the nested index ranges, innermost first.
  const int per_bin = bin_size * bin_size;
  const int per_bin_row = B * per_bin;
  const int per_image = B * per_bin_row;
  // This can be more than S^2 if S is not dividable by bin_size.
  const int num_pixels = N * per_image;

  const int stride = gridDim.x * blockDim.x;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const float radius2 = radius * radius;

  for (int pid = first; pid < num_pixels; pid += stride) {
    // Peel the linear index apart into batch, bin row, bin column and the
    // pixel inside the bin. The in-bin pixel index varies fastest, so
    // neighbouring threads work on the same bin and read the same part of
    // bin_points.
    const int n = pid / per_image;
    const int in_image = pid % per_image;
    const int by = in_image / per_bin_row;
    const int in_bin_row = in_image % per_bin_row;
    const int bx = in_bin_row / per_bin;
    const int in_bin = in_bin_row % per_bin;
    const int yi = in_bin / bin_size + by * bin_size;
    const int xi = in_bin % bin_size + bx * bin_size;

    if (yi >= S || xi >= S) {
      continue;
    }

    const float xf = PixToNdc(xi, S);
    const float yf = PixToNdc(yi, S);

    // Same candidate-list scheme as the naive kernel, restricted to the
    // points recorded for this bin. TODO abstract out this logic into some
    // data structure that is shared by both kernels?
    Pix q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;

    const int bin_base = n * B * B * M + by * B * M + bx * M;
    for (int m = 0; m < M; ++m) {
      const int p = bin_points[bin_base + m];
      // bin_points uses -1 as a sentinel value for unused slots.
      if (p >= 0) {
        CheckPixelInsidePoint(
            points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
      }
    }

    // All points of the bin have been examined: sort front to back and write
    // the output for the current pixel.
    BubbleSort(q, q_size);
    const int out_base = n * S * S * K + yi * S * K + xi * K;
    for (int k = 0; k < q_size; ++k) {
      const int out = out_base + k;
      point_idxs[out] = q[k].idx;
      zbuf[out] = q[k].z;
      pix_dists[out] = q[k].dist2;
    }
  }
}
// Host entry point for fine CUDA point rasterization.
//
// Args:
//   points: tensor read as (N, P, 3) floats on the device.
//   bin_points: int32 tensor of shape (N, B, B, M) produced by the coarse
//       pass; -1 marks unused slots.
//   image_size: side length S of the square output image.
//   radius: point radius in NDC units.
//   bin_size: side length of a bin in pixels.
//   points_per_pixel: number K of points kept per pixel; must not exceed
//       kMaxPointsPerPixel.
//
// Returns:
//   (point_idxs, zbuf, pix_dists), each of shape (N, S, S, K) and initialized
//   to -1; point_idxs is int32, the other two float32.
//
// Raises (via AT_ERROR) if points_per_pixel is too large or if the kernel
// launch reports an error.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
    const torch::Tensor& points,
    const torch::Tensor& bin_points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int points_per_pixel) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int B = bin_points.size(1);
  const int M = bin_points.size(3);
  const int S = image_size;
  const int K = points_per_pixel;
  if (K > kMaxPointsPerPixel) {
    // Report the actual limit and parameter name (same message as the naive
    // path) instead of a hard-coded, outdated bound.
    std::stringstream ss;
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }

  auto int_opts = points.options().dtype(torch::kInt32);
  auto float_opts = points.options().dtype(torch::kFloat32);
  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);

  // The kernel uses a grid-stride loop, so a fixed launch shape is valid for
  // any problem size.
  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizePointsFineCudaKernel<<<blocks, threads>>>(
      points.contiguous().data<float>(),
      bin_points.contiguous().data<int32_t>(),
      radius,
      bin_size,
      N,
      P,
      B,
      M,
      S,
      K,
      point_idxs.contiguous().data<int32_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>());

  // Kernel launches do not return a status; check explicitly.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    std::stringstream ss;
    ss << "RasterizePointsFineCudaKernel launch failed: "
       << cudaGetErrorString(launch_status);
    AT_ERROR(ss.str());
  }
  return std::make_tuple(point_idxs, zbuf, pix_dists);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// TODO(T55115174) Add more documentation for backward kernel.
// Backward pass of point rasterization.
//
// One unit of work per (n, y, x, k) output slot, distributed with a
// grid-stride loop over N * H * W * K. For every slot that refers to a point
// (idxs >= 0) the upstream gradients are accumulated atomically into that
// point's entry of grad_points:
//   - x/y receive the gradient of the squared pixel-to-point distance,
//     2 * grad_dists * (point - pixel);
//   - z receives grad_zbuf unchanged.
// grad_points is only ever added to, so it must be zero-initialized by the
// caller.
__global__ void RasterizePointsBackwardCudaKernel(
    const float* points, // (N, P, 3)
    const int32_t* idxs, // (N, H, W, K)
    const int N,
    const int P,
    const int H,
    const int W,
    const int K,
    const float* grad_zbuf, // (N, H, W, K)
    const float* grad_dists, // (N, H, W, K)
    float* grad_points) { // (N, P, 3)
  const int stride = gridDim.x * blockDim.x;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int per_row = W * K;
  const int per_image = H * W * K;

  for (int slot = first; slot < N * H * W * K; slot += stride) {
    const int p = idxs[slot];
    if (p < 0) {
      continue; // No point was rasterized into this slot.
    }

    // Recover batch index and pixel coordinates from the linear slot index.
    // (The k index itself, slot % K, is not needed.)
    const int n = slot / per_image;
    const int in_image = slot % per_image;
    const int yi = in_image / per_row;
    const int xi = (in_image % per_row) / K;

    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);

    const float grad_dist2 = grad_dists[slot];
    const int p_ind = n * P * 3 + p * 3;
    const float dx = points[p_ind] - xf;
    const float dy = points[p_ind + 1] - yf;

    const float grad_px = 2.0f * grad_dist2 * dx;
    const float grad_py = 2.0f * grad_dist2 * dy;
    const float grad_pz = grad_zbuf[slot];

    // Several slots may refer to the same point, hence the atomics.
    atomicAdd(grad_points + p_ind, grad_px);
    atomicAdd(grad_points + p_ind + 1, grad_py);
    atomicAdd(grad_points + p_ind + 2, grad_pz);
  }
}
// Host-side launcher for RasterizePointsBackwardCudaKernel.
//
// Args:
//   points: float32 tensor of shape (N, P, 3).
//   idxs: int32 tensor of shape (N, H, W, K); negative entries mark empty
//       slots and are skipped by the kernel.
//   grad_zbuf: float32 tensor of shape (N, H, W, K), upstream gradient for
//       the depth buffer.
//   grad_dists: float32 tensor of shape (N, H, W, K), upstream gradient for
//       the squared pixel-to-point distances.
//
// Returns:
//   grad_points: tensor of shape (N, P, 3) created with points.options() and
//   zero-initialised, into which the kernel accumulates gradients.
//
// Raises (c10::Error) if the kernel launch is rejected.
torch::Tensor RasterizePointsBackwardCuda(
    const torch::Tensor& points, // (N, P, 3)
    const torch::Tensor& idxs, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_dists) { // (N, H, W, K)
  const int N = points.size(0);
  const int P = points.size(1);
  const int H = idxs.size(1);
  const int W = idxs.size(2);
  const int K = idxs.size(3);

  // torch::zeros returns a contiguous tensor, so its pointer is used as-is.
  torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());

  // Named locals keep any copies made by contiguous() alive across the launch.
  const torch::Tensor points_c = points.contiguous();
  const torch::Tensor idxs_c = idxs.contiguous();
  const torch::Tensor grad_zbuf_c = grad_zbuf.contiguous();
  const torch::Tensor grad_dists_c = grad_dists.contiguous();

  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizePointsBackwardCudaKernel<<<blocks, threads>>>(
      points_c.data<float>(),
      idxs_c.data<int32_t>(),
      N,
      P,
      H,
      W,
      K,
      grad_zbuf_c.data<float>(),
      grad_dists_c.data<float>(),
      grad_points.data<float>());

  // Kernel launches do not return a status; check for launch failures now.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    AT_ERROR(
        "RasterizePointsBackwardCudaKernel launch failed: ",
        cudaGetErrorString(launch_err));
  }
  return grad_points;
}

View File

@@ -0,0 +1,230 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
// ****************************************************************************
// * NAIVE RASTERIZATION *
// ****************************************************************************
// CPU implementation (defined in the accompanying .cpp file).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel);

// CUDA implementation (defined in the accompanying .cu file).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsNaiveCuda(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel);

// Naive forward rasterization of a point cloud: every pixel is tested against
// every point. Dispatches to the CUDA or CPU implementation depending on
// where `points` lives.
//
// Args:
//  points: Tensor of shape (N, P, 3) (in NDC)
//  radius: Radius of each point (in NDC units)
//  image_size: (S) Size of the image to return (in pixels)
//  points_per_pixel: (K) The number of closest points to return for each
//                    pixel
//
// Returns:
//  idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
//        closest K points along the z-axis for each pixel, padded with -1
//        for pixels hit by fewer than K points.
//  zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
//        closest point for each pixel.
//  dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
//         distance in the (NDC) x/y plane between each pixel and its K
//         closest points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel) {
  const bool on_gpu = points.is_cuda();
  return on_gpu
      ? RasterizePointsNaiveCuda(points, image_size, radius, points_per_pixel)
      : RasterizePointsNaiveCpu(points, image_size, radius, points_per_pixel);
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
// CPU implementation (defined in the accompanying .cpp file).
torch::Tensor RasterizePointsCoarseCpu(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int max_points_per_bin);

// CUDA implementation (defined in the accompanying .cu file).
torch::Tensor RasterizePointsCoarseCuda(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int max_points_per_bin);

// Coarse rasterization: assign points to image bins. Dispatches to the CUDA
// or CPU implementation depending on where `points` lives.
//
// Args:
//  points: Tensor of shape (N, P, 3)
//  radius: Radius of points to rasterize (in NDC units)
//  image_size: Size of the image to generate (in pixels)
//  bin_size: Size of each bin within the image (in pixels)
//  max_points_per_bin: Forwarded to the implementation; the maximum number of
//                      points that may be recorded for one bin
//
// Returns:
//  A single tensor. RasterizePoints passes it on as the `bin_points` argument
//  of RasterizePointsFine, which documents that argument as an int32 tensor
//  of shape (N, B, B, M) holding the indices of the points in each bin.
torch::Tensor RasterizePointsCoarse(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int max_points_per_bin) {
  const bool on_gpu = points.is_cuda();
  return on_gpu
      ? RasterizePointsCoarseCuda(
            points, image_size, radius, bin_size, max_points_per_bin)
      : RasterizePointsCoarseCpu(
            points, image_size, radius, bin_size, max_points_per_bin);
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
// CUDA implementation (defined in the accompanying .cu file).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
    const torch::Tensor& points,
    const torch::Tensor& bin_points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int points_per_pixel);

// Fine rasterization: given the per-bin point lists from the coarse pass,
// find the closest points for every pixel. Only a CUDA implementation exists;
// calling this with a non-CUDA `points` tensor raises an error.
//
// Args:
//  points: float32 Tensor of shape (N, P, 3)
//  bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
//              that fall into each bin (output from coarse rasterization)
//  image_size: Size of image to generate (in pixels)
//  radius: Radius of points to rasterize (NDC units)
//  bin_size: Size of each bin (in pixels)
//  points_per_pixel: How many points to rasterize for each pixel
//
// Returns (same as rasterize_points):
//  idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the closest
//        points_per_pixel points along the z-axis for each pixel, padded with
//        -1 for pixels hit by fewer than points_per_pixel points
//  zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
//        closest point for each pixel
//  dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
//         distance in the (NDC) x/y plane between each pixel and its K closest
//         points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
    const torch::Tensor& points,
    const torch::Tensor& bin_points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int points_per_pixel) {
  if (!points.is_cuda()) {
    // There is no CPU counterpart of the fine pass.
    AT_ERROR("NOT IMPLEMENTED");
  }
  return RasterizePointsFineCuda(
      points, bin_points, image_size, radius, bin_size, points_per_pixel);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// CPU implementation (defined in the accompanying .cpp file).
torch::Tensor RasterizePointsBackwardCpu(
    const torch::Tensor& points,
    const torch::Tensor& idxs,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_dists);

// CUDA implementation (defined in the accompanying .cu file).
torch::Tensor RasterizePointsBackwardCuda(
    const torch::Tensor& points,
    const torch::Tensor& idxs,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_dists);

// Backward pass of point rasterization. Dispatches to the CUDA or CPU
// implementation depending on where `points` lives.
//
// Args:
//  points: float32 Tensor of shape (N, P, 3)
//  idxs: int32 Tensor of shape (N, H, W, K) (from forward pass)
//  grad_zbuf: float32 Tensor of shape (N, H, W, K) giving upstream gradient
//             d(loss)/d(zbuf) of the depths of the nearest points of each
//             pixel.
//  grad_dists: Tensor of shape (N, H, W, K) giving upstream gradient
//              d(loss)/d(dists) of the dists tensor returned by the forward
//              pass.
//
// Returns:
//  grad_points: float32 Tensor of shape (N, P, 3) giving downstream gradients
torch::Tensor RasterizePointsBackward(
    const torch::Tensor& points,
    const torch::Tensor& idxs,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_dists) {
  const bool on_gpu = points.is_cuda();
  return on_gpu
      ? RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists)
      : RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
}
// ****************************************************************************
// * MAIN ENTRY POINT *
// ****************************************************************************
// Main entry point for the forward pass of the point rasterizer. Chooses
// between the naive per-pixel algorithm (bin_size == 0) and the two-stage
// coarse-to-fine algorithm (any other bin_size).
//
// Args:
//  points: Tensor of shape (N, P, 3) (in NDC)
//  radius: Radius of each point (in NDC units)
//  image_size: (S) Size of the image to return (in pixels)
//  points_per_pixel: (K) The number of points to return for each pixel
//  bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
//            bin_size=0 uses naive rasterization instead.
//  max_points_per_bin: The maximum number of points allowed to fall into each
//                      bin when using coarse-to-fine rasterization.
//
// Returns:
//  idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
//        closest points_per_pixel points along the z-axis for each pixel,
//        padded with -1 for pixels hit by fewer than points_per_pixel points
//  zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
//        closest point for each pixel
//  dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
//         distance in the (NDC) x/y plane between each pixel and its K closest
//         points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel,
    const int bin_size,
    const int max_points_per_bin) {
  // bin_size == 0 selects the naive per-pixel implementation.
  if (bin_size == 0) {
    return RasterizePointsNaive(points, image_size, radius, points_per_pixel);
  }

  // Otherwise: first bin the points, then resolve each pixel within its bin.
  const auto bin_points = RasterizePointsCoarse(
      points, image_size, radius, bin_size, max_points_per_bin);
  return RasterizePointsFine(
      points, bin_points, image_size, radius, bin_size, points_per_pixel);
}

View File

@@ -0,0 +1,196 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <queue>
#include <tuple>
// Map a pixel index 0 <= i < S to a normalized device coordinate in [-1, 1].
// The NDC range is split into S equal-width pixels and the returned value is
// the *center* of pixel i.
inline float PixToNdc(const int i, const int S) {
  // Distance of the pixel center from the left edge of the NDC range:
  // i * pixel_width + half_pixel_width, with pixel_width = 2 / S.
  const float center_offset = (2 * i + 1.0f) / S;
  // The NDC range starts at -1.
  return -1 + center_offset;
}
// Naive CPU rasterization of a point cloud: every pixel is tested against
// every point of its batch element.
//
// Args:
//   points: float32 tensor of shape (N, P, 3); points with z < 0 are skipped.
//   image_size: (S) side length of the square output image, in pixels.
//   radius: a point hits a pixel when the squared x/y distance between the
//       point and the pixel center is strictly less than radius^2.
//   points_per_pixel: (K) number of hits kept per pixel.
//
// Returns:
//   (point_idxs, zbuf, pix_dists), each of shape (N, S, S, K). For every pixel
//   the K hits with the smallest z are stored in increasing z order; unused
//   slots keep the fill value -1.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int S = image_size;
  const int K = points_per_pixel;

  const auto int_opts = points.options().dtype(torch::kInt32);
  const auto float_opts = points.options().dtype(torch::kFloat32);
  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);

  auto points_a = points.accessor<float, 3>();
  auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
  auto zbuf_a = zbuf.accessor<float, 4>();
  auto pix_dists_a = pix_dists.accessor<float, 4>();

  // (z, point index, squared x/y distance); ordered by z first.
  using Hit = std::tuple<float, int, float>;
  const float radius2 = radius * radius;

  for (int n = 0; n < N; ++n) {
    for (int yi = 0; yi < S; ++yi) {
      const float yf = PixToNdc(yi, S);
      for (int xi = 0; xi < S; ++xi) {
        const float xf = PixToNdc(xi, S);

        // Max-heap on z: once more than K hits are stored, the farthest one
        // is evicted, so the heap always holds the K nearest hits so far.
        std::priority_queue<Hit> nearest;
        for (int p = 0; p < P; ++p) {
          const float pz = points_a[n][p][2];
          if (pz < 0) {
            continue;
          }
          const float dx = points_a[n][p][0] - xf;
          const float dy = points_a[n][p][1] - yf;
          const float dist2 = dx * dx + dy * dy;
          if (!(dist2 < radius2)) {
            continue; // The point does not cover this pixel.
          }
          nearest.emplace(pz, p, dist2);
          if ((int)nearest.size() > K) {
            nearest.pop();
          }
        }

        // The heap yields hits from farthest to nearest, so fill the output
        // slots from the back to end up with increasing z.
        for (int slot = (int)nearest.size() - 1; slot >= 0; --slot) {
          const Hit hit = nearest.top();
          nearest.pop();
          zbuf_a[n][yi][xi][slot] = std::get<0>(hit);
          point_idxs_a[n][yi][xi][slot] = std::get<1>(hit);
          pix_dists_a[n][yi][xi][slot] = std::get<2>(hit);
        }
      }
    }
  }
  return std::make_tuple(point_idxs, zbuf, pix_dists);
}
std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
const torch::Tensor& points,
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin) {
const int N = points.size(0);
const int P = points.size(1);
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
const int M = max_points_per_bin;
auto opts = points.options().dtype(torch::kInt32);
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
auto points_a = points.accessor<float, 3>();
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
auto bin_points_a = bin_points.accessor<int32_t, 4>();
const float pixel_width = 2.0f / image_size;
const float bin_width = pixel_width * bin_size;
for (int n = 0; n < N; ++n) {
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
for (int by = 0; by < B; by++) {
float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width;
for (int bx = 0; bx < B; bx++) {
int32_t points_hit = 0;
for (int32_t p = 0; p < P; p++) {
float px = points_a[n][p][0];
float py = points_a[n][p][1];
float pz = points_a[n][p][2];
if (pz < 0) {
continue;
}
float point_x_min = px - radius;
float point_x_max = px + radius;
float point_y_min = py - radius;
float point_y_max = py + radius;
// Use a half-open interval so that points exactly on the
// boundary between bins will fall into exactly one bin.
bool x_hit = (point_x_min <= bin_x_max) && (bin_x_min <= point_x_max);
bool y_hit = (point_y_min <= bin_y_max) && (bin_y_min <= point_y_max);
if (x_hit && y_hit) {
// Got too many points for this bin, so throw an error.
if (points_hit >= max_points_per_bin) {
AT_ERROR("Got too many points per bin");
}
// The current point falls in the current bin, so
// record it.
bin_points_a[n][by][bx][points_hit] = p;
points_hit++;
}
}
// Record the number of points found in this bin
points_per_bin_a[n][by][bx] = points_hit;
// Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width;
}
// Shift the bin down for the next loop iteration
bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width;
}
}
return std::make_tuple(points_per_bin, bin_points);
}
// CPU backward pass of point rasterization.
//
// For every filled slot (n, y, x, k) the upstream gradients are accumulated
// onto the point p = idxs[n][y][x][k]:
//   grad_points[n][p][0] += 2 * grad_dists * (px - xf)
//   grad_points[n][p][1] += 2 * grad_dists * (py - yf)
//   grad_points[n][p][2] += grad_zbuf
// The slots of a pixel are scanned in order and scanning stops at the first
// negative index.
//
// Raises (c10::Error) when the image is not square (H != W).
torch::Tensor RasterizePointsBackwardCpu(
    const torch::Tensor& points, // (N, P, 3)
    const torch::Tensor& idxs, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_dists) { // (N, H, W, K)
  const int N = points.size(0);
  const int P = points.size(1);
  const int H = idxs.size(1);
  const int W = idxs.size(2);
  const int K = idxs.size(3);

  // For now only support square images.
  // TODO(jcjohns): Extend to non-square images.
  if (H != W) {
    AT_ERROR("RasterizePointsBackwardCpu only supports square images");
  }

  torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());

  auto points_a = points.accessor<float, 3>();
  auto idxs_a = idxs.accessor<int32_t, 4>();
  auto grad_dists_a = grad_dists.accessor<float, 4>();
  auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
  auto grad_points_a = grad_points.accessor<float, 3>();

  for (int n = 0; n < N; ++n) { // Images in the batch
    for (int y = 0; y < H; ++y) { // Rows of the image
      const float pix_y = PixToNdc(y, H);
      for (int x = 0; x < W; ++x) { // Pixels of the row
        const float pix_x = PixToNdc(x, W);
        for (int k = 0; k < K; ++k) { // Slots of the pixel
          const int p = idxs_a[n][y][x][k];
          if (p < 0) {
            break; // First empty slot: stop scanning this pixel.
          }
          const float upstream_dist = grad_dists_a[n][y][x][k];
          const float off_x = points_a[n][p][0] - pix_x;
          const float off_y = points_a[n][p][1] - pix_y;
          // Remember: dists[n][y][x][k] = dx * dx + dy * dy;
          grad_points_a[n][p][0] += 2.0f * upstream_dist * off_x;
          grad_points_a[n][p][1] += 2.0f * upstream_dist * off_y;
          grad_points_a[n][p][2] += grad_zbuf_a[n][y][x][k];
        }
      }
    }
  }
  return grad_points;
}

7
pytorch3d/io/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .obj_io import load_obj, save_obj
from .ply_io import load_ply, save_ply
# Export every name currently in this module's globals that does not start
# with an underscore.
__all__ = [k for k in globals().keys() if not k.startswith("_")]

532
pytorch3d/io/obj_io.py Normal file
View File

@@ -0,0 +1,532 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""This module implements utility functions for loading and saving meshes."""
import numpy as np
import os
import pathlib
import warnings
from collections import namedtuple
from typing import List
import torch
from fvcore.common.file_io import PathManager
from PIL import Image
def _read_image(file_name: str, format=None):
"""
Read an image from a file using Pillow.
Args:
file_name: image file path.
format: one of ["RGB", "BGR"]
Returns:
image: an image of shape (H, W, C).
"""
if format not in ["RGB", "BGR"]:
raise ValueError("format can only be one of [RGB, BGR]; got %s", format)
with PathManager.open(file_name, "rb") as f:
image = Image.open(f)
if format is not None:
# PIL only supports RGB. First convert to RGB and flip channels
# below for BGR.
image = image.convert("RGB")
image = np.asarray(image).astype(np.float32)
if format == "BGR":
image = image[:, :, ::-1]
return image
# Faces & Aux type returned from load_obj function.
# _Faces groups the per-face index data: vertex, normal and texture indices
# plus the material index of each face.
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
# _Aux groups the auxiliary data: normals, uv coordinates, material colors and
# texture images. Note that the namedtuple's type name is "Properties".
_Aux = namedtuple(
    "Properties", "normals verts_uvs material_colors texture_images"
)
def _format_faces_indices(faces_indices, max_index):
"""
Format indices and check for invalid values. Indices can refer to
values in one of the face properties: vertices, textures or normals.
See comments of the load_obj function for more details.
Args:
faces_indices: List of ints of indices.
max_index: Max index for the face property.
Returns:
faces_indices: List of ints of indices.
Raises:
ValueError if indices are not in a valid range.
"""
faces_indices = torch.tensor(faces_indices, dtype=torch.int64)
# Change to 0 based indexing.
faces_indices[(faces_indices > 0)] -= 1
# Negative indexing counts from the end.
faces_indices[(faces_indices < 0)] += max_index
# Check indices are valid.
if not (
torch.all(faces_indices < max_index) and torch.all(faces_indices >= 0)
):
raise ValueError("Faces have invalid indices.")
return faces_indices
def _open_file(f):
new_f = False
if isinstance(f, str):
new_f = True
f = open(f, "r")
elif isinstance(f, pathlib.Path):
new_f = True
f = f.open("r")
return f, new_f
def load_obj(f_obj):
    """
    Load a mesh and textures from a .obj and .mtl file.
    Currently this handles verts, faces, vertex texture uv coordinates, normals,
    texture images and material reflectivity values.

    Note .obj files are 1-indexed. The tensors returned from this function
    are 0-indexed. OBJ spec reference: http://www.martinreddy.net/gfx/3d/OBJ.spec

    Example .obj file format:
    ::
        # this is a comment
        v 1.000000 -1.000000 -1.000000
        v 1.000000 -1.000000 1.000000
        v -1.000000 -1.000000 1.000000
        v -1.000000 -1.000000 -1.000000
        v 1.000000 1.000000 -1.000000
        vt 0.748573 0.750412
        vt 0.749279 0.501284
        vt 0.999110 0.501077
        vt 0.999455 0.750380
        vn 0.000000 0.000000 -1.000000
        vn -1.000000 -0.000000 -0.000000
        vn -0.000000 -0.000000 1.000000
        f 5/2/1 1/2/1 4/3/1
        f 5/1/1 4/3/1 2/4/1

    The first token of the line denotes the type of input:
    ::
        - v is a vertex
        - vt is the texture coordinate of one vertex
        - vn is the normal of one vertex
        - f is a face

    Faces are interpreted as follows:
    ::
        5/2/1 describes the first vertex of the first triangle
        - 5: index of vertex [1.000000 1.000000 -1.000000]
        - 2: index of texture coordinate [0.749279 0.501284]
        - 1: index of normal [0.000000 0.000000 -1.000000]

    If there are faces with more than 3 vertices
    they are subdivided into triangles. Polygonal faces are assumed to have
    vertices ordered counter-clockwise so the (right-handed) normal points
    into the screen e.g. a proper rectangular face would be specified like this:
    ::
        0_________1
        |         |
        |         |
        3 ________2

    The face would be split into two triangles: (0, 1, 2) and (0, 2, 3),
    both of which keep the vertex ordering of the original polygon.

    Args:
        f_obj: A file-like object (with methods read, readline, tell, and seek),
            a pathlib path or a string containing a file name. When a path is
            given, material files are looked up in its directory; otherwise
            they are looked up in "./".

    Returns:
        3-element tuple (verts, faces, aux) containing

        - **verts**: FloatTensor of shape (V, 3).
        - **faces**: NamedTuple with fields:
            - verts_idx: LongTensor of vertex indices, shape (F, 3).
            - normals_idx: (optional) LongTensor of normal indices, shape (F, 3).
            - textures_idx: (optional) LongTensor of texture indices, shape (F, 3).
              This can be used to index into verts_uvs.
            - materials_idx: (optional) indices indicating which
              material the texture is derived from for each face.
              If there is no material for a face, the index is -1.
              This can be used to retrieve the corresponding values
              in material_colors/texture_images after they have been
              converted to tensors or Materials/Textures data
              structures - see textures.py and materials.py for
              more info.
        - **aux**: NamedTuple with fields:
            - normals: FloatTensor of shape (N, 3), or None if the file has
              no normals.
            - verts_uvs: FloatTensor of shape (T, 2), giving the uv coordinate per
              vertex, or None if the file has no texture coordinates.
              If a vertex is shared between two faces, it can have
              a different uv value for each instance. Therefore it is
              possible that the number of verts_uvs is greater than
              num verts i.e. T > V.
            - material_colors: dict of material names and associated properties,
              or None if no material file was loaded.
              If a material does not have any properties it will have an
              empty dict.

              .. code-block:: python

                  {
                      material_name_1:  {
                          "ambient_color": tensor of shape (3,),
                          "diffuse_color": tensor of shape (3,),
                          "specular_color": tensor of shape (3,),
                          "shininess": tensor of shape (1,)
                      },
                      material_name_2: {},
                      ...
                  }
            - texture_images: dict of material names and texture images,
              or None if no material file was loaded.

              .. code-block:: python

                  {
                      material_name_1: (H, W, 3) image,
                      ...
                  }
    """
    data_dir = "./"
    if isinstance(f_obj, (str, bytes, os.PathLike)):
        # Material files referenced by the .obj are resolved relative to it.
        data_dir = os.path.dirname(f_obj)
    f_obj, new_f = _open_file(f_obj)
    try:
        return _load(f_obj, data_dir)
    finally:
        # Only close files that were opened here, never a caller's handle.
        if new_f:
            f_obj.close()
def _parse_face(
line,
material_idx,
faces_verts_idx,
faces_normals_idx,
faces_textures_idx,
faces_materials_idx,
):
face = line.split(" ")[1:]
face_list = [f.split("/") for f in face]
face_verts = []
face_normals = []
face_textures = []
for vert_props in face_list:
# Vertex index.
face_verts.append(int(vert_props[0]))
if len(vert_props) > 1:
if vert_props[1] != "":
# Texture index is present e.g. f 4/1/1.
face_textures.append(int(vert_props[1]))
if len(vert_props) > 2:
# Normal index present e.g. 4/1/1 or 4//1.
face_normals.append(int(vert_props[2]))
if len(vert_props) > 3:
raise ValueError(
"Face vertices can ony have 3 properties. \
Face vert %s, Line: %s"
% (str(vert_props), str(line))
)
# Triplets must be consistent for all vertices in a face e.g.
# legal statement: f 4/1/1 3/2/1 2/1/1.
# illegal statement: f 4/1/1 3//1 2//1.
if len(face_normals) > 0:
if not (len(face_verts) == len(face_normals)):
raise ValueError(
"Face %s is an illegal statement. \
Vertex properties are inconsistent. Line: %s"
% (str(face), str(line))
)
if len(face_textures) > 0:
if not (len(face_verts) == len(face_textures)):
raise ValueError(
"Face %s is an illegal statement. \
Vertex properties are inconsistent. Line: %s"
% (str(face), str(line))
)
# Subdivide faces with more than 3 vertices. See comments of the
# load_obj function for more details.
for i in range(len(face_verts) - 2):
faces_verts_idx.append(
(face_verts[0], face_verts[i + 1], face_verts[i + 2])
)
if len(face_normals) > 0:
faces_normals_idx.append(
(face_normals[0], face_normals[i + 1], face_normals[i + 2])
)
if len(face_textures) > 0:
faces_textures_idx.append(
(face_textures[0], face_textures[i + 1], face_textures[i + 2])
)
faces_materials_idx.append(material_idx)
def _load(f_obj, data_dir):
    """
    Load a mesh from a file-like object. See load_obj function more details.
    Any material files associated with the obj are expected to be in the
    directory given by data_dir.

    Args:
        f_obj: iterable of lines (str or bytes) of the .obj file.
        data_dir: directory in which the material file is looked up.

    Returns:
        3-element tuple (verts, faces, aux); see load_obj.

    Raises:
        ValueError: on malformed vertex, texture, normal, face, mtllib or
            usemtl statements, or if face indices are out of range.
    """
    lines = [line.strip() for line in f_obj]
    verts = []
    normals = []
    verts_uvs = []
    faces_verts_idx = []
    faces_normals_idx = []
    faces_textures_idx = []
    material_names = []
    faces_materials_idx = []
    f_mtl = None
    materials_idx = -1

    # startswith expects each line to be a string. If the file is read in as
    # bytes then first decode to strings. An empty file has no first line to
    # inspect and needs no decoding.
    if lines and isinstance(lines[0], bytes):
        lines = [raw.decode("utf-8") for raw in lines]

    for line in lines:
        tokens = line.split()
        if line.startswith("mtllib"):
            if len(tokens) < 2:
                raise ValueError("material file name is not specified")
            # NOTE: this assumes only one mtl file per .obj.
            f_mtl = os.path.join(data_dir, tokens[1])
        elif len(tokens) != 0 and tokens[0] == "usemtl":
            if len(tokens) < 2:
                raise ValueError("material name is not specified")
            material_names.append(tokens[1])
            # Faces that follow use the material that was just selected.
            materials_idx = len(material_names) - 1
        elif line.startswith("v "):
            # Line is a vertex.
            vert = [float(x) for x in tokens[1:4]]
            if len(vert) != 3:
                msg = "Vertex %s does not have 3 values. Line: %s"
                raise ValueError(msg % (str(vert), str(line)))
            verts.append(vert)
        elif line.startswith("vt "):
            # Line is a texture.
            tx = [float(x) for x in tokens[1:3]]
            if len(tx) != 2:
                raise ValueError(
                    "Texture %s does not have 2 values. Line: %s"
                    % (str(tx), str(line))
                )
            verts_uvs.append(tx)
        elif line.startswith("vn "):
            # Line is a normal.
            norm = [float(x) for x in tokens[1:4]]
            if len(norm) != 3:
                msg = "Normal %s does not have 3 values. Line: %s"
                raise ValueError(msg % (str(norm), str(line)))
            normals.append(norm)
        elif line.startswith("f "):
            # Line is a face.
            _parse_face(
                line,
                materials_idx,
                faces_verts_idx,
                faces_normals_idx,
                faces_textures_idx,
                faces_materials_idx,
            )

    verts = torch.tensor(verts)  # (V, 3)
    normals = torch.tensor(normals)  # (N, 3)
    verts_uvs = torch.tensor(verts_uvs)  # (T, 2)

    faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])

    # Repeat for normals and textures if present.
    if len(faces_normals_idx) > 0:
        faces_normals_idx = _format_faces_indices(
            faces_normals_idx, normals.shape[0]
        )
    if len(faces_textures_idx) > 0:
        faces_textures_idx = _format_faces_indices(
            faces_textures_idx, verts_uvs.shape[0]
        )
    if len(faces_materials_idx) > 0:
        faces_materials_idx = torch.tensor(
            faces_materials_idx, dtype=torch.int64
        )

    # Load materials
    material_colors, texture_images = None, None
    if (len(material_names) > 0) and (f_mtl is not None):
        if os.path.isfile(f_mtl):
            material_colors, texture_images = load_mtl(
                f_mtl, material_names, data_dir
            )
        else:
            warnings.warn(f"Mtl file does not exist: {f_mtl}")
    elif len(material_names) > 0:
        warnings.warn("No mtl file provided")

    faces = _Faces(
        verts_idx=faces_verts_idx,
        normals_idx=faces_normals_idx,
        textures_idx=faces_textures_idx,
        materials_idx=faces_materials_idx,
    )

    aux = _Aux(
        normals=normals if len(normals) > 0 else None,
        verts_uvs=verts_uvs if len(verts_uvs) > 0 else None,
        material_colors=material_colors,
        texture_images=texture_images,
    )
    return verts, faces, aux
def load_mtl(f_mtl, material_names: List, data_dir: str):
    """
    Load texture images and material reflectivity values for ambient, diffuse
    and specular light (Ka, Kd, Ks, Ns).

    Args:
        f_mtl: a file like object of the material information, or a str /
            pathlib.Path naming the material file.
        material_names: a list of the material names found in the .obj file.
        data_dir: the directory where the material texture files are located.

    Returns:
        material_colors: dict of properties for each material named in
            material_names that is defined in the file. If a material
            does not have any properties it will have an empty dict.
            {
                material_name_1:  {
                    "ambient_color": tensor of shape (3,),
                    "diffuse_color": tensor of shape (3,),
                    "specular_color": tensor of shape (3,),
                    "shininess": tensor of shape (1,)
                },
                material_name_2: {},
                ...
            }
        texture_images: dict of material names and texture images (float32
            tensors with values scaled to [0, 1]). Materials whose texture
            file is missing are skipped with a warning.
            {
                material_name_1: (H, W, 3) image,
                ...
            }
    """
    texture_files = {}
    material_colors = {}
    material_properties = {}
    texture_images = {}
    material_name = ""

    # Read everything up front and release the handle even if reading fails;
    # only files opened here are closed.
    f_mtl, new_f = _open_file(f_mtl)
    try:
        lines = [line.strip() for line in f_mtl]
    finally:
        if new_f:
            f_mtl.close()

    # .mtl keyword -> key under which the values are stored.
    color_keys = {
        "Kd": "diffuse_color",  # RGB diffuse reflectivity
        "Ka": "ambient_color",  # RGB ambient reflectivity
        "Ks": "specular_color",  # RGB specular reflectivity
        "Ns": "shininess",  # Specular exponent
    }

    for line in lines:
        tokens = line.split()
        if len(tokens) == 0:
            continue
        keyword = tokens[0]
        if keyword == "newmtl":
            material_name = tokens[1]
            material_colors[material_name] = {}
        elif keyword == "map_Kd":
            # Texture map.
            texture_files[material_name] = tokens[1]
        elif keyword in color_keys:
            values = np.array(tokens[1:4]).astype(np.float32)
            material_colors[material_name][
                color_keys[keyword]
            ] = torch.from_numpy(values)

    # Only keep the materials referenced in the obj.
    for name in material_names:
        if name in texture_files:
            # Load the texture image.
            filename = texture_files[name]
            filename_texture = os.path.join(data_dir, filename)
            if os.path.isfile(filename_texture):
                image = _read_image(filename_texture, format="RGB") / 255.0
                image = torch.from_numpy(image)
                texture_images[name] = image
            else:
                msg = f"Texture file does not exist: {filename_texture}"
                warnings.warn(msg)

        if name in material_colors:
            material_properties[name] = material_colors[name]

    return material_properties, texture_images
def save_obj(f, verts, faces, decimal_places: int = None):
    """
    Write a mesh to an .obj file.

    Args:
        f: File (or path) to which the mesh should be written. A str or
            pathlib.Path is opened for writing and closed again afterwards;
            any other object is written to and left open.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
    """
    opened_here = isinstance(f, (str, pathlib.Path))
    if isinstance(f, str):
        f = open(f, "w")
    elif opened_here:
        f = f.open("w")

    try:
        return _save(f, verts, faces, decimal_places)
    finally:
        # Leave handles owned by the caller untouched.
        if opened_here:
            f.close()
# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: int = None):
if verts.dim() != 2 or verts.size(1) != 3:
raise ValueError("Argument 'verts' should be of shape (num_verts, 3).")
if faces.dim() != 2 or faces.size(1) != 3:
raise ValueError("Argument 'faces' should be of shape (num_faces, 3).")
verts, faces = verts.cpu(), faces.cpu()
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
lines = ""
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
f.write(lines)

748
pytorch3d/io/ply_io.py Normal file
View File

@@ -0,0 +1,748 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements utility functions for loading and saving meshes."""
import numpy as np
import pathlib
import struct
import sys
import warnings
from collections import namedtuple
from typing import Optional, Tuple
import torch
# Description of one PLY scalar type: its size in bytes, the format character
# for the `struct` module and the matching numpy type.
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")

# Lookup table from PLY type name to its _PlyTypeData description.
_PLY_TYPES = {
    "char": _PlyTypeData(1, "b", np.byte),
    "uchar": _PlyTypeData(1, "B", np.ubyte),
    "short": _PlyTypeData(2, "h", np.short),
    "ushort": _PlyTypeData(2, "H", np.ushort),
    "int": _PlyTypeData(4, "i", np.int32),
    "uint": _PlyTypeData(4, "I", np.uint32),
    "float": _PlyTypeData(4, "f", np.float32),
    "double": _PlyTypeData(8, "d", np.float64),
}

# One property of a PLY element: its name, its data type and, for list
# properties, the type of the list size (None for non-list properties).
_Property = namedtuple("_Property", "name data_type list_size_type")
class _PlyElementType:
"""
Description of an element of a Ply file.
Members:
self.properties: (List[_Property]) description of all the properties.
Each one contains a name and data type.
self.count: (int) number of such elements in the file
self.name: (str) name of the element
"""
def __init__(self, name: str, count: int):
self.name = name
self.count = count
self.properties = []
def add_property(
self, name: str, data_type: str, list_size_type: Optional[str] = None
):
"""Adds a new property.
Args:
name: (str) name of the property.
data_type: (str) PLY data type.
list_size_type: (str) PLY data type of the list size, or None if not
a list.
"""
for property in self.properties:
if property.name == name:
msg = "Cannot have two properties called %s in %s."
raise ValueError(msg % (name, self.name))
self.properties.append(_Property(name, data_type, list_size_type))
def is_fixed_size(self) -> bool:
"""Return whether the Element has no list properties
Returns:
True if none of the properties are lists.
"""
for property in self.properties:
if property.list_size_type is not None:
return False
return True
def is_constant_type_fixed_size(self) -> bool:
"""Return whether the Element has all properties of the same non-list
type.
Returns:
True if none of the properties are lists and all the properties
share a type.
"""
if not self.is_fixed_size():
return False
first_type = self.properties[0].data_type
for property in self.properties:
if property.data_type != first_type:
return False
return True
def try_constant_list(self) -> bool:
"""Whether the element is just a single list, which might have a
constant size, and therefore we could try to parse quickly with numpy.
Returns:
True if the only property is a list.
"""
if len(self.properties) != 1:
return False
if self.properties[0].list_size_type is None:
return False
return True
class _PlyHeader:
def __init__(self, f):
"""
Load a header of a Ply file from a file-like object.
Members:
self.elements: (List[_PlyElementType]) element description
self.ascii: (bool) Whether in ascii format
self.big_endian: (bool) (if not ascii) whether big endian
self.obj_info: (dict) arbitrary extra data
Args:
f: file-like object.
"""
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
raise ValueError("Invalid file header.")
seen_format = False
self.elements = []
self.obj_info = {}
while True:
line = f.readline()
if isinstance(line, bytes):
line = line.decode("ascii")
line = line.strip()
if line == "end_header":
if not self.elements:
raise ValueError("No elements found.")
if not self.elements[-1].properties:
raise ValueError("Found an element with no properties.")
if not seen_format:
raise ValueError("No format line found.")
break
if not seen_format:
if line == "format ascii 1.0":
seen_format = True
self.ascii = True
continue
if line == "format binary_little_endian 1.0":
seen_format = True
self.ascii = False
self.big_endian = False
continue
if line == "format binary_big_endian 1.0":
seen_format = True
self.ascii = False
self.big_endian = True
continue
if line.startswith("format"):
raise ValueError("Invalid format line.")
if line.startswith("comment") or len(line) == 0:
continue
if line.startswith("element"):
self._parse_element(line)
continue
if line.startswith("obj_info"):
items = line.split(" ")
if len(items) != 3:
raise ValueError("Invalid line: %s" % line)
self.obj_info[items[1]] = items[2]
continue
if line.startswith("property"):
self._parse_property(line)
continue
raise ValueError("Invalid line: %s." % line)
def _parse_property(self, line: str):
"""
Decode a ply file header property line.
Args:
line: (str) the ply file's line.
"""
if not self.elements:
raise ValueError("Encountered property before any element.")
items = line.split(" ")
if len(items) not in [3, 5]:
raise ValueError("Invalid line: %s" % line)
datatype = items[1]
name = items[-1]
if datatype == "list":
datatype = items[3]
list_size_type = items[2]
if list_size_type not in _PLY_TYPES:
raise ValueError("Invalid datatype: %s" % list_size_type)
else:
list_size_type = None
if datatype not in _PLY_TYPES:
raise ValueError("Invalid datatype: %s" % datatype)
self.elements[-1].add_property(name, datatype, list_size_type)
def _parse_element(self, line: str):
"""
Decode a ply file header element line.
Args:
line: (str) the ply file's line.
"""
if self.elements and not self.elements[-1].properties:
raise ValueError("Found an element with no properties.")
items = line.split(" ")
if len(items) != 3:
raise ValueError("Invalid line: %s" % line)
try:
count = int(items[2])
except ValueError:
msg = "Number of items for %s was not a number."
raise ValueError(msg % items[1])
self.elements.append(_PlyElementType(items[1], count))
def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
    """
    Read an element whose properties are all scalars of one common type from
    an ascii .ply file, using a single numpy bulk read.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.

    Returns:
        2D numpy array corresponding to the data. The rows are the different
        values. There is one column for each property.

    Raises:
        ValueError: if the number of columns or rows read does not match the
            definition.
    """
    expected_rows = definition.count
    expected_cols = len(definition.properties)
    # All properties share a type, so the first one determines the dtype.
    dtype = _PLY_TYPES[definition.properties[0].data_type].np_type

    table = np.loadtxt(
        f, dtype=dtype, comments=None, ndmin=2, max_rows=expected_rows
    )
    n_rows, n_cols = table.shape
    if n_cols != expected_cols:
        raise ValueError("Inconsistent data for %s." % definition.name)
    if n_rows != expected_rows:
        raise ValueError("Not enough data for %s." % definition.name)
    return table
def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
    """
    Attempt a bulk read of an element consisting of a single list property
    from an ascii .ply file, under the assumption that every list has the
    same length.

    If numpy cannot parse the rows as one rectangular table, None is returned
    and f is rewound to where it was on entry.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.

    Returns:
        If every element has the same size, 2D numpy array corresponding to
        the data (without the leading length column). The rows are the
        different values. Otherwise None.

    Raises:
        ValueError: if a row's declared length disagrees with the number of
            values on it, or if fewer rows than expected were read.
    """
    item_dtype = _PLY_TYPES[definition.properties[0].data_type].np_type
    rewind_to = f.tell()
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=".* Empty input file.*", category=UserWarning
        )
        try:
            table = np.loadtxt(
                f,
                dtype=item_dtype,
                comments=None,
                ndmin=2,
                max_rows=definition.count,
            )
        except ValueError:
            # Not a rectangular table: let the caller parse line by line.
            f.seek(rewind_to)
            return None

    # Column 0 holds each row's declared list length; the rest are the items.
    declared_lengths = table[:, 0]
    items_per_row = table.shape[1] - 1
    if (items_per_row != declared_lengths).any():
        msg = "A line of %s data did not have the specified length."
        raise ValueError(msg % definition.name)
    if table.shape[0] != definition.count:
        raise ValueError("Not enough data for %s." % definition.name)
    return table[:, 1:]
def _parse_heterogenous_property_ascii(datum, line_iter, property: _Property):
    """
    Consume the words belonging to one property from a line of an ascii .ply
    file and append the decoded value to datum.

    Args:
        datum: list to append the single value to. That value will be a numpy
            array if the property is a list property, otherwise an int or
            float.
        line_iter: iterator to words on the line from which we read.
        property: the property object describing the property we are reading.

    Raises:
        ValueError: if the line runs out of words or a word is not a number.
    """
    token = next(line_iter, None)
    if token is None:
        raise ValueError("Too little data for an element.")

    if property.list_size_type is None:
        # Scalar property: one word, decoded as float or int.
        convert = float if property.data_type in ["double", "float"] else int
        try:
            scalar = convert(token)
        except ValueError:
            raise ValueError("Bad numerical data.")
        datum.append(scalar)
        return

    # List property: the first word is the length, followed by the items.
    try:
        length = int(token)
    except ValueError:
        raise ValueError("A list length was not a number.")
    items = np.zeros(length, dtype=_PLY_TYPES[property.data_type].np_type)
    for position in range(length):
        word = next(line_iter, None)
        if word is None:
            raise ValueError("Too little data for an element.")
        try:
            items[position] = float(word)
        except ValueError:
            raise ValueError("Bad numerical data.")
    datum.append(items)
def _read_ply_element_ascii(f, definition: _PlyElementType):
    """
    Decode all instances of a single element from an ascii .ply file.

    Args:
        f: file-like object being read. Both text and binary streams are
            handled.
        definition: The element object which describes what we are reading.

    Returns:
        In simple cases where every element has the same size, 2D numpy array
        corresponding to the data. The rows are the different values.
        Otherwise a list of lists of values, where the outer list is
        each occurrence of the element, and the inner lists have one value per
        property.

    Raises:
        ValueError: if the file ends early, or a line has too few or too many
            values for the element.
    """
    if definition.is_constant_type_fixed_size():
        return _read_ply_fixed_size_element_ascii(f, definition)
    if definition.try_constant_list():
        data = _try_read_ply_constant_list_ascii(f, definition)
        if data is not None:
            return data

    # We failed to read the element as a lump, must process each line manually.
    data = []
    for _i in range(definition.count):
        line_string = f.readline()
        # At end of file readline() gives "" for text streams but b"" for
        # binary ones, so test for emptiness rather than comparing with "".
        if not line_string:
            raise ValueError("Not enough data for %s." % definition.name)
        datum = []
        line_iter = iter(line_string.strip().split())
        for property in definition.properties:
            _parse_heterogenous_property_ascii(datum, line_iter, property)
        data.append(datum)
        # Every word on the line must have been consumed by some property.
        if next(line_iter, None) is not None:
            raise ValueError("Too much data for an element.")
    return data
def _read_ply_fixed_size_element_binary(
    f, definition: _PlyElementType, big_endian: bool
):
    """
    Read an element whose properties are all scalars of one common type from
    a binary .ply file, as a single block of bytes.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        big_endian: (bool) whether the document is encoded as big endian.

    Returns:
        2D numpy array corresponding to the data. The rows are the different
        values. There is one column for each property.

    Raises:
        ValueError: if the file holds fewer bytes than the element needs.
    """
    # All properties share a type, so the first one describes every value.
    scalar_type = _PLY_TYPES[definition.properties[0].data_type]
    n_rows = definition.count
    n_cols = len(definition.properties)

    expected_bytes = n_rows * n_cols * scalar_type.size
    raw = f.read(expected_bytes)
    if len(raw) != expected_bytes:
        raise ValueError("Not enough data for %s." % definition.name)

    flat = np.frombuffer(raw, dtype=scalar_type.np_type)
    file_is_native_order = (sys.byteorder == "big") == big_endian
    if not file_is_native_order:
        flat = flat.byteswap()
    return flat.reshape(n_rows, n_cols)
def _read_ply_element_struct(f, definition: _PlyElementType, endian_str: str):
    """
    Read an element which has no list properties from a binary .ply file,
    decoding each row with the struct library.

    Note: It looks like struct would also support lists where
    type=size_type=char, but it is hard to know how much data to read in that
    case.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        endian_str: ">" or "<" according to whether the document is big or
            little endian.

    Returns:
        A list with one tuple per occurrence of the element. Each tuple has
        one value per property.

    Raises:
        ValueError: if the file holds fewer bytes than the element needs.
    """
    # One struct describing a whole row: the byte order marker followed by
    # one format character per property.
    struct_chars = [
        _PLY_TYPES[prop.data_type].struct_char for prop in definition.properties
    ]
    row_struct = struct.Struct(endian_str + "".join(struct_chars))
    row_size = row_struct.size

    total_bytes = row_size * definition.count
    raw = f.read(total_bytes)
    if len(raw) != total_bytes:
        raise ValueError("Not enough data for %s." % definition.name)

    rows = []
    for row_index in range(definition.count):
        rows.append(row_struct.unpack_from(raw, row_index * row_size))
    return rows
def _try_read_ply_constant_list_binary(
    f, definition: _PlyElementType, big_endian: bool
):
    """
    If definition is an element which is a single list, attempt to read the
    corresponding data assuming every value has the same length.
    If the data is ragged, return None and leave f undisturbed.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        big_endian: (bool) whether the document is encoded as big endian.

    Returns:
        If every element has the same size, 2D numpy array corresponding to the
        data. The rows are the different values. Otherwise None.
        If the element has a count of zero, an empty array of shape (0, 0) is
        returned and nothing is read from f.

    Raises:
        ValueError: if the file ends before all the data has been read.
    """
    list_property = definition.properties[0]
    np_type = _PLY_TYPES[list_property.data_type].np_type
    type_size = _PLY_TYPES[list_property.data_type].size

    if definition.count == 0:
        # No rows means there is no length prefix to read either. Reading one
        # anyway would consume bytes belonging to the next element, or fail
        # at the end of the file.
        return np.zeros((0, 0), dtype=np_type)

    endian_str = ">" if big_endian else "<"
    length_format = endian_str + _PLY_TYPES[list_property.list_size_type].struct_char
    length_struct = struct.Struct(length_format)

    def get_length():
        # Read and decode the length prefix of the next list.
        bytes_data = f.read(length_struct.size)
        if len(bytes_data) != length_struct.size:
            raise ValueError("Not enough data for %s." % definition.name)
        [length] = length_struct.unpack(bytes_data)
        return length

    start_point = f.tell()

    # The first row's length is the one every other row has to match.
    length = get_length()
    data_size = type_size * length

    output = np.zeros((definition.count, length), dtype=np_type)

    for i in range(definition.count):
        bytes_data = f.read(data_size)
        if len(bytes_data) != data_size:
            raise ValueError("Not enough data for %s" % definition.name)
        output[i] = np.frombuffer(bytes_data, dtype=np_type)
        if i + 1 == definition.count:
            break
        if length != get_length():
            # Ragged data: rewind so the caller can parse row by row.
            f.seek(start_point)
            return None
    if (sys.byteorder == "big") != big_endian:
        output = output.byteswap()

    return output
def _read_ply_element_binary(
    f, definition: _PlyElementType, big_endian: bool
) -> list:
    """
    Decode all instances of a single element from a binary .ply file.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        big_endian: (bool) whether the document is encoded as big endian.

    Returns:
        In simple cases where every element has the same size, 2D numpy array
        corresponding to the data. The rows are the different values.
        Otherwise a list of lists/tuples of values, where the outer list is
        each occurrence of the element, and the inner lists have one value per
        property.

    Raises:
        ValueError: if the file ends before all the data has been read.
    """
    endian_str = ">" if big_endian else "<"

    # Bulk readers for the regular layouts.
    if definition.is_constant_type_fixed_size():
        return _read_ply_fixed_size_element_binary(f, definition, big_endian)
    if definition.is_fixed_size():
        return _read_ply_element_struct(f, definition, endian_str)
    if definition.try_constant_list():
        lump = _try_read_ply_constant_list_binary(f, definition, big_endian)
        if lump is not None:
            return lump

    # We failed to read the element as a lump, must process each value
    # manually. The first thing stored for a property is its list length if
    # it is a list, and the value itself otherwise.
    def leading_struct(prop):
        leading_type = prop.list_size_type or prop.data_type
        return struct.Struct(endian_str + _PLY_TYPES[leading_type].struct_char)

    leading_structs = [leading_struct(prop) for prop in definition.properties]
    swap_bytes = (sys.byteorder == "big") != big_endian

    rows = []
    for _ in range(definition.count):
        row = []
        for prop, leading in zip(definition.properties, leading_structs):
            head = f.read(leading.size)
            if len(head) != leading.size:
                raise ValueError("Not enough data for %s" % definition.name)
            (first_value,) = leading.unpack(head)

            if prop.list_size_type is None:
                row.append(first_value)
                continue

            # first_value is the list length; read that many items.
            item_type = _PLY_TYPES[prop.data_type]
            payload_size = item_type.size * first_value
            payload = f.read(payload_size)
            if len(payload) != payload_size:
                raise ValueError("Not enough data for %s" % definition.name)
            items = np.frombuffer(payload, dtype=item_type.np_type)
            if swap_bytes:
                items = items.byteswap()
            row.append(items)
        rows.append(row)
    return rows
def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]:
    """
    Implementation for _load_ply_raw which takes a stream.

    Args:
        f: A binary or text file-like object.

    Returns:
        header: A _PlyHeader object describing the metadata in the ply file.
        elements: A dictionary of element names to values. If an element is
            regular, in the sense of having no lists or being one
            uniformly-sized list, then the value will be a 2D numpy array.
            If not, it is a list of the relevant property values.

    Raises:
        ValueError: if anything other than whitespace follows the last
            element.
    """
    header = _PlyHeader(f)

    # Elements are stored in the file in the order the header declares them.
    if header.ascii:
        elements = {
            element.name: _read_ply_element_ascii(f, element)
            for element in header.elements
        }
    else:
        big = header.big_endian
        elements = {
            element.name: _read_ply_element_binary(f, element, big)
            for element in header.elements
        }

    trailing = f.read().strip()
    if len(trailing) != 0:
        raise ValueError("Extra data at end of file: " + str(trailing[:20]))
    return header, elements
def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]:
    """
    Load the data from a .ply file.

    A file named by a string or pathlib path is opened in binary mode and
    closed again before returning. A stream passed in by the caller is left
    open.

    Args:
        f: A binary or text file-like object (with methods read, readline,
            tell and seek), a pathlib path or a string containing a file name.
            If the ply file is binary, a text stream is not supported.
            It is recommended to use a binary stream.

    Returns:
        header: A _PlyHeader object describing the metadata in the ply file.
        elements: A dictionary of element names to values. If an element is
            regular, in the sense of having no lists or being one
            uniformly-sized list, then the value will be a 2D numpy array.
            If not, it is a list of the relevant property values.
    """
    if isinstance(f, str):
        with open(f, "rb") as stream:
            return _load_ply_raw_stream(stream)
    if isinstance(f, pathlib.Path):
        with f.open("rb") as stream:
            return _load_ply_raw_stream(stream)
    # Already a stream: the caller owns it.
    return _load_ply_raw_stream(f)
def load_ply(f):
    """
    Load the data from a .ply file.

    Faces with more than three vertices are split into triangles which all
    share the face's first vertex (a triangle fan).

    Example .ply file format:

        ply
        format ascii 1.0           { ascii/binary, format version number }
        comment made by Greg Turk  { comments keyword specified, like all lines }
        comment this file is a cube
        element vertex 8           { define "vertex" element, 8 of them in file }
        property float x           { vertex contains float "x" coordinate }
        property float y           { y coordinate is also a vertex property }
        property float z           { z coordinate, too }
        element face 6             { there are 6 "face" elements in the file }
        property list uchar int vertex_index { "vertex_indices" is a list of ints }
        end_header                 { delimits the end of the header }
        0 0 0                      { start of vertex list }
        0 0 1
        0 1 1
        0 1 0
        1 0 0
        1 0 1
        1 1 1
        1 1 0
        4 0 1 2 3                  { start of face list }
        4 7 6 5 4
        4 0 4 5 1
        4 1 5 6 2
        4 2 6 7 3
        4 3 7 4 0

    Args:
        f: A binary or text file-like object (with methods read, readline,
            tell and seek), a pathlib path or a string containing a file name.
            If the ply file is in the binary ply format rather than the text
            ply format, then a text stream is not supported.
            It is easiest to use a binary stream in all cases.

    Returns:
        verts: FloatTensor of shape (V, 3).
        faces: LongTensor of vertex indices, shape (F, 3).

    Raises:
        ValueError: if the file has no vertex or face element, if the
            vertices are not three values of one common type, if the face
            element is not a single list property, or if a face has fewer
            than 3 vertices.
    """
    header, elements = _load_ply_raw(f)

    vertex = elements.get("vertex", None)
    if vertex is None:
        raise ValueError("The ply file has no vertex element.")

    face = elements.get("face", None)
    if face is None:
        raise ValueError("The ply file has no face element.")

    # Vertices only come back as a 2D array when every property is a scalar
    # of one common type; anything else is rejected here.
    if (
        not isinstance(vertex, np.ndarray)
        or vertex.ndim != 2
        or vertex.shape[1] != 3
    ):
        raise ValueError("Invalid vertices in file.")
    verts = torch.tensor(vertex, dtype=torch.float32)

    face_head = next(head for head in header.elements if head.name == "face")
    if (
        len(face_head.properties) != 1
        or face_head.properties[0].list_size_type is None
    ):
        raise ValueError("Unexpected form of faces data.")
    # face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
    # but we don't need to enforce this.

    if isinstance(face, np.ndarray) and face.ndim == 2:
        # Every face has the same number of vertices: triangulate all the
        # faces at once, one fan triangle per column offset.
        if face.shape[1] < 3:
            raise ValueError("Faces must have at least 3 vertices.")
        face_arrays = [
            face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)
        ]
        faces = torch.tensor(np.vstack(face_arrays), dtype=torch.int64)
    else:
        # Ragged faces: the raw reader returns one row per face, and each row
        # is a sequence with one entry per property. The check above
        # guarantees a single (list) property, so unpack it from its row.
        face_list = []
        for face_row in face:
            (face_item,) = face_row
            if face_item.ndim != 1:
                raise ValueError("Bad face data.")
            if face_item.shape[0] < 3:
                raise ValueError("Faces must have at least 3 vertices.")
            for i in range(face_item.shape[0] - 2):
                face_list.append(
                    [face_item[0], face_item[i + 1], face_item[i + 2]]
                )
        faces = torch.tensor(face_list, dtype=torch.int64)

    return verts, faces
def _save_ply(f, verts, faces, decimal_places: Optional[int]):
"""
Internal implementation for saving a mesh to a .ply file.
Args:
f: File object to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
"""
print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f)
print("property float x", file=f)
print("property float y", file=f)
print("property float z", file=f)
print(f"element face {faces.shape[0]}", file=f)
print("property list uchar int vertex_index", file=f)
print("end_header", file=f)
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(f, verts.detach().numpy(), float_str)
np.savetxt(f, faces.detach().numpy(), "3 %d %d %d")
def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
    """
    Save a mesh to a .ply file.

    A file named by a string or pathlib path is opened for writing in text
    mode and closed again before returning. A file object passed in by the
    caller is left open.

    Args:
        f: File (or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
    """
    if isinstance(f, str):
        with open(f, "w") as stream:
            _save_ply(stream, verts, faces, decimal_places)
    elif isinstance(f, pathlib.Path):
        with f.open("w") as stream:
            _save_ply(stream, verts, faces, decimal_places)
    else:
        # Already a file object: the caller owns it.
        _save_ply(f, verts, faces, decimal_places)

View File

@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .chamfer import chamfer_distance
from .mesh_edge_loss import mesh_edge_loss
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
from .mesh_normal_consistency import mesh_normal_consistency
# Export every name defined in this module so far (the loss functions
# imported above), leaving out names which start with an underscore.
__all__ = [k for k in globals().keys() if not k.startswith("_")]

152
pytorch3d/loss/chamfer.py Normal file
View File

@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
def _validate_chamfer_reduction_inputs(
batch_reduction: str, point_reduction: str
):
"""Check the requested reductions are valid.
Args:
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["none", "mean", "sum"].
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["none", "mean", "sum"].
"""
if batch_reduction not in ["none", "mean", "sum"]:
raise ValueError(
'batch_reduction must be one of ["none", "mean", "sum"]'
)
if point_reduction not in ["none", "mean", "sum"]:
raise ValueError(
'point_reduction must be one of ["none", "mean", "sum"]'
)
if batch_reduction == "none" and point_reduction == "none":
raise ValueError(
'batch_reduction and point_reduction cannot both be "none".'
)
def chamfer_distance(
    x,
    y,
    x_normals=None,
    y_normals=None,
    weights=None,
    batch_reduction: str = "mean",
    point_reduction: str = "mean",
):
    """
    Chamfer distance between two pointclouds x and y.

    Args:
        x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
            with P1 points in each batch element, batch size N and feature
            dimension D.
        y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
            with P2 points in each batch element, batch size N and feature
            dimension D.
        x_normals: Optional FloatTensor of shape (N, P1, D).
        y_normals: Optional FloatTensor of shape (N, P2, D).
        weights: Optional FloatTensor of shape (N,) giving weights for
            batch elements for reduction operation. Must not contain
            negative values.
        batch_reduction: Reduction operation to apply for the loss across the
            batch, can be one of ["none", "mean", "sum"].
        point_reduction: Reduction operation to apply for the loss across the
            points, can be one of ["none", "mean", "sum"].

    Returns:
        2-element tuple containing

        - **loss**: Tensor giving the reduced distance between the pointclouds
          in x and the pointclouds in y.
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
          between pointclouds in x and pointclouds in y. Returns None if
          x_normals and y_normals are None.

    Raises:
        ValueError: if the reductions are invalid, if y or weights does not
            have a shape compatible with x, or if a weight is negative.
    """
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

    N, P1, D = x.shape
    P2 = y.shape[1]

    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")

    return_normals = x_normals is not None and y_normals is not None

    if weights is not None:
        if weights.size(0) != N:
            raise ValueError("weights must be of shape (N,).")
        if not (weights >= 0).all():
            raise ValueError("weights cannot be negative.")
        if weights.sum() == 0.0:
            # Every batch element has weight zero, so the loss is zero. It is
            # still computed from x so that it stays connected to the
            # autograd graph. The weights are flattened to (N,) so that the
            # unreduced result has shape (N,) rather than broadcasting to
            # (N, N).
            zero_loss = x.sum((1, 2)) * weights.reshape(N) * 0.0  # (N,)
            if batch_reduction in ["mean", "sum"]:
                zero_loss = zero_loss.sum()
            zero_normals = zero_loss.clone() if return_normals else None
            return zero_loss, zero_normals

    cham_norm_x = x.new_zeros(())
    cham_norm_y = x.new_zeros(())

    # NOTE(review): nn_points_idx is defined outside this file; from its use
    # here it presumably returns, for each point of its first argument, the
    # nearest point of the second argument, that point's index and that
    # point's normal - confirm against its implementation.
    x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
    y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)

    # Squared euclidean distance from each point to its match.
    cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0  # (N, P1)
    cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0  # (N, P2)

    if weights is not None:
        cham_x *= weights.view(N, 1)
        cham_y *= weights.view(N, 1)

    if return_normals:
        # Cosine distance which ignores the sign of the normals.
        cham_norm_x = 1 - torch.abs(
            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
        )
        cham_norm_y = 1 - torch.abs(
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
        )
        if weights is not None:
            cham_norm_x *= weights.view(N, 1)
            cham_norm_y *= weights.view(N, 1)

    if point_reduction != "none":
        # If not 'none' then either 'sum' or 'mean'.
        cham_x = cham_x.sum(1)  # (N,)
        cham_y = cham_y.sum(1)  # (N,)
        if return_normals:
            cham_norm_x = cham_norm_x.sum(1)  # (N,)
            cham_norm_y = cham_norm_y.sum(1)  # (N,)
        if point_reduction == "mean":
            cham_x /= P1
            cham_y /= P2
            if return_normals:
                cham_norm_x /= P1
                cham_norm_y /= P2

    if batch_reduction != "none":
        cham_x = cham_x.sum()
        cham_y = cham_y.sum()
        if return_normals:
            cham_norm_x = cham_norm_x.sum()
            cham_norm_y = cham_norm_y.sum()
        if batch_reduction == "mean":
            # A weighted mean divides by the total weight, not by N.
            div = weights.sum() if weights is not None else N
            cham_x /= div
            cham_y /= div
            if return_normals:
                cham_norm_x /= div
                cham_norm_y /= div

    cham_dist = cham_x + cham_y
    cham_normals = cham_norm_x + cham_norm_y if return_normals else None

    return cham_dist, cham_normals

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
def mesh_edge_loss(meshes, target_length: float = 0.0):
    """
    Edge length regularization loss, averaged over the meshes of a batch.

    The loss of an edge is the squared difference between its length and
    target_length. Within a mesh, each edge is weighted by the inverse of the
    number of edges of that mesh, so that every mesh contributes equally to
    the final loss regardless of how many edges it has. For example, if
    mesh 3 (out of N) has only E=4 edges, then the loss for each edge in
    mesh 3 is multiplied by 1/E.

    Args:
        meshes: Meshes object with a batch of meshes.
        target_length: Resting value for the edge length.

    Returns:
        loss: Average loss across the batch. Returns 0 if meshes contains
        no meshes or all empty meshes.
    """
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    num_meshes = len(meshes)
    edges = meshes.edges_packed()  # (sum(E_n), 2)
    verts = meshes.verts_packed()  # (sum(V_n), 3)
    edge_mesh_idx = meshes.edges_packed_to_mesh_idx()  # (sum(E_n), )
    edges_per_mesh = meshes.num_edges_per_mesh()  # N

    # Look up, for every edge, the number of edges of the mesh it belongs to.
    # TODO (nikhilar) Find a faster way of computing the weights for each edge
    # as this is currently a bottleneck for meshes with a large number of faces.
    inv_edge_count = 1.0 / edges_per_mesh.gather(0, edge_mesh_idx).float()

    # The two end points of every edge.
    start, end = verts[edges].unbind(1)
    sq_error = ((start - end).norm(dim=1, p=2) - target_length) ** 2.0

    return (sq_error * inv_edge_count).sum() / num_meshes

View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
    r"""
    Computes the laplacian smoothing objective for a batch of meshes.
    This function supports three variants of Laplacian smoothing,
    namely with uniform weights ("uniform"), with cotangent weights ("cot"),
    and cotangent curvature ("cotcurv"). For more details read [1, 2].

    Args:
        meshes: Meshes object with a batch of meshes.
        method: str specifying the method for the laplacian.

    Returns:
        loss: Average laplacian smoothing loss across the batch.
        Returns 0 if meshes contains no meshes or all empty meshes.

    Raises:
        ValueError: if method is not one of "uniform", "cot" or "cotcurv".

    Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3.
    The Laplacian matrix L is a NxN tensor such that LV gives a tensor of
    vectors: for a uniform Laplacian, LuV[i] points to the centroid of its
    neighboring vertices, a cotangent Laplacian LcV[i] is known to be an
    approximation of the surface normal, while the curvature variant LckV[i]
    scales the normals by the discrete mean curvature. For vertex i, assume
    S[i] is the set of neighboring vertices to i, a_ij and b_ij are the
    "outside" angles in the two triangles connecting vertex v_i and its
    neighboring vertex v_j for j in S[i], as seen in the diagram below.

    .. code-block:: python

               a_ij
                /\
               /  \
              /    \
             /      \
        v_i /________\ v_j
            \        /
             \      /
              \    /
               \  /
                \/
               b_ij

        The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i)
        For the uniform variant,    w_ij = 1 / |S[i]|
        For the cotangent variant,
            w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik)
        For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i])
        where A[i] is the sum of the areas of all triangles containing vertex v_i.

    There is a nice trigonometry identity to compute cotangents. Consider a
    triangle with side lengths A, B, C and angles a, b, c.

    .. code-block:: python

               c
              /|\
             / | \
            /  |  \
         B /  H|   \ A
          /    |    \
         /     |     \
        /a_____|_____b\
               C

        Then cot a = (B^2 + C^2 - A^2) / (4 * area)
        We know that area = CH/2, and by the law of cosines we have

        A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a

        Putting these together, we get:

        B^2 + C^2 - A^2     2BC cos a
        _______________  =  _________ = (B/H) cos a = cos a / sin a = cot a
           4 * area            2CH

    [1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion
    and curvature flow", SIGGRAPH 1999.

    [2] Nealen et al, "Laplacian Mesh Optimization", Graphite 2006.
    """
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    num_meshes = len(meshes)
    verts = meshes.verts_packed()  # (sum(V_n), 3)
    verts_per_mesh = meshes.num_verts_per_mesh()  # (N,)
    vert_mesh_idx = meshes.verts_packed_to_mesh_idx()  # (sum(V_n),)

    # Each vertex is weighted by the inverse of the number of vertices of the
    # mesh it belongs to.
    inv_vert_count = 1.0 / verts_per_mesh.gather(0, vert_mesh_idx).float()

    # The Laplacian and its normalization are treated as constants: no
    # gradient flows through their computation, only through the vertices
    # they are applied to.
    with torch.no_grad():
        if method == "uniform":
            L = meshes.laplacian_packed()
        elif method == "cot":
            L, inv_areas = laplacian_cot(meshes)
            # Normalize by the sum of each row, leaving rows which do not
            # have a positive sum as they are.
            norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
            positive = norm_w > 0
            norm_w[positive] = 1.0 / norm_w[positive]
        elif method == "cotcurv":
            L, inv_areas = laplacian_cot(meshes)
            norm_w = 0.25 * inv_areas
        else:
            raise ValueError("Method should be one of {uniform, cot, cotcurv}")

    laplacian_verts = L.mm(verts)
    if method == "uniform":
        offsets = laplacian_verts
    elif method == "cot":
        offsets = laplacian_verts * norm_w - verts
    else:
        # Only "cotcurv" can get here; other names were rejected above.
        offsets = (laplacian_verts - verts) * norm_w

    per_vertex = offsets.norm(dim=1) * inv_vert_count
    return per_vertex.sum() / num_meshes
def laplacian_cot(meshes):
    """
    Returns the Laplacian matrix with cotangent weights and the inverse of the
    face areas.

    For a triangle with side lengths A, B, C and area `area`, the cotangent of
    the angle opposite side A is (B^2 + C^2 - A^2) / (4 * area); that identity
    is what is evaluated below for all three angles of every face.

    Args:
        meshes: Meshes object with a batch of meshes.

    Returns:
        2-element tuple containing

        - **L**: Sparse tensor of shape (V, V) for the Laplacian matrix
          (V = sum(V_n)). Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is
          an edge in meshes, where a_ij and b_ij are the angles opposite that
          edge in the faces that contain it. It has the same dtype and device
          as the packed vertices.
        - **inv_areas**: Tensor of shape (V, 1) containing, for each vertex,
          the inverse of the sum of the areas of the faces that contain it
          (0 for vertices that belong to no face).
    """
    verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
    faces_packed = meshes.faces_packed()  # (sum(F_n), 3)
    # V = sum(V_n), F = sum(F_n)
    V, F = verts_packed.shape[0], faces_packed.shape[0]

    face_verts = verts_packed[faces_packed]
    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

    # Side lengths of each triangle, of shape (sum(F_n),)
    # A is the side opposite v0, B is opposite v1, and C is opposite v2
    A = (v1 - v2).norm(dim=1)
    B = (v0 - v2).norm(dim=1)
    C = (v0 - v1).norm(dim=1)

    # Area of each triangle (with Heron's formula); shape is (sum(F_n),)
    s = 0.5 * (A + B + C)
    # note that the area can be negative (close to 0) causing nans after sqrt()
    # we clip it to a small positive value
    area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()

    # Compute cotangents of angles, of shape (sum(F_n), 3)
    A2, B2, C2 = A * A, B * B, C * C
    cota = (B2 + C2 - A2) / area
    cotb = (A2 + C2 - B2) / area
    cotc = (A2 + B2 - C2) / area
    cot = torch.stack([cota, cotb, cotc], dim=1)
    cot /= 4.0

    # Construct a sparse matrix by basically doing:
    # L[v1, v2] = cota
    # L[v2, v0] = cotb
    # L[v0, v1] = cotc
    ii = faces_packed[:, [1, 2, 0]]
    jj = faces_packed[:, [2, 0, 1]]
    idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
    # sparse_coo_tensor takes dtype and device from the values, so this is not
    # tied to float32 the way the legacy torch.sparse.FloatTensor constructor is.
    # Duplicate (i, j) entries coming from different faces are summed.
    L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V))

    # Make it symmetric; this means we are also setting
    # L[v2, v1] = cota
    # L[v0, v2] = cotb
    # L[v1, v0] = cotc
    L = L + L.t()

    # For each vertex, compute the sum of areas for triangles containing it.
    idx = faces_packed.view(-1)
    # Accumulate in the dtype/device of the areas so that scatter_add_ does not
    # fail on a dtype mismatch for non-float32 vertices.
    inv_areas = area.new_zeros(V)
    val = torch.stack([area] * 3, dim=1).view(-1)
    inv_areas.scatter_add_(0, idx, val)
    # Invert only the non-zero sums; isolated vertices keep 0.
    idx = inv_areas > 0
    inv_areas[idx] = 1.0 / inv_areas[idx]
    inv_areas = inv_areas.view(-1, 1)

    return L, inv_areas

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from itertools import islice
import torch
def mesh_normal_consistency(meshes):
    r"""
    Computes the normal consistency of each mesh in meshes.
    We compute the normal consistency for each pair of neighboring faces.
    If e = (v0, v1) is the connecting edge of two neighboring faces f0 and f1,
    then the normal consistency between f0 and f1

    .. code-block:: python

                    a
                    /\
                   /  \
                  / f0 \
                 /      \
            v0  /____e___\ v1
                \        /
                 \      /
                  \ f1 /
                   \  /
                    \/
                    b

    The normal consistency is

    .. code-block:: python

        nc(f0, f1) = 1 - cos(n0, n1)

        where cos(n0, n1) = n0^n1 / ||n0|| / ||n1|| is the cosine of the angle
        between the normals n0 and n1, and

        n0 = (v1 - v0) x (a - v0)
        n1 = - (v1 - v0) x (b - v0) = (b - v0) x (v1 - v0)

    This means that if nc(f0, f1) = 0 then n0 and n1 point to the same
    direction, while if nc(f0, f1) = 2 then n0 and n1 point opposite direction.

    .. note::
        For well-constructed meshes the assumption that only two faces share an
        edge is true. This assumption could make the implementation easier and faster.
        This implementation does not follow this assumption. All the faces sharing e,
        which can be any in number, are discovered, and every unordered pair of
        them contributes exactly one term to the loss.

    Args:
        meshes: Meshes object with a batch of meshes.

    Returns:
        loss: Average normal consistency across the batch.
        Returns 0 if meshes contains no meshes or all empty meshes, or if no
        edge is shared by two or more faces.
    """
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    N = len(meshes)
    verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
    faces_packed = meshes.faces_packed()  # (sum(F_n), 3)
    edges_packed = meshes.edges_packed()  # (sum(E_n), 2)
    verts_packed_to_mesh_idx = meshes.verts_packed_to_mesh_idx()  # (sum(V_n),)
    face_to_edge = meshes.faces_packed_to_edges_packed()  # (sum(F_n), 3)
    E = edges_packed.shape[0]  # sum(E_n)
    F = faces_packed.shape[0]  # sum(F_n)

    # We don't want gradients for the following operation. The goal is to
    # find for each edge e all the vertices associated with e. In the example above,
    # the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1)
    # and points connected on faces to e (=a, b).
    with torch.no_grad():
        edge_idx = face_to_edge.reshape(F * 3)  # (3 * F,) indexes into edges
        # Row k of vert_idx holds the three vertices of the face that produced
        # entry k of edge_idx (each face is repeated once per edge).
        vert_idx = (
            faces_packed.view(1, F, 3)
            .expand(3, F, 3)
            .transpose(0, 1)
            .reshape(3 * F, 3)
        )
        # Sort so that all (face, edge) entries of the same edge are contiguous.
        edge_idx, edge_sort_idx = edge_idx.sort()
        vert_idx = vert_idx[edge_sort_idx]

        # In well constructed meshes each edge is shared by precisely 2 faces
        # However, in many meshes, this assumption is not always satisfied.
        # We want to find all faces that share an edge, a number which can
        # vary and which depends on the topology.
        # This operation is done more efficiently in cpu with lists.
        # TODO(gkioxari) find a better way to do this.

        # Number of (face, edge) entries per edge, i.e. the length of each
        # contiguous run in the sorted edge_idx.
        edge_num = edge_idx.bincount(minlength=E)

        # For every run, enumerate each unordered pair (i < j) of positions
        # exactly once. Using j > i (rather than j >= 1, j != i) avoids the
        # duplicated, reversed pairs that otherwise appear for runs of
        # length >= 4.
        pairs = []
        start = 0
        for count in edge_num.tolist():
            stop = start + count
            for i in range(start, stop):
                for j in range(i + 1, stop):
                    pairs.append([i, j])
            start = stop
        vert_edge_pair_idx = torch.tensor(
            pairs, device=meshes.device, dtype=torch.int64
        )

    if vert_edge_pair_idx.shape[0] == 0:
        # No edge is shared by two faces (e.g. isolated triangles): there are
        # no neighboring normals to compare, so the loss is zero.
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    v0_idx = edges_packed[edge_idx, 0]
    v0 = verts_packed[v0_idx]
    v1_idx = edges_packed[edge_idx, 1]
    v1 = verts_packed[v1_idx]

    # two of the following cross products are zeros as they are cross product
    # with either (v1-v0)x(v1-v0) or (v1-v0)x(v0-v0)
    n_temp0 = (v1 - v0).cross(verts_packed[vert_idx[:, 0]] - v0, dim=1)
    n_temp1 = (v1 - v0).cross(verts_packed[vert_idx[:, 1]] - v0, dim=1)
    n_temp2 = (v1 - v0).cross(verts_packed[vert_idx[:, 2]] - v0, dim=1)
    n = n_temp0 + n_temp1 + n_temp2
    n0 = n[vert_edge_pair_idx[:, 0]]
    n1 = -n[vert_edge_pair_idx[:, 1]]
    loss = 1 - torch.cosine_similarity(n0, n1, dim=1)

    # Mesh index of every pair, used to average the loss per mesh.
    verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]]
    verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[
        vert_edge_pair_idx[:, 0]
    ]
    num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)
    weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()

    loss = loss * weights
    return loss.sum() / N
def split_list(input, length_to_split):
    """
    Cut `input` into consecutive chunks whose sizes are given by
    `length_to_split`.

    Args:
        input: Any iterable; it is consumed once, front to back.
        length_to_split: Iterable of chunk sizes.

    Returns:
        A list with one list per requested size. A chunk is shorter than
        requested (possibly empty) if `input` runs out, and items left over
        after the last chunk are dropped.
    """
    remaining = iter(input)
    chunks = []
    for size in length_to_split:
        # islice advances the shared iterator, so each chunk starts where the
        # previous one stopped.
        chunks.append(list(islice(remaining, size)))
    return chunks

11
pytorch3d/ops/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .cubify import cubify
from .graph_conv import GraphConv
from .nearest_neighbor_points import nn_points_idx
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .vert_align import vert_align
# Export every name bound in this module's namespace at this point that does
# not start with an underscore. This must stay below the imports it is meant
# to re-export.
__all__ = [k for k in globals().keys() if not k.startswith("_")]

208
pytorch3d/ops/cubify.py Normal file
View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.structures import Meshes
def unravel_index(idx, dims) -> torch.Tensor:
    r"""
    Equivalent to np.unravel_index

    Args:
        idx: A LongTensor whose elements are indices into the
            flattened version of an array of dimensions dims.
        dims: The shape of the array to be indexed.
            Implemented only for dims=(N, H, W, D)

    Returns:
        A tensor of shape (len(idx), 4) and the same integer dtype as idx
        whose columns are the (n, h, w, d) coordinates of each flat index.

    Raises:
        ValueError: if dims does not have exactly 4 elements.
    """
    if len(dims) != 4:
        raise ValueError("Expects a 4-element list.")
    N, H, W, D = dims
    # Use floor division explicitly: torch.div on integer tensors performs
    # true division in current PyTorch and would yield fractional floats.
    n = idx // (H * W * D)
    rem = idx - n * (H * W * D)
    h = rem // (W * D)
    rem = rem - h * (W * D)
    w = rem // D
    d = rem - w * D
    return torch.stack((n, h, w, d), dim=1)
def ravel_index(idx, dims) -> torch.Tensor:
    """
    Flatten 3-D coordinates into linear indices for an array of shape dims.
    This is the inverse of unravel_index for a single (H, W, D) volume.

    Args:
        idx: A LongTensor of shape (N, 3). Each row holds the coordinates of
            one element of an array of dimensions dims.
        dims: The shape of the array to be indexed.
            Implemented only for dims=(H, W, D)

    Returns:
        A tensor of shape (N,) with the row-major linear index of each row.

    Raises:
        ValueError: if dims does not have 3 elements or idx is not Nx3.
    """
    if len(dims) != 3:
        raise ValueError("Expects a 3-element list")
    if idx.shape[1] != 3:
        raise ValueError("Expects an index tensor of shape Nx3")
    # The leading extent is not needed to compute row-major offsets.
    _, width, depth = dims
    first, second, third = idx[:, 0], idx[:, 1], idx[:, 2]
    return first * width * depth + second * depth + third
@torch.no_grad()
def cubify(voxels, thresh, device=None) -> Meshes:
    r"""
    Converts a voxel to a mesh by replacing each occupied voxel with a cube
    consisting of 12 faces and 8 vertices. Shared vertices are merged, and
    internal faces are removed.

    The whole function runs under torch.no_grad().

    Args:
        voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
        thresh: A scalar threshold. A voxel is considered occupied if its
            occupancy is greater than or equal to thresh (voxels.ge(thresh)).
        device: Device on which the helper tensors (cube templates, averaging
            kernels, face mask, vertex grid) are created. Defaults to
            voxels.device.
            NOTE(review): a device different from voxels.device looks like it
            would mix devices when the mask is multiplied by the thresholded
            voxels - confirm before relying on it.

    Returns:
        meshes: A Meshes object of the corresponding meshes. An empty Meshes
        is returned when voxels has length 0, and N meshes with empty
        verts/faces tensors when no face survives.
    """
    if device is None:
        device = voxels.device

    if len(voxels) == 0:
        return Meshes(verts=[], faces=[])

    N, D, H, W = voxels.size()
    # vertices corresponding to a unit cube: 8x3
    cube_verts = torch.tensor(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=torch.int64,
        device=device,
    )

    # faces corresponding to a unit cube: 12x3
    cube_faces = torch.tensor(
        [
            [0, 1, 2],
            [1, 3, 2],  # left face: 0, 1
            [2, 3, 6],
            [3, 7, 6],  # bottom face: 2, 3
            [0, 2, 6],
            [0, 6, 4],  # front face: 4, 5
            [0, 5, 1],
            [0, 4, 5],  # up face: 6, 7
            [6, 7, 5],
            [6, 5, 4],  # right face: 8, 9
            [1, 7, 3],
            [1, 5, 7],  # back face: 10, 11
        ],
        dtype=torch.int64,
        device=device,
    )

    # 2-tap averaging kernels along the last (W), middle (H) and first (D)
    # spatial axis respectively.
    wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
    wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
    wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)

    voxelt = voxels.ge(thresh).float()
    # N x 1 x D x H x W
    voxelt = voxelt.view(N, 1, D, H, W)

    # Each convolution averages two adjacent voxels along one axis, so the
    # result is > 0.5 only where both neighbors are occupied. Only the axis the
    # kernel spans shrinks by one:
    #   voxelt_x: N x 1 x D x H x (W-1)
    #   voxelt_y: N x 1 x D x (H-1) x W
    #   voxelt_z: N x 1 x (D-1) x H x W
    voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
    voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
    voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()

    # 12 x N x 1 x D x H x W
    # Start with every face kept (1) and clear the faces that lie between two
    # occupied neighboring voxels.
    faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)

    # add left face
    faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
    faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
    # add bottom face
    faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
    faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
    # add front face
    faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
    faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
    # add up face
    faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
    faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
    # add right face
    faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
    faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
    # add back face
    faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
    faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z

    # Faces of unoccupied voxels are dropped altogether.
    faces_idx *= voxelt

    # N x H x W x D x 12
    faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
    # (NHWD) x 12
    faces_idx = faces_idx.contiguous()
    faces_idx = faces_idx.view(-1, cube_faces.size(0))

    # boolean to linear index
    # NF x 2
    # Column 0 is the flat (n, h, w, d) voxel index, column 1 the cube face id.
    linind = torch.nonzero(faces_idx)
    # NF x 4
    nyxz = unravel_index(linind[:, 0], (N, H, W, D))

    # NF x 3: faces
    faces = torch.index_select(cube_faces, 0, linind[:, 1])

    grid_faces = []
    for d in range(cube_faces.size(1)):
        # NF x 3
        xyz = torch.index_select(cube_verts, 0, faces[:, d])
        # Reorder the template corner columns to match the column order of
        # nyxz[:, 1:] before adding the voxel offset.
        permute_idx = torch.tensor([1, 0, 2], device=device)
        yxz = torch.index_select(xyz, 1, permute_idx)
        yxz += nyxz[:, 1:]
        # NF x 1
        temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
        grid_faces.append(temp)
    # NF x 3
    grid_faces = torch.stack(grid_faces, dim=1)

    y, x, z = torch.meshgrid(
        torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
    )
    # Rescale the integer grid coordinates; index 0 maps to -1 and index
    # (size - 1) maps to +1.
    # NOTE(review): H, W or D equal to 1 makes the divisor zero here.
    y = y.to(device=device, dtype=torch.float32)
    y = y * 2.0 / (H - 1.0) - 1.0
    x = x.to(device=device, dtype=torch.float32)
    x = x * 2.0 / (W - 1.0) - 1.0
    z = z.to(device=device, dtype=torch.float32)
    z = z * 2.0 / (D - 1.0) - 1.0
    # ((H+1)(W+1)(D+1)) x 3
    grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)

    if len(nyxz) == 0:
        # No face survived in any batch element: return N empty meshes.
        verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
        faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
        return Meshes(verts=verts_list, faces=faces_list)

    num_verts = grid_verts.size(0)
    # Temporarily offset the vertex ids by batch element so that one flat mask
    # covers the whole batch, mark every referenced vertex as not idle (0),
    # then undo the offset.
    grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
    idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)

    idleverts.scatter_(0, grid_faces.flatten(), 0)
    grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
    # Number of faces per batch element, used to split the packed face list.
    split_size = torch.bincount(nyxz[:, 0], minlength=N)
    faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))

    idleverts = idleverts.view(N, num_verts)
    # idlenum[n][i] counts the idle vertices at positions <= i; subtracting it
    # from a vertex id gives its index after the idle vertices are removed.
    idlenum = idleverts.cumsum(1)

    verts_list = [
        grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
        for n in range(N)
    ]
    faces_list = [
        nface - idlenum[n][nface] for n, nface in enumerate(faces_list)
    ]

    return Meshes(verts=verts_list, faces=faces_list)

174
pytorch3d/ops/graph_conv.py Normal file
View File

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from pytorch3d import _C
class GraphConv(nn.Module):
    """
    A single graph convolution layer.

    Each output vertex feature is w0(x_v) plus the sum of w1(x_u) over the
    neighbors u of v, where w0 and w1 are learned linear layers.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        init: str = "normal",
        directed: bool = False,
    ):
        """
        Args:
            input_dim: Number of input features per vertex.
            output_dim: Number of output features per vertex.
            init: Weight initialization method. Can be one of ['zero', 'normal'].
                'normal' draws the weights from N(0, 0.01^2) and zeroes the
                biases; 'zero' zeroes the weights and leaves the biases at the
                nn.Linear default.
            directed: Bool indicating if edges in the graph are directed.

        Raises:
            ValueError: if init is not one of the supported methods.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.directed = directed
        self.w0 = nn.Linear(input_dim, output_dim)
        self.w1 = nn.Linear(input_dim, output_dim)

        if init == "normal":
            nn.init.normal_(self.w0.weight, mean=0, std=0.01)
            nn.init.normal_(self.w1.weight, mean=0, std=0.01)
            self.w0.bias.data.zero_()
            self.w1.bias.data.zero_()
        elif init == "zero":
            self.w0.weight.data.zero_()
            self.w1.weight.data.zero_()
        else:
            raise ValueError('Invalid GraphConv initialization "%s"' % init)

    def forward(self, verts, edges):
        """
        Args:
            verts: FloatTensor of shape (V, input_dim) where V is the number of
                vertices and input_dim is the number of input features
                per vertex. input_dim has to match the input_dim specified
                in __init__.
            edges: LongTensor of shape (E, 2) where E is the number of edges
                where each edge has the indices of the two vertices which
                form the edge.

        Returns:
            out: FloatTensor of shape (V, output_dim) where output_dim is the
            number of output features per vertex. For an empty graph (V == 0)
            this is an empty (0, output_dim) tensor.

        Raises:
            ValueError: if verts and edges are not both on CPU or both on CUDA.
        """
        if verts.is_cuda != edges.is_cuda:
            raise ValueError(
                "verts and edges tensors must be on the same device."
            )
        if verts.shape[0] == 0:
            # empty graph: keep the documented (V, output_dim) shape instead of
            # collapsing to a scalar. Multiplying by verts.sum() keeps the
            # result attached to the autograd graph of verts.
            return verts.new_zeros((0, self.output_dim)) * verts.sum()

        verts_w0 = self.w0(verts)  # (V, output_dim)
        verts_w1 = self.w1(verts)  # (V, output_dim)

        # Use the compiled kernel for CUDA inputs and the pure PyTorch
        # implementation otherwise.
        if torch.cuda.is_available() and verts.is_cuda and edges.is_cuda:
            neighbor_sums = gather_scatter(verts_w1, edges, self.directed)
        else:
            neighbor_sums = gather_scatter_python(
                verts_w1, edges, self.directed
            )  # (V, output_dim)

        # Add neighbor features to each vertex's features.
        out = verts_w0 + neighbor_sums
        return out

    def __repr__(self):
        Din, Dout, directed = self.input_dim, self.output_dim, self.directed
        return "GraphConv(%d -> %d, directed=%r)" % (Din, Dout, directed)
def gather_scatter_python(input, edges, directed: bool = False):
    """
    Pure PyTorch neighbor aggregation: for every edge (a, b) the features of
    b are added to the output row of a, and for undirected graphs the features
    of a are also added to the output row of b.

    For a directed graph v0 -> v1 -> v2 the updated feature for v1 therefore
    depends on v2, consistent with Morris et al. AAAI 2019
    (https://arxiv.org/abs/1810.02244). For undirected graphs v1 depends on
    both v0 and v2 regardless of how the edges are stored.

    Args:
        input: Tensor of shape (num_vertices, input_dim).
        edges: Tensor of edge indices of shape (num_edges, 2).
        directed: bool indicating if edges are directed.

    Returns:
        output: Tensor of same shape as input.

    Raises:
        ValueError: if input or edges is not 2-dimensional, or edges does not
            have exactly two columns.
    """
    if input.dim() != 2:
        raise ValueError("input can only have 2 dimensions.")
    if edges.dim() != 2:
        raise ValueError("edges can only have 2 dimensions.")
    if edges.shape[1] != 2:
        raise ValueError("edges must be of shape (num_edges, 2).")

    feature_dim = input.shape[1]
    edge_count = edges.shape[0]

    # Broadcast each endpoint index across the feature dimension so it can be
    # used with gather / scatter_add along dim 0.
    src = edges[:, 0].unsqueeze(1).expand(edge_count, feature_dim)
    dst = edges[:, 1].unsqueeze(1).expand(edge_count, feature_dim)

    aggregated = torch.zeros_like(input).scatter_add(0, src, input.gather(0, dst))
    if directed:
        return aggregated
    # Undirected: also accumulate in the opposite direction.
    return aggregated.scatter_add(0, dst, input.gather(0, src))
class GatherScatter(Function):
    """
    Autograd Function that routes neighbor aggregation through the compiled
    `_C.gather_scatter` op, for both the forward and the backward pass.
    """

    @staticmethod
    def forward(ctx, input, edges, directed=False):
        """
        Args:
            ctx: Context object used to calculate gradients.
            input: Tensor of shape (num_vertices, input_dim)
            edges: Tensor of edge indices of shape (num_edges, 2)
            directed: Bool indicating if edges are directed.

        Returns:
            output: Tensor of same shape as input.

        Raises:
            ValueError: if input or edges is not 2-dimensional, edges does not
                have two columns, or input is not float32.
        """
        if input.dim() != 2:
            raise ValueError("input can only have 2 dimensions.")
        if edges.dim() != 2:
            raise ValueError("edges can only have 2 dimensions.")
        if edges.shape[1] != 2:
            raise ValueError("edges must be of shape (num_edges, 2).")
        if input.dtype != torch.float32:
            raise ValueError("input has to be of type torch.float32.")

        ctx.directed = directed
        input = input.contiguous()
        edges = edges.contiguous()
        # Only the edges are needed to route gradients in backward.
        ctx.save_for_backward(edges)
        # The last argument selects the forward (False) direction of the op.
        return _C.gather_scatter(input, edges, directed, False)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """
        Propagate gradients by calling the same op with its `backward` flag
        set. Only `input` receives a gradient; `edges` and `directed` get None.
        """
        (edges,) = ctx.saved_tensors
        grad_input = _C.gather_scatter(
            grad_output.contiguous(), edges, ctx.directed, True
        )
        return grad_input, None, None


gather_scatter = GatherScatter.apply

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d import _C
def nn_points_idx(p1, p2, p2_normals=None) -> tuple:
    """
    For each point in p1, find its nearest neighbor in the pointcloud p2.

    The neighbor search itself runs under torch.no_grad(); the returned points
    and normals are gathered from p2 / p2_normals outside of it.

    Args:
        p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
            containing P1 points of dimension D.
        p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
            containing P2 points of dimension D.
        p2_normals: [optional] FloatTensor of shape (N, P2, D) giving
            normals for p2. Default: None.

    Returns:
        3-element tuple containing

        - **p1_nn_points**: FloatTensor of shape (N, P1, D) where
          p1_nn_points[n, i] is the point in p2[n] which is
          the nearest neighbor to p1[n, i].
        - **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of
          the neighbors.
        - **p1_nn_normals**: Normal vectors for each point in p1_nn_points,
          gathered from p2_normals; an empty list if p2_normals is None.

    Raises:
        ValueError: if p2_normals is given and its shape differs from p2's.
    """
    N, P1, D = p1.shape

    # Validate before launching the neighbor search so that bad input fails
    # fast instead of after the expensive kernel call.
    if p2_normals is not None and p2_normals.shape != p2.shape:
        raise ValueError("p2_normals has incorrect shape.")

    with torch.no_grad():
        p1_nn_idx = _C.nn_points_idx(
            p1.contiguous(), p2.contiguous()
        )  # (N, P1)

    # Repeat each neighbor index across the D coordinates for use with gather.
    p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
    p1_nn_points = p2.gather(1, p1_nn_idx_expanded)

    if p2_normals is None:
        p1_nn_normals = []
    else:
        p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded)

    return p1_nn_points, p1_nn_idx, p1_nn_normals

View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This module implements utility functions for sampling points from
batches of meshes.
"""
import sys
from typing import Tuple, Union
import torch
from pytorch3d import _C
def sample_points_from_meshes(
    meshes, num_samples: int = 10000, return_normals: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Convert a batch of meshes to a pointcloud by uniformly sampling points on
    the surface of the mesh with probability proportional to the face area.

    Faces are drawn with replacement from a multinomial over the face areas
    (computed under torch.no_grad()), and a point is then placed on each drawn
    face using random barycentric coordinates.

    Args:
        meshes: A Meshes object with a batch of N meshes.
        num_samples: Integer giving the number of point samples per mesh.
        return_normals: If True, return normals for the sampled points.

    Returns:
        `samples` alone if return_normals is False, otherwise the 2-element
        tuple (samples, normals):

        - **samples**: FloatTensor of shape (N, num_samples, 3) giving the
          coordinates of sampled points for each mesh in the batch. For empty
          meshes the corresponding row in the samples array will be filled with 0.
        - **normals**: FloatTensor of shape (N, num_samples, 3) giving a normal vector
          to each sampled point. Only returned if return_normals is True.
          For empty meshes the corresponding row in the normals array will
          be filled with 0. The norm used for normalization is clamped to at
          least sys.float_info.epsilon to avoid dividing by 0.

    Raises:
        ValueError: if meshes.isempty() is true.
    """
    if meshes.isempty():
        raise ValueError("Meshes are empty.")

    verts = meshes.verts_packed()
    faces = meshes.faces_packed()
    mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
    num_meshes = len(meshes)
    num_valid_meshes = torch.sum(meshes.valid)  # Non empty meshes.

    # Initialize samples tensor with fill value 0 for empty meshes.
    samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)

    # Only compute samples for non empty meshes
    with torch.no_grad():
        areas, _ = _C.face_areas_normals(
            verts, faces
        )  # Face areas can be zero.
        max_faces = meshes.num_faces_per_mesh().max().item()
        areas_padded = _C.packed_to_padded_tensor(
            areas, mesh_to_face[meshes.valid], max_faces
        )  # (N, F)

        # TODO (gkioxari) Confirm multinomial bug is not present with real data.
        sample_face_idxs = areas_padded.multinomial(
            num_samples, replacement=True
        )  # (N, num_samples)
        # Shift the per-mesh face indices by each mesh's first face index so
        # that they index into the packed face list.
        sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)

    # Get the vertex coordinates of the sampled faces.
    face_verts = verts[faces.long()]
    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

    # Randomly generate barycentric coords.
    w0, w1, w2 = _rand_barycentric_coords(
        num_valid_meshes, num_samples, verts.dtype, verts.device
    )

    # Use the barycentric coords to get a point on each sampled face.
    a = v0[sample_face_idxs]  # (N, num_samples, 3)
    b = v1[sample_face_idxs]
    c = v2[sample_face_idxs]
    # Rows of meshes not selected by meshes.valid keep their zero fill.
    samples[meshes.valid] = (
        w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
    )

    if return_normals:
        # Initialize normals tensor with fill value 0 for empty meshes.
        # Normals for the sampled points are face normals computed from
        # the vertices of the face in which the sampled point lies.
        normals = torch.zeros(
            (num_meshes, num_samples, 3), device=meshes.device
        )
        vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
        # Normalize to unit length; the clamp guards against zero-area faces.
        vert_normals = vert_normals / vert_normals.norm(
            dim=1, p=2, keepdim=True
        ).clamp(min=sys.float_info.epsilon)
        vert_normals = vert_normals[sample_face_idxs]
        normals[meshes.valid] = vert_normals

        return samples, normals
    else:
        return samples
def _rand_barycentric_coords(
    size1, size2, dtype, device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Draw random barycentric coordinates that are uniformly distributed over a
    triangle.

    Two uniform variates u and v are drawn per sample and mapped to
    (1 - sqrt(u), sqrt(u) * (1 - v), sqrt(u) * v), which sums to one.

    Args:
        size1, size2: The number of coordinates generated will be size1*size2.
            Output tensors will each be of shape (size1, size2).
        dtype: Datatype to generate.
        device: A torch.device object on which the outputs will be allocated.

    Returns:
        w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric
        coordinates
    """
    # A single rand call supplies both variates for every sample.
    draws = torch.rand(2, size1, size2, dtype=dtype, device=device)
    root_u = draws[0].sqrt()
    v = draws[1]
    return 1.0 - root_u, root_u * (1.0 - v), root_u * v

View File

@@ -0,0 +1,479 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from pytorch3d.structures import Meshes
class SubdivideMeshes(nn.Module):
"""
Subdivide a triangle mesh by adding a new vertex at the center of each edge
and dividing each face into four new faces. Vectors of vertex
attributes can also be subdivided by averaging the values of the attributes
at the two vertices which form each edge. This implementation
preserves face orientation - if the vertices of a face are all ordered
counter-clockwise, then the faces in the subdivided meshes will also have
their vertices ordered counter-clockwise.
If meshes is provided as an input, the initializer performs the relatively
expensive computation of determining the new face indices. This one-time
computation can be reused for all meshes with the same face topology
but different vertex positions.
"""
def __init__(self, meshes=None):
    """
    Args:
        meshes: Meshes object or None. If a meshes object is provided,
            the first mesh is used to compute the new faces of the
            subdivided topology which can be reused for meshes with
            the same input topology.

    Raises:
        ValueError: if the precomputed subdivided faces do not have three
            vertex indices per face.
    """
    super(SubdivideMeshes, self).__init__()

    self.precomputed = False
    self._N = -1

    if meshes is None:
        # Nothing to precompute.
        return

    # Only indices are computed here, so gradient tracking is switched off.
    first_mesh = meshes[0]
    with torch.no_grad():
        new_faces = self.subdivide_faces(first_mesh)
        if new_faces.shape[1] != 3:
            raise ValueError("faces can only have three vertices")
        self.register_buffer("_subdivided_faces", new_faces)
        self.precomputed = True
def subdivide_faces(self, meshes):
    r"""
    Compute the faces of the subdivided meshes. The input meshes object is
    left unmodified.

    Args:
        meshes: a Meshes object.

    Returns:
        subdivided_faces_packed: (4*sum(F_n), 3) shape LongTensor of
        original and new faces.

    Refer to pytorch3d.structures.meshes.py for more details on packed
    representations of faces.

    Each face is split into 4 faces e.g. Input face
    ::
               v0
               /\
              /  \
             /    \
         e1 /      \ e0
           /        \
          /          \
         /            \
        /______________\
      v2       e2       v1

      faces_packed = [[0, 1, 2]]
      faces_packed_to_edges_packed = [[2, 1, 0]]

    `faces_packed_to_edges_packed` is used to represent all the new
    vertex indices corresponding to the mid-points of edges in the mesh.
    The actual vertex coordinates will be computed in the forward function.
    To get the indices of the new vertices, offset
    `faces_packed_to_edges_packed` by the total number of vertices.
    ::
      faces_packed_to_edges_packed = [[2, 1, 0]] + 3 = [[5, 4, 3]]

    e.g. subdivided face
    ::
            v0
            /\
           /  \
          / f0 \
      v4 /______\ v3
        /\      /\
       /  \ f3 /  \
      / f2 \  / f1 \
     /______\/______\
    v2       v5       v1

    f0 = [0, 3, 4]
    f1 = [1, 5, 3]
    f2 = [2, 4, 5]
    f3 = [5, 4, 3]
    """
    verts_packed = meshes.verts_packed()
    with torch.no_grad():
        faces_packed = meshes.faces_packed()
        # Offset out-of-place: an in-place += would modify the tensor handed
        # out by the meshes object (which may be cached there), corrupting it
        # for later callers and for repeated calls.
        faces_packed_to_edges_packed = (
            meshes.faces_packed_to_edges_packed() + verts_packed.shape[0]
        )

        # Corner faces: one original vertex plus the midpoints of its two
        # incident edges.
        f0 = torch.stack(
            [
                faces_packed[:, 0],
                faces_packed_to_edges_packed[:, 2],
                faces_packed_to_edges_packed[:, 1],
            ],
            dim=1,
        )
        f1 = torch.stack(
            [
                faces_packed[:, 1],
                faces_packed_to_edges_packed[:, 0],
                faces_packed_to_edges_packed[:, 2],
            ],
            dim=1,
        )
        f2 = torch.stack(
            [
                faces_packed[:, 2],
                faces_packed_to_edges_packed[:, 1],
                faces_packed_to_edges_packed[:, 0],
            ],
            dim=1,
        )
        # Center face: the three edge midpoints.
        f3 = faces_packed_to_edges_packed
        subdivided_faces_packed = torch.cat(
            [f0, f1, f2, f3], dim=0
        )  # (4*sum(F_n), 3)

        return subdivided_faces_packed
def forward(self, meshes, feats=None):
    """
    Subdivide a batch of meshes by adding a new vertex on each edge, and
    dividing each face into four new faces. New meshes contains two types
    of vertices:
    1) Vertices that appear in the input meshes.
       Data for these vertices are copied from the input meshes.
    2) New vertices at the midpoint of each edge.
       Data for these vertices is the average of the data for the two
       vertices that make up the edge.

    The batch size is recorded on the module, and the work is delegated to
    subdivide_homogeneous when the faces were precomputed in the initializer
    and to subdivide_heterogenerous otherwise.

    Args:
        meshes: Meshes object representing a batch of meshes.
        feats: Per-vertex features to be subdivided along with the verts.
            Should be parallel to the packed vert representation of the
            input meshes; so it should have shape (V, D) where V is the
            total number of verts in the input meshes. Default: None.

    Returns:
        2-element tuple containing

        - **new_meshes**: Meshes object of a batch of subdivided meshes.
        - **new_feats**: (optional) Tensor of subdivided feats, parallel to the
          (packed) vertices of the subdivided meshes. Only returned
          if feats is not None.
    """
    self._N = len(meshes)
    subdivide = (
        self.subdivide_homogeneous
        if self.precomputed
        else self.subdivide_heterogenerous
    )
    return subdivide(meshes, feats)
def subdivide_homogeneous(self, meshes, feats=None):
    """
    Subdivide verts (and optionally features) of a batch of meshes
    where each mesh has the same topology of faces. The subdivided faces
    are the ones precomputed in the initializer, and the edges of the first
    mesh in the batch are used for every mesh.

    Args:
        meshes: Meshes object representing a batch of meshes.
        feats: Per-vertex features to be subdivided along with the verts,
            of shape (N, V, D) or (N*V, D).

    Returns:
        `new_meshes` alone if feats is None, otherwise the 2-element tuple
        (new_meshes, new_feats):

        - **new_meshes**: Meshes object of a batch of subdivided meshes.
        - **new_feats**: Tensor of subdivided feats of shape (N, V + E, D):
          the input features followed by one averaged feature per edge.

    Raises:
        ValueError: if feats is neither 2- nor 3-dimensional.
    """
    padded_verts = meshes.verts_padded()  # (N, V, D)
    edge_idx = meshes[0].edges_packed()

    # Every mesh in the batch shares the same precomputed set of faces.
    faces = self._subdivided_faces.view(1, -1, 3).expand(self._N, -1, -1)

    # One new vertex per edge, placed at the mean of the edge's two end points.
    midpoints = padded_verts[:, edge_idx].mean(dim=2)
    all_verts = torch.cat([padded_verts, midpoints], dim=1)  # (N, V + E, 3)

    if feats is None:
        return Meshes(verts=all_verts, faces=faces)

    if feats.dim() == 2:
        # feats is in packed format, transform it from packed to
        # padded, i.e. (N*V, D) to (N, V, D).
        feats = feats.view(
            padded_verts.size(0), padded_verts.size(1), feats.size(1)
        )
    if feats.dim() != 3:
        raise ValueError("features need to be of shape (N, V, D) or (N*V, D)")

    # Features of the new vertices are the mean of the features at the two
    # end points of each edge.
    edge_feats = feats[:, edge_idx].mean(dim=2)
    all_feats = torch.cat([feats, edge_feats], dim=1)  # (N, V + E, D)

    return Meshes(verts=all_verts, faces=faces), all_feats
def subdivide_heterogenerous(self, meshes, feats=None):
    """
    Subdivide faces, verts (and optionally features) of a batch of meshes
    where each mesh can have different face topologies.

    Args:
        meshes: Meshes object representing a batch of meshes.
        feats: Per-vertex features to be subdivided along with the verts.
            They are indexed with the packed edges, so they are presumably
            expected in packed layout (sum(V_n), D) -- confirm with callers.

    Returns:
        new_meshes if feats is None, otherwise the 2-element tuple
        (new_meshes, new_feats) where

        - **new_meshes**: Meshes object of a batch of subdivided meshes.
        - **new_feats**: Tensor of subdivided feats, parallel to the
          (packed) vertices of the subdivided meshes.
    """
    # Read the batch size from the input instead of self._N, which is only
    # assigned by the forward entry point; this keeps direct calls working.
    num_meshes = len(meshes)

    verts = meshes.verts_packed()

    # The computation of new faces is on face indices, so gradients do not
    # need to be tracked.
    with torch.no_grad():
        new_faces = self.subdivide_faces(meshes)
        edges = meshes.edges_packed()
        face_to_mesh_idx = meshes.faces_packed_to_mesh_idx()
        edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()
        num_edges_per_mesh = edge_to_mesh_idx.bincount(minlength=num_meshes)
        num_verts_per_mesh = meshes.num_verts_per_mesh()
        num_faces_per_mesh = meshes.num_faces_per_mesh()

        # Add one new vertex at the midpoint of each edge.
        new_verts_per_mesh = num_verts_per_mesh + num_edges_per_mesh  # (N,)
        new_face_to_mesh_idx = torch.cat([face_to_mesh_idx] * 4, dim=0)

        # Calculate the indices needed to group the new and existing verts
        # for each mesh.
        verts_sort_idx = create_verts_index(
            num_verts_per_mesh, num_edges_per_mesh, meshes.device
        )  # (sum(V_n)+sum(E_n),)

        verts_ordered_idx_init = torch.zeros(
            new_verts_per_mesh.sum(),
            dtype=torch.int64,
            device=meshes.device,
        )  # (sum(V_n)+sum(E_n),)

        # Reassign vertex indices so that existing and new vertices for each
        # mesh are sequential. Scattering arange through the sort index
        # yields the inverse permutation: old index -> new position.
        verts_ordered_idx = verts_ordered_idx_init.scatter_add(
            0,
            verts_sort_idx,
            torch.arange(new_verts_per_mesh.sum(), device=meshes.device),
        )

        # Retrieve vertex indices for each face.
        new_faces = verts_ordered_idx[new_faces]

        # Calculate the indices needed to group the existing and new faces
        # for each mesh.
        face_sort_idx = create_faces_index(
            num_faces_per_mesh, device=meshes.device
        )

        # Reorder the faces to sequentially group existing and new faces
        # for each mesh.
        new_faces = new_faces[face_sort_idx]
        new_face_to_mesh_idx = new_face_to_mesh_idx[face_sort_idx]
        new_faces_per_mesh = new_face_to_mesh_idx.bincount(
            minlength=num_meshes
        )  # (N,)

    # Add one new vertex at the midpoint of each edge by taking the average
    # of the verts that form each edge.
    new_verts = verts[edges].mean(dim=1)
    new_verts = torch.cat([verts, new_verts], dim=0)
    # Reorder the verts to sequentially group existing and new verts for
    # each mesh.
    new_verts = new_verts[verts_sort_idx]

    new_feats = None
    if feats is not None:
        new_feats = feats[edges].mean(dim=1)
        new_feats = torch.cat([feats, new_feats], dim=0)
        new_feats = new_feats[verts_sort_idx]

    verts_list = list(new_verts.split(new_verts_per_mesh.tolist(), 0))
    faces_list = list(new_faces.split(new_faces_per_mesh.tolist(), 0))

    # Offset of the first vertex of each mesh in the packed layout; it is
    # subtracted so that every mesh's faces index its own verts from zero.
    new_verts_per_mesh_cumsum = torch.cat(
        [
            new_verts_per_mesh.new_full(size=(1,), fill_value=0),
            new_verts_per_mesh.cumsum(0)[:-1],
        ],
        dim=0,
    )
    faces_list = [
        faces_list[n] - new_verts_per_mesh_cumsum[n]
        for n in range(num_meshes)
    ]

    new_meshes = Meshes(verts=verts_list, faces=faces_list)
    if feats is None:
        return new_meshes
    return new_meshes, new_feats
def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
    """
    Helper function to group the vertex indices for each mesh. New vertices are
    stacked at the end of the original verts tensor, so in order to have
    sequential packing, the verts tensor needs to be reordered so that the
    vertices corresponding to each mesh are grouped together.

    Meshes with zero vertices and/or zero edges are handled: they simply
    contribute no entries to the result.

    Args:
        verts_per_mesh: Tensor of shape (N,) giving the number of vertices
            in each mesh in the batch where N is the batch size.
        edges_per_mesh: Tensor of shape (N,) giving the number of edges
            in each mesh in the batch
        device: device on which to build the index. Defaults to the device
            of verts_per_mesh.

    Returns:
        verts_idx: An int64 tensor of shape (sum(V_n) + sum(E_n),) with vert
        indices for each mesh ordered sequentially by mesh index.
    """
    # e.g. verts_per_mesh = (4, 5, 6)
    # e.g. edges_per_mesh = (5, 7, 9)
    if device is None:
        device = verts_per_mesh.device
    verts_per_mesh = verts_per_mesh.to(device=device, dtype=torch.int64)
    edges_per_mesh = edges_per_mesh.to(device=device, dtype=torch.int64)

    num_meshes = verts_per_mesh.shape[0]
    V = int(verts_per_mesh.sum())  # e.g. 15
    E = int(edges_per_mesh.sum())  # e.g. 21

    # Mesh that owns each existing vertex / each new (per-edge) vertex.
    mesh_ids = torch.arange(num_meshes, device=device)
    vert_to_mesh = torch.repeat_interleave(mesh_ids, verts_per_mesh)  # (V,)
    edge_to_mesh = torch.repeat_interleave(mesh_ids, edges_per_mesh)  # (E,)

    # Number of new verts belonging to earlier meshes, e.g. (0, 5, 12).
    edges_before = edges_per_mesh.cumsum(dim=0) - edges_per_mesh
    # Number of existing verts up to and including each mesh, e.g. (4, 9, 15).
    verts_through = verts_per_mesh.cumsum(dim=0)

    vert_src = torch.arange(V, device=device)
    edge_src = torch.arange(E, device=device)

    # In the grouped layout an existing vertex is pushed back by the new
    # verts of all earlier meshes; a new vertex is pushed back by the
    # existing verts of its own mesh and all earlier ones.
    vert_dst = vert_src + edges_before[vert_to_mesh]
    edge_dst = edge_src + verts_through[edge_to_mesh]

    # vert_dst and edge_dst together cover every output slot exactly once.
    verts_idx = torch.empty(V + E, dtype=torch.int64, device=device)
    verts_idx[vert_dst] = vert_src
    # New vertices live after the V existing ones in the stacked tensor.
    verts_idx[edge_dst] = edge_src + V

    # e.g.
    # [
    #  0, 1, 2, 3, 15, 16, 17, 18, 19,                              --> mesh 0
    #  4, 5, 6, 7, 8, 20, 21, 22, 23, 24, 25, 26,                   --> mesh 1
    #  9, 10, 11, 12, 13, 14, 27, 28, 29, 30, 31, 32, 33, 34, 35    --> mesh 2
    # ]
    # where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
    # [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
    return verts_idx
def create_faces_index(faces_per_mesh, device=None):
    """
    Helper function to group the faces indices for each mesh. New faces are
    stacked at the end of the original faces tensor, so in order to have
    sequential packing, the faces tensor needs to be reordered so that faces
    corresponding to each mesh are grouped together.

    The stacked faces tensor is assumed to consist of 4 consecutive copies
    of length F = sum(F_n), each ordered by mesh. Meshes with zero faces are
    handled: they contribute no entries to the result.

    Args:
        faces_per_mesh: Tensor of shape (N,) giving the number of faces
            in each mesh in the batch where N is the batch size.
        device: device on which to build the index. Defaults to the device
            of faces_per_mesh.

    Returns:
        faces_idx: An int64 tensor of shape (4 * sum(F_n),) with face indices
        for each mesh ordered sequentially by mesh index.
    """
    # e.g. faces_per_mesh = [2, 5, 3]
    if device is None:
        device = faces_per_mesh.device
    faces_per_mesh = faces_per_mesh.to(device=device, dtype=torch.int64)

    num_meshes = faces_per_mesh.shape[0]
    num_faces = int(faces_per_mesh.sum())  # e.g. 10

    # Mesh that owns each face of one copy.
    mesh_ids = torch.arange(num_meshes, device=device)
    face_to_mesh = torch.repeat_interleave(mesh_ids, faces_per_mesh)  # (F,)

    # Index of the first face of each mesh within one copy, e.g. (0, 2, 7).
    first_face = faces_per_mesh.cumsum(dim=0) - faces_per_mesh

    face_src = torch.arange(num_faces, device=device)  # (F,)

    # In the grouped layout mesh n starts at 4 * first_face[n], so face i of
    # copy 0 lands at 4 * first_face[n] + (i - first_face[n]).
    base_dst = face_src + 3 * first_face[face_to_mesh]  # (F,)
    # Each further copy of mesh n is shifted by that mesh's face count.
    group_size = faces_per_mesh[face_to_mesh]  # (F,)

    copy_ids = torch.arange(4, device=device)[:, None]  # (4, 1)
    dst = base_dst[None, :] + copy_ids * group_size[None, :]  # (4, F)
    src = face_src[None, :] + copy_ids * num_faces  # (4, F)

    # dst covers every output slot exactly once.
    faces_idx = torch.empty(4 * num_faces, dtype=torch.int64, device=device)
    faces_idx[dst.reshape(-1)] = src.reshape(-1)

    # e.g.
    # [
    #  0, 1, 10, 11, 20, 21, 30, 31,
    #  2, 3, 4, 5, 6, 12, 13, 14, 15, 16, 22, 23, 24, 25, 26, 32, 33, 34, 35, 36,
    #  7, 8, 9, 17, 18, 19, 27, 28, 29, 37, 38, 39
    # ]
    # where for mesh 0, [0, 1] are the indices of the existing faces, and
    # [10, 11, 20, 21, 30, 31] are the indices of the new faces after subdivision.
    return faces_idx

101
pytorch3d/ops/vert_align.py Normal file
View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
def vert_align(
    feats,
    verts,
    return_packed: bool = False,
    interp_mode: str = "bilinear",
    padding_mode: str = "zeros",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Sample per-vertex features from one or more feature maps at the (x, y)
    location of every vertex. This operation is called "perceptual feature
    pooling" in [1] or "vert align" in [2].

    [1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single
        RGB Images", ECCV 2018.
    [2] Gkioxari et al, "Mesh R-CNN", ICCV 2019

    Args:
        feats: FloatTensor of shape (N, C, H, W) representing image features
            from which to sample or a list of features each with potentially
            different C, H or W dimensions.
        verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with
            'verts_padded' as an attribute giving the (x, y, z) vertex
            positions for which to sample. (x, y) verts should be normalized
            such that (-1, -1) corresponds to top-left and (+1, +1) to
            bottom-right location in the input feature map.
        return_packed: (bool) Indicates whether to return packed features
        interp_mode: (str) Specifies how to interpolate features.
            ('bilinear' or 'nearest')
        padding_mode: (str) Specifies how to handle vertices outside of the
            [-1, 1] range. ('zeros', 'reflection', or 'border')
        align_corners (bool): Geometrically, we consider the pixels of the
            input as squares rather than points.
            If set to ``True``, the extrema (``-1`` and ``1``) are considered
            as referring to the center points of the input's corner pixels.
            If set to ``False``, they are instead considered as referring to
            the corner points of the input's corner pixels, making the
            sampling more resolution agnostic. Default: ``True``

    Returns:
        feats_sampled: FloatTensor of shape (N, V, C) giving sampled features
        for each vertex. If feats is a list, the features are concatenated
        along axis=2 giving shape (N, V, sum(C_n)) where
        C_n = feats[n].shape[1]. If return_packed = True, the first two
        dimensions are flattened and, when verts provides
        `verts_padded_to_packed_idx`, the rows are gathered with that index
        into a packed representation of shape (sum(V), C).

    Raises:
        ValueError: if verts is a tensor that is not 3-dimensional, if verts
            is neither a tensor nor has a `verts_padded` attribute, if a
            feature map is not 4-dimensional, or if batch sizes disagree.
    """
    # Resolve the padded vertex positions from either supported input kind.
    if torch.is_tensor(verts):
        if verts.dim() != 3:
            raise ValueError("verts tensor should be 3 dimensional")
        sample_locs = verts
    elif hasattr(verts, "verts_padded"):
        sample_locs = verts.verts_padded()
    else:
        raise ValueError(
            "verts must be a tensor or have a `verts_padded` attribute"
        )
    # Keep only (x, y) and insert a dummy "height" axis for grid_sample.
    sample_locs = sample_locs[:, None, :, :2]  # (N, 1, V, 2)

    feat_maps = [feats] if torch.is_tensor(feats) else feats

    # Validate every feature map before sampling any of them.
    batch_size = sample_locs.shape[0]
    for fmap in feat_maps:
        if fmap.dim() != 4:
            raise ValueError("feats must have shape (N, C, H, W)")
        if batch_size != fmap.shape[0]:
            raise ValueError("inconsistent batch dimension")

    per_map = []
    for fmap in feat_maps:
        sampled = F.grid_sample(
            fmap,
            sample_locs,
            mode=interp_mode,
            padding_mode=padding_mode,
            align_corners=align_corners,
        )  # (N, C, 1, V)
        per_map.append(sampled.squeeze(dim=2).transpose(1, 2))  # (N, V, C)
    result = torch.cat(per_map, dim=2)  # (N, V, sum(C))

    if not return_packed:
        return result

    # Flatten the first two dimensions: (N*V, C).
    num_channels = result.shape[-1]
    result = result.view(-1, num_channels)
    if hasattr(verts, "verts_padded_to_packed_idx"):
        # Select the rows listed by the padded-to-packed index.
        gather_idx = (
            verts.verts_padded_to_packed_idx()
            .view(-1, 1)
            .expand(-1, num_channels)
        )
        result = result.gather(0, gather_idx)  # (sum(V), C)
    return result

View File

@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from .cameras import (
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
camera_position_from_spherical_angles,
get_world_to_view_transform,
look_at_rotation,
look_at_view_transform,
)
from .lighting import DirectionalLights, PointLights, diffuse, specular
from .materials import Materials
from .mesh import (
GouradShader,
MeshRasterizer,
MeshRenderer,
PhongShader,
RasterizationSettings,
SilhouetteShader,
TexturedPhongShader,
gourad_shading,
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
phong_shading,
rasterize_meshes,
)
from .utils import TensorProperties, convert_to_tensors_and_broadcast
# Export every name bound in this module's global namespace at this point
# (i.e. everything imported above) except names starting with an underscore.
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Sequence, Union

import numpy as np
import torch
# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return an RGBA image per batch element
# Data class to store blending params with defaults
class BlendParams(NamedTuple):
    """
    Data class to store blending params with defaults.

    Every member carries a type annotation: typing.NamedTuple only turns
    annotated names into fields, so an un-annotated assignment would be a
    plain class attribute that cannot be set through the constructor.
    """

    # Controls the width of the sigmoid used to turn 2D distances into
    # probabilities.
    sigma: float = 1e-4
    # Controls the scaling of the exponential used to weight faces by depth.
    gamma: float = 1e-4
    # RGB background color: 3-element list/tuple or torch.Tensor.
    background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
def hard_rgb_blend(colors, fragments) -> torch.Tensor:
    """
    Naive blending of the top K faces into an RGBA image.

    - **RGB** - color of the first face along K (index 0) at each pixel.
    - **A** - 1.0 everywhere.

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: the outputs of rasterization. From this we use
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the
              indices of the faces (in the packed representation) which
              overlap each pixel in the image. Only its shape and device
              are used, to allocate the output.

    Returns:
        RGBA pixel_colors: (N, H, W, 4), flipped along the H axis.
    """
    pix_to_face = fragments.pix_to_face
    batch, height, width, _ = pix_to_face.shape

    # Start from all ones so the alpha channel is already 1.0.
    rgba = torch.ones(
        (batch, height, width, 4),
        dtype=colors.dtype,
        device=pix_to_face.device,
    )
    rgba[..., :3] = colors[..., 0, :]
    return torch.flip(rgba, [1])
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
    """
    Silhouette blending to return an RGBA image.

    - **RGB** - color of the first face along K (index 0) at each pixel.
    - **A** - blend based on the 2D distance based probability map [0].

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: the outputs of rasterization. From this we use
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the
              indices of the faces (in the packed representation) which
              overlap each pixel in the image. Negative entries are
              treated as padding.
            - dists: FloatTensor of shape (N, H, W, K) specifying
              the 2D euclidean distance from the center of each pixel
              to each of the top K overlapping faces.
        blend_params: object whose `sigma` attribute scales the distances
            before the sigmoid.

    Returns:
        RGBA pixel_colors: (N, H, W, 4), clamped to [0, 1] and flipped along
        the H axis.

    [0] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """
    batch, height, width, _ = fragments.pix_to_face.shape
    rgba = torch.ones(
        (batch, height, width, 4), dtype=colors.dtype, device=colors.device
    )

    # Slots that actually hold a face (padding is marked with a negative id).
    valid = fragments.pix_to_face >= 0

    # The distance is negative if a pixel is inside a face and positive
    # outside the face, hence the sign flip before the sigmoid.
    coverage = torch.sigmoid(-fragments.dists / blend_params.sigma) * valid

    # alpha = 1 - prod_k(1 - coverage_k): alpha is 1 as soon as one face
    # fully covers the pixel.
    # TODO: investigate why torch.cumprod backwards is very slow for large
    # values of K. Until then the product is computed as exp(sum(log(.)))
    # using a*b = exp(log(a) + log(b)).
    log_miss = torch.log((1.0 - coverage))
    alpha = 1.0 - torch.exp(log_miss.sum(dim=-1))

    rgba[..., :3] = colors[..., 0, :]  # Hard assign for RGB
    rgba[..., 3] = alpha
    rgba = torch.clamp(rgba, min=0, max=1.0)
    return torch.flip(rgba, [1])
def softmax_rgb_blend(
    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100.0
) -> torch.Tensor:
    """
    RGB and alpha channel blending to return an RGBA image based on the method
    proposed in [0]

    - **RGB** - blend the colors based on the 2D distance based probability
      map and relative z distances.
    - **A** - blend based on the 2D distance based probability map.

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: namedtuple with outputs of rasterization. We use properties
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the
              indices of the faces (in the packed representation) which
              overlap each pixel in the image.
            - dists: FloatTensor of shape (N, H, W, K) specifying
              the 2D euclidean distance from the center of each pixel
              to each of the top K overlapping faces.
            - zbuf: FloatTensor of shape (N, H, W, K) specifying
              the interpolated depth from each pixel to each of the
              top K overlapping faces.
        blend_params: instance of BlendParams dataclass containing properties
            - sigma: float, parameter which controls the width of the sigmoid
              function used to calculate the 2D distance based probability.
              Sigma controls the sharpness of the edges of the shape.
            - gamma: float, parameter which controls the scaling of the
              exponential function used to control the opacity of the color.
            - background_color: (3) element list/tuple/torch.Tensor specifying
              the RGB values for the background color.
        znear: float, near clipping plane used to normalize the depths.
        zfar: float, far clipping plane used to normalize the depths.

    Returns:
        RGBA pixel_colors: (N, H, W, 4), clamped to [0, 1] and flipped along
        the H axis.

    [0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
        Image-based 3D Reasoning'
    """
    N, H, W, _ = fragments.pix_to_face.shape
    device = fragments.pix_to_face.device
    pix_colors = torch.ones(
        (N, H, W, 4), dtype=colors.dtype, device=colors.device
    )
    background = blend_params.background_color
    if not torch.is_tensor(background):
        background = torch.tensor(
            background, dtype=torch.float32, device=device
        )

    # Weight of the background color; it also keeps the normalization below
    # away from a division by zero for pixels that no face touches.
    delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
    delta = torch.tensor(delta, device=device)

    # Mask for padded pixels.
    mask = fragments.pix_to_face >= 0

    # Sigmoid probability map based on the distance of the pixel to the face.
    prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask

    # alpha = 1 - prod_k(1 - prob_k): alpha is 1 as soon as one face fully
    # covers the pixel.
    # TODO: investigate why torch.cumprod backwards is very slow for large
    # values of K. Until then the product is computed as exp(sum(log(.)))
    # using a*b = exp(log(a) + log(b)).
    alpha = 1.0 - torch.exp(torch.log((1.0 - prob_map)).sum(dim=-1))

    # Weights for each face. Adjust the exponential by the max z to prevent
    # overflow. zbuf shape (N, H, W, K), find max over K.
    # TODO: there may still be some instability in the exponent calculation.
    z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
    z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
    weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

    # Normalize weights.
    # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
    denom = weights_num.sum(dim=-1)[..., None] + delta
    weights = weights_num / denom

    # Sum: weights * textures + background color
    weighted_colors = (weights[..., None] * colors).sum(dim=-2)
    weighted_background = (delta / denom) * background
    pix_colors[..., :3] = weighted_colors + weighted_background
    pix_colors[..., 3] = alpha

    # Clamp colors to the range 0-1 and flip y axis.
    pix_colors = torch.clamp(pix_colors, min=0, max=1.0)
    return torch.flip(pix_colors, [1])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,284 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from .utils import TensorProperties, convert_to_tensors_and_broadcast
def diffuse(normals, color, direction) -> torch.Tensor:
    """
    Calculate the diffuse component of light reflection using Lambert's
    cosine law.

    Args:
        normals: (N, ..., 3) xyz normal vectors. Normals and points are
            expected to have the same shape.
        color: (1, 3) or (N, 3) RGB color of the diffuse component of the light.
        direction: (x,y,z) direction of the light

    Returns:
        colors: (N, ..., 3), same shape as the input points.

    The normals and light direction should be in the same coordinate frame
    i.e. if the points have been transformed from world -> view space then
    the normals and direction should also be in view space.

    NOTE: to use with the packed vertices (i.e. no batch dimension) reformat the
    inputs in the following way.

    .. code-block:: python

        Args:
            normals: (P, 3)
            color: (N, 3)[batch_idx, :] -> (P, 3)
            direction: (N, 3)[batch_idx, :] -> (P, 3)

        Returns:
            colors: (P, 3)

        where batch_idx is of shape (P). For meshes, batch_idx can be:
        meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx()
        depending on whether points refers to the vertex coordinates or
        average/interpolated face coordinates.
    """
    # TODO: handle multiple directional lights per batch element.
    # TODO: handle attenuation.

    # Bring color and direction to the batch dimension and device of normals.
    normals, color, direction = convert_to_tensors_and_broadcast(
        normals, color, direction, device=normals.device
    )

    # Shape (N, 1, ..., 1, 3): one singleton axis per intermediate dimension
    # of normals, assuming first dim = batch dim and last dim = 3.
    target_shape = normals.shape
    inner_dims = target_shape[1:-1]
    broadcast_shape = (-1,) + (1,) * len(inner_dims) + (3,)
    if direction.shape != target_shape:
        direction = direction.view(broadcast_shape)
    if color.shape != target_shape:
        color = color.view(broadcast_shape)

    # Renormalize the normals in case they have been interpolated.
    unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    unit_direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)

    # Lambert's law: intensity follows the cosine between normal and light
    # direction, with negative cosines zeroed out.
    cos_theta = F.relu(torch.sum(unit_normals * unit_direction, dim=-1))
    return color * cos_theta[..., None]
def specular(
    points, normals, direction, color, camera_position, shininess
) -> torch.Tensor:
    """
    Calculate the specular component of light reflection.

    Args:
        points: (N, ..., 3) xyz coordinates of the points.
        normals: (N, ..., 3) xyz normal vectors for each point.
        color: (N, 3) RGB color of the specular component of the light.
        direction: (N, 3) vector direction of the light.
        camera_position: (N, 3) The xyz position of the camera.
        shininess: (N) The specular exponent of the material.

    Returns:
        colors: (N, ..., 3), same shape as the input points.

    Raises:
        ValueError: if points and normals do not have the same shape.

    The points, normals, camera_position, and direction should be in the same
    coordinate frame i.e. if the points have been transformed from
    world -> view space then the normals, camera_position, and light direction
    should also be in view space.

    To use with a batch of packed points reindex in the following way.

    .. code-block:: python

        Args:
            points: (P, 3)
            normals: (P, 3)
            color: (N, 3)[batch_idx] -> (P, 3)
            direction: (N, 3)[batch_idx] -> (P, 3)
            camera_position: (N, 3)[batch_idx] -> (P, 3)
            shininess: (N)[batch_idx] -> (P)

        Returns:
            colors: (P, 3)

        where batch_idx is of shape (P). For meshes batch_idx can be:
        meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx().
    """
    # TODO: handle multiple directional lights
    # TODO: attenuate based on inverse squared distance to the light source

    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Bring all inputs to the batch dimension and device of points.
    _, color, direction, camera_position, shininess = (
        convert_to_tensors_and_broadcast(
            points,
            color,
            direction,
            camera_position,
            shininess,
            device=points.device,
        )
    )

    # Shapes with one singleton axis per intermediate dimension of points,
    # assuming first dim = batch dim and last dim = 3.
    target_shape = normals.shape
    scalar_shape = (-1,) + (1,) * len(points.shape[1:-1])
    vector_shape = scalar_shape + (3,)
    if direction.shape != target_shape:
        direction = direction.view(vector_shape)
    if color.shape != target_shape:
        color = color.view(vector_shape)
    if camera_position.shape != target_shape:
        camera_position = camera_position.view(vector_shape)
    if shininess.shape != target_shape:
        shininess = shininess.view(scalar_shape)

    # Renormalize the normals in case they have been interpolated.
    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
    cos_angle = torch.sum(normals * direction, dim=-1)

    # No specular highlights where the cosine between normal and light
    # direction is not positive.
    lit = (cos_angle > 0).to(torch.float32)

    # Unit vector from each point towards the camera.
    to_camera = F.normalize(camera_position - points, p=2, dim=-1, eps=1e-6)
    # Light direction mirrored about the normal.
    reflected = -direction + 2 * (cos_angle[..., None] * normals)

    # Cosine of the angle between the reflected light ray and the viewer
    alpha = F.relu(torch.sum(to_camera * reflected, dim=-1)) * lit
    return color * torch.pow(alpha, shininess)[..., None]
class DirectionalLights(TensorProperties):
    """
    Batch of directional lights, each described by ambient, diffuse and
    specular colors plus a direction vector.
    """

    def __init__(
        self,
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
        direction=((0, 1, 0),),
        device: str = "cpu",
    ):
        """
        Args:
            ambient_color: RGB color of the ambient component.
            diffuse_color: RGB color of the diffuse component.
            specular_color: RGB color of the specular component.
            direction: (x, y, z) direction vector of the light.
            device: torch.device on which the tensors should be located

        The inputs can each be
            - 3 element tuple/list or list of lists
            - torch tensor of shape (1, 3)
            - torch tensor of shape (N, 3)
        The inputs are broadcast against each other so they all have batch
        dimension N.

        Raises:
            ValueError: if the last dimension of any color or of the
                direction is not 3.
        """
        super().__init__(
            device=device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
            direction=direction,
        )
        _validate_light_properties(self)
        direction_shape = self.direction.shape
        if direction_shape[-1] != 3:
            raise ValueError(
                "Expected direction to have shape (N, 3); got %r"
                % repr(direction_shape)
            )

    def clone(self):
        """Return a copy made by the base class into a fresh instance."""
        return super().clone(DirectionalLights(device=self.device))

    def diffuse(self, normals, points=None) -> torch.Tensor:
        """Diffuse reflection of these lights for the given normals."""
        # NOTE: Points is not used but is kept in the args so that the API is
        # the same for directional and point lights. The call sites should not
        # need to know the light type.
        return diffuse(
            normals=normals,
            color=self.diffuse_color,
            direction=self.direction,
        )

    def specular(
        self, normals, points, camera_position, shininess
    ) -> torch.Tensor:
        """Specular reflection of these lights at the given points."""
        return specular(
            points=points,
            normals=normals,
            color=self.specular_color,
            direction=self.direction,
            camera_position=camera_position,
            shininess=shininess,
        )
class PointLights(TensorProperties):
    """
    Batch of point lights, each described by ambient, diffuse and specular
    colors plus a location.
    """

    def __init__(
        self,
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
        location=((0, 1, 0),),
        device: str = "cpu",
    ):
        """
        Args:
            ambient_color: RGB color of the ambient component
            diffuse_color: RGB color of the diffuse component
            specular_color: RGB color of the specular component
            location: xyz position of the light.
            device: torch.device on which the tensors should be located

        The inputs can each be
            - 3 element tuple/list or list of lists
            - torch tensor of shape (1, 3)
            - torch tensor of shape (N, 3)
        The inputs are broadcast against each other so they all have batch
        dimension N.

        Raises:
            ValueError: if the last dimension of any color or of the
                location is not 3.
        """
        super().__init__(
            device=device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
            location=location,
        )
        _validate_light_properties(self)
        location_shape = self.location.shape
        if location_shape[-1] != 3:
            raise ValueError(
                "Expected location to have shape (N, 3); got %r"
                % repr(location_shape)
            )

    def clone(self):
        """Return a copy made by the base class into a fresh instance."""
        return super().clone(PointLights(device=self.device))

    def diffuse(self, normals, points) -> torch.Tensor:
        """Diffuse reflection of these lights at the given points."""
        # The light direction is the vector from each point to the light.
        to_light = self.location - points
        return diffuse(
            normals=normals, color=self.diffuse_color, direction=to_light
        )

    def specular(
        self, normals, points, camera_position, shininess
    ) -> torch.Tensor:
        """Specular reflection of these lights at the given points."""
        # The light direction is the vector from each point to the light.
        to_light = self.location - points
        return specular(
            points=points,
            normals=normals,
            color=self.specular_color,
            direction=to_light,
            camera_position=camera_position,
            shininess=shininess,
        )
def _validate_light_properties(obj):
props = ("ambient_color", "diffuse_color", "specular_color")
for n in props:
t = getattr(obj, n)
if t.shape[-1] != 3:
msg = "Expected %s to have shape (N, 3); got %r"
raise ValueError(msg % (n, t.shape))

View File

@@ -0,0 +1,59 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from .utils import TensorProperties
class Materials(TensorProperties):
    """
    A class for storing a batch of material properties. Currently only one
    material per batch element is supported.
    """

    def __init__(
        self,
        ambient_color=((1, 1, 1),),
        diffuse_color=((1, 1, 1),),
        specular_color=((1, 1, 1),),
        shininess=64,
        device="cpu",
    ):
        """
        Args:
            ambient_color: RGB ambient reflectivity of the material
            diffuse_color: RGB diffuse reflectivity of the material
            specular_color: RGB specular reflectivity of the material
            shininess: The specular exponent for the material. This defines
                the focus of the specular highlight with a high value
                resulting in a concentrated highlight. Shininess values
                can range from 0-1000.
            device: torch.device or string

        ambient_color, diffuse_color and specular_color can be of shape
        (1, 3) or (N, 3). shininess can be of shape (1) or (N).

        The colors and shininess are broadcast against each other so need to
        have either the same batch dimension or batch dimension = 1.

        Raises:
            ValueError: if the last dimension of a color is not 3, or if
                shininess does not end up with shape (N).
        """
        super().__init__(
            device=device,
            diffuse_color=diffuse_color,
            ambient_color=ambient_color,
            specular_color=specular_color,
            shininess=shininess,
        )
        for attr_name in ("ambient_color", "diffuse_color", "specular_color"):
            shape = getattr(self, attr_name).shape
            if shape[-1] != 3:
                raise ValueError(
                    "Expected %s to have shape (N, 3); got %r"
                    % (attr_name, shape)
                )
        # self._N is provided by the base-class initializer.
        shininess_shape = self.shininess.shape
        if shininess_shape != torch.Size([self._N]):
            raise ValueError(
                "shininess should have shape (N); got %r"
                % repr(shininess_shape)
            )

    def clone(self):
        """Return a copy made by the base class into a fresh instance."""
        return super().clone(Materials(device=self.device))

View File

@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer
from .shader import (
GouradShader,
PhongShader,
SilhouetteShader,
TexturedPhongShader,
)
from .shading import gourad_shading, phong_shading
from .texturing import ( # isort: skip
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
)
# Export every name bound in this module's global namespace at this point
# (i.e. everything imported above) except names starting with an underscore.
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@@ -0,0 +1,477 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
from typing import Optional
import torch
from pytorch3d import _C
# TODO make the epsilon user configurable
kEpsilon = 1e-30
def rasterize_meshes(
    meshes,
    image_size: int = 256,
    blur_radius: float = 0.0,
    faces_per_pixel: int = 8,
    bin_size: Optional[int] = None,
    max_faces_per_bin: Optional[int] = None,
    perspective_correct: bool = False,
):
    """
    Rasterize a batch of meshes given the shape of the desired output image.
    Each mesh is rasterized onto a separate image of shape
    (image_size, image_size).

    Args:
        meshes: A Meshes object representing a batch of meshes, batch size N.
        image_size: Size in pixels of the output raster image for each mesh
            in the batch. Assumes square images.
        blur_radius: Float distance in the range [0, 2] used to expand the face
            bounding boxes for rasterization. Setting blur radius
            results in blurred edges around the shape instead of a
            hard boundary. Set to 0 for no blur.
        faces_per_pixel (Optional): Number of faces to save per pixel, returning
            the nearest faces_per_pixel points along the z-axis.
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts to
            set it heuristically based on the shape of the input. This should not
            affect the output, but can affect the speed of the forward pass.
            The heuristic only covers image_size <= 1024 on CUDA; for larger
            images (and always on CPU) it falls back to naive rasterization,
            so pass an explicit bin_size to use binning there.
        max_faces_per_bin: Only applicable when using coarse-to-fine rasterization
            (bin_size > 0); this is the maximum number of faces allowed within each
            bin. If more than this many faces actually fall into a bin, an error
            will be raised. This should not affect the output values, but can affect
            the memory usage in the forward pass. If None, it defaults to
            max(10000, V / 5) where V is the number of packed vertices.
        perspective_correct: Whether to apply perspective correction when computing
            barycentric coordinates for pixels.

    Returns:
        4-element tuple containing

        - **pix_to_face**: LongTensor of shape
          (N, image_size, image_size, faces_per_pixel)
          giving the indices of the nearest faces at each pixel,
          sorted in ascending z-order.
          Concretely ``pix_to_face[n, y, x, k] = f`` means that
          ``faces_verts[f]`` is the kth closest face (in the z-direction)
          to pixel (y, x). Pixels that are hit by fewer than
          faces_per_pixel are padded with -1.
        - **zbuf**: FloatTensor of shape (N, image_size, image_size, faces_per_pixel)
          giving the z-coordinates of the nearest faces at each pixel,
          sorted in ascending z-order. The value stored is the z of the face
          interpolated at the pixel with the barycentric coordinates.
          Pixels hit by fewer than faces_per_pixel are padded with -1.
        - **barycentric**: FloatTensor of shape
          (N, image_size, image_size, faces_per_pixel, 3)
          giving the barycentric coordinates in NDC units of the
          nearest faces at each pixel, sorted in ascending z-order.
          Concretely, if ``pix_to_face[n, y, x, k] = f`` then
          ``[w0, w1, w2] = barycentric[n, y, x, k]`` gives
          the barycentric coords for pixel (y, x) relative to the face
          defined by ``face_verts[f]``. Pixels hit by fewer than
          faces_per_pixel are padded with -1.
        - **pix_dists**: FloatTensor of shape
          (N, image_size, image_size, faces_per_pixel)
          giving the signed squared Euclidean distance (in NDC units) in the
          x/y plane between each pixel and the face. Concretely if
          ``pix_to_face[n, y, x, k] = f`` then ``pix_dists[n, y, x, k]`` is the
          squared distance between the pixel (y, x) and the face given
          by vertices ``face_verts[f]``; in the Python reference
          implementation it is negated for pixels inside the face. Pixels
          hit with fewer than ``faces_per_pixel`` are padded with -1.
    """
    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()
    face_verts = verts_packed[faces_packed]
    mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
    num_faces_per_mesh = meshes.num_faces_per_mesh()

    # TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
    if bin_size is None:
        if not verts_packed.is_cuda:
            # Binned CPU rasterization is not supported.
            bin_size = 0
        else:
            # TODO better heuristics for bin size.
            if image_size <= 64:
                bin_size = 8
            elif image_size <= 256:
                bin_size = 16
            elif image_size <= 512:
                bin_size = 32
            elif image_size <= 1024:
                bin_size = 64
            else:
                # No tuned bin size for very large images. Previously bin_size
                # stayed None here and was forwarded to the C++ op; use the
                # naive path instead, which yields the same output.
                bin_size = 0

    if max_faces_per_bin is None:
        max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5))

    return _RasterizeFaceVerts.apply(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        faces_per_pixel,
        bin_size,
        max_faces_per_bin,
        perspective_correct,
    )
class _RasterizeFaceVerts(torch.autograd.Function):
    """
    Autograd bridge between Python and the compiled rasterize_meshes
    forward/backward ops exposed by ``_C``.

    Args:
        face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions
            for faces in all the meshes in the batch. Concretely,
            face_verts[f, i] = [x, y, z] gives the coordinates for the
            ith vertex of the fth face. These vertices are expected to
            be in NDC coordinates in the range [-1, 1].
        mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
            faces_verts of the first face in each mesh in the batch.
        num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
            for each mesh in the batch.
        image_size, blur_radius, faces_per_pixel: same as rasterize_meshes.
        bin_size, max_faces_per_bin: forwarded unchanged to the compiled op.
        perspective_correct: same as rasterize_meshes.

    Returns:
        same as rasterize_meshes function.
    """

    @staticmethod
    def forward(
        ctx,
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size: int = 256,
        blur_radius: float = 0.01,
        faces_per_pixel: int = 0,
        bin_size: int = 0,
        max_faces_per_bin: int = 0,
        perspective_correct: bool = False,
    ):
        forward_args = (
            face_verts,
            mesh_to_face_first_idx,
            num_faces_per_mesh,
            image_size,
            blur_radius,
            faces_per_pixel,
            bin_size,
            max_faces_per_bin,
            perspective_correct,
        )
        pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes(
            *forward_args
        )
        # Only what the backward op consumes is stashed on the context.
        ctx.save_for_backward(face_verts, pix_to_face)
        ctx.perspective_correct = perspective_correct
        return pix_to_face, zbuf, barycentric_coords, dists

    @staticmethod
    def backward(
        ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists
    ):
        face_verts, pix_to_face = ctx.saved_tensors
        grad_face_verts = _C.rasterize_meshes_backward(
            face_verts,
            pix_to_face,
            grad_zbuf,
            grad_barycentric_coords,
            grad_dists,
            ctx.perspective_correct,
        )
        # One slot per forward() input after ctx. Only face_verts receives a
        # gradient; the remaining eight inputs (index tensors and settings)
        # get None.
        return (grad_face_verts,) + (None,) * 8
def rasterize_meshes_python(
    meshes,
    image_size: int = 256,
    blur_radius: float = 0.0,
    faces_per_pixel: int = 8,
    perspective_correct: bool = False,
):
    """
    Naive PyTorch implementation of mesh rasterization with the same inputs and
    outputs as the rasterize_meshes function.

    This function is not optimized and is implemented as a comparison for the
    C++/CUDA implementations. It visits every (mesh, pixel, face) triple in
    nested Python loops and keeps, per pixel, the faces_per_pixel candidates
    with the smallest interpolated z.

    Returns:
        (face_idxs, zbuf, bary_coords, pix_dists) with shapes (N, H, W, K),
        (N, H, W, K), (N, H, W, K, 3) and (N, H, W, K); unused slots hold -1.
    """
    N = len(meshes)
    # Assume only square images.
    # TODO(T52813608) extend support for non-square images.
    H = W = image_size
    K = faces_per_pixel
    device = meshes.device

    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()
    faces_verts = verts_packed[faces_packed]
    mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
    num_faces_per_mesh = meshes.num_faces_per_mesh()

    # Initialize output tensors; -1 marks "no face" in every output.
    face_idxs = torch.full(
        (N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
    )
    zbuf = torch.full(
        (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
    )
    bary_coords = torch.full(
        (N, H, W, K, 3), fill_value=-1, dtype=torch.float32, device=device
    )
    pix_dists = torch.full(
        (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
    )

    # NDC is from [-1, 1]. Get pixel size using specified image size.
    pixel_width = 2.0 / W
    pixel_height = 2.0 / H

    # Calculate all face bounding boxes.
    x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
    x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
    y_mins = torch.min(faces_verts[:, :, 1], dim=1, keepdim=True).values
    y_maxs = torch.max(faces_verts[:, :, 1], dim=1, keepdim=True).values

    # Expand by blur radius. blur_radius is compared against squared
    # distances below, hence the square root when padding the boxes.
    x_mins = x_mins - np.sqrt(blur_radius) - kEpsilon
    x_maxs = x_maxs + np.sqrt(blur_radius) + kEpsilon
    y_mins = y_mins - np.sqrt(blur_radius) - kEpsilon
    y_maxs = y_maxs + np.sqrt(blur_radius) + kEpsilon

    # Loop through meshes in the batch.
    for n in range(N):
        face_start_idx = mesh_to_face_first_idx[n]
        face_stop_idx = face_start_idx + num_faces_per_mesh[n]

        # Y coordinate of the centre of the first row of pixels.
        yf = -1.0 + 0.5 * pixel_height
        # Iterate through the horizontal lines of the image.
        for yi in range(H):
            # X coordinate of the centre of the first pixel in the row.
            xf = -1.0 + 0.5 * pixel_width
            # Iterate through pixels on this horizontal line, left to right.
            for xi in range(W):
                top_k_points = []
                # The pixel centre does not depend on the face, so build the
                # tensor once per pixel rather than once per face.
                pxy = torch.tensor(
                    [xf, yf], dtype=torch.float32, device=device
                )

                # Check whether each face in the mesh affects this pixel.
                for f in range(face_start_idx, face_stop_idx):
                    face = faces_verts[f].squeeze()
                    v0, v1, v2 = face.unbind(0)

                    # Ignore faces which have zero area.
                    face_area = edge_function(v2, v0, v1)
                    if face_area == 0.0:
                        continue

                    # Check if pixel is outside of the (padded) face bbox.
                    outside_bbox = (
                        xf < x_mins[f]
                        or xf > x_maxs[f]
                        or yf < y_mins[f]
                        or yf > y_maxs[f]
                    )
                    if outside_bbox:
                        continue

                    # Compute barycentric coordinates and pixel z distance.
                    bary = barycentric_coordinates(pxy, v0[:2], v1[:2], v2[:2])
                    if perspective_correct:
                        z0, z1, z2 = v0[2], v1[2], v2[2]
                        l0, l1, l2 = bary[0], bary[1], bary[2]
                        top0 = l0 * z1 * z2
                        top1 = z0 * l1 * z2
                        top2 = z0 * z1 * l2
                        bot = top0 + top1 + top2
                        bary = torch.stack([top0 / bot, top1 / bot, top2 / bot])
                    pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2]

                    # Check if point is behind the image.
                    if pz < 0:
                        continue

                    # Calculate signed 2D distance from point to face.
                    # Points inside the triangle have negative distance.
                    dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
                    inside = all(x > 0.0 for x in bary)
                    signed_dist = dist * -1.0 if inside else dist

                    # Pixels outside the face only count if they lie within
                    # the blur radius of it.
                    if not inside and dist >= blur_radius:
                        continue

                    # Keep the list sorted by (z, face index) and capped at K.
                    top_k_points.append((pz, f, bary, signed_dist))
                    top_k_points.sort()
                    if len(top_k_points) > K:
                        top_k_points = top_k_points[:K]

                # Save to output tensors.
                for k, (pz, f, bary, dist) in enumerate(top_k_points):
                    zbuf[n, yi, xi, k] = pz
                    face_idxs[n, yi, xi, k] = f
                    bary_coords[n, yi, xi, k, 0] = bary[0]
                    bary_coords[n, yi, xi, k, 1] = bary[1]
                    bary_coords[n, yi, xi, k, 2] = bary[2]
                    pix_dists[n, yi, xi, k] = dist

                # Move to the next horizontal pixel
                xf += pixel_width
            # Move to the next vertical pixel
            yf += pixel_height

    return face_idxs, zbuf, bary_coords, pix_dists
def edge_function(p, v0, v1):
    """
    Signed area test for a point against the directed 2D edge v0 -> v1.

    Args:
        p: (x, y) Coordinates of a point.
        v0, v1: (x, y) Coordinates of the end points of the edge.

    Returns:
        area: The signed area of the parallelogram spanned by the vectors
        ``A = p - v0`` and ``B = v1 - v0``, i.e. the 2D cross product
        ``A x B = A.x * B.y - A.y * B.x``.

        A positive value means p lies on the right side of the edge when
        looking from v0 towards v1, a negative value means it lies on the
        left side, and zero means p is on the line through v0 and v1:

        .. code-block:: python

                 v1
                 |
            -    |    +
                 |
                 v0
    """
    # A = p - v0
    a_x = p[0] - v0[0]
    a_y = p[1] - v0[1]
    # B = v1 - v0
    b_x = v1[0] - v0[0]
    b_y = v1[1] - v0[1]
    return a_x * b_y - a_y * b_x
def barycentric_coordinates(p, v0, v1, v2):
    """
    Compute the barycentric coordinates of a 2D point relative to a triangle.

    Args:
        p: Coordinates of a point.
        v0, v1, v2: Coordinates of the triangle vertices.

    Returns
        bary: tuple (w0, w1, w2) of barycentric weights. Each weight is an
        edge-function value divided by twice the triangle area (plus
        kEpsilon). For points inside the triangle the weights lie in [0, 1];
        no clamping is applied, so points outside give values beyond that
        range.
    """
    # 2 x face area, nudged by kEpsilon before it is used as a divisor.
    twice_area = edge_function(v2, v0, v1) + kEpsilon
    weights = (
        edge_function(p, v1, v2) / twice_area,
        edge_function(p, v2, v0) / twice_area,
        edge_function(p, v0, v1) / twice_area,
    )
    return weights
def point_line_distance(p, v0, v1):
    """
    Return the squared minimum distance between the line segment (v1 - v0)
    and the point p.

    Args:
        p: Coordinates of a point (1D tensor).
        v0, v1: Coordinates of the end points of the line segment (1D tensors
            of the same shape as p).

    Returns:
        0-dim tensor holding the squared Euclidean distance from p to the
        closest point on the segment. The degenerate case v0 == v1 returns
        the squared distance to that single point.

    Raises:
        ValueError: if p, v0 and v1 do not all have the same shape.

    Consider the line extending the segment - this can be parameterized as
    ``v0 + t (v1 - v0)``.

    First find the projection of point p onto the line. It falls where
    ``t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2``
    where . is the dot product.

    The parameter t is clamped from [0, 1] to handle points outside the
    segment (v1 - v0).

    Once the projection of the point on the segment is known, the distance from
    p to the projection gives the minimum distance to the segment.
    """
    # `a != b != c` only means (a != b) and (b != c), which lets most
    # mismatches through; require all three shapes to be equal instead.
    if not (p.shape == v0.shape == v1.shape):
        raise ValueError("All points must have the same number of coordinates")

    v1v0 = v1 - v0
    l2 = v1v0.dot(v1v0)  # |v1 - v0|^2
    if l2 == 0.0:
        # v0 == v1: the segment is a point. Return the squared distance so
        # the units match the non-degenerate branch below.
        delta = p - v1
        return delta.dot(delta)

    t = (v1v0).dot(p - v0) / l2
    t = torch.clamp(t, min=0.0, max=1.0)
    p_proj = v0 + t * v1v0
    delta_p = p_proj - p
    return delta_p.dot(delta_p)
def point_triangle_distance(p, v0, v1, v2):
    """
    Return the shortest distance between a point and the boundary of a
    triangle, in the units returned by point_line_distance (squared
    Euclidean distance for non-degenerate edges).

    Args:
        p: Coordinates of a point.
        v0, v1, v2: Coordinates of the three triangle vertices.

    Returns:
        The minimum over the three edges (v0, v1), (v0, v2), (v1, v2) of the
        point-to-segment distance. The value is never negative; whether the
        point is inside the triangle is not taken into account here.

    Raises:
        ValueError: if p, v0, v1 and v2 do not all have the same shape.
    """
    # A chained `!=` only compares neighbouring operands and is joined by
    # `and`, so it misses most mismatches; check that all shapes are equal.
    if not (p.shape == v0.shape == v1.shape == v2.shape):
        raise ValueError("All points must have the same number of coordinates")

    e01_dist = point_line_distance(p, v0, v1)
    e02_dist = point_line_distance(p, v0, v2)
    e12_dist = point_line_distance(p, v1, v2)
    edge_dists_min = torch.min(torch.min(e01_dist, e02_dist), e12_dist)
    return edge_dists_min

View File

@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from dataclasses import dataclass
from typing import NamedTuple, Optional
import torch
import torch.nn as nn
from ..cameras import get_world_to_view_transform
from .rasterize_meshes import rasterize_meshes
# Container for the four tensors produced by mesh rasterization.
class Fragments(NamedTuple):
    # Fields are filled from the outputs of rasterize_meshes, in this order.
    pix_to_face: torch.Tensor  # face index per pixel
    zbuf: torch.Tensor  # z value per pixel
    bary_coords: torch.Tensor  # barycentric coordinates per pixel
    dists: torch.Tensor  # pixel-to-face distance per pixel
# Rasterization parameters bundled with their default values. Each field is
# forwarded to the rasterize_meshes argument of the same name.
@dataclass
class RasterizationSettings:
    # Side length in pixels of the output image.
    image_size: int = 256
    # 0.0 disables blurring.
    blur_radius: float = 0.0
    # Number of faces kept per pixel.
    faces_per_pixel: int = 1
    # None lets rasterize_meshes choose the bin size.
    bin_size: Optional[int] = None
    # None lets rasterize_meshes choose the per-bin face limit.
    max_faces_per_bin: Optional[int] = None
    perspective_correct: bool = False
class MeshRasterizer(nn.Module):
    """
    This class implements methods for rasterizing a batch of heterogenous
    Meshes.
    """

    def __init__(self, cameras, raster_settings=None):
        """
        Args:
            cameras: A cameras object which has a `transform_points` method
                which returns the transformed points after applying the
                world-to-view and view-to-screen
                transformations.
            raster_settings: the parameters for rasterization. This should be a
                RasterizationSettings instance (or any object exposing the
                same attributes). Defaults to RasterizationSettings().

        All these initial settings can be overridden by passing keyword
        arguments to the forward function.
        """
        super().__init__()
        if raster_settings is None:
            raster_settings = RasterizationSettings()

        self.cameras = cameras
        self.raster_settings = raster_settings

    def transform(self, meshes_world, **kwargs):
        """
        Args:
            meshes_world: a Meshes object representing a batch of meshes with
                vertex coordinates in world space.
            **kwargs: may contain "cameras" to override self.cameras; all
                kwargs are also forwarded to cameras.transform_points.

        Returns:
            meshes_screen: a Meshes object (the result of
            meshes_world.offset_verts) with the vertex x/y positions in
            screen space and z taken from view space.

        NOTE: keeping this as a separate function for readability but it could
        be moved into forward.
        """
        cameras = kwargs.get("cameras", self.cameras)
        verts_world = meshes_world.verts_padded()
        verts_world_packed = meshes_world.verts_packed()
        verts_screen = cameras.transform_points(verts_world, **kwargs)

        # NOTE: Retaining view space z coordinate for now.
        # TODO: Revisit whether or not to transform z coordinate to [-1, 1] or
        # [0, 1] range.
        view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
        verts_view = view_transform.transform_points(verts_world)
        verts_screen[..., 2] = verts_view[..., 2]

        # Offset verts of input mesh to reuse cached padded/packed calculations.
        pad_to_packed_idx = meshes_world.verts_padded_to_packed_idx()
        verts_screen_packed = verts_screen.view(-1, 3)[pad_to_packed_idx, :]
        verts_packed_offset = verts_screen_packed - verts_world_packed
        return meshes_world.offset_verts(verts_packed_offset)

    def forward(self, meshes_world, **kwargs) -> Fragments:
        """
        Args:
            meshes_world: a Meshes object representing a batch of meshes with
                coordinates in world space.
            **kwargs: may contain "raster_settings" to override
                self.raster_settings, plus anything accepted by transform.

        Returns:
            Fragments: Rasterization outputs as a named tuple.
        """
        meshes_screen = self.transform(meshes_world, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)
        # TODO(jcjohns): Should we try to set perspective_correct automatically
        # based on the type of the camera?
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            perspective_correct=raster_settings.perspective_correct,
        )
        return Fragments(
            pix_to_face=pix_to_face,
            zbuf=zbuf,
            bary_coords=bary_coords,
            dists=dists,
        )

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
# A renderer class should be initialized with a
# function for rasterization and a function for shading.
# The rasterizer should:
# - transform inputs from world -> screen space
# - rasterize inputs
# - return fragments
# The shader can take fragments as input along with any other properties of
# the scene and generate images.
# E.g. rasterize inputs and then shade
#
# fragments = self.rasterize(meshes)
# images = self.shader(fragments, meshes)
# return images
class MeshRenderer(nn.Module):
    """
    Renders a batch of heterogeneous meshes by chaining two callables: a
    rasterizer that turns meshes into fragments, and a shader that turns
    fragments (plus the meshes) into images.
    """

    def __init__(self, rasterizer, shader):
        """
        Args:
            rasterizer: callable invoked as ``rasterizer(meshes, **kwargs)``.
            shader: callable invoked as ``shader(fragments, meshes, **kwargs)``.
        """
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def forward(self, meshes_world, **kwargs) -> torch.Tensor:
        """
        Rasterize ``meshes_world`` and shade the result. The same keyword
        arguments are passed to both stages.
        """
        rasterized = self.rasterizer(meshes_world, **kwargs)
        return self.shader(rasterized, meshes_world, **kwargs)

View File

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from ..blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from ..cameras import OpenGLPerspectiveCameras
from ..lighting import PointLights
from ..materials import Materials
from .shading import gourad_shading, phong_shading
from .texturing import interpolate_texture_map, interpolate_vertex_colors
# A Shader should take as input fragments from the output of rasterization
# along with scene params and output images. A shader could perform operations
# such as:
# - interpolate vertex attributes for all the fragments
# - sample colors from a texture map
# - apply per pixel lighting
# - blend colors across top K faces per pixel.
class PhongShader(nn.Module):
    """
    Per pixel lighting. Vertex colors are interpolated to every pixel, the
    lighting model is evaluated per pixel by phong_shading, and the result is
    combined with hard_rgb_blend.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = PhongShader(device=torch.device("cuda:0"))
    """

    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
        super().__init__()
        # Defaults are only constructed for the pieces that were not given.
        if lights is None:
            lights = PointLights(device=device)
        if materials is None:
            materials = Materials(device=device)
        if cameras is None:
            cameras = OpenGLPerspectiveCameras(device=device)
        self.lights = lights
        self.materials = materials
        self.cameras = cameras

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """
        Shade ``fragments``. "cameras", "lights" and "materials" in kwargs
        override the values stored on the shader for this call only.
        """
        texels = interpolate_vertex_colors(fragments, meshes)
        shaded = phong_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=kwargs.get("lights", self.lights),
            cameras=kwargs.get("cameras", self.cameras),
            materials=kwargs.get("materials", self.materials),
        )
        return hard_rgb_blend(shaded, fragments)
class GouradShader(nn.Module):
    """
    Per vertex lighting. gourad_shading applies the lighting model to the
    vertex colors and interpolates them with the barycentric coordinates to
    get colors for each pixel; the result is combined with hard_rgb_blend.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = GouradShader(device=torch.device("cuda:0"))
    """

    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
        super().__init__()
        # Defaults are only constructed for the pieces that were not given.
        if lights is None:
            lights = PointLights(device=device)
        if materials is None:
            materials = Materials(device=device)
        if cameras is None:
            cameras = OpenGLPerspectiveCameras(device=device)
        self.lights = lights
        self.materials = materials
        self.cameras = cameras

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """
        Shade ``fragments``. "cameras", "lights" and "materials" in kwargs
        override the values stored on the shader for this call only.
        """
        per_pixel = gourad_shading(
            meshes=meshes,
            fragments=fragments,
            lights=kwargs.get("lights", self.lights),
            cameras=kwargs.get("cameras", self.cameras),
            materials=kwargs.get("materials", self.materials),
        )
        return hard_rgb_blend(per_pixel, fragments)
class TexturedPhongShader(nn.Module):
    """
    Per pixel lighting applied to a texture map. First interpolate the vertex
    uv coordinates and sample from a texture map. Then apply the lighting model
    using the interpolated coords and normals for each pixel. The shaded
    colors are combined with softmax_rgb_blend.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = TexturedPhongShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device="cpu",
        cameras=None,
        lights=None,
        materials=None,
        blend_params=None,
    ):
        super().__init__()
        self.lights = (
            lights if lights is not None else PointLights(device=device)
        )
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = (
            cameras
            if cameras is not None
            else OpenGLPerspectiveCameras(device=device)
        )
        self.blend_params = (
            blend_params if blend_params is not None else BlendParams()
        )

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """
        Shade ``fragments``. "cameras", "lights", "materials" and
        "blend_params" in kwargs override the values stored on the shader for
        this call only.
        """
        texels = interpolate_texture_map(fragments, meshes)
        cameras = kwargs.get("cameras", self.cameras)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        # Honour a per-call blend_params override, consistent with
        # SilhouetteShader and with the other per-call overrides above.
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = phong_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = softmax_rgb_blend(colors, fragments, blend_params)
        return images
class SilhouetteShader(nn.Module):
    """
    Calculate the silhouette by blending the top K faces for each pixel based
    on the 2d euclidean distance of the centre of the pixel to the mesh face.

    Use this shader for generating silhouettes similar to SoftRasterizer [0].

    .. note::

        To be consistent with SoftRasterizer, initialize the
        RasterizationSettings for the rasterizer with
        `blur_radius = np.log(1. / 1e-4 - 1.) * blend_params.sigma`

    [0] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """

    def __init__(self, blend_params=None):
        super().__init__()
        self.blend_params = (
            blend_params if blend_params is not None else BlendParams()
        )

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """
        Only want to render the silhouette so RGB values can be ones.
        There is no need for lighting or texturing.

        ``meshes`` is accepted for interface parity with the other shaders
        but is not used. "blend_params" in kwargs overrides the value stored
        on the shader for this call only.
        """
        colors = torch.ones_like(fragments.bary_coords)
        blend_params = kwargs.get("blend_params", self.blend_params)
        images = sigmoid_alpha_blend(colors, fragments, blend_params)
        return images

View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
import torch
from .texturing import interpolate_face_attributes
def _apply_lighting(
    points, normals, lights, cameras, materials
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Combine the light contributions with the material reflectivities.

    Args:
        points: torch tensor of shape (N, P, 3) or (P, 3).
        normals: torch tensor of shape (N, P, 3) or (P, 3)
        lights: instance of the Lights class.
        cameras: instance of the Cameras class.
        materials: instance of the Materials class.

    Returns:
        ambient_color: same shape as materials.ambient_color
        diffuse_color: same shape as the input points
        specular_color: same shape as the input points

        When both points and normals are 2D (packed inputs), ``squeeze()`` is
        applied to all three outputs, which drops every size-1 dimension.
    """
    diffuse_light = lights.diffuse(normals=normals, points=points)
    specular_light = lights.specular(
        normals=normals,
        points=points,
        camera_position=cameras.get_camera_center(),
        shininess=materials.shininess,
    )
    shaded = (
        materials.ambient_color * lights.ambient_color,
        materials.diffuse_color * diffuse_light,
        materials.specular_color * specular_light,
    )
    packed_inputs = normals.dim() == 2 and points.dim() == 2
    if not packed_inputs:
        return shaded
    # If given packed inputs remove batch dim in output.
    return tuple(component.squeeze() for component in shaded)
def phong_shading(
    meshes, fragments, lights, cameras, materials, texels
) -> torch.Tensor:
    """
    Apply per pixel shading. The packed vertex positions and vertex normals
    are gathered per face and interpolated with the barycentric coordinates
    to get a position and a normal for every pixel. The lighting model is
    then evaluated at those positions, and the pixel color is
    ``(ambient + diffuse) * texels + specular``.

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights
        cameras: Cameras class containing a batch of cameras
        materials: Materials class containing a batch of material properties
        texels: texture per pixel of shape (N, H, W, K, 3)

    Returns:
        colors: (N, H, W, K, 3)
    """
    packed_verts = meshes.verts_packed()  # (V, 3)
    packed_faces = meshes.faces_packed()  # (F, 3)
    packed_normals = meshes.verts_normals_packed()  # (V, 3)

    # Per-face vertex attributes, interpolated to every pixel.
    pixel_coords = interpolate_face_attributes(
        fragments, packed_verts[packed_faces]
    )
    pixel_normals = interpolate_face_attributes(
        fragments, packed_normals[packed_faces]
    )

    ambient, diffuse, specular = _apply_lighting(
        pixel_coords, pixel_normals, lights, cameras, materials
    )
    return (ambient + diffuse) * texels + specular
def gourad_shading(
    meshes, fragments, lights, cameras, materials
) -> torch.Tensor:
    """
    Apply per vertex shading. The ambient, diffuse and specular illumination
    is computed at every packed vertex, combined with the vertex colors as
    ``vertex_colors * (ambient + diffuse) + specular``, and the shaded vertex
    colors are then interpolated with the barycentric coordinates to get a
    color per pixel.

    The vertex colors are read from ``meshes.textures.verts_rgb_packed()``
    unconditionally.

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights parameters
        cameras: Cameras class containing a batch of cameras parameters
        materials: Materials class containing a batch of material properties

    Returns:
        colors: (N, H, W, K, 3)
    """
    packed_faces = meshes.faces_packed()  # (F, 3)
    packed_verts = meshes.verts_packed()
    packed_normals = meshes.verts_normals_packed()  # (V, 3)
    packed_colors = meshes.textures.verts_rgb_packed()
    mesh_idx_per_vert = meshes.verts_packed_to_mesh_idx()

    # Format properties of lights and materials so they are compatible
    # with the packed representation of the vertices. This transforms
    # all tensor properties in the class from shape (N, ...) -> (V, ...) where
    # V is the number of packed vertices. If the number of meshes in the
    # batch is one then this is not necessary.
    if len(meshes) > 1:
        lights = lights.clone().gather_props(mesh_idx_per_vert)
        cameras = cameras.clone().gather_props(mesh_idx_per_vert)
        materials = materials.clone().gather_props(mesh_idx_per_vert)

    # Calculate the illumination at each vertex
    ambient, diffuse, specular = _apply_lighting(
        packed_verts, packed_normals, lights, cameras, materials
    )
    shaded_vert_colors = packed_colors * (ambient + diffuse) + specular
    return interpolate_face_attributes(
        fragments, shaded_vert_colors[packed_faces]
    )

View File

@@ -0,0 +1,182 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.structures.textures import Textures
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
    """
    Args:
        bary: barycentric coordinates of shape (...., 3) where `...` represents
            an arbitrary number of dimensions

    Returns:
        bary: All barycentric coordinate values clipped to the range [0, 1]
        and renormalized. The output is the same shape as the input.

    Raises:
        ValueError: if the last dimension of bary is not 3.
    """
    if bary.shape[-1] != 3:
        msg = "Expected barycentric coords to have last dim = 3; got %r"
        # torch.Size is a tuple subclass, so it must be wrapped in a 1-tuple;
        # passed bare, `%` treats it as the argument list and raises
        # TypeError instead of the intended ValueError.
        raise ValueError(msg % (bary.shape,))
    clipped = bary.clamp(min=0, max=1)
    # Lower-bound the sum so an all-zero row does not divide by zero.
    clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
    clipped = clipped / clipped_sum
    return clipped
def interpolate_face_attributes(
    fragments, face_attributes: torch.Tensor, bary_clip: bool = False
) -> torch.Tensor:
    """
    Interpolate arbitrary face attributes using the barycentric coordinates
    for each pixel in the rasterized output.

    Args:
        fragments:
            The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - bary_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.
        face_attributes: packed attributes of shape (total_faces, 3, D),
            specifying the value of the attribute for each
            vertex in the face.
        bary_clip: Bool to indicate if barycentric_coords should be clipped
            before being used for interpolation.

    Returns:
        pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
        value of the face attribute for each pixel. Entries whose
        pix_to_face is -1 are set to 0.

    Raises:
        ValueError: if face_attributes does not have 3 vertices per face, or
            if pix_to_face does not have shape (N, H, W, K).
    """
    pix_to_face = fragments.pix_to_face
    barycentric_coords = fragments.bary_coords
    # The face count itself is not needed; avoid binding it to `F`, which is
    # the module-level alias for torch.nn.functional in this file.
    _num_faces, FV, D = face_attributes.shape
    if FV != 3:
        raise ValueError("Faces can only have three vertices; got %r" % FV)
    N, H, W, K, _ = barycentric_coords.shape
    if pix_to_face.shape != (N, H, W, K):
        msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
        # Wrap the shape in a 1-tuple: torch.Size is itself a tuple, and
        # passing it bare to `%` raises TypeError instead of ValueError.
        raise ValueError(msg % (pix_to_face.shape,))
    if bary_clip:
        barycentric_coords = _clip_barycentric_coordinates(barycentric_coords)

    # Replace empty pixels in pix_to_face with 0 in order to interpolate.
    # Work on a clone so the caller's fragments are left untouched.
    mask = pix_to_face == -1
    pix_to_face = pix_to_face.clone()
    pix_to_face[mask] = 0
    idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
    pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
    # Weighted sum over the three vertices of each face.
    pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
    pixel_vals[mask] = 0  # Replace masked values in output.
    return pixel_vals
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
    """
    Interpolate a 2D texture map using uv vertex texture coordinates for each
    face in the mesh. First interpolate the vertex uvs using barycentric
    coordinates for each pixel in the rasterized output. Then sample the
    texture map at the uv coordinate of each pixel.

    Args:
        fragments:
            The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the
              indices of the faces (in the packed representation) which
              overlap each pixel in the image.
            - bary_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel relative to the
              faces (in the packed representation) which overlap the pixel.
        meshes: Meshes representing a batch of meshes. It is expected that
            meshes has a textures attribute which is an instance of the
            Textures class.

    Returns:
        texels: tensor of shape (N, H, W, K, C) giving the interpolated
        texture for each pixel in the rasterized image.

    Raises:
        ValueError: if meshes.textures is not an instance of Textures.
    """
    if not isinstance(meshes.textures, Textures):
        msg = "Expected meshes.textures to be an instance of Textures; got %r"
        raise ValueError(msg % type(meshes.textures))

    faces_uvs = meshes.textures.faces_uvs_packed()
    verts_uvs = meshes.textures.verts_uvs_packed()
    faces_verts_uvs = verts_uvs[faces_uvs]
    texture_maps = meshes.textures.maps_padded()

    # pixel_uvs: (N, H, W, K, 2)
    pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs)

    N, H_out, W_out, K = fragments.pix_to_face.shape
    N, H_in, W_in, C = texture_maps.shape  # 3 for RGB

    # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
    # The permuted tensor is not contiguous, so `.view` cannot merge the N and
    # K dimensions when K > 1; `.reshape` copies only when it has to.
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(
        N * K, H_out, W_out, 2
    )

    # textures.map:
    #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
    #   -> expand (K, N, C, H, W) -> transpose (N, K, C, H, W)
    #   -> reshape (N*K, C, H, W)
    texture_maps = (
        texture_maps.permute(0, 3, 1, 2)[None, ...]
        .expand(K, -1, -1, -1, -1)
        .transpose(0, 1)
        .reshape(N * K, C, H_in, W_in)
    )

    # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
    # Now need to format the pixel uvs and the texture map correctly!
    # From pytorch docs, grid_sample takes `grid` and `input`:
    #   grid specifies the sampling pixel locations normalized by
    #   the input spatial dimensions. It should have most
    #   values in the range of [-1, 1]. Values x = -1, y = -1
    #   is the left-top pixel of input, and values x = 1, y = 1 is the
    #   right-bottom pixel of input.
    pixel_uvs = pixel_uvs * 2.0 - 1.0
    texture_maps = torch.flip(
        texture_maps, [2]
    )  # flip y axis of the texture map
    if texture_maps.device != pixel_uvs.device:
        texture_maps = texture_maps.to(pixel_uvs.device)
    texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
    # (N*K, C, H, W) -> (N, K, C, H, W) -> (N, H, W, K, C)
    texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
    return texels
def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
    """
    Determine the color for each rasterized face by interpolating the colors
    of the three vertices of the face with the barycentric coordinates.

    Args:
        fragments:
            The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the
              indices of the faces (in the packed representation) which
              overlap each pixel in the image.
            - bary_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel relative to the
              faces (in the packed representation) which overlap the pixel.
        meshes: A Meshes class representing a batch of meshes.

    Returns:
        texels: A texture per pixel of shape (N, H, W, K, C).
        There will be one C dimensional value for each element in
        fragments.pix_to_face.
    """
    # Flatten the padded per-vertex colors to rows of 3 channels, then keep
    # only the rows selected by the padded-to-packed index.
    flat_rgb = meshes.textures.verts_rgb_padded().view(-1, 3)  # (V, C)
    packed_rgb = flat_rgb[meshes.verts_padded_to_packed_idx(), :]
    # Look up the colors of each face's three vertices: (F, 3, C)
    per_face_rgb = packed_rgb[meshes.faces_packed()]
    return interpolate_face_attributes(fragments, per_face_rgb)

317
pytorch3d/renderer/utils.py Normal file
View File

@@ -0,0 +1,317 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
from typing import Any, Union
import torch
class TensorAccessor(object):
    """
    A helper class to be used with the __getitem__ method. This can be used for
    getting/setting the values for an attribute of a class at one particular
    index. This is useful when the attributes of a class are batched tensors
    and one element in the batch needs to be modified.
    """

    def __init__(self, class_object, index: Union[int, slice]):
        """
        Args:
            class_object: this should be an instance of a class which has
                attributes which are tensors representing a batch of
                values.
            index: int/slice, an index indicating the position in the batch.
                In __setattr__ and __getattr__ only the value of class
                attributes at this index will be accessed.
        """
        # Write through __dict__ so the overridden __setattr__ (which forwards
        # to class_object) is not triggered for the accessor's own fields.
        self.__dict__["class_object"] = class_object
        self.__dict__["index"] = index

    def __setattr__(self, name: str, value: Any):
        """
        Update the attribute given by `name` to the value given by `value`
        at the index specified by `self.index`.

        Args:
            name: str, name of the attribute.
            value: value to set the attribute to.

        Raises:
            AttributeError: if the target attribute is not a tensor.
            ValueError: if the non-batch dimensions of `value` do not match
                those of the target attribute.
        """
        v = getattr(self.class_object, name)
        if not torch.is_tensor(v):
            msg = "Can only set values on attributes which are tensors; got %r"
            raise AttributeError(msg % type(v))

        # Convert the new value to a tensor if it is not one already.
        if not torch.is_tensor(value):
            value = torch.tensor(
                value,
                device=v.device,
                dtype=v.dtype,
                requires_grad=v.requires_grad,
            )

        # Check the shapes match the existing shape and the shape of the index.
        if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]:
            msg = "Expected value to have shape %r; got %r"
            raise ValueError(msg % (v.shape, value.shape))
        # NOTE(review): `len()` is not defined for slice objects, so this
        # branch cannot produce the ValueError below as written; the intended
        # check should be confirmed before changing it.
        if (
            v.dim() == 0
            and isinstance(self.index, slice)
            and len(value) != len(self.index)
        ):
            msg = "Expected value to have len %r; got %r"
            raise ValueError(msg % (len(self.index), len(value)))
        self.class_object.__dict__[name][self.index] = value

    def __getattr__(self, name: str):
        """
        Return the value of the attribute given by "name" on self.class_object
        at the index specified in self.index.

        Args:
            name: string of the attribute name

        Raises:
            AttributeError: if `name` is not an attribute of class_object.
        """
        if hasattr(self.class_object, name):
            return self.class_object.__dict__[name][self.index]
        # Raise (not return) the error, and read the class name from the type:
        # instances do not have a __name__ attribute.
        msg = "Attribue %s not found on %r"
        raise AttributeError(msg % (name, type(self.class_object).__name__))
# Value types which TensorProperties.__init__ collects, converts to tensors and
# broadcasts to a common batch size.
BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray)
class TensorProperties(object):
    """
    A mix-in class for storing tensors as properties with helper methods.
    """

    def __init__(self, dtype=torch.float32, device="cpu", **kwargs):
        """
        Args:
            dtype: data type to set for the inputs
            device: str or torch.device
            kwargs: any number of keyword arguments. Any arguments which are
                of type (float/int/list/tuple/tensor/array) are broadcasted
                and str/bool keyword arguments are set as attributes. Values
                of any other type are reported with a printed message and
                ignored.
        """
        super().__init__()
        self.device = device
        self._N = 0
        if kwargs is not None:
            # broadcast all inputs which are float/int/list/tuple/tensor/array
            # set as attributes anything else e.g. strings, bools
            args_to_broadcast = {}
            for k, v in kwargs.items():
                # bool must be tested before BROADCAST_TYPES because bool is a
                # subclass of int.
                if isinstance(v, (str, bool)):
                    setattr(self, k, v)
                elif isinstance(v, BROADCAST_TYPES):
                    args_to_broadcast[k] = v
                else:
                    msg = "Arg %s with type %r is not broadcastable"
                    print(msg % (k, type(v)))

            names = args_to_broadcast.keys()
            # convert from type dict.values to tuple
            values = tuple(v for v in args_to_broadcast.values())

            if len(values) > 0:
                broadcasted_values = convert_to_tensors_and_broadcast(
                    *values, device=device
                )

                # Set broadcasted values as attributes on self.
                for i, n in enumerate(names):
                    setattr(self, n, broadcasted_values[i])
                    if self._N == 0:
                        self._N = broadcasted_values[i].shape[0]

    def __len__(self) -> int:
        """Return the batch size recorded at construction time."""
        return self._N

    def isempty(self) -> bool:
        """Return True if no broadcastable values were provided."""
        return self._N == 0

    def __getitem__(self, index: Union[int, slice]):
        """
        Args:
            index: an int or slice used to index all the fields.

        Returns:
            if `index` is an index int/slice return a TensorAccessor class
            with getattribute/setattribute methods which return/update the
            value at the index in the original object.

        Raises:
            ValueError: if `index` is neither an int nor a slice.
        """
        if isinstance(index, (int, slice)):
            return TensorAccessor(class_object=self, index=index)

        msg = "Expected index of type int or slice; got %r"
        raise ValueError(msg % type(index))

    def to(self, device: str = "cpu"):
        """
        In place operation to move class properties which are tensors to a
        specified device. If self has a property "device", update this as well.
        """
        for k in dir(self):
            v = getattr(self, k)
            if k == "device":
                setattr(self, k, device)
            if torch.is_tensor(v) and v.device != device:
                setattr(self, k, v.to(device))
        return self

    def clone(self, other):
        """
        Update the tensor properties of other with the cloned properties of
        self. The "device" property of self is copied to other as well.

        Args:
            other: the object which receives the cloned properties.

        Returns:
            other, after it has been updated.
        """
        for k in dir(self):
            v = getattr(self, k)
            if k == "device":
                # Copy the device onto the clone target (not back onto self).
                setattr(other, k, v)
            if torch.is_tensor(v):
                setattr(other, k, v.clone())
        return other

    def gather_props(self, batch_idx):
        """
        This is an in place operation to reformat all tensor class attributes
        based on a set of given indices using torch.gather. This is useful when
        attributes which are batched tensors e.g. shape (N, 3) need to be
        multiplied with another tensor which has a different first dimension
        e.g. packed vertices of shape (V, 3).

        Example

        .. code-block:: python

            self.specular_color = (N, 3) tensor of specular colors for each mesh

        A lighting calculation may use

        .. code-block:: python

            verts_packed = meshes.verts_packed()  # (V, 3)

        To multiply these two tensors the batch dimension needs to be the same.
        To achieve this we can do

        .. code-block:: python

            batch_idx = meshes.verts_packed_to_mesh_idx()  # (V)

        This gives index of the mesh for each vertex in verts_packed.

        .. code-block:: python

            self.gather_props(batch_idx)
            self.specular_color = (V, 3) tensor with the specular color for
                                  each packed vertex.

        torch.gather requires the index tensor to have the same shape as the
        input tensor so this method takes care of the reshaping of the index
        tensor to use with class attributes with arbitrary dimensions.

        Args:
            batch_idx: shape (B, ...) where `...` represents an arbitrary
                number of dimensions

        Returns:
            self with all properties reshaped. e.g. a property with shape (N, 3)
            is transformed to shape (B, 3).

        Raises:
            ValueError: if batch_idx has more dimensions than a tensor
                attribute it is applied to.
        """
        for k in dir(self):
            v = getattr(self, k)
            if torch.is_tensor(v):
                if v.shape[0] > 1:
                    # There are different values for each batch element
                    # so gather these using the batch_idx.
                    # Reshape a per-attribute copy of the index: overwriting
                    # `batch_idx` itself would leak this attribute's shape
                    # into the handling of the next attribute.
                    idx = batch_idx
                    idx_dims = idx.shape
                    tensor_dims = v.shape
                    if len(idx_dims) > len(tensor_dims):
                        msg = "batch_idx cannot have more dimensions than %s. "
                        msg += "got shape %r and %s has shape %r"
                        raise ValueError(msg % (k, idx_dims, k, tensor_dims))
                    if idx_dims != tensor_dims:
                        # To use torch.gather the index tensor (batch_idx) has
                        # to have the same shape as the input tensor.
                        new_dims = len(tensor_dims) - len(idx_dims)
                        new_shape = idx_dims + (1,) * new_dims
                        expand_dims = (-1,) + tensor_dims[1:]
                        idx = idx.view(*new_shape)
                        idx = idx.expand(*expand_dims)
                    v = v.gather(0, idx)
                    setattr(self, k, v)
        return self
def format_tensor(
    input, dtype=torch.float32, device: str = "cpu"
) -> torch.Tensor:
    """
    Helper function for converting a scalar value to a tensor.

    Args:
        input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
        dtype: data type used when a new tensor has to be created. An input
            which is already a tensor keeps its own dtype.
        device: torch device on which the tensor should be placed.

    Returns:
        input_vec: torch tensor with optional added batch dimension.
    """
    if torch.is_tensor(input):
        tensor = input
    else:
        tensor = torch.tensor(input, dtype=dtype, device=device)

    # Scalars (0-dim tensors) get a batch dimension of size 1.
    if tensor.dim() == 0:
        tensor = tensor.view(1)

    if tensor.device != device:
        tensor = tensor.to(device=device)
    return tensor
def convert_to_tensors_and_broadcast(
    *args, dtype=torch.float32, device: str = "cpu"
):
    """
    Helper function to handle parsing an arbitrary number of inputs (*args)
    which all need to have the same batch dimension.

    Args:
        *args: an arbitrary number of inputs
            Each of the values in `args` can be one of the following

            - Python scalar
            - Torch scalar
            - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are
              an arbitrary number of dimensions which can vary for each
              value in args. In this case each input is broadcast to a
              tensor of shape (N, K_i)
        dtype: data type to use when creating new tensors.
        device: torch device on which the tensors should be placed.

    Output:
        args: A list of tensors of shape (N, K_i). When exactly one input is
        given, the single broadcast tensor is returned instead of a list.

    Raises:
        ValueError: if an input has a batch size which is neither 1 nor N.
    """
    # Give every input a leading batch dimension.
    tensors = [format_tensor(value, dtype, device) for value in args]

    # The broadcast size is the largest batch dimension among the inputs.
    batch_sizes = [t.shape[0] for t in tensors]
    N = max(batch_sizes)

    broadcasted = []
    for t in tensors:
        if t.shape[0] not in (1, N):
            msg = "Got non-broadcastable sizes %r" % (batch_sizes)
            raise ValueError(msg)
        # Expand only the batch dim; -1 keeps every other dim unchanged.
        keep_dims = (-1,) * len(t.shape[1:])
        broadcasted.append(t.expand(N, *keep_dims))

    if len(args) == 1:
        return broadcasted[0]  # Return the first element
    return broadcasted

View File

@@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .meshes import Meshes
from .textures import Textures
from .utils import (
list_to_packed,
list_to_padded,
packed_to_list,
padded_to_list,
)
# Export every name defined in this module so far that does not start with "_".
__all__ = [k for k in globals().keys() if not k.startswith("_")]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Union
import torch
import torchvision.transforms as T
from .utils import list_to_packed, padded_to_list
"""
This file has functions for interpolating textures after rasterization.
"""
def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (H, W)
Returns:
tex_maps: Tensor of shape (N, max_H, max_W)
"""
tex_maps = []
max_H = 0
max_W = 0
for im in images:
h, w, _3 = im.shape
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_H, max_W)
# If all texture images are not the same size then resize to the
# largest size.
resize = T.Compose([T.ToPILImage(), T.Resize(size=max_shape), T.ToTensor()])
for i, image in enumerate(tex_maps):
if image.shape != max_shape:
# ToPIL takes and returns a C x H x W tensor
image = resize(image.permute(2, 0, 1)).permute(1, 2, 0)
tex_maps[i] = image
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
return tex_maps
def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor:
"""
Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch
dimension. This is done in the following sequence of steps (where `B` is
the batch dimension):
.. code-block:: python
input_tensor (B, ...)
-> add leading empty dimension (1, B, ...)
-> expand (N, B, ...)
-> reshape (N * B, ...)
Args:
input_tensor: torch.Tensor with ndim > 2 representing a batched input.
N: number of times to extend each element of the batch.
"""
if input_tensor.ndim < 2:
raise ValueError("Input tensor must have ndimensions >= 2.")
B = input_tensor.shape[0]
non_batch_dims = tuple(input_tensor.shape[1:])
constant_dims = (-1,) * input_tensor.ndim # these dims are not expanded.
return (
input_tensor.clone()[None, ...]
.expand(N, *constant_dims)
.transpose(0, 1)
.reshape(N * B, *non_batch_dims)
)
class Textures(object):
    """
    Container for the texture data of a batch of meshes: either texture maps
    together with per-face uv indices and per-vertex uv coordinates, or
    per-vertex rgb colors.
    """

    def __init__(
        self,
        maps: Union[List, torch.Tensor] = None,
        faces_uvs: torch.Tensor = None,
        verts_uvs: torch.Tensor = None,
        verts_rgb: torch.Tensor = None,
    ):
        """
        Args:
            maps: texture map per mesh. This can either be a list of maps
              [(H, W, 3)] or a padded tensor of shape (N, H, W, 3).
            faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
                vertex in the face. Padding value is assumed to be -1.
            verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
            verts_rgb: (N, V, 3) tensor giving the rgb color per vertex.

        Raises:
            ValueError: if any of the tensor inputs does not have the expected
                number of dimensions.
        """
        if faces_uvs is not None and faces_uvs.ndim != 3:
            msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
            raise ValueError(msg % repr(faces_uvs.shape))
        if verts_uvs is not None and verts_uvs.ndim != 3:
            msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
            # Report the shape of the tensor that failed the check.
            raise ValueError(msg % repr(verts_uvs.shape))
        if verts_rgb is not None and verts_rgb.ndim != 3:
            msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
            # torch.Size is a tuple, so it must be wrapped before being used
            # as a single %-format argument.
            raise ValueError(msg % repr(verts_rgb.shape))
        if maps is not None:
            # Validate `maps` (the argument), not the builtin `map`.
            if torch.is_tensor(maps) and maps.ndim != 4:
                msg = "Expected maps to be of shape (N, H, W, 3); got %r"
                raise ValueError(msg % repr(maps.shape))
            elif isinstance(maps, list):
                maps = _pad_texture_maps(maps)
        self._faces_uvs_padded = faces_uvs
        self._verts_uvs_padded = verts_uvs
        self._verts_rgb_padded = verts_rgb
        self._maps_padded = maps
        self._num_faces_per_mesh = None

        if self._faces_uvs_padded is not None:
            # A face counts as valid if none of its three indices is the
            # padding value -1.
            self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()

    def clone(self):
        """
        Return a new Textures object holding clones of all tensor attributes
        and a copy of the per-mesh face counts.
        """
        other = Textures()
        for k in dir(self):
            v = getattr(self, k)
            if torch.is_tensor(v):
                setattr(other, k, v.clone())
        # The face counts are a plain list and are skipped by the tensor loop;
        # without them faces_uvs_list() on the clone would keep padded faces.
        if self._num_faces_per_mesh is not None:
            other._num_faces_per_mesh = list(self._num_faces_per_mesh)
        return other

    def to(self, device):
        """Move all tensor attributes to `device` in place and return self."""
        for k in dir(self):
            v = getattr(self, k)
            if torch.is_tensor(v) and v.device != device:
                setattr(self, k, v.to(device))
        return self

    def faces_uvs_padded(self) -> torch.Tensor:
        """Return the padded (N, F, 3) faces_uvs tensor (may be None)."""
        return self._faces_uvs_padded

    def faces_uvs_list(self) -> List[torch.Tensor]:
        """
        Return the per-mesh faces_uvs with padded faces removed, or None if
        no faces_uvs were provided.
        """
        if self._faces_uvs_padded is not None:
            return padded_to_list(
                self._faces_uvs_padded, split_size=self._num_faces_per_mesh
            )

    def faces_uvs_packed(self) -> torch.Tensor:
        """Return the faces_uvs of all meshes concatenated along dim 0."""
        return list_to_packed(self.faces_uvs_list())[0]

    def verts_uvs_padded(self) -> torch.Tensor:
        """Return the padded (N, V, 2) verts_uvs tensor (may be None)."""
        return self._verts_uvs_padded

    def verts_uvs_list(self) -> List[torch.Tensor]:
        """Return the per-mesh verts_uvs as a list of tensors."""
        return padded_to_list(self._verts_uvs_padded)

    def verts_uvs_packed(self) -> torch.Tensor:
        """Return the verts_uvs of all meshes concatenated along dim 0."""
        return list_to_packed(self.verts_uvs_list())[0]

    def verts_rgb_padded(self) -> torch.Tensor:
        """Return the padded (N, V, 3) verts_rgb tensor (may be None)."""
        return self._verts_rgb_padded

    def verts_rgb_list(self) -> List[torch.Tensor]:
        """Return the per-mesh verts_rgb as a list of tensors."""
        return padded_to_list(self._verts_rgb_padded)

    def verts_rgb_packed(self) -> torch.Tensor:
        """Return the verts_rgb of all meshes concatenated along dim 0."""
        return list_to_packed(self.verts_rgb_list())[0]

    # Currently only the padded maps are used.
    def maps_padded(self) -> torch.Tensor:
        """Return the padded (N, H, W, 3) texture maps (may be None)."""
        return self._maps_padded

    def extend(self, N: int) -> "Textures":
        """
        Create new Textures class which contains each input texture N times

        Args:
            N: number of new copies of each texture.

        Returns:
            new Textures object.

        Raises:
            ValueError: if N is not a positive integer, or if neither a full
                set of texture map inputs nor vertex colors are available.
        """
        if not isinstance(N, int):
            raise ValueError("N must be an integer.")
        if N <= 0:
            raise ValueError("N must be > 0.")

        if all(
            v is not None
            for v in [
                self._faces_uvs_padded,
                self._verts_uvs_padded,
                self._maps_padded,
            ]
        ):
            new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
            new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
            new_maps = _extend_tensor(self._maps_padded, N)
            return Textures(
                verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
            )
        elif self._verts_rgb_padded is not None:
            new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
            return Textures(verts_rgb=new_verts_rgb)
        else:
            msg = "Either vertex colors or texture maps are required."
            raise ValueError(msg)

View File

@@ -0,0 +1,150 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Union
import torch
"""
Util functions containing representation transforms for points/verts/faces.
"""
def list_to_padded(
    x: List[torch.Tensor],
    pad_size: Union[list, tuple, None] = None,
    pad_value: float = 0.0,
    equisized: bool = False,
) -> torch.Tensor:
    r"""
    Transforms a list of N tensors each of shape (Mi, Ki) into a single tensor
    of shape (N, pad_size(0), pad_size(1)), or (N, max(Mi), max(Ki))
    if pad_size is None.

    Args:
      x: list of Tensors
      pad_size: list(int) specifying the size of the padded tensor
      pad_value: float value to be used to fill the padded tensor
      equisized: bool indicating whether the items in x are of equal size
        (sometimes this is known and if provided saves computation)

    Returns:
      x_padded: tensor consisting of padded input tensors. If pad_size is
        None and every item in x is empty, the result has shape (N, 0, 0).

    Raises:
      ValueError: if pad_size does not have exactly two entries, or if a
        non-empty item in x is not 2-dimensional.
    """
    if equisized:
        return torch.stack(x, 0)

    if pad_size is not None and len(pad_size) != 2:
        raise ValueError(
            "Pad size must contain target size for 1st and 2nd dim"
        )

    # Validate before reading y.shape[1] so that a non-2D item produces the
    # intended ValueError rather than an IndexError.
    non_empty = [y for y in x if len(y) > 0]
    for y in non_empty:
        if y.ndim != 2:
            raise ValueError("Supports only 2-dimensional tensor items")

    if pad_size is None:
        # default=0 keeps a list made only of empty tensors from crashing.
        pad_dim0 = max((y.shape[0] for y in non_empty), default=0)
        pad_dim1 = max((y.shape[1] for y in non_empty), default=0)
    else:
        pad_dim0, pad_dim1 = pad_size

    N = len(x)
    x_padded = torch.full(
        (N, pad_dim0, pad_dim1), pad_value, dtype=x[0].dtype, device=x[0].device
    )
    for i, y in enumerate(x):
        if len(y) > 0:
            x_padded[i, : y.shape[0], : y.shape[1]] = y
    return x_padded
def padded_to_list(
    x: torch.Tensor, split_size: Union[list, tuple, None] = None
):
    r"""
    Transforms a padded tensor of shape (N, M, K) into a list of N tensors
    of shape (Mi, Ki) where (Mi, Ki) is specified in split_size(i), or of shape
    (M, K) if split_size is None.
    Support only for 3-dimensional input tensor.

    Args:
      x: tensor
      split_size: the shape of the final tensor to be returned (of length N).
        Each entry is either an int Mi (only the first dimension is cut) or
        a pair (Mi, Ki).

    Returns:
      A list of N tensors.

    Raises:
      ValueError: if x is not 3-dimensional, if split_size does not have N
        entries, or if an entry is a sequence whose length is not 2.
    """
    if x.ndim != 3:
        raise ValueError("Supports only 3-dimensional input tensors")

    unpadded = list(x.unbind(0))
    if split_size is None:
        return unpadded

    N = len(split_size)
    if x.shape[0] != N:
        raise ValueError(
            "Split size must be of same length as inputs first dimension"
        )

    for i, size in enumerate(split_size):
        if isinstance(size, int):
            unpadded[i] = unpadded[i][:size]
        elif len(size) == 2:
            rows, cols = size[0], size[1]
            unpadded[i] = unpadded[i][:rows, :cols]
        else:
            raise ValueError(
                "Support only for 2-dimensional unbinded tensor. "
                "Split size for more dimensions provided"
            )
    return unpadded
def list_to_packed(x: List[torch.Tensor]):
    r"""
    Transforms a list of N tensors each of shape (Mi, K, ...) into a single
    tensor of shape (sum(Mi), K, ...).

    Args:
      x: list of tensors.

    Returns:
        4-element tuple containing

        - **x_packed**: tensor consisting of packed input tensors along the
          1st dimension.
        - **num_items**: tensor of shape N containing Mi for each element in x.
        - **item_packed_first_idx**: tensor of shape N indicating the index of
          the first item belonging to the same element in the original list.
        - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the
          index of the element in the list the item belongs to.
    """
    # The bookkeeping tensors live on the device of the first list element.
    device = x[0].device

    counts = []
    first_indices = []
    owner_chunks = []
    offset = 0
    for list_idx, item in enumerate(x):
        count = len(item)
        counts.append(count)
        first_indices.append(offset)
        # One entry per packed row, holding the index of its source element.
        owner_chunks.append(
            torch.full((count,), list_idx, dtype=torch.int64, device=item.device)
        )
        offset += count

    num_items = torch.tensor(counts, dtype=torch.int64, device=device)
    item_packed_first_idx = torch.tensor(
        first_indices, dtype=torch.int64, device=device
    )
    x_packed = torch.cat(x, dim=0)
    item_packed_to_list_idx = torch.cat(owner_chunks, dim=0)
    return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx
def packed_to_list(x: torch.Tensor, split_size: Union[list, int]):
    r"""
    Transforms a tensor of shape (sum(Mi), K, L, ...) to N set of tensors of
    shape (Mi, K, L, ...) where Mi's are defined in split_size

    Args:
      x: tensor
      split_size: list or int defining the number of items for each split

    Returns:
      x_list: the sequence of tensors produced by splitting x along its
        first dimension.
    """
    pieces = torch.split(x, split_size, dim=0)
    return pieces

View File

@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .rotation_conversions import (
euler_angles_to_matrix,
matrix_to_euler_angles,
matrix_to_quaternion,
quaternion_apply,
quaternion_invert,
quaternion_multiply,
quaternion_raw_multiply,
quaternion_to_matrix,
random_quaternions,
random_rotation,
random_rotations,
standardize_quaternion,
)
from .so3 import (
so3_exponential_map,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
)
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
# Export every name defined in this module so far that does not start with "_".
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@@ -0,0 +1,374 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import functools
import torch
def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    # 2 / |q|^2: dividing by the squared norm means the input does not need
    # to be normalized beforehand.
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    row0 = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
    )
    row1 = (
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
    )
    row2 = (
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )

    flat = torch.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def _copysign(a, b):
"""
Return a tensor where each element has the absolute value taken from the,
corresponding element of a, with sign taken from the corresponding
element of b. This is like the standard copysign floating-point operation,
but is not careful about negative 0 and NaN.
Args:
a: source tensor.
b: tensor whose signs will be used, of the same shape as a.
Returns:
Tensor of the same shape as a with the signs of b.
"""
signs_differ = (a < 0) != (b < 0)
return torch.where(signs_differ, -a, a)
def matrix_to_quaternion(matrix):
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of matrix are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    zero = matrix.new_zeros((1,))
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    # Magnitudes of the four components from the diagonal; clamping at zero
    # guards the square roots against slightly negative arguments.
    o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
    x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
    y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
    z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
    # Signs of the imaginary parts come from the antisymmetric off-diagonal
    # differences.
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
def _primary_matrix(axis: str, angle):
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
Args:
axis: Axis label "X" or "Y or "Z".
angle: any shape tensor of Euler angles in radians
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
if axis == "Y":
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
if axis == "Z":
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str):
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if euler_angles does not end in a dimension of size 3, or
            if convention is not three letters from X/Y/Z in which the middle
            letter differs from both of its neighbours.
    """
    has_three_angles = euler_angles.dim() != 0 and euler_angles.shape[-1] == 3
    if not has_three_angles:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    # One elementary rotation per (axis letter, angle) pair, multiplied
    # together from left to right.
    per_axis_angles = torch.unbind(euler_angles, -1)
    matrices = (
        _primary_matrix(axis, angle)
        for axis, angle in zip(convention, per_axis_angles)
    )
    return functools.reduce(torch.matmul, matrices)
def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
):
"""
Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine.
Args:
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
convention.
data: Rotation matrices as tensor of shape (..., 3, 3).
horizontal: Whether we are looking for the angle for the third axis,
which means the relevant entries are in the same row of the
rotation matrix. If not, they are in the same column.
tait_bryan: Whether the first and third axes in the convention differ.
Returns:
Euler Angles in radians for each matrix in data as a tensor
of shape (...).
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def _index_from_letter(letter: str):
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
def matrix_to_euler_angles(matrix, convention: str):
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: if convention is not three letters from X/Y/Z in which
            the middle letter differs from both of its neighbours, or if the
            last two dimensions of matrix are not (3, 3).
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    # Tait-Bryan conventions use three distinct axes; otherwise the first and
    # third axes coincide.
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        # First angle: read from column i2 of the matrix.
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        # Third angle: read from row i0 of the matrix.
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)
def random_quaternions(
    n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Args:
        n: Number to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.
        requires_grad: Whether the resulting tensor should have the gradient
            flag set.

    Returns:
        Quaternions as tensor of shape (n, 4).
    """
    samples = torch.randn(
        (n, 4), dtype=dtype, device=device, requires_grad=requires_grad
    )
    squared_norm = (samples * samples).sum(1)
    # Divide by the norm carrying the sign of the real part: this normalizes
    # each row and makes its real part nonnegative in one step.
    signed_norm = _copysign(torch.sqrt(squared_norm), samples[:, 0])
    return samples / signed_norm[:, None]
def random_rotations(
    n: int, dtype: torch.dtype = None, device=None, requires_grad=False
):
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.
        requires_grad: Whether the resulting tensor should have the gradient
            flag set.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    # Sample random rotation quaternions and convert them to matrices.
    return quaternion_to_matrix(
        random_quaternions(
            n, dtype=dtype, device=device, requires_grad=requires_grad
        )
    )
def random_rotation(
    dtype: torch.dtype = None, device=None, requires_grad=False
):
    """
    Generate a single random 3x3 rotation matrix.

    Args:
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type
        requires_grad: Whether the resulting tensor should have the gradient
            flag set

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    # Generate a batch of one and drop the batch dimension.
    single_batch = random_rotations(
        1, dtype=dtype, device=device, requires_grad=requires_grad
    )
    return single_batch[0]
def standardize_quaternion(quaternions):
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # Keep the last dimension (size 1) so the mask broadcasts over all four
    # components.
    real_is_negative = quaternions[..., 0:1] < 0
    return torch.where(real_is_negative, -quaternions, quaternions)
def quaternion_raw_multiply(a, b):
    """
    Multiply two quaternions.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    a_w, a_x, a_y, a_z = torch.unbind(a, -1)
    b_w, b_x, b_y, b_z = torch.unbind(b, -1)
    # Hamilton product, component by component (real part first).
    product = (
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    )
    return torch.stack(product, -1)
def quaternion_multiply(a, b):
    """
    Compose two rotations given as quaternions. The raw quaternion product
    is standardized, i.e. the returned versor has nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    return standardize_quaternion(quaternion_raw_multiply(a, b))
def quaternion_invert(quaternion):
    """
    Return the quaternion representing the inverse rotation. This is the
    conjugate (imaginary parts negated), which equals the inverse only for
    unit quaternions.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """
    # new_tensor keeps the dtype and device of the input.
    conjugation_signs = quaternion.new_tensor([1, -1, -1, -1])
    return quaternion * conjugation_signs
def quaternion_apply(quaternion, point):
    """
    Apply the rotation given by a quaternion to a 3D point.
    Usual torch rules for broadcasting apply.

    The point p is embedded as the pure quaternion (0, p) and rotated as
    q * (0, p) * q^-1; the imaginary part of the result is returned.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of `point` is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    # Prepend a zero real part to turn each point into a pure quaternion.
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    # Drop the real part, leaving the rotated xyz coordinates.
    return out[..., 1:]

236
pytorch3d/transforms/so3.py Normal file
View File

@@ -0,0 +1,236 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
# Largest absolute entry of `h + h^T` that `hat_inv` accepts before it
# rejects the input batch as not skew-symmetric.
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
def so3_relative_angle(R1, R2, cos_angle: bool = False):
    """
    Calculates the relative angle (in radians) between pairs of
    rotation matrices `R1` and `R2` with
    `angle = acos(0.5 * (Trace(R1 R2^T) - 1))`.

    .. note::
        This corresponds to a geodesic distance on the 3D manifold of rotation
        matrices.

    Args:
        R1: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        R2: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        cos_angle: If==True return cosine of the relative angle rather than
            the angle itself. This can avoid the unstable
            calculation of `acos`.

    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.

    Raises:
        ValueError if `R1` or `R2` is of incorrect shape.
        ValueError if `R1` or `R2` has an unexpected trace.
    """
    # R1 @ R2^T is the rotation that takes R2 to R1; its rotation angle is
    # the relative angle between the two inputs.
    R2_transposed = R2.permute(0, 2, 1)
    relative_rotation = torch.bmm(R1, R2_transposed)
    return so3_rotation_angle(relative_rotation, cos_angle=cos_angle)
def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
    """
    Calculates angles (in radians) of a batch of rotation matrices `R` with
    `angle = acos(0.5 * (Trace(R)-1))`. The trace of the
    input matrices is checked to be in the valid range `[-1-eps,3+eps]`.
    The `eps` argument is a small constant that allows for small errors
    caused by limited machine precision.

    Args:
        R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: Tolerance for the valid trace check.
        cos_angle: If==True return cosine of the rotation angles rather than
            the angle itself. This can avoid the unstable
            calculation of `acos`.

    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, rows, cols = R.shape
    if (rows, cols) != (3, 3):
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]

    below_range = trace < -1.0 - eps
    above_range = trace > 3.0 + eps
    if (below_range | above_range).any():
        raise ValueError(
            "A matrix has trace outside valid range [-1-eps,3+eps]."
        )

    # Traces within eps of the valid range are snapped back into it so
    # that the cosine below stays inside [-1, 1].
    trace = torch.clamp(trace, min=-1.0, max=3.0)

    cos_phi = 0.5 * (trace - 1.0)
    return cos_phi if cos_angle else cos_phi.acos()
def so3_exponential_map(log_rot, eps: float = 0.0001):
    """
    Convert a batch of logarithmic representations of rotation matrices
    `log_rot` to a batch of 3x3 rotation matrices using Rodrigues formula [1].

    In the logarithmic representation, each rotation matrix is represented as
    a 3-dimensional vector (`log_rot`) whose l2-norm and direction correspond
    to the magnitude of the rotation angle and the axis of rotation
    respectively.

    The conversion has a singularity around `log(R) = 0`
    which is handled by clamping controlled with the `eps` argument. Note
    that the clamp is applied to the squared norm, before the square root.

    Args:
        log_rot: Batch of vectors of shape `(minibatch , 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of rotation matrices of shape `(minibatch , 3 , 3)`.

    Raises:
        ValueError if `log_rot` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    """
    _, vec_dim = log_rot.shape
    if vec_dim != 3:
        raise ValueError("Input tensor shape has to be Nx3.")

    squared_norms = (log_rot * log_rot).sum(1)
    # Rotation angles; the clamp keeps the division below finite.
    rot_angles = torch.clamp(squared_norms, eps).sqrt()
    rot_angles_inv = 1.0 / rot_angles

    # Coefficients of the Rodrigues formula: sin(t)/t and (1 - cos(t))/t^2.
    fac1 = rot_angles_inv * rot_angles.sin()
    fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())

    skews = hat(log_rot)
    identity = torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
    return (
        fac1[:, None, None] * skews
        + fac2[:, None, None] * torch.bmm(skews, skews)
        + identity
    )
def so3_log_map(R, eps: float = 0.0001):
    """
    Convert a batch of 3x3 rotation matrices `R`
    to a batch of 3-dimensional matrix logarithms of rotation matrices.

    The conversion has a singularity around `(R=I)` which is handled
    by clamping controlled with the `eps` argument. For an exact identity
    input the rotation angle is replaced by `+eps`, so the result is a
    zero vector rather than NaN.

    Args:
        R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of logarithms of input rotation matrices
        of shape `(minibatch, 3)`.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, dim1, dim2 = R.shape
    if dim1 != 3 or dim2 != 3:
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    phi = so3_rotation_angle(R)

    # sign() returns 0 for phi == 0, which would cancel the clamp and give
    # 0 / sin(0) = NaN below; treat an exactly-zero angle as positive.
    phi_sign = phi.sign() + (phi == 0).type_as(phi)
    phi_valid = torch.clamp(phi.abs(), eps) * phi_sign

    # log(R) = phi / (2 sin(phi)) * (R - R^T), a skew-symmetric matrix.
    log_rot_hat = (phi_valid / (2.0 * phi_valid.sin()))[:, None, None] * (
        R - R.permute(0, 2, 1)
    )
    return hat_inv(log_rot_hat)
def hat_inv(h):
    """
    Compute the inverse Hat operator [1] of a batch of 3x3 matrices, i.e.
    read the 3D vector back out of each skew-symmetric matrix.

    Args:
        h: Batch of skew-symmetric matrices of shape `(minibatch, 3, 3)`.

    Returns:
        Batch of 3d vectors of shape `(minibatch, 3)`.

    Raises:
        ValueError if `h` is of incorrect shape.
        ValueError if `h` not skew-symmetric.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    _, rows, cols = h.shape
    if (rows, cols) != (3, 3):
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    # For a skew-symmetric matrix h + h^T vanishes; the largest absolute
    # entry over the whole batch is compared against the tolerance.
    skew_residual = (h + h.transpose(1, 2)).abs().max()
    if float(skew_residual) > HAT_INV_SKEW_SYMMETRIC_TOL:
        raise ValueError("One of input matrices not skew-symmetric.")

    vector_entries = (h[:, 2, 1], h[:, 0, 2], h[:, 1, 0])
    return torch.stack(vector_entries, dim=1)
def hat(v):
    """
    Compute the Hat operator [1] of a batch of 3D vectors.

    Args:
        v: Batch of vectors of shape `(minibatch , 3)`.

    Returns:
        Batch of skew-symmetric matrices of shape
        `(minibatch, 3 , 3)` where each matrix is of the form:
            `[    0  -v_z   v_y ]
             [  v_z     0  -v_x ]
             [ -v_y   v_x     0 ]`

    Raises:
        ValueError if `v` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    batch_size, vec_dim = v.shape
    if vec_dim != 3:
        raise ValueError("Input vectors have to be 3-dimensional.")

    x, y, z = v.unbind(1)
    zero = torch.zeros_like(x)
    # Entries listed row by row, then folded into 3x3 matrices.
    row_major_entries = (zero, -z, y, z, zero, -x, -y, x, zero)
    return torch.stack(row_major_entries, dim=1).reshape(batch_size, 3, 3)

View File

@@ -0,0 +1,677 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import torch
class Transform3d:
    """
    A Transform3d object encapsulates a batch of N 3D transformations, and knows
    how to transform points and normal vectors. Suppose that t is a Transform3d;
    then we can do the following:

    .. code-block:: python

        N = len(t)
        points = torch.randn(N, P, 3)
        normals = torch.randn(N, P, 3)
        points_transformed = t.transform_points(points)    # => (N, P, 3)
        normals_transformed = t.transform_normals(normals)  # => (N, P, 3)

    BROADCASTING
    Transform3d objects supports broadcasting. Suppose that t1 and tN are
    Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we
    can broadcast transforms like this:

    .. code-block:: python

        t1.transform_points(torch.randn(P, 3))     # => (P, 3)
        t1.transform_points(torch.randn(1, P, 3))  # => (1, P, 3)
        t1.transform_points(torch.randn(M, P, 3))  # => (M, P, 3)
        tN.transform_points(torch.randn(P, 3))     # => (N, P, 3)
        tN.transform_points(torch.randn(1, P, 3))  # => (N, P, 3)

    COMBINING TRANSFORMS
    Transform3d objects can be combined in two ways: composing and stacking.
    Composing is function composition. Given Transform3d objects t1, t2, t3,
    the following all compute the same thing:

    .. code-block:: python

        y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
        y2 = t1.compose(t2).compose(t3).transform_points(x)
        y3 = t1.compose(t2, t3).transform_points(x)

    Composing transforms should broadcast.

    .. code-block:: python

        if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N.

    We can also stack a sequence of Transform3d objects, which represents
    composition along the batch dimension; then the following should compute the
    same thing.

    .. code-block:: python

        N, M = len(tN), len(tM)
        xN = torch.randn(N, P, 3)
        xM = torch.randn(M, P, 3)
        y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0)
        y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0))

    BUILDING TRANSFORMS
    We provide convenience methods for easily building Transform3d objects
    as compositions of basic transforms.

    .. code-block:: python

        # Scale by 0.5, then translate by (1, 2, 3)
        t1 = Transform3d().scale(0.5).translate(1, 2, 3)

        # Scale each axis by a different amount, then translate, then scale
        t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0)

        t3 = t1.compose(t2)
        tN = t1.stack(t3, t3)

    BACKPROP THROUGH TRANSFORMS
    When building transforms, we can also parameterize them by Torch tensors;
    in this case we can backprop through the construction and application of
    Transform objects, so they could be learned via gradient descent or
    predicted by a neural network.

    .. code-block:: python

        s1_params = torch.randn(N, requires_grad=True)
        t_params = torch.randn(N, 3, requires_grad=True)
        s2_params = torch.randn(N, 3, requires_grad=True)

        t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params)
        x = torch.randn(N, 3)
        y = t.transform_points(x)
        loss = compute_loss(y)
        loss.backward()

        with torch.no_grad():
            s1_params -= lr * s1_params.grad
            t_params -= lr * t_params.grad
            s2_params -= lr * s2_params.grad
    """

    def __init__(self, dtype=torch.float32, device="cpu"):
        """
        Create an identity transform with a batch size of one.

        This class assumes a row major ordering for all matrices: points are
        row vectors that are multiplied from the left onto the 4x4 matrix.

        Args:
            dtype: dtype of the stored (1, 4, 4) identity matrix.
            device: device of the stored matrix; also remembered in
                `self.device` and handed to transforms derived from this one.
        """
        self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
        self._transforms = []  # store transforms to compose
        self._lu = None  # reserved for a cached LU factorization; never set here
        self.device = device

    def __len__(self):
        """Return the batch size of the composed transformation matrix."""
        return self.get_matrix().shape[0]

    def compose(self, *others):
        """
        Return a new Transform3d with the transforms to compose stored as
        an internal list. The matrices are only multiplied together lazily,
        in `get_matrix()`.

        Args:
            *others: Any number of Transform3d objects

        Returns:
            A new Transform3d with the stored transforms

        Raises:
            ValueError: if any element of `others` is not a Transform3d.
        """
        for other in others:
            if not isinstance(other, Transform3d):
                msg = "Only possible to compose Transform3d objects; got %s"
                raise ValueError(msg % type(other))
        out = Transform3d(device=self.device)
        out._matrix = self._matrix.clone()
        out._transforms = self._transforms + list(others)
        return out

    def get_matrix(self):
        """
        Return a matrix which is the result of composing this transform
        with others stored in self._transforms. Where necessary transforms
        are broadcast against each other.
        For example, if self._transforms contains transforms t1, t2, and t3,
        and given a set of points x, the following should be true:

        .. code-block:: python

            y1 = t1.compose(t2, t3).transform(x)
            y2 = t3.transform(t2.transform(t1.transform(x)))
            y1.get_matrix() == y2.get_matrix()

        Returns:
            A transformation matrix representing the composed inputs.
        """
        composed_matrix = self._matrix.clone()
        # Row-vector convention: later transforms are multiplied on the right.
        for other in self._transforms:
            composed_matrix = _broadcast_bmm(composed_matrix, other.get_matrix())
        return composed_matrix

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix.
        """
        return torch.inverse(self._matrix)

    def inverse(self, invert_composed: bool = False):
        """
        Returns a new Transform3d object that represents an inverse of the
        current transformation.

        Args:
            invert_composed:
                - True: First compose the list of stored transformations
                  and then apply inverse to the result. This is
                  potentially slower for classes of transformations
                  with inverses that can be computed efficiently
                  (e.g. rotations and translations).
                - False: Invert the individual stored transformations
                  independently without composing them.

        Returns:
            A new Transform3d object containing the inverse of the original
            transformation.
        """
        tinv = Transform3d(device=self.device)

        if invert_composed:
            # first compose then invert
            tinv._matrix = torch.inverse(self.get_matrix())
            return tinv

        # self._get_matrix_inverse() implements efficient inverse
        # of self._matrix
        i_matrix = self._get_matrix_inverse()

        if len(self._transforms) > 0:
            # a) Non-empty list of transforms:
            #    (M T1 ... Tk)^-1 = Tk^-1 ... T1^-1 M^-1, so the stored
            #    transforms are inverted in reverse order and the inverse
            #    of self._matrix is appended at the very end, where
            #    get_matrix() right-multiplies it.
            tinv._transforms = [
                t.inverse() for t in reversed(self._transforms)
            ]
            last = Transform3d(device=self.device)
            last._matrix = i_matrix
            tinv._transforms.append(last)
        else:
            # b) No stored transformations: the inverted matrix is the
            #    whole answer.
            tinv._matrix = i_matrix

        return tinv

    def stack(self, *others):
        """
        Concatenate this transform and `others` along the batch dimension.

        The fully composed matrix of every input (`get_matrix()`) is used,
        so transforms built with `compose` / `scale` / `translate` keep their
        effect in the stacked result.

        Args:
            *others: Any number of Transform3d objects.

        Returns:
            A new Transform3d, on `self.device`, whose batch size is the sum
            of the batch sizes of all inputs.
        """
        transforms = [self] + list(others)
        matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
        out = Transform3d(device=self.device)
        out._matrix = matrix
        return out

    def transform_points(self, points, eps: float = None):
        """
        Use this transform to transform a set of 3D points. Assumes row major
        ordering of the input points.

        Args:
            points: Tensor of shape (P, 3) or (N, P, 3)
            eps: If eps!=None, the argument is used to clamp the
                last coordinate before performing the final division.
                The clamping corresponds to:
                last_coord := (last_coord.sign() + (last_coord==0)) *
                torch.clamp(last_coord.abs(), eps),
                i.e. the last coordinates that are exactly 0 will
                be clamped to +eps.

        Returns:
            points_out: points of shape (N, P, 3) or (P, 3) depending
            on the dimensions of the transform

        Raises:
            ValueError: if `points` is neither 2- nor 3-dimensional.
        """
        points_batch = points.clone()
        if points_batch.dim() == 2:
            points_batch = points_batch[None]  # (P, 3) -> (1, P, 3)
        if points_batch.dim() != 3:
            msg = "Expected points to have dim = 2 or dim = 3: got shape %r"
            # torch.Size is a tuple: wrap it so that "%" formats the whole
            # shape instead of unpacking it as several arguments.
            raise ValueError(msg % (points.shape,))

        N, P, _3 = points_batch.shape
        # Append a homogeneous coordinate of 1 to every point.
        ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device)
        points_batch = torch.cat([points_batch, ones], dim=2)

        composed_matrix = self.get_matrix()
        points_out = _broadcast_bmm(points_batch, composed_matrix)
        denom = points_out[..., 3:]  # denominator
        if eps is not None:
            denom_sign = denom.sign() + (denom == 0.0).type_as(denom)
            denom = denom_sign * torch.clamp(denom.abs(), eps)
        points_out = points_out[..., :3] / denom

        # When transform is (1, 4, 4) and points is (P, 3) return
        # points_out of shape (P, 3)
        if points_out.shape[0] == 1 and points.dim() == 2:
            points_out = points_out.reshape(points.shape)

        return points_out

    def transform_normals(self, normals):
        """
        Use this transform to transform a set of normal vectors. Normals are
        multiplied by the inverse transpose of the upper-left 3x3 block of
        the composed matrix.

        Args:
            normals: Tensor of shape (P, 3) or (N, P, 3)

        Returns:
            normals_out: Tensor of shape (P, 3) or (N, P, 3) depending
            on the dimensions of the transform

        Raises:
            ValueError: if `normals` is neither 2- nor 3-dimensional.
        """
        if normals.dim() not in [2, 3]:
            msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
            # See transform_points: the shape must be wrapped in a tuple.
            raise ValueError(msg % (normals.shape,))
        composed_matrix = self.get_matrix()

        # TODO: inverse is bad! Solve a linear system instead
        mat = composed_matrix[:, :3, :3]
        normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse())

        # This doesn't pass unit tests. TODO investigate further
        # if self._lu is None:
        #     self._lu = self._matrix[:, :3, :3].transpose(1, 2).lu()
        # normals_out = normals.lu_solve(*self._lu)

        # When transform is (1, 4, 4) and normals is (P, 3) return
        # normals_out of shape (P, 3)
        if normals_out.shape[0] == 1 and normals.dim() == 2:
            normals_out = normals_out.reshape(normals.shape)

        return normals_out

    def translate(self, *args, **kwargs):
        """Compose this transform with `Translate(*args, **kwargs)`."""
        return self.compose(Translate(device=self.device, *args, **kwargs))

    def scale(self, *args, **kwargs):
        """Compose this transform with `Scale(*args, **kwargs)`."""
        return self.compose(Scale(device=self.device, *args, **kwargs))

    def rotate_axis_angle(self, *args, **kwargs):
        """Compose this transform with `RotateAxisAngle(*args, **kwargs)`."""
        return self.compose(
            RotateAxisAngle(device=self.device, *args, **kwargs)
        )

    def clone(self):
        """
        Deep copy of Transforms object. All internal tensors are cloned
        individually. The copy (and the copies of the stored transforms)
        are plain Transform3d instances.

        Returns:
            new Transforms object.
        """
        other = Transform3d(device=self.device)
        if self._lu is not None:
            other._lu = [elem.clone() for elem in self._lu]
        other._matrix = self._matrix.clone()
        other._transforms = [t.clone() for t in self._transforms]
        return other

    def to(self, device, copy: bool = False, dtype=None):
        """
        Match functionality of torch.Tensor.to()
        If copy = True, the self Tensor is on a different device, or a dtype
        is requested, the returned object is a copy of self with the desired
        torch.device / dtype.
        If copy = False, the self Tensor already has the correct torch.device
        and no dtype is requested, then self is returned.

        Args:
            device: Device id for the new tensor.
            copy: Boolean indicator whether or not to clone self. Default False.
            dtype: If not None, casts the internal tensor variables
                to a given torch.dtype.

        Returns:
            Transform3d object.
        """
        same_device = self.device == device
        if not copy and same_device and dtype is None:
            return self

        other = self.clone()
        if same_device and dtype is None:
            # A pure copy was requested; nothing needs converting.
            return other

        other.device = device
        other._matrix = self._matrix.to(device=device, dtype=dtype)
        # `to` returns new objects rather than converting in place, so the
        # converted sub-transforms have to be stored back.
        other._transforms = [
            t.to(device, copy=copy, dtype=dtype) for t in other._transforms
        ]
        return other

    def cpu(self):
        """Return this transform on the CPU (see `to`)."""
        return self.to(torch.device("cpu"))

    def cuda(self):
        """Return this transform on the default CUDA device (see `to`)."""
        return self.to(torch.device("cuda"))
class Translate(Transform3d):
    def __init__(
        self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
    ):
        """
        Create a new Transform3d representing 3D translations.

        Option I: Translate(xyz, dtype=torch.float32, device='cpu')
            xyz should be a tensor of shape (N, 3)

        Option II: Translate(x, y, z, dtype=torch.float32, device='cpu')
            Here x, y, and z will be broadcast against each other and
            concatenated to form the translation. Each can be:
                - A python scalar
                - A torch scalar
                - A 1D torch tensor
        """
        super().__init__(device=device)
        offsets = _handle_input(x, y, z, dtype, device, "Translate")
        batch_size = offsets.shape[0]

        identity = torch.eye(4, dtype=dtype, device=device)
        matrix = identity.view(1, 4, 4).repeat(batch_size, 1, 1)
        # Row-vector convention: the translation lives in the last row.
        matrix[:, 3, :3] = offsets
        self._matrix = matrix

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix, obtained by negating the
        translation entries in the last row.
        """
        sign_flip = self._matrix.new_ones([1, 4, 4])
        sign_flip[0, 3, :3] = -1.0
        return self._matrix * sign_flip
class Scale(Transform3d):
    def __init__(
        self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
    ):
        """
        A Transform3d representing a scaling operation, with different scale
        factors along each coordinate axis.

        Option I: Scale(s, dtype=torch.float32, device='cpu')
            s can be one of
                - Python scalar or torch scalar: Single uniform scale
                - 1D torch tensor of shape (N,): A batch of uniform scale
                - 2D torch tensor of shape (N, 3): Scale differently along
                  each axis

        Option II: Scale(x, y, z, dtype=torch.float32, device='cpu')
            Each of x, y, and z can be one of
                - python scalar
                - torch scalar
                - 1D torch tensor
        """
        super().__init__(device=device)
        factors = _handle_input(
            x, y, z, dtype, device, "scale", allow_singleton=True
        )
        batch_size = factors.shape[0]

        identity = torch.eye(4, dtype=dtype, device=device)
        matrix = identity.view(1, 4, 4).repeat(batch_size, 1, 1)
        # The scale factors occupy the first three diagonal entries; the
        # homogeneous entry stays 1.
        for axis in range(3):
            matrix[:, axis, axis] = factors[:, axis]
        self._matrix = matrix

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix: a diagonal matrix holding the
        reciprocals of the four diagonal entries.
        """
        diagonal = torch.diagonal(self._matrix, dim1=1, dim2=2)
        reciprocal = 1.0 / diagonal
        return torch.diag_embed(reciprocal, dim1=1, dim2=2)
class Rotate(Transform3d):
    def __init__(
        self,
        R,
        dtype=torch.float32,
        device: str = "cpu",
        orthogonal_tol: float = 1e-5,
    ):
        """
        Create a new Transform3d representing 3D rotation using a rotation
        matrix as the input.

        Args:
            R: a tensor of shape (3, 3) or (N, 3, 3)
            dtype: dtype that `R` is cast to before it is stored.
            device: device that `R` is moved to before it is stored.
            orthogonal_tol: tolerance for the test of the orthogonality of R

        Raises:
            ValueError: if the trailing dimensions of `R` are not (3, 3).
        """
        super().__init__(device=device)
        if R.dim() == 2:
            R = R[None]
        if R.shape[-2:] != (3, 3):
            msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
            raise ValueError(msg % repr(R.shape))
        R = R.to(dtype=dtype).to(device=device)
        _check_valid_rotation_matrix(R, tol=orthogonal_tol)

        batch_size = R.shape[0]
        identity = torch.eye(4, dtype=dtype, device=device)
        matrix = identity.view(1, 4, 4).repeat(batch_size, 1, 1)
        # The rotation fills the upper-left 3x3 block of the 4x4 matrix.
        matrix[:, :3, :3] = R
        self._matrix = matrix

    def _get_matrix_inverse(self):
        """
        Return the inverse of self._matrix, computed as its transpose.
        """
        return self._matrix.transpose(1, 2).contiguous()
class RotateAxisAngle(Rotate):
    def __init__(
        self,
        angle,
        axis: str = "X",
        degrees: bool = True,
        dtype=torch.float64,
        device: str = "cpu",
    ):
        """
        Create a new Transform3d representing 3D rotation about an axis
        by an angle.

        Args:
            angle:
                - A torch tensor of shape (N, 1)
                - A python scalar
                - A torch scalar
            axis:
                string: one of ["X", "Y", "Z"] indicating the axis about which
                to rotate (case insensitive).
                NOTE: All batch elements are rotated about the same axis.
            degrees: If True, `angle` is given in degrees and is converted
                to radians; otherwise it is used as radians.
            dtype: dtype used when `angle` has to be converted to a tensor.
                NOTE(review): it is not forwarded to `Rotate.__init__`, so
                the stored matrix uses that constructor's default dtype.
            device: device of the resulting transform.

        Raises:
            ValueError: if `axis` is not one of "X", "Y", "Z".

        NOTE(review): the matrices built here have the layout usually used
        with column vectors, e.g. [[1, 0, 0], [0, cos, -sin], [0, sin, cos]]
        for "X" - confirm the intended direction of rotation against how
        the transform is applied to points.
        """
        axis = axis.upper()
        if axis not in ["X", "Y", "Z"]:
            msg = "Expected axis to be one of ['X', 'Y', 'Z']; got %s"
            raise ValueError(msg % axis)

        angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
        if degrees:
            angle = angle / 180.0 * math.pi
        batch_size = angle.shape[0]

        c = torch.cos(angle)
        s = torch.sin(angle)
        one = torch.ones_like(angle)
        zero = torch.zeros_like(angle)

        # Entries of each 3x3 matrix, listed row by row.
        if axis == "X":
            entries = (one, zero, zero, zero, c, -s, zero, s, c)
        elif axis == "Y":
            entries = (c, zero, s, zero, one, zero, -s, zero, c)
        else:  # axis == "Z"
            entries = (c, -s, zero, s, c, zero, zero, zero, one)

        R = torch.stack(entries, -1).reshape((batch_size, 3, 3))
        super().__init__(device=device, R=R)
def _handle_coord(c, dtype, device):
"""
Helper function for _handle_input.
Args:
c: Python scalar, torch scalar, or 1D torch tensor
Returns:
c_vec: 1D torch tensor
"""
if not torch.is_tensor(c):
c = torch.tensor(c, dtype=dtype, device=device)
if c.dim() == 0:
c = c.view(1)
return c
def _handle_input(
x, y, z, dtype, device, name: str, allow_singleton: bool = False
):
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
input.
Case I: Single Matrix
In this case x is a tensor of shape (N, 3), and y and z are None. Here just
return x.
Case II: Vectors and Scalars
In this case each of x, y, and z can be one of the following
- Python scalar
- Torch scalar
- Torch tensor of shape (N, 1) or (1, 1)
In this case x, y and z are broadcast to tensors of shape (N, 1)
and concatenated to a tensor of shape (N, 3)
Case III: Singleton (only if allow_singleton=True)
In this case y and z are None, and x can be one of the following:
- Python scalar
- Torch scalar
- Torch tensor of shape (N, 1) or (1, 1)
Here x will be duplicated 3 times, and we return a tensor of shape (N, 3)
Returns:
xyz: Tensor of shape (N, 3)
"""
# If x is actually a tensor of shape (N, 3) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 3:
msg = "Expected tensor of shape (N, 3); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
if y is not None or z is not None:
msg = "Expected y and z to be None (in %s)" % name
raise ValueError(msg)
return x
if allow_singleton and y is None and z is None:
y = x
z = x
# Convert all to 1D tensors
xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
# Broadcast and concatenate
sizes = [c.shape[0] for c in xyz]
N = max(sizes)
for c in xyz:
if c.shape[0] != 1 and c.shape[0] != N:
msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name)
raise ValueError(msg)
xyz = [c.expand(N) for c in xyz]
xyz = torch.stack(xyz, dim=1)
return xyz
def _handle_angle_input(x, dtype, device: str, name: str):
"""
Helper function for building a rotation function using angles.
The output is always of shape (N, 1).
The input can be one of:
- Torch tensor (N, 1) or (N)
- Python scalar
- Torch scalar
"""
# If x is actually a tensor of shape (N, 1) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 1:
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
return x
else:
return _handle_coord(x, dtype, device)
def _broadcast_bmm(a, b):
"""
Batch multiply two matrices and broadcast if necessary.
Args:
a: torch tensor of shape (P, K) or (M, P, K)
b: torch tensor of shape (N, K, K)
Returns:
a and b broadcast multipled. The output batch dimension is max(N, M).
To broadcast transforms across a batch dimension if M != N then
expect that either M = 1 or N = 1. The tensor with batch dimension 1 is
expanded to have shape N or M.
"""
if a.dim() == 2:
a = a[None]
if len(a) != len(b):
if not ((len(a) == 1) or (len(b) == 1)):
msg = "Expected batch dim for bmm to be equal or 1; got %r, %r"
raise ValueError(msg % (a.shape, b.shape))
if len(a) == 1:
a = a.expand(len(b), -1, -1)
if len(b) == 1:
b = b.expand(len(a), -1, -1)
return a.bmm(b)
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
"""
Determine if R is a valid rotation matrix by checking it satisfies the
following conditions:
``RR^T = I and det(R) = 1``
Args:
R: an (N, 3, 3) matrix
Returns:
None
Prints an warning if R is an invalid rotation matrix. Else return.
"""
N = R.shape[0]
eye = torch.eye(3, dtype=R.dtype, device=R.device)
eye = eye.view(1, 3, 3).expand(N, -1, -1)
orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
det_R = torch.det(R)
no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
if not (orthogonal and no_distortion):
msg = "R is not a valid rotation matrix"
print(msg)
return

View File

@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .ico_sphere import ico_sphere
# Export every name currently in the module globals that does not start
# with an underscore.
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d.ops.subdivide_meshes import SubdivideMeshes
from pytorch3d.structures.meshes import Meshes
# Vertex coordinates for a level 0 ico-sphere: 12 vertices, one per row
# as [x, y, z]. Every vertex has (approximately) unit length, since
# 0.5257**2 + 0.8507**2 is 1 up to rounding.
_ico_verts0 = [
[-0.5257, 0.8507, 0.0000],
[0.5257, 0.8507, 0.0000],
[-0.5257, -0.8507, 0.0000],
[0.5257, -0.8507, 0.0000],
[0.0000, -0.5257, 0.8507],
[0.0000, 0.5257, 0.8507],
[0.0000, -0.5257, -0.8507],
[0.0000, 0.5257, -0.8507],
[0.8507, 0.0000, -0.5257],
[0.8507, 0.0000, 0.5257],
[-0.8507, 0.0000, -0.5257],
[-0.8507, 0.0000, 0.5257],
]
# Faces for level 0 ico-sphere: 20 triangles, each row holding three
# indices into _ico_verts0.
_ico_faces0 = [
[0, 11, 5],
[0, 5, 1],
[0, 1, 7],
[0, 7, 10],
[0, 10, 11],
[1, 5, 9],
[5, 11, 4],
[11, 10, 2],
[10, 7, 6],
[7, 1, 8],
[3, 9, 4],
[3, 4, 2],
[3, 2, 6],
[3, 6, 8],
[3, 8, 9],
[4, 9, 5],
[2, 4, 11],
[6, 2, 10],
[8, 6, 7],
[9, 8, 1],
]
def ico_sphere(level: int = 0, device=None):
    """
    Create verts and faces for a unit ico-sphere, with all faces oriented
    consistently.

    Level 0 is built from the hard-coded icosahedron tables. Each higher
    level subdivides the mesh of the level below and rescales the resulting
    vertices to unit length.

    Args:
        level: integer specifying the number of iterations for subdivision
            of the mesh faces. Each additional level will result in four new
            faces per face.
        device: A torch.device object on which the outputs will be allocated.
            Defaults to the CPU.

    Returns:
        Meshes object with verts and faces.

    Raises:
        ValueError: if `level` is negative.
    """
    if device is None:
        device = torch.device("cpu")
    if level < 0:
        raise ValueError("level must be >= 0.")

    if level == 0:
        base_verts = torch.tensor(
            _ico_verts0, dtype=torch.float32, device=device
        )
        base_faces = torch.tensor(
            _ico_faces0, dtype=torch.int64, device=device
        )
        return Meshes(verts=[base_verts], faces=[base_faces])

    coarser_mesh = ico_sphere(level - 1, device)
    refined_mesh = SubdivideMeshes()(coarser_mesh)
    verts = refined_mesh.verts_list()[0]
    # Rescale the subdivided vertices in place to unit length.
    verts /= verts.norm(p=2, dim=1, keepdim=True)
    faces = refined_mesh.faces_list()[0]
    return Meshes(verts=[verts], faces=[faces])