mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 12:22:49 +08:00
Summary: Removes the now-unnecessary kernels from the point mesh edge file. Migrates all point mesh functionality into one file. Reviewed By: gkioxari Differential Revision: D24550086 fbshipit-source-id: f924996cd38a7c2c1cf189d8a01611de4506cfa3
159 lines
6.3 KiB
C++
159 lines
6.3 KiB
C++
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
// clang-format off
|
|
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
|
#include <torch/extension.h>
|
|
// clang-format on
|
|
#include "./pulsar/pytorch/renderer.h"
|
|
#include "./pulsar/pytorch/tensor_util.h"
|
|
#include "blending/sigmoid_alpha_blend.h"
|
|
#include "compositing/alpha_composite.h"
|
|
#include "compositing/norm_weighted_sum.h"
|
|
#include "compositing/weighted_sum.h"
|
|
#include "face_areas_normals/face_areas_normals.h"
|
|
#include "gather_scatter/gather_scatter.h"
|
|
#include "interp_face_attrs/interp_face_attrs.h"
|
|
#include "knn/knn.h"
|
|
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
|
#include "point_mesh/point_mesh_cuda.h"
|
|
#include "rasterize_meshes/rasterize_meshes.h"
|
|
#include "rasterize_points/rasterize_points.h"
|
|
|
|
// Extension module definition. Each `m.def` call below exposes one C++
// function on module `m` under the string name given as its first argument;
// the Pulsar renderer class and a handful of numeric constants follow.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Short local alias so the renderer bindings further down stay readable.
  using PulsarRenderer = pulsar::pytorch::Renderer;

  // --- Face areas / normals ------------------------------------------------
  m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
  m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);

  // --- Packed <-> padded conversion ----------------------------------------
  m.def("packed_to_padded", &PackedToPadded);
  m.def("padded_to_packed", &PaddedToPacked);

  // --- Face attribute interpolation ----------------------------------------
  m.def("interp_face_attrs_forward", &InterpFaceAttrsForward);
  m.def("interp_face_attrs_backward", &InterpFaceAttrsBackward);

  // --- K nearest neighbors -------------------------------------------------
  // The version check is only registered when building with CUDA.
#ifdef WITH_CUDA
  m.def("knn_check_version", &KnnCheckVersion);
#endif
  m.def("knn_points_idx", &KNearestNeighborIdx);
  m.def("knn_points_backward", &KNearestNeighborBackward);

  // --- Gather / scatter ----------------------------------------------------
  m.def("gather_scatter", &GatherScatter);

  // --- Rasterization -------------------------------------------------------
  m.def("rasterize_points", &RasterizePoints);
  m.def("rasterize_points_backward", &RasterizePointsBackward);
  m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
  m.def("rasterize_meshes", &RasterizeMeshes);

  // --- Blending ------------------------------------------------------------
  m.def("sigmoid_alpha_blend", &SigmoidAlphaBlend);
  m.def("sigmoid_alpha_blend_backward", &SigmoidAlphaBlendBackward);

  // --- Accumulation functions ----------------------------------------------
  m.def("accum_weightedsumnorm", &weightedSumNormForward);
  m.def("accum_weightedsum", &weightedSumForward);
  m.def("accum_alphacomposite", &alphaCompositeForward);
  m.def("accum_weightedsumnorm_backward", &weightedSumNormBackward);
  m.def("accum_weightedsum_backward", &weightedSumBackward);
  m.def("accum_alphacomposite_backward", &alphaCompositeBackward);

  // --- Test-only entry points ----------------------------------------------
  // Exposed with a leading underscore for testing; not meant to be called
  // directly by users.
  m.def("_rasterize_points_coarse", &RasterizePointsCoarse);
  m.def("_rasterize_points_naive", &RasterizePointsNaive);
  m.def("_rasterize_meshes_naive", &RasterizeMeshesNaive);
  m.def("_rasterize_meshes_coarse", &RasterizeMeshesCoarse);
  m.def("_rasterize_meshes_fine", &RasterizeMeshesFine);

  // --- Point/edge distance functions ---------------------------------------
  m.def("point_edge_dist_forward", &PointEdgeDistanceForward);
  m.def("point_edge_dist_backward", &PointEdgeDistanceBackward);
  m.def("edge_point_dist_forward", &EdgePointDistanceForward);
  m.def("edge_point_dist_backward", &EdgePointDistanceBackward);
  m.def("point_edge_array_dist_forward", &PointEdgeArrayDistanceForward);
  m.def("point_edge_array_dist_backward", &PointEdgeArrayDistanceBackward);

  // --- Point/face distance functions ---------------------------------------
  m.def("point_face_dist_forward", &PointFaceDistanceForward);
  m.def("point_face_dist_backward", &PointFaceDistanceBackward);
  m.def("face_point_dist_forward", &FacePointDistanceForward);
  m.def("face_point_dist_backward", &FacePointDistanceBackward);
  m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
  m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);

  // --- Pulsar --------------------------------------------------------------
#ifdef PULSAR_LOGGING_ENABLED
  c10::ShowLogInfoToStderr();
#endif
  {
    // The class handle lives only inside this scope; every binding is attached
    // through it, one statement per binding.
    py::class_<PulsarRenderer, std::shared_ptr<PulsarRenderer>> renderer_cls(
        m, "PulsarRenderer");

    // Constructor: three uints, two bools, one float, two uints (in order).
    renderer_cls.def(py::init<
                     const uint&,
                     const uint&,
                     const uint&,
                     const bool&,
                     const bool&,
                     const float&,
                     const uint&,
                     const uint&>());

    // Equality and inequality both delegate to the renderer's operator==.
    renderer_cls.def(
        "__eq__",
        [](const PulsarRenderer& lhs, const PulsarRenderer& rhs) {
          return lhs == rhs;
        },
        py::is_operator());
    renderer_cls.def(
        "__ne__",
        [](const PulsarRenderer& lhs, const PulsarRenderer& rhs) {
          return !(lhs == rhs);
        },
        py::is_operator());

    // String representation is whatever the renderer's operator<< writes.
    renderer_cls.def("__repr__", [](const PulsarRenderer& renderer) {
      std::stringstream buffer;
      buffer << renderer;
      return buffer.str();
    });

    // Forward pass with named arguments. `min_depth`, `bg_col` and `opacity`
    // get no binding-level default (the original note on `bg_col` says
    // at::nullopt is not exposed properly in pytorch 1.1); only the last
    // three arguments carry defaults.
    renderer_cls.def(
        "forward",
        &PulsarRenderer::forward,
        py::arg("vert_pos"),
        py::arg("vert_col"),
        py::arg("vert_radii"),
        py::arg("cam_pos"),
        py::arg("pixel_0_0_center"),
        py::arg("pixel_vec_x"),
        py::arg("pixel_vec_y"),
        py::arg("focal_length"),
        py::arg("principal_point_offsets"),
        py::arg("gamma"),
        py::arg("max_depth"),
        py::arg("min_depth"),
        py::arg("bg_col"),
        py::arg("opacity"),
        py::arg("percent_allowed_difference") = 0.01f,
        py::arg("max_n_hits") = MAX_UINT,
        py::arg("mode") = 0);

    renderer_cls.def("backward", &PulsarRenderer::backward);

    // Read/write property backed by the renderer's `device_tracker` member.
    renderer_cls.def_property(
        "device_tracker",
        [](const PulsarRenderer& renderer) { return renderer.device_tracker; },
        [](PulsarRenderer& renderer, const torch::Tensor& value) {
          renderer.device_tracker = value;
        });

    // Read-only properties.
    renderer_cls.def_property_readonly("width", &PulsarRenderer::width);
    renderer_cls.def_property_readonly("height", &PulsarRenderer::height);
    renderer_cls.def_property_readonly(
        "max_num_balls", &PulsarRenderer::max_num_balls);
    renderer_cls.def_property_readonly(
        "orthogonal", &PulsarRenderer::orthogonal);
    renderer_cls.def_property_readonly(
        "right_handed", &PulsarRenderer::right_handed);
    renderer_cls.def_property_readonly("n_track", &PulsarRenderer::n_track);
  }

  m.def(
      "pulsar_sphere_ids_from_result_info_nograd",
      &pulsar::pytorch::sphere_ids_from_result_info_nograd);

  // --- Constants -----------------------------------------------------------
  m.attr("EPS") = py::float_(EPS);
  m.attr("MAX_FLOAT") = py::float_(MAX_FLOAT);
  m.attr("MAX_INT") = py::int_(MAX_INT);
  m.attr("MAX_UINT") = py::int_(MAX_UINT);
  m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
  m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
}
|