mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-23 15:50:39 +08:00
pulsar integration.
Summary: This diff integrates the pulsar renderer source code into PyTorch3D as an alternative backend for the PyTorch3D point renderer. This diff is the first of a series of three diffs to complete that migration and focuses on the packaging and integration of the source code. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder `docs/examples`. Tasks addressed in the following diffs: * Add the PyTorch3D interface, * Add notebook examples and documentation (or adapt the existing ones to feature both interfaces). Reviewed By: nikhilaravi Differential Revision: D23947736 fbshipit-source-id: a5e77b53e6750334db22aefa89b4c079cda1b443
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d565032399
commit
b19fe1de2f
63
pytorch3d/csrc/pulsar/pytorch/camera.cpp
Normal file
63
pytorch3d/csrc/pulsar/pytorch/camera.cpp
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./camera.h"
|
||||
#include "../include/math.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
CamInfo cam_info_from_params(
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& principal_point_offset,
|
||||
const float& focal_length,
|
||||
const uint& width,
|
||||
const uint& height,
|
||||
const float& min_dist,
|
||||
const float& max_dist,
|
||||
const bool& right_handed) {
|
||||
CamInfo res;
|
||||
fill_cam_vecs(
|
||||
cam_pos.detach().cpu(),
|
||||
pixel_0_0_center.detach().cpu(),
|
||||
pixel_vec_x.detach().cpu(),
|
||||
pixel_vec_y.detach().cpu(),
|
||||
principal_point_offset.detach().cpu(),
|
||||
right_handed,
|
||||
&res);
|
||||
res.half_pixel_size = 0.5f * length(res.pixel_dir_x);
|
||||
if (length(res.pixel_dir_y) * 0.5f - res.half_pixel_size > EPS) {
|
||||
throw std::runtime_error("Pixel sizes must agree in x and y direction!");
|
||||
}
|
||||
res.focal_length = focal_length;
|
||||
res.aperture_width =
|
||||
width + 2u * static_cast<uint>(abs(res.principal_point_offset_x));
|
||||
res.aperture_height =
|
||||
height + 2u * static_cast<uint>(abs(res.principal_point_offset_y));
|
||||
res.pixel_0_0_center -=
|
||||
res.pixel_dir_x * static_cast<float>(abs(res.principal_point_offset_x));
|
||||
res.pixel_0_0_center -=
|
||||
res.pixel_dir_y * static_cast<float>(abs(res.principal_point_offset_y));
|
||||
res.film_width = width;
|
||||
res.film_height = height;
|
||||
res.film_border_left =
|
||||
static_cast<uint>(std::max(0, 2 * res.principal_point_offset_x));
|
||||
res.film_border_top =
|
||||
static_cast<uint>(std::max(0, 2 * res.principal_point_offset_y));
|
||||
LOG_IF(INFO, PULSAR_LOG_INIT)
|
||||
<< "Aperture width, height: " << res.aperture_width << ", "
|
||||
<< res.aperture_height;
|
||||
LOG_IF(INFO, PULSAR_LOG_INIT)
|
||||
<< "Film width, height: " << res.film_width << ", " << res.film_height;
|
||||
LOG_IF(INFO, PULSAR_LOG_INIT)
|
||||
<< "Film border left, top: " << res.film_border_left << ", "
|
||||
<< res.film_border_top;
|
||||
res.min_dist = min_dist;
|
||||
res.max_dist = max_dist;
|
||||
res.norm_fac = 1.f / (max_dist - min_dist);
|
||||
return res;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
61
pytorch3d/csrc/pulsar/pytorch/camera.h
Normal file
61
pytorch3d/csrc/pulsar/pytorch/camera.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_CAMERA_H_
|
||||
#define PULSAR_NATIVE_CAMERA_H_
|
||||
|
||||
#include <tuple>
|
||||
#include "../global.h"
|
||||
|
||||
#include "../include/camera.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
inline void fill_cam_vecs(
|
||||
const torch::Tensor& pos_vec,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_dir_x,
|
||||
const torch::Tensor& pixel_dir_y,
|
||||
const torch::Tensor& principal_point_offset,
|
||||
const bool& right_handed,
|
||||
CamInfo* res) {
|
||||
res->eye.x = pos_vec.data_ptr<float>()[0];
|
||||
res->eye.y = pos_vec.data_ptr<float>()[1];
|
||||
res->eye.z = pos_vec.data_ptr<float>()[2];
|
||||
res->pixel_0_0_center.x = pixel_0_0_center.data_ptr<float>()[0];
|
||||
res->pixel_0_0_center.y = pixel_0_0_center.data_ptr<float>()[1];
|
||||
res->pixel_0_0_center.z = pixel_0_0_center.data_ptr<float>()[2];
|
||||
res->pixel_dir_x.x = pixel_dir_x.data_ptr<float>()[0];
|
||||
res->pixel_dir_x.y = pixel_dir_x.data_ptr<float>()[1];
|
||||
res->pixel_dir_x.z = pixel_dir_x.data_ptr<float>()[2];
|
||||
res->pixel_dir_y.x = pixel_dir_y.data_ptr<float>()[0];
|
||||
res->pixel_dir_y.y = pixel_dir_y.data_ptr<float>()[1];
|
||||
res->pixel_dir_y.z = pixel_dir_y.data_ptr<float>()[2];
|
||||
auto sensor_dir_z = pixel_dir_y.cross(pixel_dir_x);
|
||||
sensor_dir_z /= sensor_dir_z.norm();
|
||||
if (right_handed) {
|
||||
sensor_dir_z *= -1.f;
|
||||
}
|
||||
res->sensor_dir_z.x = sensor_dir_z.data_ptr<float>()[0];
|
||||
res->sensor_dir_z.y = sensor_dir_z.data_ptr<float>()[1];
|
||||
res->sensor_dir_z.z = sensor_dir_z.data_ptr<float>()[2];
|
||||
res->principal_point_offset_x = principal_point_offset.data_ptr<int32_t>()[0];
|
||||
res->principal_point_offset_y = principal_point_offset.data_ptr<int32_t>()[1];
|
||||
}
|
||||
|
||||
CamInfo cam_info_from_params(
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& principal_point_offset,
|
||||
const float& focal_length,
|
||||
const uint& width,
|
||||
const uint& height,
|
||||
const float& min_dist,
|
||||
const float& max_dist,
|
||||
const bool& right_handed);
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
||||
1481
pytorch3d/csrc/pulsar/pytorch/renderer.cpp
Normal file
1481
pytorch3d/csrc/pulsar/pytorch/renderer.cpp
Normal file
File diff suppressed because it is too large
Load Diff
167
pytorch3d/csrc/pulsar/pytorch/renderer.h
Normal file
167
pytorch3d/csrc/pulsar/pytorch/renderer.h
Normal file
@@ -0,0 +1,167 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_PYTORCH_RENDERER_H_
|
||||
#define PULSAR_NATIVE_PYTORCH_RENDERER_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "../include/renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
struct Renderer {
|
||||
public:
|
||||
/**
|
||||
* Pytorch Pulsar differentiable rendering module.
|
||||
*/
|
||||
explicit Renderer(
|
||||
const unsigned int& width,
|
||||
const unsigned int& height,
|
||||
const uint& max_n_balls,
|
||||
const bool& orthogonal_projection,
|
||||
const bool& right_handed_system,
|
||||
const float& background_normalization_depth,
|
||||
const uint& n_channels,
|
||||
const uint& n_track);
|
||||
~Renderer();
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> forward(
|
||||
const torch::Tensor& vert_pos,
|
||||
const torch::Tensor& vert_col,
|
||||
const torch::Tensor& vert_radii,
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& focal_length,
|
||||
const torch::Tensor& principal_point_offsets,
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
backward(
|
||||
const torch::Tensor& grad_im,
|
||||
const torch::Tensor& image,
|
||||
const torch::Tensor& forw_info,
|
||||
const torch::Tensor& vert_pos,
|
||||
const torch::Tensor& vert_col,
|
||||
const torch::Tensor& vert_radii,
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& focal_length,
|
||||
const torch::Tensor& principal_point_offsets,
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const at::optional<std::pair<uint, uint>>& dbg_pos);
|
||||
|
||||
// Infrastructure.
|
||||
/**
|
||||
* Ensure that the renderer is placed on this device.
|
||||
* Is nearly a no-op if the device is correct.
|
||||
*/
|
||||
void ensure_on_device(torch::Device device, bool non_blocking = false);
|
||||
|
||||
/**
|
||||
* Ensure that at least n renderers are available.
|
||||
*/
|
||||
void ensure_n_renderers_gte(const size_t& batch_size);
|
||||
|
||||
/**
|
||||
* Check the parameters.
|
||||
*/
|
||||
std::tuple<size_t, size_t, bool, torch::Tensor> arg_check(
|
||||
const torch::Tensor& vert_pos,
|
||||
const torch::Tensor& vert_col,
|
||||
const torch::Tensor& vert_radii,
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& focal_length,
|
||||
const torch::Tensor& principal_point_offsets,
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float& min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
bool operator==(const Renderer& rhs) const;
|
||||
inline friend std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
const Renderer& self) {
|
||||
stream << "pulsar::Renderer[";
|
||||
// Device info.
|
||||
stream << self.device_type;
|
||||
if (self.device_index != -1)
|
||||
stream << ", ID " << self.device_index;
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
inline uint width() const {
|
||||
return this->renderer_vec[0].cam.film_width;
|
||||
}
|
||||
inline uint height() const {
|
||||
return this->renderer_vec[0].cam.film_height;
|
||||
}
|
||||
inline int max_num_balls() const {
|
||||
return this->renderer_vec[0].max_num_balls;
|
||||
}
|
||||
inline bool orthogonal() const {
|
||||
return this->renderer_vec[0].cam.orthogonal_projection;
|
||||
}
|
||||
inline bool right_handed() const {
|
||||
return this->renderer_vec[0].cam.right_handed;
|
||||
}
|
||||
inline uint n_track() const {
|
||||
return static_cast<uint>(this->renderer_vec[0].n_track);
|
||||
}
|
||||
|
||||
/** A tensor that is registered as a buffer with this Module to track its
|
||||
* device placement. Unfortunately, pytorch doesn't offer tracking Module
|
||||
* device placement in a better way as of now.
|
||||
*/
|
||||
torch::Tensor device_tracker;
|
||||
|
||||
protected:
|
||||
/** The device type for this renderer. */
|
||||
c10::DeviceType device_type;
|
||||
/** The device index for this renderer. */
|
||||
c10::DeviceIndex device_index;
|
||||
/** Pointer to the underlying pulsar renderers. */
|
||||
std::vector<pulsar::Renderer::Renderer> renderer_vec;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
||||
48
pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp
Normal file
48
pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "./tensor_util.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
torch::Tensor sphere_ids_from_result_info_nograd(
|
||||
const torch::Tensor& forw_info) {
|
||||
torch::Tensor result = torch::zeros(
|
||||
{forw_info.size(0),
|
||||
forw_info.size(1),
|
||||
forw_info.size(2),
|
||||
(forw_info.size(3) - 3) / 2},
|
||||
torch::TensorOptions().device(forw_info.device()).dtype(torch::kInt32));
|
||||
// Get the relevant slice, contiguous.
|
||||
torch::Tensor tmp =
|
||||
forw_info
|
||||
.slice(
|
||||
/*dim=*/3, /*start=*/3, /*end=*/forw_info.size(3), /*step=*/2)
|
||||
.contiguous();
|
||||
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
||||
cudaMemcpyAsync(
|
||||
result.data_ptr(),
|
||||
tmp.data_ptr(),
|
||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||
tmp.size(3),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
} else {
|
||||
memcpy(
|
||||
result.data_ptr(),
|
||||
tmp.data_ptr(),
|
||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||
tmp.size(3));
|
||||
}
|
||||
// `tmp` is freed after this, the memory might get reallocated. However,
|
||||
// only kernels in the same stream should ever be able to write to this
|
||||
// memory, which are executed only after the memcpy is complete. That's
|
||||
// why we can just continue.
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
16
pytorch3d/csrc/pulsar/pytorch/tensor_util.h
Normal file
16
pytorch3d/csrc/pulsar/pytorch/tensor_util.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_
|
||||
#define PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
torch::Tensor sphere_ids_from_result_info_nograd(
|
||||
const torch::Tensor& forw_info);
|
||||
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
||||
24
pytorch3d/csrc/pulsar/pytorch/util.cpp
Normal file
24
pytorch3d/csrc/pulsar/pytorch/util.cpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
void cudaDevToDev(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
void cudaDevToHost(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
59
pytorch3d/csrc/pulsar/pytorch/util.h
Normal file
59
pytorch3d/csrc/pulsar/pytorch/util.h
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_PYTORCH_UTIL_H_
|
||||
#define PULSAR_NATIVE_PYTORCH_UTIL_H_
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include "../global.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
void cudaDevToDev(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream);
|
||||
void cudaDevToHost(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream);
|
||||
|
||||
/**
|
||||
* This method takes a memory pointer and wraps it into a pytorch tensor.
|
||||
*
|
||||
* This is preferred over `torch::from_blob`, since that requires a CUDA
|
||||
* managed pointer. However, working with these for high performance
|
||||
* operations is slower. Most of the rendering operations should stay
|
||||
* local to the respective GPU anyways, so unmanaged pointers are
|
||||
* preferred.
|
||||
*/
|
||||
template <typename T>
|
||||
torch::Tensor from_blob(
|
||||
const T* ptr,
|
||||
const torch::IntArrayRef& shape,
|
||||
const c10::DeviceType& device_type,
|
||||
const c10::DeviceIndex& device_index,
|
||||
const torch::Dtype& dtype,
|
||||
const cudaStream_t& stream) {
|
||||
torch::Tensor ret = torch::zeros(
|
||||
shape, torch::device({device_type, device_index}).dtype(dtype));
|
||||
const int num_elements =
|
||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
|
||||
if (device_type == c10::DeviceType::CUDA) {
|
||||
cudaDevToDev(
|
||||
ret.data_ptr(),
|
||||
static_cast<const void*>(ptr),
|
||||
sizeof(T) * num_elements,
|
||||
stream);
|
||||
// TODO: check for synchronization.
|
||||
} else {
|
||||
memcpy(ret.data_ptr(), ptr, sizeof(T) * num_elements);
|
||||
}
|
||||
return ret;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user