mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
mesh_normal_consistency speedup
Summary: One step in finding all the pairs of vertices which share faces is a simple calculation but annoying to parallelize. It was implemented in pure Python. We move it to C++. We still pull the data to the CPU and put the answer back on the device. Reviewed By: nikhilaravi, gkioxari Differential Revision: D26073475 fbshipit-source-id: ffbf4e2c347a511ab5084bceff600465812b6a52
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5ac2f42184
commit
4bfe7158b1
@@ -14,6 +14,7 @@
|
||||
#include "gather_scatter/gather_scatter.h"
|
||||
#include "interp_face_attrs/interp_face_attrs.h"
|
||||
#include "knn/knn.h"
|
||||
#include "mesh_normal_consistency/mesh_normal_consistency.h"
|
||||
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
||||
#include "point_mesh/point_mesh_cuda.h"
|
||||
#include "rasterize_meshes/rasterize_meshes.h"
|
||||
@@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
#endif
|
||||
m.def("knn_points_idx", &KNearestNeighborIdx);
|
||||
m.def("knn_points_backward", &KNearestNeighborBackward);
|
||||
m.def(
|
||||
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
|
||||
m.def("gather_scatter", &GatherScatter);
|
||||
m.def("rasterize_points", &RasterizePoints);
|
||||
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// For mesh_normal_consistency, find pairs of vertices opposite the same edge.
|
||||
//
|
||||
// Args:
|
||||
// edge_num: int64 Tensor of shape (E,) giving the number of vertices
|
||||
// corresponding to each edge.
|
||||
//
|
||||
// Returns:
|
||||
// pairs: int64 Tensor of shape (N,2)
|
||||
|
||||
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num);
|
||||
|
||||
// Exposed implementation.
|
||||
at::Tensor MeshNormalConsistencyFindVertices(const at::Tensor& edge_num) {
|
||||
if (edge_num.is_cuda()) {
|
||||
AT_ERROR("This function needs a CPU tensor.");
|
||||
}
|
||||
return MeshNormalConsistencyFindVerticesCpu(edge_num);
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num) {
|
||||
// We take a LongTensor of shape (E,) giving the number of things intersecting
|
||||
// each edge. The things are taken to be numbered in order.
|
||||
// (In fact, the "things" are opposite vertices to edges, renumbered).
|
||||
// We return a tensor of shape (?, 2) where for every pair of things which
|
||||
// intersect the same edge there is a row of their numbers in the output.
|
||||
|
||||
// Example possible inputs and outputs (order of output is not specified):
|
||||
// [1,0,1,1,0] => [[]]
|
||||
// [3] => [[0,1], [0,2], [1,2]]
|
||||
// [0,3] => [[0,1], [0,2], [1,2]]
|
||||
// [1,3] => [[1,2], [1,3], [2,3]]
|
||||
//[1,0,2,1,0,2] => [[1,2], [4,5]]
|
||||
|
||||
const auto num_edges = edge_num.size(0);
|
||||
auto edges_a = edge_num.accessor<int64_t, 1>();
|
||||
|
||||
int64_t vert_idx = 0;
|
||||
std::vector<std::pair<int64_t, int64_t>> pairs;
|
||||
for (int64_t i_edge = 0; i_edge < num_edges; ++i_edge) {
|
||||
int64_t e = edges_a[i_edge];
|
||||
for (int64_t j = 0; j < e; ++j) {
|
||||
for (int64_t i = 0; i < j; ++i) {
|
||||
pairs.emplace_back(vert_idx + i, vert_idx + j);
|
||||
}
|
||||
}
|
||||
vert_idx += e;
|
||||
}
|
||||
|
||||
// Convert from std::vector by copying over the items to a new empty torch
|
||||
// tensor.
|
||||
auto pairs_tensor = at::empty({(int64_t)pairs.size(), 2}, edge_num.options());
|
||||
auto pairs_a = pairs_tensor.accessor<int64_t, 2>();
|
||||
for (int64_t i_pair = 0; i_pair < pairs.size(); ++i_pair) {
|
||||
auto accessor = pairs_a[i_pair];
|
||||
accessor[0] = pairs[i_pair].first;
|
||||
accessor[1] = pairs[i_pair].second;
|
||||
}
|
||||
|
||||
return pairs_tensor;
|
||||
}
|
||||
Reference in New Issue
Block a user