Remove THGeneral (#69041)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69041

`TH_CONCAT_{N}` is still being used by THP so I've moved that into
it's own header but all the compiled code is gone.

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D32872477

Pulled By: ngimel

fbshipit-source-id: 06c82d8f96dbcee0715be407c61dfc7d7e8be47a
This commit is contained in:
Peter Bell 2021-12-13 16:12:58 -08:00 committed by Facebook GitHub Bot
parent d049cd2e01
commit f8fe9a2be1
4 changed files with 93 additions and 97 deletions

View File

@ -208,8 +208,7 @@ __device__ static float atomicMin(float* address, float val) {
#define IABS(a) abs(a) #define IABS(a) abs(a)
// Checks. // Checks.
#define CHECKOK C10_CUDA_CHECK #define ARGCHECK TORCH_CHECK_ARG
#define ARGCHECK THArgCheck
// Math. // Math.
#define NORM3DF(x, y, z) norm3df(x, y, z) #define NORM3DF(x, y, z) norm3df(x, y, z)

View File

@ -155,8 +155,7 @@ INLINE void ATOMICADD_F3(T* address, T val) {
#define IABS(a) abs(a) #define IABS(a) abs(a)
// Checks. // Checks.
#define CHECKOK THCheck #define ARGCHECK TORCH_CHECK_ARG
#define ARGCHECK THArgCheck
// Math. // Math.
#define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z) #define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z)

View File

@ -24,12 +24,10 @@
// #pragma diag_suppress = 68 // #pragma diag_suppress = 68
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
// #pragma pop // #pragma pop
#include <TH/TH.h>
#include "../cuda/commands.h" #include "../cuda/commands.h"
#else #else
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything" #pragma clang diagnostic ignored "-Weverything"
#include <TH/TH.h>
#pragma clang diagnostic pop #pragma clang diagnostic pop
#include "../host/commands.h" #include "../host/commands.h"
#endif #endif

View File

@ -32,16 +32,16 @@ Renderer::Renderer(
const uint& n_channels, const uint& n_channels,
const uint& n_track) { const uint& n_track) {
LOG_IF(INFO, PULSAR_LOG_INIT) << "Initializing renderer."; LOG_IF(INFO, PULSAR_LOG_INIT) << "Initializing renderer.";
THArgCheck(width > 0, 1, "image width must be > 0!"); TORCH_CHECK_ARG(width > 0, 1, "image width must be > 0!");
THArgCheck(height > 0, 2, "image height must be > 0!"); TORCH_CHECK_ARG(height > 0, 2, "image height must be > 0!");
THArgCheck(max_n_balls > 0, 3, "max_n_balls must be > 0!"); TORCH_CHECK_ARG(max_n_balls > 0, 3, "max_n_balls must be > 0!");
THArgCheck( TORCH_CHECK_ARG(
background_normalization_depth > 0.f && background_normalization_depth > 0.f &&
background_normalization_depth < 1.f, background_normalization_depth < 1.f,
5, 5,
"background_normalization_depth must be in ]0., 1.["); "background_normalization_depth must be in ]0., 1.[");
THArgCheck(n_channels > 0, 6, "n_channels must be > 0"); TORCH_CHECK_ARG(n_channels > 0, 6, "n_channels must be > 0");
THArgCheck( TORCH_CHECK_ARG(
n_track > 0 && n_track <= MAX_GRAD_SPHERES, n_track > 0 && n_track <= MAX_GRAD_SPHERES,
7, 7,
("n_track must be > 0 and <" + std::to_string(MAX_GRAD_SPHERES) + ("n_track must be > 0 and <" + std::to_string(MAX_GRAD_SPHERES) +
@ -92,7 +92,7 @@ bool Renderer::operator==(const Renderer& rhs) const {
}; };
void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) { void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) {
THArgCheck( TORCH_CHECK_ARG(
device.type() == c10::DeviceType::CUDA || device.type() == c10::DeviceType::CUDA ||
device.type() == c10::DeviceType::CPU, device.type() == c10::DeviceType::CPU,
1, 1,
@ -220,48 +220,48 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
// Check all parameters adhere batch size. // Check all parameters adhere batch size.
batch_processing = true; batch_processing = true;
batch_size = vert_pos.size(0); batch_size = vert_pos.size(0);
THArgCheck( TORCH_CHECK_ARG(
vert_col.ndimension() == 3 && vert_col.ndimension() == 3 &&
vert_col.size(0) == static_cast<int64_t>(batch_size), vert_col.size(0) == static_cast<int64_t>(batch_size),
2, 2,
"vert_col needs to have batch size."); "vert_col needs to have batch size.");
THArgCheck( TORCH_CHECK_ARG(
vert_radii.ndimension() == 2 && vert_radii.ndimension() == 2 &&
vert_radii.size(0) == static_cast<int64_t>(batch_size), vert_radii.size(0) == static_cast<int64_t>(batch_size),
3, 3,
"vert_radii must be specified per batch."); "vert_radii must be specified per batch.");
THArgCheck( TORCH_CHECK_ARG(
cam_pos.ndimension() == 2 && cam_pos.ndimension() == 2 &&
cam_pos.size(0) == static_cast<int64_t>(batch_size), cam_pos.size(0) == static_cast<int64_t>(batch_size),
4, 4,
"cam_pos must be specified per batch and have the correct batch size."); "cam_pos must be specified per batch and have the correct batch size.");
THArgCheck( TORCH_CHECK_ARG(
pixel_0_0_center.ndimension() == 2 && pixel_0_0_center.ndimension() == 2 &&
pixel_0_0_center.size(0) == static_cast<int64_t>(batch_size), pixel_0_0_center.size(0) == static_cast<int64_t>(batch_size),
5, 5,
"pixel_0_0_center must be specified per batch."); "pixel_0_0_center must be specified per batch.");
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_x.ndimension() == 2 && pixel_vec_x.ndimension() == 2 &&
pixel_vec_x.size(0) == static_cast<int64_t>(batch_size), pixel_vec_x.size(0) == static_cast<int64_t>(batch_size),
6, 6,
"pixel_vec_x must be specified per batch."); "pixel_vec_x must be specified per batch.");
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_y.ndimension() == 2 && pixel_vec_y.ndimension() == 2 &&
pixel_vec_y.size(0) == static_cast<int64_t>(batch_size), pixel_vec_y.size(0) == static_cast<int64_t>(batch_size),
7, 7,
"pixel_vec_y must be specified per batch."); "pixel_vec_y must be specified per batch.");
THArgCheck( TORCH_CHECK_ARG(
focal_length.ndimension() == 1 && focal_length.ndimension() == 1 &&
focal_length.size(0) == static_cast<int64_t>(batch_size), focal_length.size(0) == static_cast<int64_t>(batch_size),
8, 8,
"focal_length must be specified per batch."); "focal_length must be specified per batch.");
THArgCheck( TORCH_CHECK_ARG(
principal_point_offsets.ndimension() == 2 && principal_point_offsets.ndimension() == 2 &&
principal_point_offsets.size(0) == static_cast<int64_t>(batch_size), principal_point_offsets.size(0) == static_cast<int64_t>(batch_size),
9, 9,
"principal_point_offsets must be specified per batch."); "principal_point_offsets must be specified per batch.");
if (opacity.has_value()) { if (opacity.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
opacity.value().ndimension() == 2 && opacity.value().ndimension() == 2 &&
opacity.value().size(0) == static_cast<int64_t>(batch_size), opacity.value().size(0) == static_cast<int64_t>(batch_size),
13, 13,
@ -269,14 +269,14 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
} }
// Check all parameters are for a matching number of points. // Check all parameters are for a matching number of points.
n_points = vert_pos.size(1); n_points = vert_pos.size(1);
THArgCheck( TORCH_CHECK_ARG(
vert_col.size(1) == static_cast<int64_t>(n_points), vert_col.size(1) == static_cast<int64_t>(n_points),
2, 2,
("The number of points for vertex positions (" + ("The number of points for vertex positions (" +
std::to_string(n_points) + ") and vertex colors (" + std::to_string(n_points) + ") and vertex colors (" +
std::to_string(vert_col.size(1)) + ") doesn't agree.") std::to_string(vert_col.size(1)) + ") doesn't agree.")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
vert_radii.size(1) == static_cast<int64_t>(n_points), vert_radii.size(1) == static_cast<int64_t>(n_points),
3, 3,
("The number of points for vertex positions (" + ("The number of points for vertex positions (" +
@ -284,50 +284,50 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(vert_col.size(1)) + ") doesn't agree.") std::to_string(vert_col.size(1)) + ") doesn't agree.")
.c_str()); .c_str());
if (opacity.has_value()) { if (opacity.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
opacity.value().size(1) == static_cast<int64_t>(n_points), opacity.value().size(1) == static_cast<int64_t>(n_points),
13, 13,
"Opacity needs to be specified per point."); "Opacity needs to be specified per point.");
} }
// Check all parameters have the correct last dimension size. // Check all parameters have the correct last dimension size.
THArgCheck( TORCH_CHECK_ARG(
vert_pos.size(2) == 3, vert_pos.size(2) == 3,
1, 1,
("Vertex positions must be 3D (have shape " + ("Vertex positions must be 3D (have shape " +
std::to_string(vert_pos.size(2)) + ")!") std::to_string(vert_pos.size(2)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
vert_col.size(2) == this->renderer_vec[0].cam.n_channels, vert_col.size(2) == this->renderer_vec[0].cam.n_channels,
2, 2,
("Vertex colors must have the right number of channels (have shape " + ("Vertex colors must have the right number of channels (have shape " +
std::to_string(vert_col.size(2)) + ", need " + std::to_string(vert_col.size(2)) + ", need " +
std::to_string(this->renderer_vec[0].cam.n_channels) + ")!") std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
cam_pos.size(1) == 3, cam_pos.size(1) == 3,
4, 4,
("Camera position must be 3D (has shape " + ("Camera position must be 3D (has shape " +
std::to_string(cam_pos.size(1)) + ")!") std::to_string(cam_pos.size(1)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
pixel_0_0_center.size(1) == 3, pixel_0_0_center.size(1) == 3,
5, 5,
("pixel_0_0_center must be 3D (has shape " + ("pixel_0_0_center must be 3D (has shape " +
std::to_string(pixel_0_0_center.size(1)) + ")!") std::to_string(pixel_0_0_center.size(1)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_x.size(1) == 3, pixel_vec_x.size(1) == 3,
6, 6,
("pixel_vec_x must be 3D (has shape " + ("pixel_vec_x must be 3D (has shape " +
std::to_string(pixel_vec_x.size(1)) + ")!") std::to_string(pixel_vec_x.size(1)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_y.size(1) == 3, pixel_vec_y.size(1) == 3,
7, 7,
("pixel_vec_y must be 3D (has shape " + ("pixel_vec_y must be 3D (has shape " +
std::to_string(pixel_vec_y.size(1)) + ")!") std::to_string(pixel_vec_y.size(1)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
principal_point_offsets.size(1) == 2, principal_point_offsets.size(1) == 2,
9, 9,
"principal_point_offsets must contain x and y offsets."); "principal_point_offsets must contain x and y offsets.");
@ -335,43 +335,43 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
ensure_n_renderers_gte(batch_size); ensure_n_renderers_gte(batch_size);
} else { } else {
// Check all parameters are of correct dimension. // Check all parameters are of correct dimension.
THArgCheck( TORCH_CHECK_ARG(
vert_col.ndimension() == 2, 2, "vert_col needs to have dimension 2."); vert_col.ndimension() == 2, 2, "vert_col needs to have dimension 2.");
THArgCheck( TORCH_CHECK_ARG(
vert_radii.ndimension() == 1, 3, "vert_radii must have dimension 1."); vert_radii.ndimension() == 1, 3, "vert_radii must have dimension 1.");
THArgCheck(cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1."); TORCH_CHECK_ARG(cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1.");
THArgCheck( TORCH_CHECK_ARG(
pixel_0_0_center.ndimension() == 1, pixel_0_0_center.ndimension() == 1,
5, 5,
"pixel_0_0_center must have dimension 1."); "pixel_0_0_center must have dimension 1.");
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_x.ndimension() == 1, 6, "pixel_vec_x must have dimension 1."); pixel_vec_x.ndimension() == 1, 6, "pixel_vec_x must have dimension 1.");
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_y.ndimension() == 1, 7, "pixel_vec_y must have dimension 1."); pixel_vec_y.ndimension() == 1, 7, "pixel_vec_y must have dimension 1.");
THArgCheck( TORCH_CHECK_ARG(
focal_length.ndimension() == 0, focal_length.ndimension() == 0,
8, 8,
"focal_length must have dimension 0."); "focal_length must have dimension 0.");
THArgCheck( TORCH_CHECK_ARG(
principal_point_offsets.ndimension() == 1, principal_point_offsets.ndimension() == 1,
9, 9,
"principal_point_offsets must have dimension 1."); "principal_point_offsets must have dimension 1.");
if (opacity.has_value()) { if (opacity.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
opacity.value().ndimension() == 1, opacity.value().ndimension() == 1,
13, 13,
"Opacity needs to be specified per sample."); "Opacity needs to be specified per sample.");
} }
// Check each. // Check each.
n_points = vert_pos.size(0); n_points = vert_pos.size(0);
THArgCheck( TORCH_CHECK_ARG(
vert_col.size(0) == static_cast<int64_t>(n_points), vert_col.size(0) == static_cast<int64_t>(n_points),
2, 2,
("The number of points for vertex positions (" + ("The number of points for vertex positions (" +
std::to_string(n_points) + ") and vertex colors (" + std::to_string(n_points) + ") and vertex colors (" +
std::to_string(vert_col.size(0)) + ") doesn't agree.") std::to_string(vert_col.size(0)) + ") doesn't agree.")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
vert_radii.size(0) == static_cast<int64_t>(n_points), vert_radii.size(0) == static_cast<int64_t>(n_points),
3, 3,
("The number of points for vertex positions (" + ("The number of points for vertex positions (" +
@ -379,57 +379,57 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(vert_col.size(0)) + ") doesn't agree.") std::to_string(vert_col.size(0)) + ") doesn't agree.")
.c_str()); .c_str());
if (opacity.has_value()) { if (opacity.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
opacity.value().size(0) == static_cast<int64_t>(n_points), opacity.value().size(0) == static_cast<int64_t>(n_points),
12, 12,
"Opacity needs to be specified per point."); "Opacity needs to be specified per point.");
} }
// Check all parameters have the correct last dimension size. // Check all parameters have the correct last dimension size.
THArgCheck( TORCH_CHECK_ARG(
vert_pos.size(1) == 3, vert_pos.size(1) == 3,
1, 1,
("Vertex positions must be 3D (have shape " + ("Vertex positions must be 3D (have shape " +
std::to_string(vert_pos.size(1)) + ")!") std::to_string(vert_pos.size(1)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
vert_col.size(1) == this->renderer_vec[0].cam.n_channels, vert_col.size(1) == this->renderer_vec[0].cam.n_channels,
2, 2,
("Vertex colors must have the right number of channels (have shape " + ("Vertex colors must have the right number of channels (have shape " +
std::to_string(vert_col.size(1)) + ", need " + std::to_string(vert_col.size(1)) + ", need " +
std::to_string(this->renderer_vec[0].cam.n_channels) + ")!") std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
cam_pos.size(0) == 3, cam_pos.size(0) == 3,
4, 4,
("Camera position must be 3D (has shape " + ("Camera position must be 3D (has shape " +
std::to_string(cam_pos.size(0)) + ")!") std::to_string(cam_pos.size(0)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
pixel_0_0_center.size(0) == 3, pixel_0_0_center.size(0) == 3,
5, 5,
("pixel_0_0_center must be 3D (has shape " + ("pixel_0_0_center must be 3D (has shape " +
std::to_string(pixel_0_0_center.size(0)) + ")!") std::to_string(pixel_0_0_center.size(0)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_x.size(0) == 3, pixel_vec_x.size(0) == 3,
6, 6,
("pixel_vec_x must be 3D (has shape " + ("pixel_vec_x must be 3D (has shape " +
std::to_string(pixel_vec_x.size(0)) + ")!") std::to_string(pixel_vec_x.size(0)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_y.size(0) == 3, pixel_vec_y.size(0) == 3,
7, 7,
("pixel_vec_y must be 3D (has shape " + ("pixel_vec_y must be 3D (has shape " +
std::to_string(pixel_vec_y.size(0)) + ")!") std::to_string(pixel_vec_y.size(0)) + ")!")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
principal_point_offsets.size(0) == 2, principal_point_offsets.size(0) == 2,
9, 9,
"principal_point_offsets must have x and y component."); "principal_point_offsets must have x and y component.");
} }
// Check device placement. // Check device placement.
auto dev = torch::device_of(vert_pos).value(); auto dev = torch::device_of(vert_pos).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
1, 1,
("Vertex positions must be stored on device " + ("Vertex positions must be stored on device " +
@ -439,7 +439,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(vert_col).value(); dev = torch::device_of(vert_col).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
2, 2,
("Vertex colors must be stored on device " + ("Vertex colors must be stored on device " +
@ -449,7 +449,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(vert_radii).value(); dev = torch::device_of(vert_radii).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
3, 3,
("Vertex radii must be stored on device " + ("Vertex radii must be stored on device " +
@ -459,7 +459,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(cam_pos).value(); dev = torch::device_of(cam_pos).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
4, 4,
("Camera position must be stored on device " + ("Camera position must be stored on device " +
@ -469,7 +469,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(pixel_0_0_center).value(); dev = torch::device_of(pixel_0_0_center).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
5, 5,
("pixel_0_0_center must be stored on device " + ("pixel_0_0_center must be stored on device " +
@ -479,7 +479,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(pixel_vec_x).value(); dev = torch::device_of(pixel_vec_x).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
6, 6,
("pixel_vec_x must be stored on device " + ("pixel_vec_x must be stored on device " +
@ -489,7 +489,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(pixel_vec_y).value(); dev = torch::device_of(pixel_vec_y).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
7, 7,
("pixel_vec_y must be stored on device " + ("pixel_vec_y must be stored on device " +
@ -499,7 +499,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(principal_point_offsets).value(); dev = torch::device_of(principal_point_offsets).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
9, 9,
("principal_point_offsets must be stored on device " + ("principal_point_offsets must be stored on device " +
@ -510,7 +510,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
.c_str()); .c_str());
if (opacity.has_value()) { if (opacity.has_value()) {
dev = torch::device_of(opacity.value()).value(); dev = torch::device_of(opacity.value()).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
13, 13,
("opacity must be stored on device " + ("opacity must be stored on device " +
@ -521,33 +521,33 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
.c_str()); .c_str());
} }
// Type checks. // Type checks.
THArgCheck( TORCH_CHECK_ARG(
vert_pos.scalar_type() == c10::kFloat, 1, "pulsar requires float types."); vert_pos.scalar_type() == c10::kFloat, 1, "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
vert_col.scalar_type() == c10::kFloat, 2, "pulsar requires float types."); vert_col.scalar_type() == c10::kFloat, 2, "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
vert_radii.scalar_type() == c10::kFloat, vert_radii.scalar_type() == c10::kFloat,
3, 3,
"pulsar requires float types."); "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
cam_pos.scalar_type() == c10::kFloat, 4, "pulsar requires float types."); cam_pos.scalar_type() == c10::kFloat, 4, "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
pixel_0_0_center.scalar_type() == c10::kFloat, pixel_0_0_center.scalar_type() == c10::kFloat,
5, 5,
"pulsar requires float types."); "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_x.scalar_type() == c10::kFloat, pixel_vec_x.scalar_type() == c10::kFloat,
6, 6,
"pulsar requires float types."); "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
pixel_vec_y.scalar_type() == c10::kFloat, pixel_vec_y.scalar_type() == c10::kFloat,
7, 7,
"pulsar requires float types."); "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
focal_length.scalar_type() == c10::kFloat, focal_length.scalar_type() == c10::kFloat,
8, 8,
"pulsar requires float types."); "pulsar requires float types.");
THArgCheck( TORCH_CHECK_ARG(
// Unfortunately, the PyTorch interface is inconsistent for // Unfortunately, the PyTorch interface is inconsistent for
// Int32: in Python, there exists an explicit int32 type, in // Int32: in Python, there exists an explicit int32 type, in
// C++ this is currently `c10::kInt`. // C++ this is currently `c10::kInt`.
@ -555,68 +555,68 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
9, 9,
"principal_point_offsets must be provided as int32."); "principal_point_offsets must be provided as int32.");
if (opacity.has_value()) { if (opacity.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
opacity.value().scalar_type() == c10::kFloat, opacity.value().scalar_type() == c10::kFloat,
13, 13,
"opacity must be a float type."); "opacity must be a float type.");
} }
// Content checks. // Content checks.
THArgCheck( TORCH_CHECK_ARG(
(vert_radii > FEPS).all().item<bool>(), (vert_radii > FEPS).all().item<bool>(),
3, 3,
("Vertex radii must be > FEPS (min is " + ("Vertex radii must be > FEPS (min is " +
std::to_string(vert_radii.min().item<float>()) + ").") std::to_string(vert_radii.min().item<float>()) + ").")
.c_str()); .c_str());
if (this->orthogonal()) { if (this->orthogonal()) {
THArgCheck( TORCH_CHECK_ARG(
(focal_length == 0.f).all().item<bool>(), (focal_length == 0.f).all().item<bool>(),
8, 8,
("for an orthogonal projection focal length must be zero (abs max: " + ("for an orthogonal projection focal length must be zero (abs max: " +
std::to_string(focal_length.abs().max().item<float>()) + ").") std::to_string(focal_length.abs().max().item<float>()) + ").")
.c_str()); .c_str());
} else { } else {
THArgCheck( TORCH_CHECK_ARG(
(focal_length > FEPS).all().item<bool>(), (focal_length > FEPS).all().item<bool>(),
8, 8,
("for a perspective projection focal length must be > FEPS (min " + ("for a perspective projection focal length must be > FEPS (min " +
std::to_string(focal_length.min().item<float>()) + ").") std::to_string(focal_length.min().item<float>()) + ").")
.c_str()); .c_str());
} }
THArgCheck( TORCH_CHECK_ARG(
gamma <= 1.f && gamma >= 1E-5f, gamma <= 1.f && gamma >= 1E-5f,
10, 10,
("gamma must be in [1E-5, 1] (" + std::to_string(gamma) + ").").c_str()); ("gamma must be in [1E-5, 1] (" + std::to_string(gamma) + ").").c_str());
if (min_depth == 0.f) { if (min_depth == 0.f) {
min_depth = focal_length.max().item<float>() + 2.f * FEPS; min_depth = focal_length.max().item<float>() + 2.f * FEPS;
} }
THArgCheck( TORCH_CHECK_ARG(
min_depth > focal_length.max().item<float>(), min_depth > focal_length.max().item<float>(),
12, 12,
("min_depth must be > focal_length (" + std::to_string(min_depth) + ("min_depth must be > focal_length (" + std::to_string(min_depth) +
" vs. " + std::to_string(focal_length.max().item<float>()) + ").") " vs. " + std::to_string(focal_length.max().item<float>()) + ").")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
max_depth > min_depth + FEPS, max_depth > min_depth + FEPS,
11, 11,
("max_depth must be > min_depth + FEPS (" + std::to_string(max_depth) + ("max_depth must be > min_depth + FEPS (" + std::to_string(max_depth) +
" vs. " + std::to_string(min_depth + FEPS) + ").") " vs. " + std::to_string(min_depth + FEPS) + ").")
.c_str()); .c_str());
THArgCheck( TORCH_CHECK_ARG(
percent_allowed_difference >= 0.f && percent_allowed_difference < 1.f, percent_allowed_difference >= 0.f && percent_allowed_difference < 1.f,
14, 14,
("percent_allowed_difference must be in [0., 1.[ (" + ("percent_allowed_difference must be in [0., 1.[ (" +
std::to_string(percent_allowed_difference) + ").") std::to_string(percent_allowed_difference) + ").")
.c_str()); .c_str());
THArgCheck(max_n_hits > 0, 14, "max_n_hits must be > 0!"); TORCH_CHECK_ARG(max_n_hits > 0, 14, "max_n_hits must be > 0!");
THArgCheck(mode < 2, 15, "mode must be in {0, 1}."); TORCH_CHECK_ARG(mode < 2, 15, "mode must be in {0, 1}.");
torch::Tensor real_bg_col; torch::Tensor real_bg_col;
if (bg_col.has_value()) { if (bg_col.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
bg_col.value().device().type() == this->device_type && bg_col.value().device().type() == this->device_type &&
bg_col.value().device().index() == this->device_index, bg_col.value().device().index() == this->device_index,
13, 13,
"bg_col must be stored on the renderer device!"); "bg_col must be stored on the renderer device!");
THArgCheck( TORCH_CHECK_ARG(
bg_col.value().ndimension() == 1 && bg_col.value().ndimension() == 1 &&
bg_col.value().size(0) == renderer_vec[0].cam.n_channels, bg_col.value().size(0) == renderer_vec[0].cam.n_channels,
13, 13,
@ -629,11 +629,11 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
.to(c10::kFloat); .to(c10::kFloat);
} }
if (opacity.has_value()) { if (opacity.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
(opacity.value() >= 0.f).all().item<bool>(), (opacity.value() >= 0.f).all().item<bool>(),
13, 13,
"opacity must be >= 0."); "opacity must be >= 0.");
THArgCheck( TORCH_CHECK_ARG(
(opacity.value() <= 1.f).all().item<bool>(), (opacity.value() <= 1.f).all().item<bool>(),
13, 13,
"opacity must be <= 1."); "opacity must be <= 1.");
@ -941,7 +941,7 @@ Renderer::backward(
max_n_hits, max_n_hits,
mode); mode);
// Additional checks for the gradient computation. // Additional checks for the gradient computation.
THArgCheck( TORCH_CHECK_ARG(
(grad_im.ndimension() == 3 + batch_processing && (grad_im.ndimension() == 3 + batch_processing &&
static_cast<uint>(grad_im.size(0 + batch_processing)) == static_cast<uint>(grad_im.size(0 + batch_processing)) ==
this->height() && this->height() &&
@ -950,7 +950,7 @@ Renderer::backward(
this->renderer_vec[0].cam.n_channels), this->renderer_vec[0].cam.n_channels),
1, 1,
"The gradient image size is not correct."); "The gradient image size is not correct.");
THArgCheck( TORCH_CHECK_ARG(
(image.ndimension() == 3 + batch_processing && (image.ndimension() == 3 + batch_processing &&
static_cast<uint>(image.size(0 + batch_processing)) == this->height() && static_cast<uint>(image.size(0 + batch_processing)) == this->height() &&
static_cast<uint>(image.size(1 + batch_processing)) == this->width() && static_cast<uint>(image.size(1 + batch_processing)) == this->width() &&
@ -958,32 +958,32 @@ Renderer::backward(
this->renderer_vec[0].cam.n_channels), this->renderer_vec[0].cam.n_channels),
2, 2,
"The result image size is not correct."); "The result image size is not correct.");
THArgCheck( TORCH_CHECK_ARG(
grad_im.scalar_type() == c10::kFloat, grad_im.scalar_type() == c10::kFloat,
1, 1,
"The gradient image must be of float type."); "The gradient image must be of float type.");
THArgCheck( TORCH_CHECK_ARG(
image.scalar_type() == c10::kFloat, image.scalar_type() == c10::kFloat,
2, 2,
"The image must be of float type."); "The image must be of float type.");
if (dif_opy) { if (dif_opy) {
THArgCheck(opacity.has_value(), 13, "dif_opy set requires opacity values."); TORCH_CHECK_ARG(opacity.has_value(), 13, "dif_opy set requires opacity values.");
} }
if (batch_processing) { if (batch_processing) {
THArgCheck( TORCH_CHECK_ARG(
grad_im.size(0) == static_cast<int64_t>(batch_size), grad_im.size(0) == static_cast<int64_t>(batch_size),
1, 1,
"Gradient image batch size must agree."); "Gradient image batch size must agree.");
THArgCheck( TORCH_CHECK_ARG(
image.size(0) == static_cast<int64_t>(batch_size), image.size(0) == static_cast<int64_t>(batch_size),
2, 2,
"Image batch size must agree."); "Image batch size must agree.");
THArgCheck( TORCH_CHECK_ARG(
forw_info.size(0) == static_cast<int64_t>(batch_size), forw_info.size(0) == static_cast<int64_t>(batch_size),
3, 3,
"forward info must have batch size."); "forward info must have batch size.");
} }
THArgCheck( TORCH_CHECK_ARG(
(forw_info.ndimension() == 3 + batch_processing && (forw_info.ndimension() == 3 + batch_processing &&
static_cast<uint>(forw_info.size(0 + batch_processing)) == static_cast<uint>(forw_info.size(0 + batch_processing)) ==
this->height() && this->height() &&
@ -993,13 +993,13 @@ Renderer::backward(
3 + 2 * this->n_track()), 3 + 2 * this->n_track()),
3, 3,
"The forward info image size is not correct."); "The forward info image size is not correct.");
THArgCheck( TORCH_CHECK_ARG(
forw_info.scalar_type() == c10::kFloat, forw_info.scalar_type() == c10::kFloat,
3, 3,
"The forward info must be of float type."); "The forward info must be of float type.");
// Check device. // Check device.
auto dev = torch::device_of(grad_im).value(); auto dev = torch::device_of(grad_im).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
1, 1,
("grad_im must be stored on device " + ("grad_im must be stored on device " +
@ -1009,7 +1009,7 @@ Renderer::backward(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(image).value(); dev = torch::device_of(image).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
2, 2,
("image must be stored on device " + ("image must be stored on device " +
@ -1019,7 +1019,7 @@ Renderer::backward(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
dev = torch::device_of(forw_info).value(); dev = torch::device_of(forw_info).value();
THArgCheck( TORCH_CHECK_ARG(
dev.type() == this->device_type && dev.index() == this->device_index, dev.type() == this->device_type && dev.index() == this->device_index,
3, 3,
("forw_info must be stored on device " + ("forw_info must be stored on device " +
@ -1029,7 +1029,7 @@ Renderer::backward(
std::to_string(dev.index()) + ".") std::to_string(dev.index()) + ".")
.c_str()); .c_str());
if (dbg_pos.has_value()) { if (dbg_pos.has_value()) {
THArgCheck( TORCH_CHECK_ARG(
dbg_pos.value().first < this->width() && dbg_pos.value().first < this->width() &&
dbg_pos.value().second < this->height(), dbg_pos.value().second < this->height(),
23, 23,