mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Remove THGeneral (#69041)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69041 `TH_CONCAT_{N}` is still being used by THP so I've moved that into it's own header but all the compiled code is gone. Test Plan: Imported from OSS Reviewed By: anjali411 Differential Revision: D32872477 Pulled By: ngimel fbshipit-source-id: 06c82d8f96dbcee0715be407c61dfc7d7e8be47a
This commit is contained in:
		
							parent
							
								
									d049cd2e01
								
							
						
					
					
						commit
						f8fe9a2be1
					
				@ -208,8 +208,7 @@ __device__ static float atomicMin(float* address, float val) {
 | 
			
		||||
#define IABS(a) abs(a)
 | 
			
		||||
 | 
			
		||||
// Checks.
 | 
			
		||||
#define CHECKOK C10_CUDA_CHECK
 | 
			
		||||
#define ARGCHECK THArgCheck
 | 
			
		||||
#define ARGCHECK TORCH_CHECK_ARG
 | 
			
		||||
 | 
			
		||||
// Math.
 | 
			
		||||
#define NORM3DF(x, y, z) norm3df(x, y, z)
 | 
			
		||||
 | 
			
		||||
@ -155,8 +155,7 @@ INLINE void ATOMICADD_F3(T* address, T val) {
 | 
			
		||||
#define IABS(a) abs(a)
 | 
			
		||||
 | 
			
		||||
// Checks.
 | 
			
		||||
#define CHECKOK THCheck
 | 
			
		||||
#define ARGCHECK THArgCheck
 | 
			
		||||
#define ARGCHECK TORCH_CHECK_ARG
 | 
			
		||||
 | 
			
		||||
// Math.
 | 
			
		||||
#define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z)
 | 
			
		||||
 | 
			
		||||
@ -24,12 +24,10 @@
 | 
			
		||||
// #pragma diag_suppress = 68
 | 
			
		||||
#include <ATen/cuda/CUDAContext.h>
 | 
			
		||||
// #pragma pop
 | 
			
		||||
#include <TH/TH.h>
 | 
			
		||||
#include "../cuda/commands.h"
 | 
			
		||||
#else
 | 
			
		||||
#pragma clang diagnostic push
 | 
			
		||||
#pragma clang diagnostic ignored "-Weverything"
 | 
			
		||||
#include <TH/TH.h>
 | 
			
		||||
#pragma clang diagnostic pop
 | 
			
		||||
#include "../host/commands.h"
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
@ -32,16 +32,16 @@ Renderer::Renderer(
 | 
			
		||||
    const uint& n_channels,
 | 
			
		||||
    const uint& n_track) {
 | 
			
		||||
  LOG_IF(INFO, PULSAR_LOG_INIT) << "Initializing renderer.";
 | 
			
		||||
  THArgCheck(width > 0, 1, "image width must be > 0!");
 | 
			
		||||
  THArgCheck(height > 0, 2, "image height must be > 0!");
 | 
			
		||||
  THArgCheck(max_n_balls > 0, 3, "max_n_balls must be > 0!");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(width > 0, 1, "image width must be > 0!");
 | 
			
		||||
  TORCH_CHECK_ARG(height > 0, 2, "image height must be > 0!");
 | 
			
		||||
  TORCH_CHECK_ARG(max_n_balls > 0, 3, "max_n_balls must be > 0!");
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      background_normalization_depth > 0.f &&
 | 
			
		||||
          background_normalization_depth < 1.f,
 | 
			
		||||
      5,
 | 
			
		||||
      "background_normalization_depth must be in ]0., 1.[");
 | 
			
		||||
  THArgCheck(n_channels > 0, 6, "n_channels must be > 0");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(n_channels > 0, 6, "n_channels must be > 0");
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      n_track > 0 && n_track <= MAX_GRAD_SPHERES,
 | 
			
		||||
      7,
 | 
			
		||||
      ("n_track must be > 0 and <" + std::to_string(MAX_GRAD_SPHERES) +
 | 
			
		||||
@ -92,7 +92,7 @@ bool Renderer::operator==(const Renderer& rhs) const {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) {
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      device.type() == c10::DeviceType::CUDA ||
 | 
			
		||||
          device.type() == c10::DeviceType::CPU,
 | 
			
		||||
      1,
 | 
			
		||||
@ -220,48 +220,48 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
    // Check all parameters adhere batch size.
 | 
			
		||||
    batch_processing = true;
 | 
			
		||||
    batch_size = vert_pos.size(0);
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_col.ndimension() == 3 &&
 | 
			
		||||
            vert_col.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        2,
 | 
			
		||||
        "vert_col needs to have batch size.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_radii.ndimension() == 2 &&
 | 
			
		||||
            vert_radii.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        3,
 | 
			
		||||
        "vert_radii must be specified per batch.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        cam_pos.ndimension() == 2 &&
 | 
			
		||||
            cam_pos.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        4,
 | 
			
		||||
        "cam_pos must be specified per batch and have the correct batch size.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_0_0_center.ndimension() == 2 &&
 | 
			
		||||
            pixel_0_0_center.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        5,
 | 
			
		||||
        "pixel_0_0_center must be specified per batch.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_x.ndimension() == 2 &&
 | 
			
		||||
            pixel_vec_x.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        6,
 | 
			
		||||
        "pixel_vec_x must be specified per batch.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_y.ndimension() == 2 &&
 | 
			
		||||
            pixel_vec_y.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        7,
 | 
			
		||||
        "pixel_vec_y must be specified per batch.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        focal_length.ndimension() == 1 &&
 | 
			
		||||
            focal_length.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        8,
 | 
			
		||||
        "focal_length must be specified per batch.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        principal_point_offsets.ndimension() == 2 &&
 | 
			
		||||
            principal_point_offsets.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        9,
 | 
			
		||||
        "principal_point_offsets must be specified per batch.");
 | 
			
		||||
    if (opacity.has_value()) {
 | 
			
		||||
      THArgCheck(
 | 
			
		||||
      TORCH_CHECK_ARG(
 | 
			
		||||
          opacity.value().ndimension() == 2 &&
 | 
			
		||||
              opacity.value().size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
          13,
 | 
			
		||||
@ -269,14 +269,14 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
    }
 | 
			
		||||
    // Check all parameters are for a matching number of points.
 | 
			
		||||
    n_points = vert_pos.size(1);
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_col.size(1) == static_cast<int64_t>(n_points),
 | 
			
		||||
        2,
 | 
			
		||||
        ("The number of points for vertex positions (" +
 | 
			
		||||
         std::to_string(n_points) + ") and vertex colors (" +
 | 
			
		||||
         std::to_string(vert_col.size(1)) + ") doesn't agree.")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_radii.size(1) == static_cast<int64_t>(n_points),
 | 
			
		||||
        3,
 | 
			
		||||
        ("The number of points for vertex positions (" +
 | 
			
		||||
@ -284,50 +284,50 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
         std::to_string(vert_col.size(1)) + ") doesn't agree.")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    if (opacity.has_value()) {
 | 
			
		||||
      THArgCheck(
 | 
			
		||||
      TORCH_CHECK_ARG(
 | 
			
		||||
          opacity.value().size(1) == static_cast<int64_t>(n_points),
 | 
			
		||||
          13,
 | 
			
		||||
          "Opacity needs to be specified per point.");
 | 
			
		||||
    }
 | 
			
		||||
    // Check all parameters have the correct last dimension size.
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_pos.size(2) == 3,
 | 
			
		||||
        1,
 | 
			
		||||
        ("Vertex positions must be 3D (have shape " +
 | 
			
		||||
         std::to_string(vert_pos.size(2)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_col.size(2) == this->renderer_vec[0].cam.n_channels,
 | 
			
		||||
        2,
 | 
			
		||||
        ("Vertex colors must have the right number of channels (have shape " +
 | 
			
		||||
         std::to_string(vert_col.size(2)) + ", need " +
 | 
			
		||||
         std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        cam_pos.size(1) == 3,
 | 
			
		||||
        4,
 | 
			
		||||
        ("Camera position must be 3D (has shape " +
 | 
			
		||||
         std::to_string(cam_pos.size(1)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_0_0_center.size(1) == 3,
 | 
			
		||||
        5,
 | 
			
		||||
        ("pixel_0_0_center must be 3D (has shape " +
 | 
			
		||||
         std::to_string(pixel_0_0_center.size(1)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_x.size(1) == 3,
 | 
			
		||||
        6,
 | 
			
		||||
        ("pixel_vec_x must be 3D (has shape " +
 | 
			
		||||
         std::to_string(pixel_vec_x.size(1)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_y.size(1) == 3,
 | 
			
		||||
        7,
 | 
			
		||||
        ("pixel_vec_y must be 3D (has shape " +
 | 
			
		||||
         std::to_string(pixel_vec_y.size(1)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        principal_point_offsets.size(1) == 2,
 | 
			
		||||
        9,
 | 
			
		||||
        "principal_point_offsets must contain x and y offsets.");
 | 
			
		||||
@ -335,43 +335,43 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
    ensure_n_renderers_gte(batch_size);
 | 
			
		||||
  } else {
 | 
			
		||||
    // Check all parameters are of correct dimension.
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_col.ndimension() == 2, 2, "vert_col needs to have dimension 2.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_radii.ndimension() == 1, 3, "vert_radii must have dimension 1.");
 | 
			
		||||
    THArgCheck(cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1.");
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_0_0_center.ndimension() == 1,
 | 
			
		||||
        5,
 | 
			
		||||
        "pixel_0_0_center must have dimension 1.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_x.ndimension() == 1, 6, "pixel_vec_x must have dimension 1.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_y.ndimension() == 1, 7, "pixel_vec_y must have dimension 1.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        focal_length.ndimension() == 0,
 | 
			
		||||
        8,
 | 
			
		||||
        "focal_length must have dimension 0.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        principal_point_offsets.ndimension() == 1,
 | 
			
		||||
        9,
 | 
			
		||||
        "principal_point_offsets must have dimension 1.");
 | 
			
		||||
    if (opacity.has_value()) {
 | 
			
		||||
      THArgCheck(
 | 
			
		||||
      TORCH_CHECK_ARG(
 | 
			
		||||
          opacity.value().ndimension() == 1,
 | 
			
		||||
          13,
 | 
			
		||||
          "Opacity needs to be specified per sample.");
 | 
			
		||||
    }
 | 
			
		||||
    // Check each.
 | 
			
		||||
    n_points = vert_pos.size(0);
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_col.size(0) == static_cast<int64_t>(n_points),
 | 
			
		||||
        2,
 | 
			
		||||
        ("The number of points for vertex positions (" +
 | 
			
		||||
         std::to_string(n_points) + ") and vertex colors (" +
 | 
			
		||||
         std::to_string(vert_col.size(0)) + ") doesn't agree.")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_radii.size(0) == static_cast<int64_t>(n_points),
 | 
			
		||||
        3,
 | 
			
		||||
        ("The number of points for vertex positions (" +
 | 
			
		||||
@ -379,57 +379,57 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
         std::to_string(vert_col.size(0)) + ") doesn't agree.")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    if (opacity.has_value()) {
 | 
			
		||||
      THArgCheck(
 | 
			
		||||
      TORCH_CHECK_ARG(
 | 
			
		||||
          opacity.value().size(0) == static_cast<int64_t>(n_points),
 | 
			
		||||
          12,
 | 
			
		||||
          "Opacity needs to be specified per point.");
 | 
			
		||||
    }
 | 
			
		||||
    // Check all parameters have the correct last dimension size.
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_pos.size(1) == 3,
 | 
			
		||||
        1,
 | 
			
		||||
        ("Vertex positions must be 3D (have shape " +
 | 
			
		||||
         std::to_string(vert_pos.size(1)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        vert_col.size(1) == this->renderer_vec[0].cam.n_channels,
 | 
			
		||||
        2,
 | 
			
		||||
        ("Vertex colors must have the right number of channels (have shape " +
 | 
			
		||||
         std::to_string(vert_col.size(1)) + ", need " +
 | 
			
		||||
         std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        cam_pos.size(0) == 3,
 | 
			
		||||
        4,
 | 
			
		||||
        ("Camera position must be 3D (has shape " +
 | 
			
		||||
         std::to_string(cam_pos.size(0)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_0_0_center.size(0) == 3,
 | 
			
		||||
        5,
 | 
			
		||||
        ("pixel_0_0_center must be 3D (has shape " +
 | 
			
		||||
         std::to_string(pixel_0_0_center.size(0)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_x.size(0) == 3,
 | 
			
		||||
        6,
 | 
			
		||||
        ("pixel_vec_x must be 3D (has shape " +
 | 
			
		||||
         std::to_string(pixel_vec_x.size(0)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        pixel_vec_y.size(0) == 3,
 | 
			
		||||
        7,
 | 
			
		||||
        ("pixel_vec_y must be 3D (has shape " +
 | 
			
		||||
         std::to_string(pixel_vec_y.size(0)) + ")!")
 | 
			
		||||
            .c_str());
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        principal_point_offsets.size(0) == 2,
 | 
			
		||||
        9,
 | 
			
		||||
        "principal_point_offsets must have x and y component.");
 | 
			
		||||
  }
 | 
			
		||||
  // Check device placement.
 | 
			
		||||
  auto dev = torch::device_of(vert_pos).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      1,
 | 
			
		||||
      ("Vertex positions must be stored on device " +
 | 
			
		||||
@ -439,7 +439,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(vert_col).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      2,
 | 
			
		||||
      ("Vertex colors must be stored on device " +
 | 
			
		||||
@ -449,7 +449,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(vert_radii).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      3,
 | 
			
		||||
      ("Vertex radii must be stored on device " +
 | 
			
		||||
@ -459,7 +459,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(cam_pos).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      4,
 | 
			
		||||
      ("Camera position must be stored on device " +
 | 
			
		||||
@ -469,7 +469,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(pixel_0_0_center).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      5,
 | 
			
		||||
      ("pixel_0_0_center must be stored on device " +
 | 
			
		||||
@ -479,7 +479,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(pixel_vec_x).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      6,
 | 
			
		||||
      ("pixel_vec_x must be stored on device " +
 | 
			
		||||
@ -489,7 +489,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(pixel_vec_y).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      7,
 | 
			
		||||
      ("pixel_vec_y must be stored on device " +
 | 
			
		||||
@ -499,7 +499,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(principal_point_offsets).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      9,
 | 
			
		||||
      ("principal_point_offsets must be stored on device " +
 | 
			
		||||
@ -510,7 +510,7 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
          .c_str());
 | 
			
		||||
  if (opacity.has_value()) {
 | 
			
		||||
    dev = torch::device_of(opacity.value()).value();
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
        13,
 | 
			
		||||
        ("opacity must be stored on device " +
 | 
			
		||||
@ -521,33 +521,33 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
            .c_str());
 | 
			
		||||
  }
 | 
			
		||||
  // Type checks.
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      vert_pos.scalar_type() == c10::kFloat, 1, "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      vert_col.scalar_type() == c10::kFloat, 2, "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      vert_radii.scalar_type() == c10::kFloat,
 | 
			
		||||
      3,
 | 
			
		||||
      "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      cam_pos.scalar_type() == c10::kFloat, 4, "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      pixel_0_0_center.scalar_type() == c10::kFloat,
 | 
			
		||||
      5,
 | 
			
		||||
      "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      pixel_vec_x.scalar_type() == c10::kFloat,
 | 
			
		||||
      6,
 | 
			
		||||
      "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      pixel_vec_y.scalar_type() == c10::kFloat,
 | 
			
		||||
      7,
 | 
			
		||||
      "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      focal_length.scalar_type() == c10::kFloat,
 | 
			
		||||
      8,
 | 
			
		||||
      "pulsar requires float types.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      // Unfortunately, the PyTorch interface is inconsistent for
 | 
			
		||||
      // Int32: in Python, there exists an explicit int32 type, in
 | 
			
		||||
      // C++ this is currently `c10::kInt`.
 | 
			
		||||
@ -555,68 +555,68 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
      9,
 | 
			
		||||
      "principal_point_offsets must be provided as int32.");
 | 
			
		||||
  if (opacity.has_value()) {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        opacity.value().scalar_type() == c10::kFloat,
 | 
			
		||||
        13,
 | 
			
		||||
        "opacity must be a float type.");
 | 
			
		||||
  }
 | 
			
		||||
  // Content checks.
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      (vert_radii > FEPS).all().item<bool>(),
 | 
			
		||||
      3,
 | 
			
		||||
      ("Vertex radii must be > FEPS (min is " +
 | 
			
		||||
       std::to_string(vert_radii.min().item<float>()) + ").")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  if (this->orthogonal()) {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        (focal_length == 0.f).all().item<bool>(),
 | 
			
		||||
        8,
 | 
			
		||||
        ("for an orthogonal projection focal length must be zero (abs max: " +
 | 
			
		||||
         std::to_string(focal_length.abs().max().item<float>()) + ").")
 | 
			
		||||
            .c_str());
 | 
			
		||||
  } else {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        (focal_length > FEPS).all().item<bool>(),
 | 
			
		||||
        8,
 | 
			
		||||
        ("for a perspective projection focal length must be > FEPS (min " +
 | 
			
		||||
         std::to_string(focal_length.min().item<float>()) + ").")
 | 
			
		||||
            .c_str());
 | 
			
		||||
  }
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      gamma <= 1.f && gamma >= 1E-5f,
 | 
			
		||||
      10,
 | 
			
		||||
      ("gamma must be in [1E-5, 1] (" + std::to_string(gamma) + ").").c_str());
 | 
			
		||||
  if (min_depth == 0.f) {
 | 
			
		||||
    min_depth = focal_length.max().item<float>() + 2.f * FEPS;
 | 
			
		||||
  }
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      min_depth > focal_length.max().item<float>(),
 | 
			
		||||
      12,
 | 
			
		||||
      ("min_depth must be > focal_length (" + std::to_string(min_depth) +
 | 
			
		||||
       " vs. " + std::to_string(focal_length.max().item<float>()) + ").")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      max_depth > min_depth + FEPS,
 | 
			
		||||
      11,
 | 
			
		||||
      ("max_depth must be > min_depth + FEPS (" + std::to_string(max_depth) +
 | 
			
		||||
       " vs. " + std::to_string(min_depth + FEPS) + ").")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      percent_allowed_difference >= 0.f && percent_allowed_difference < 1.f,
 | 
			
		||||
      14,
 | 
			
		||||
      ("percent_allowed_difference must be in [0., 1.[ (" +
 | 
			
		||||
       std::to_string(percent_allowed_difference) + ").")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  THArgCheck(max_n_hits > 0, 14, "max_n_hits must be > 0!");
 | 
			
		||||
  THArgCheck(mode < 2, 15, "mode must be in {0, 1}.");
 | 
			
		||||
  TORCH_CHECK_ARG(max_n_hits > 0, 14, "max_n_hits must be > 0!");
 | 
			
		||||
  TORCH_CHECK_ARG(mode < 2, 15, "mode must be in {0, 1}.");
 | 
			
		||||
  torch::Tensor real_bg_col;
 | 
			
		||||
  if (bg_col.has_value()) {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        bg_col.value().device().type() == this->device_type &&
 | 
			
		||||
            bg_col.value().device().index() == this->device_index,
 | 
			
		||||
        13,
 | 
			
		||||
        "bg_col must be stored on the renderer device!");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        bg_col.value().ndimension() == 1 &&
 | 
			
		||||
            bg_col.value().size(0) == renderer_vec[0].cam.n_channels,
 | 
			
		||||
        13,
 | 
			
		||||
@ -629,11 +629,11 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
 | 
			
		||||
                      .to(c10::kFloat);
 | 
			
		||||
  }
 | 
			
		||||
  if (opacity.has_value()) {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        (opacity.value() >= 0.f).all().item<bool>(),
 | 
			
		||||
        13,
 | 
			
		||||
        "opacity must be >= 0.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        (opacity.value() <= 1.f).all().item<bool>(),
 | 
			
		||||
        13,
 | 
			
		||||
        "opacity must be <= 1.");
 | 
			
		||||
@ -941,7 +941,7 @@ Renderer::backward(
 | 
			
		||||
          max_n_hits,
 | 
			
		||||
          mode);
 | 
			
		||||
  // Additional checks for the gradient computation.
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      (grad_im.ndimension() == 3 + batch_processing &&
 | 
			
		||||
       static_cast<uint>(grad_im.size(0 + batch_processing)) ==
 | 
			
		||||
           this->height() &&
 | 
			
		||||
@ -950,7 +950,7 @@ Renderer::backward(
 | 
			
		||||
           this->renderer_vec[0].cam.n_channels),
 | 
			
		||||
      1,
 | 
			
		||||
      "The gradient image size is not correct.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      (image.ndimension() == 3 + batch_processing &&
 | 
			
		||||
       static_cast<uint>(image.size(0 + batch_processing)) == this->height() &&
 | 
			
		||||
       static_cast<uint>(image.size(1 + batch_processing)) == this->width() &&
 | 
			
		||||
@ -958,32 +958,32 @@ Renderer::backward(
 | 
			
		||||
           this->renderer_vec[0].cam.n_channels),
 | 
			
		||||
      2,
 | 
			
		||||
      "The result image size is not correct.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      grad_im.scalar_type() == c10::kFloat,
 | 
			
		||||
      1,
 | 
			
		||||
      "The gradient image must be of float type.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      image.scalar_type() == c10::kFloat,
 | 
			
		||||
      2,
 | 
			
		||||
      "The image must be of float type.");
 | 
			
		||||
  if (dif_opy) {
 | 
			
		||||
    THArgCheck(opacity.has_value(), 13, "dif_opy set requires opacity values.");
 | 
			
		||||
    TORCH_CHECK_ARG(opacity.has_value(), 13, "dif_opy set requires opacity values.");
 | 
			
		||||
  }
 | 
			
		||||
  if (batch_processing) {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        grad_im.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        1,
 | 
			
		||||
        "Gradient image batch size must agree.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        image.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        2,
 | 
			
		||||
        "Image batch size must agree.");
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        forw_info.size(0) == static_cast<int64_t>(batch_size),
 | 
			
		||||
        3,
 | 
			
		||||
        "forward info must have batch size.");
 | 
			
		||||
  }
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      (forw_info.ndimension() == 3 + batch_processing &&
 | 
			
		||||
       static_cast<uint>(forw_info.size(0 + batch_processing)) ==
 | 
			
		||||
           this->height() &&
 | 
			
		||||
@ -993,13 +993,13 @@ Renderer::backward(
 | 
			
		||||
           3 + 2 * this->n_track()),
 | 
			
		||||
      3,
 | 
			
		||||
      "The forward info image size is not correct.");
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      forw_info.scalar_type() == c10::kFloat,
 | 
			
		||||
      3,
 | 
			
		||||
      "The forward info must be of float type.");
 | 
			
		||||
  // Check device.
 | 
			
		||||
  auto dev = torch::device_of(grad_im).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      1,
 | 
			
		||||
      ("grad_im must be stored on device " +
 | 
			
		||||
@ -1009,7 +1009,7 @@ Renderer::backward(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(image).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      2,
 | 
			
		||||
      ("image must be stored on device " +
 | 
			
		||||
@ -1019,7 +1019,7 @@ Renderer::backward(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  dev = torch::device_of(forw_info).value();
 | 
			
		||||
  THArgCheck(
 | 
			
		||||
  TORCH_CHECK_ARG(
 | 
			
		||||
      dev.type() == this->device_type && dev.index() == this->device_index,
 | 
			
		||||
      3,
 | 
			
		||||
      ("forw_info must be stored on device " +
 | 
			
		||||
@ -1029,7 +1029,7 @@ Renderer::backward(
 | 
			
		||||
       std::to_string(dev.index()) + ".")
 | 
			
		||||
          .c_str());
 | 
			
		||||
  if (dbg_pos.has_value()) {
 | 
			
		||||
    THArgCheck(
 | 
			
		||||
    TORCH_CHECK_ARG(
 | 
			
		||||
        dbg_pos.value().first < this->width() &&
 | 
			
		||||
            dbg_pos.value().second < this->height(),
 | 
			
		||||
        23,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user