mesh_normal_consistency speedup

Summary: One step in finding all the pairs of vertices which share faces is a simple calculation but annoying to parallelize. It was implemented in pure Python. We move it to C++. We still pull the data to the CPU and put the answer back on the device.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D26073475

fbshipit-source-id: ffbf4e2c347a511ab5084bceff600465812b6a52
This commit is contained in:
Jeremy Reizenstein
2021-02-11 13:54:55 -08:00
committed by Facebook GitHub Bot
parent 5ac2f42184
commit 4bfe7158b1
4 changed files with 84 additions and 28 deletions

View File

@@ -14,6 +14,7 @@
#include "gather_scatter/gather_scatter.h"
#include "interp_face_attrs/interp_face_attrs.h"
#include "knn/knn.h"
#include "mesh_normal_consistency/mesh_normal_consistency.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
@@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
m.def("knn_points_idx", &KNearestNeighborIdx);
m.def("knn_points_backward", &KNearestNeighborBackward);
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);

View File

@@ -0,0 +1,24 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include "utils/pytorch3d_cutils.h"
// For mesh_normal_consistency, find pairs of vertices opposite the same edge.
//
// Args:
// edge_num: int64 Tensor of shape (E,) giving the number of vertices
// corresponding to each edge.
//
// Returns:
// pairs: int64 Tensor of shape (N,2)
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num);
// Exposed implementation.
at::Tensor MeshNormalConsistencyFindVertices(const at::Tensor& edge_num) {
if (edge_num.is_cuda()) {
AT_ERROR("This function needs a CPU tensor.");
}
return MeshNormalConsistencyFindVerticesCpu(edge_num);
}

View File

@@ -0,0 +1,47 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <utility>
#include <vector>
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num) {
// We take a LongTensor of shape (E,) giving the number of things intersecting
// each edge. The things are taken to be numbered in order.
// (In fact, the "things" are opposite vertices to edges, renumbered).
// We return a tensor of shape (?, 2) where for every pair of things which
// intersect the same edge there is a row of their numbers in the output.
// Example possible inputs and outputs (order of output is not specified):
// [1,0,1,1,0] => [[]]
// [3] => [[0,1], [0,2], [1,2]]
// [0,3] => [[0,1], [0,2], [1,2]]
// [1,3] => [[1,2], [1,3], [2,3]]
//[1,0,2,1,0,2] => [[1,2], [4,5]]
const auto num_edges = edge_num.size(0);
auto edges_a = edge_num.accessor<int64_t, 1>();
int64_t vert_idx = 0;
std::vector<std::pair<int64_t, int64_t>> pairs;
for (int64_t i_edge = 0; i_edge < num_edges; ++i_edge) {
int64_t e = edges_a[i_edge];
for (int64_t j = 0; j < e; ++j) {
for (int64_t i = 0; i < j; ++i) {
pairs.emplace_back(vert_idx + i, vert_idx + j);
}
}
vert_idx += e;
}
// Convert from std::vector by copying over the items to a new empty torch
// tensor.
auto pairs_tensor = at::empty({(int64_t)pairs.size(), 2}, edge_num.options());
auto pairs_a = pairs_tensor.accessor<int64_t, 2>();
for (int64_t i_pair = 0; i_pair < pairs.size(); ++i_pair) {
auto accessor = pairs_a[i_pair];
accessor[0] = pairs[i_pair].first;
accessor[1] = pairs[i_pair].second;
}
return pairs_tensor;
}