diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 012a95be..79dc76db 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -14,6 +14,7 @@ #include "gather_scatter/gather_scatter.h" #include "interp_face_attrs/interp_face_attrs.h" #include "knn/knn.h" +#include "mesh_normal_consistency/mesh_normal_consistency.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "point_mesh/point_mesh_cuda.h" #include "rasterize_meshes/rasterize_meshes.h" @@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #endif m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_backward", &KNearestNeighborBackward); + m.def( + "mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices); m.def("gather_scatter", &GatherScatter); m.def("rasterize_points", &RasterizePoints); m.def("rasterize_points_backward", &RasterizePointsBackward); diff --git a/pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency.h b/pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency.h new file mode 100644 index 00000000..521f635a --- /dev/null +++ b/pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency.h @@ -0,0 +1,24 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#pragma once +#include +#include "utils/pytorch3d_cutils.h" + +// For mesh_normal_consistency, find pairs of vertices opposite the same edge. +// +// Args: +// edge_num: int64 Tensor of shape (E,) giving the number of vertices +// corresponding to each edge. +// +// Returns: +// pairs: int64 Tensor of shape (N,2) + +at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num); + +// Exposed implementation. +at::Tensor MeshNormalConsistencyFindVertices(const at::Tensor& edge_num) { + if (edge_num.is_cuda()) { + AT_ERROR("This function needs a CPU tensor."); + } + return MeshNormalConsistencyFindVerticesCpu(edge_num); +} diff --git a/pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency_cpu.cpp b/pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency_cpu.cpp new file mode 100644 index 00000000..73e31aa0 --- /dev/null +++ b/pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency_cpu.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include + +at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num) { + // We take a LongTensor of shape (E,) giving the number of things intersecting + // each edge. The things are taken to be numbered in order. + // (In fact, the "things" are opposite vertices to edges, renumbered). + // We return a tensor of shape (?, 2) where for every pair of things which + // intersect the same edge there is a row of their numbers in the output. + + // Example possible inputs and outputs (order of output is not specified): + // [1,0,1,1,0] => [[]] + // [3] => [[0,1], [0,2], [1,2]] + // [0,3] => [[0,1], [0,2], [1,2]] + // [1,3] => [[1,2], [1,3], [2,3]] + //[1,0,2,1,0,2] => [[1,2], [4,5]] + + const auto num_edges = edge_num.size(0); + auto edges_a = edge_num.accessor(); + + int64_t vert_idx = 0; + std::vector> pairs; + for (int64_t i_edge = 0; i_edge < num_edges; ++i_edge) { + int64_t e = edges_a[i_edge]; + for (int64_t j = 0; j < e; ++j) { + for (int64_t i = 0; i < j; ++i) { + pairs.emplace_back(vert_idx + i, vert_idx + j); + } + } + vert_idx += e; + } + + // Convert from std::vector by copying over the items to a new empty torch + // tensor. + auto pairs_tensor = at::empty({(int64_t)pairs.size(), 2}, edge_num.options()); + auto pairs_a = pairs_tensor.accessor(); + for (int64_t i_pair = 0; i_pair < pairs.size(); ++i_pair) { + auto accessor = pairs_a[i_pair]; + accessor[0] = pairs[i_pair].first; + accessor[1] = pairs[i_pair].second; + } + + return pairs_tensor; +} diff --git a/pytorch3d/loss/mesh_normal_consistency.py b/pytorch3d/loss/mesh_normal_consistency.py index 1433da52..c20e73f5 100644 --- a/pytorch3d/loss/mesh_normal_consistency.py +++ b/pytorch3d/loss/mesh_normal_consistency.py @@ -1,10 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from itertools import islice - import torch +# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. +from pytorch3d import _C + def mesh_normal_consistency(meshes): r""" @@ -71,9 +71,9 @@ def mesh_normal_consistency(meshes): F = faces_packed.shape[0] # sum(F_n) # We don't want gradients for the following operation. The goal is to - # find for each edge e all the vertices associated with e. In the example above, - # the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1) - # and points connected on faces to e (=a, b). + # find for each edge e all the vertices associated with e. In the example + # above, the vertices associated with e are (a, b), i.e. the points connected + # on faces to e. with torch.no_grad(): edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges vert_idx = ( @@ -95,23 +95,10 @@ def mesh_normal_consistency(meshes): # the number of vertices which are associated with each edge. # There can be a different number for each edge. edge_num = edge_idx.bincount(minlength=E) - # Create pairs of vertices associated to e. We generate a list of lists: - # each list has the indices of the vertices which are opposite to one edge. - # The length of the list for each edge will vary. - vert_edge_pair_idx = split_list( - list(range(edge_idx.shape[0])), edge_num.tolist() - ) - # For each list find all combinations of pairs in the list. This represents - # all pairs of vertices which are opposite to the same edge. - vert_edge_pair_idx = [ - [e[i], e[j]] - for e in vert_edge_pair_idx - for i in range(len(e) - 1) - for j in range(1, len(e)) - if i < j - ] - vert_edge_pair_idx = torch.tensor( - vert_edge_pair_idx, device=meshes.device, dtype=torch.int64 + + # This calculates all pairs of vertices which are opposite to the same edge. + vert_edge_pair_idx = _C.mesh_normal_consistency_find_verts(edge_num.cpu()).to( + edge_num.device ) if vert_edge_pair_idx.shape[0] == 0: @@ -141,8 +128,3 @@ def mesh_normal_consistency(meshes): loss = loss * weights return loss.sum() / N - - -def split_list(input, length_to_split): - inputt = iter(input) - return [list(islice(inputt, elem)) for elem in length_to_split]