From 069c9fd759461b60f089dd01f6779f9437a5695f Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 15 Dec 2021 08:32:49 -0800 Subject: [PATCH] pytorch TORCH_CHECK_ARG version compatibility Summary: Restore compatibility with old C++ after recent torch change. https://github.com/facebookresearch/pytorch3d/issues/995 Reviewed By: patricklabatut Differential Revision: D33093174 fbshipit-source-id: 841202fb875d601db265e93dcf9cfa4249d02b25 --- pytorch3d/csrc/pulsar/cuda/commands.h | 4 +++- pytorch3d/csrc/pulsar/host/commands.h | 4 +++- pytorch3d/csrc/pulsar/pytorch/renderer.cpp | 6 ++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pytorch3d/csrc/pulsar/cuda/commands.h b/pytorch3d/csrc/pulsar/cuda/commands.h index 77bf479c..a7a038bd 100644 --- a/pytorch3d/csrc/pulsar/cuda/commands.h +++ b/pytorch3d/csrc/pulsar/cuda/commands.h @@ -208,7 +208,9 @@ __device__ static float atomicMin(float* address, float val) { #define IABS(a) abs(a) // Checks. -#define ARGCHECK TORCH_CHECK_ARG +// like TORCH_CHECK_ARG in PyTorch > 1.10 +#define ARGCHECK(cond, argN, ...) \ + TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__) // Math. #define NORM3DF(x, y, z) norm3df(x, y, z) diff --git a/pytorch3d/csrc/pulsar/host/commands.h b/pytorch3d/csrc/pulsar/host/commands.h index 6384969d..997c410d 100644 --- a/pytorch3d/csrc/pulsar/host/commands.h +++ b/pytorch3d/csrc/pulsar/host/commands.h @@ -155,7 +155,9 @@ INLINE void ATOMICADD_F3(T* address, T val) { #define IABS(a) abs(a) // Checks. -#define ARGCHECK TORCH_CHECK_ARG +// like TORCH_CHECK_ARG in PyTorch > 1.10 +#define ARGCHECK(cond, argN, ...) \ + TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__) // Math. #define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z) diff --git a/pytorch3d/csrc/pulsar/pytorch/renderer.cpp b/pytorch3d/csrc/pulsar/pytorch/renderer.cpp index 18347595..39f07c3b 100644 --- a/pytorch3d/csrc/pulsar/pytorch/renderer.cpp +++ b/pytorch3d/csrc/pulsar/pytorch/renderer.cpp @@ -17,6 +17,12 @@ #include #endif +#ifndef TORCH_CHECK_ARG +// torch <= 1.10 +#define TORCH_CHECK_ARG(cond, argN, ...) \ + TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__) +#endif + namespace PRE = ::pulsar::Renderer; namespace pulsar {