diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 49ec02c6..0d698526 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -7,15 +7,11 @@ */ // clang-format off -#if !defined(USE_ROCM) #include "./pulsar/global.h" // Include before . -#endif #include // clang-format on -#if !defined(USE_ROCM) #include "./pulsar/pytorch/renderer.h" #include "./pulsar/pytorch/tensor_util.h" -#endif #include "ball_query/ball_query.h" #include "blending/sigmoid_alpha_blend.h" #include "compositing/alpha_composite.h" @@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Pulsar. // Pulsar not enabled on AMD. -#if !defined(USE_ROCM) #ifdef PULSAR_LOGGING_ENABLED c10::ShowLogInfoToStderr(); #endif @@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.attr("MAX_UINT") = py::int_(MAX_UINT); m.attr("MAX_USHORT") = py::int_(MAX_USHORT); m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES); -#endif } diff --git a/pytorch3d/csrc/pulsar/global.h b/pytorch3d/csrc/pulsar/global.h index 3cea957e..16040119 100644 --- a/pytorch3d/csrc/pulsar/global.h +++ b/pytorch3d/csrc/pulsar/global.h @@ -36,11 +36,13 @@ #pragma nv_diag_suppress 2951 #pragma nv_diag_suppress 2967 #else +#if !defined(USE_ROCM) #pragma diag_suppress = attribute_not_allowed #pragma diag_suppress = 1866 #pragma diag_suppress = 2941 #pragma diag_suppress = 2951 #pragma diag_suppress = 2967 +#endif //! USE_ROCM #endif #else // __CUDACC__ #define INLINE inline @@ -56,7 +58,9 @@ #pragma clang diagnostic pop #ifdef WITH_CUDA #include +#if !defined(USE_ROCM) #include +#endif //! USE_ROCM #else #ifndef cudaStream_t typedef void* cudaStream_t; diff --git a/pytorch3d/csrc/pulsar/cuda/README.md b/pytorch3d/csrc/pulsar/gpu/README.md similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/README.md rename to pytorch3d/csrc/pulsar/gpu/README.md diff --git a/pytorch3d/csrc/pulsar/cuda/commands.h b/pytorch3d/csrc/pulsar/gpu/commands.h similarity index 98% rename from pytorch3d/csrc/pulsar/cuda/commands.h rename to pytorch3d/csrc/pulsar/gpu/commands.h index 00e6f378..1e8f83f0 100644 --- a/pytorch3d/csrc/pulsar/cuda/commands.h +++ b/pytorch3d/csrc/pulsar/gpu/commands.h @@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) { #define SHARED __shared__ #define ACTIVEMASK() __activemask() #define BALLOT(mask, val) __ballot_sync((mask), val) + +/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively + * supports __shfl_*. Disabling until the move to ROCM-6.2. + */ +#if !defined(USE_ROCM) /** * Find the cumulative sum within a warp up to the current * thread lane, with each mask thread contributing base. @@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3( ret.z = WARP_SUM(group, mask, base.z); return ret; } +#endif //! USE_ROCM // Floating point. // #define FMUL(a, b) __fmul_rn((a), (b)) @@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3( #define FMA(x, y, z) __fmaf_rn((x), (y), (z)) #define I2F(a) __int2float_rn(a) #define FRCP(x) __frcp_rn(x) +#if !defined(USE_ROCM) __device__ static float atomicMax(float* address, float val) { int* address_as_i = (int*)address; int old = *address_as_i, assumed; @@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) { } while (assumed != old); return __int_as_float(old); } +#endif //! USE_ROCM #define DMAX(a, b) FMAX(a, b) #define DMIN(a, b) FMIN(a, b) #define DSQRT(a) sqrt(a) diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.backward.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.backward.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.backward_dbg.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.backward_dbg.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.calc_gradients.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.calc_gradients.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.calc_gradients.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.calc_gradients.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.calc_signature.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.calc_signature.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.calc_signature.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.calc_signature.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.construct.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.construct.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.create_selector.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.create_selector.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.create_selector.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.create_selector.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.destruct.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.destruct.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.fill_bg.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.fill_bg.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.forward.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.forward.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.norm_cam_gradients.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.norm_cam_gradients.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.norm_cam_gradients.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.norm_cam_gradients.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.norm_sphere_gradients.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.norm_sphere_gradients.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.norm_sphere_gradients.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.norm_sphere_gradients.gpu.cu diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu b/pytorch3d/csrc/pulsar/gpu/renderer.render.gpu.cu similarity index 100% rename from pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu rename to pytorch3d/csrc/pulsar/gpu/renderer.render.gpu.cu diff --git a/pytorch3d/csrc/pulsar/include/camera.device.h b/pytorch3d/csrc/pulsar/include/camera.device.h index f003db31..73b9a80c 100644 --- a/pytorch3d/csrc/pulsar/include/camera.device.h +++ b/pytorch3d/csrc/pulsar/include/camera.device.h @@ -14,7 +14,7 @@ #include "./commands.h" namespace pulsar { -IHD CamGradInfo::CamGradInfo() { +IHD CamGradInfo::CamGradInfo(int x) { cam_pos = make_float3(0.f, 0.f, 0.f); pixel_0_0_center = make_float3(0.f, 0.f, 0.f); pixel_dir_x = make_float3(0.f, 0.f, 0.f); diff --git a/pytorch3d/csrc/pulsar/include/camera.h b/pytorch3d/csrc/pulsar/include/camera.h index cbb583a1..7264c811 100644 --- a/pytorch3d/csrc/pulsar/include/camera.h +++ b/pytorch3d/csrc/pulsar/include/camera.h @@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) { }; struct CamGradInfo { - HOST DEVICE CamGradInfo(); + HOST DEVICE CamGradInfo(int = 0); float3 cam_pos; float3 pixel_0_0_center; float3 pixel_dir_x; diff --git a/pytorch3d/csrc/pulsar/include/commands.h b/pytorch3d/csrc/pulsar/include/commands.h index c0b17f40..9cc61db6 100644 --- a/pytorch3d/csrc/pulsar/include/commands.h +++ b/pytorch3d/csrc/pulsar/include/commands.h @@ -24,7 +24,7 @@ // #pragma diag_suppress = 68 #include // #pragma pop -#include "../cuda/commands.h" +#include "../gpu/commands.h" #else #pragma clang diagnostic push #pragma clang diagnostic ignored "-Weverything" diff --git a/pytorch3d/csrc/pulsar/include/math.h b/pytorch3d/csrc/pulsar/include/math.h index d77e2ee1..1ea6b567 100644 --- a/pytorch3d/csrc/pulsar/include/math.h +++ b/pytorch3d/csrc/pulsar/include/math.h @@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) { } // TODO: put intrinsics here. +#if !defined(USE_ROCM) IHD float3 operator+(const float3& a, const float3& b) { return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); } @@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) { IHD float3 operator*(const float& a, const float3& b) { return b * a; } +#endif //! USE_ROCM INLINE DEVICE float length(const float3& v) { // TODO: benchmark what's faster. diff --git a/pytorch3d/csrc/pulsar/include/renderer.render.device.h b/pytorch3d/csrc/pulsar/include/renderer.render.device.h index d1fe23f4..ab13c66d 100644 --- a/pytorch3d/csrc/pulsar/include/renderer.render.device.h +++ b/pytorch3d/csrc/pulsar/include/renderer.render.device.h @@ -283,9 +283,15 @@ GLOBAL void render( (percent_allowed_difference > 0.f && max_closest_possible_intersection > depth_threshold) || tracker.get_n_hits() >= max_n_hits; +#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__) + unsigned long long warp_done = __ballot(done); + int warp_done_bit_cnt = __popcll(warp_done); +#else uint warp_done = thread_warp.ballot(done); + int warp_done_bit_cnt = POPC(warp_done); +#endif //__CUDACC__ && __HIP_PLATFORM_AMD__ if (thread_warp.thread_rank() == 0) - ATOMICADD_B(&n_pixels_done, POPC(warp_done)); + ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt); // This sync is necessary to keep n_loaded until all threads are done with // painting. thread_block.sync(); diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index e19e3eef..cc9d7fab 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -76,13 +76,9 @@ from .points import ( PointsRasterizationSettings, PointsRasterizer, PointsRenderer, + PulsarPointsRenderer, rasterize_points, ) - -# Pulsar is not enabled on amd. -if not torch.version.hip: - from .points import PulsarPointsRenderer - from .splatter_blend import SplatterBlender from .utils import ( convert_to_tensors_and_broadcast, diff --git a/pytorch3d/renderer/points/__init__.py b/pytorch3d/renderer/points/__init__.py index 2185f8f9..24c26c77 100644 --- a/pytorch3d/renderer/points/__init__.py +++ b/pytorch3d/renderer/points/__init__.py @@ -10,9 +10,7 @@ import torch from .compositor import AlphaCompositor, NormWeightedCompositor -# Pulsar not enabled on amd. -if not torch.version.hip: - from .pulsar.unified import PulsarPointsRenderer +from .pulsar.unified import PulsarPointsRenderer from .rasterize_points import rasterize_points from .rasterizer import PointsRasterizationSettings, PointsRasterizer