From 05cbea115acbbcbea77999c03d55155b23479991 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 15 Aug 2024 16:18:22 -0700 Subject: [PATCH] Hipify Pytorch3D (#1851) Summary: X-link: https://github.com/pytorch/pytorch/pull/133343 X-link: https://github.com/fairinternal/pytorch3d/pull/45 Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1851 Enables pytorch3d to build on AMD. An important part of enabling this was not compiling the Pulsar backend when the target is AMD. There are simply too many kernel incompatibilities to make it work (I tried haha). Fortunately, it doesn't seem like most modern applications of pytorch3d rely on Pulsar. We should be able to unlock most of pytorch3d's goodness on AMD without it. Reviewed By: bottler, houseroad Differential Revision: D61171993 fbshipit-source-id: fd4aee378a3568b22676c5bf2b727c135ff710af --- pytorch3d/csrc/ext.cpp | 7 +++++ .../csrc/rasterize_meshes/rasterize_meshes.cu | 2 +- pytorch3d/csrc/utils/float_math.cuh | 3 +++ pytorch3d/csrc/utils/warp_reduce.cuh | 26 +++++++++++++++++++ pytorch3d/renderer/__init__.py | 8 +++++- pytorch3d/renderer/points/__init__.py | 7 ++++- 6 files changed, 50 insertions(+), 3 deletions(-) diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 6a17dbb0..49ec02c6 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -7,11 +7,15 @@ */ // clang-format off +#if !defined(USE_ROCM) #include "./pulsar/global.h" // Include before . +#endif #include // clang-format on +#if !defined(USE_ROCM) #include "./pulsar/pytorch/renderer.h" #include "./pulsar/pytorch/tensor_util.h" +#endif #include "ball_query/ball_query.h" #include "blending/sigmoid_alpha_blend.h" #include "compositing/alpha_composite.h" @@ -99,6 +103,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("marching_cubes", &MarchingCubes); // Pulsar. + // Pulsar not enabled on AMD. 
+#if !defined(USE_ROCM) #ifdef PULSAR_LOGGING_ENABLED c10::ShowLogInfoToStderr(); #endif @@ -183,4 +189,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.attr("MAX_UINT") = py::int_(MAX_UINT); m.attr("MAX_USHORT") = py::int_(MAX_USHORT); m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES); +#endif } diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index 21ff7e50..9dd3e266 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -144,7 +144,7 @@ __device__ void CheckPixelInsideFace( const bool zero_face_area = (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon); - if (zmax < 0 || cull_backfaces && back_face || outside_bbox || + if (zmax < 0 || (cull_backfaces && back_face) || outside_bbox || zero_face_area) { return; } diff --git a/pytorch3d/csrc/utils/float_math.cuh b/pytorch3d/csrc/utils/float_math.cuh index e48e960e..2a0e3e38 100644 --- a/pytorch3d/csrc/utils/float_math.cuh +++ b/pytorch3d/csrc/utils/float_math.cuh @@ -18,6 +18,8 @@ const auto vEpsilon = 1e-8; // Common functions and operators for float2. +// Complex arithmetic is already defined for AMD. 
+#if !defined(USE_ROCM) __device__ inline float2 operator-(const float2& a, const float2& b) { return make_float2(a.x - b.x, a.y - b.y); } @@ -41,6 +43,7 @@ __device__ inline float2 operator*(const float2& a, const float2& b) { __device__ inline float2 operator*(const float a, const float2& b) { return make_float2(a * b.x, a * b.y); } +#endif __device__ inline float FloatMin3(const float a, const float b, const float c) { return fminf(a, fminf(b, c)); diff --git a/pytorch3d/csrc/utils/warp_reduce.cuh b/pytorch3d/csrc/utils/warp_reduce.cuh index 3c903019..172f67b2 100644 --- a/pytorch3d/csrc/utils/warp_reduce.cuh +++ b/pytorch3d/csrc/utils/warp_reduce.cuh @@ -23,37 +23,51 @@ WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) { min_idxs[tid] = min_idxs[tid + 32]; min_dists[tid] = min_dists[tid + 32]; } +// AMD does not use explicit syncwarp and instead automatically inserts memory +// fences during compilation. +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 16 if (min_dists[tid] > min_dists[tid + 16]) { min_idxs[tid] = min_idxs[tid + 16]; min_dists[tid] = min_dists[tid + 16]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 8 if (min_dists[tid] > min_dists[tid + 8]) { min_idxs[tid] = min_idxs[tid + 8]; min_dists[tid] = min_dists[tid + 8]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 4 if (min_dists[tid] > min_dists[tid + 4]) { min_idxs[tid] = min_idxs[tid + 4]; min_dists[tid] = min_dists[tid + 4]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 2 if (min_dists[tid] > min_dists[tid + 2]) { min_idxs[tid] = min_idxs[tid + 2]; min_dists[tid] = min_dists[tid + 2]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 1 if (min_dists[tid] > min_dists[tid + 1]) { min_idxs[tid] = min_idxs[tid + 1]; min_dists[tid] = min_dists[tid + 1]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif } template @@ -65,30 +79,42 @@ __device__ void WarpReduceMax( dists[tid] = dists[tid + 32]; dists_idx[tid] = dists_idx[tid + 32]; } +#if 
!defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 16]) { dists[tid] = dists[tid + 16]; dists_idx[tid] = dists_idx[tid + 16]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 8]) { dists[tid] = dists[tid + 8]; dists_idx[tid] = dists_idx[tid + 8]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 4]) { dists[tid] = dists[tid + 4]; dists_idx[tid] = dists_idx[tid + 4]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 2]) { dists[tid] = dists[tid + 2]; dists_idx[tid] = dists_idx[tid + 2]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 1]) { dists[tid] = dists[tid + 1]; dists_idx[tid] = dists_idx[tid + 1]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif } diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index affe30d6..e19e3eef 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -6,6 +6,8 @@ # pyre-unsafe +import torch + from .blending import ( BlendParams, hard_rgb_blend, @@ -74,9 +76,13 @@ from .points import ( PointsRasterizationSettings, PointsRasterizer, PointsRenderer, - PulsarPointsRenderer, rasterize_points, ) + +# Pulsar is not enabled on amd. +if not torch.version.hip: + from .points import PulsarPointsRenderer + from .splatter_blend import SplatterBlender from .utils import ( convert_to_tensors_and_broadcast, diff --git a/pytorch3d/renderer/points/__init__.py b/pytorch3d/renderer/points/__init__.py index 2abb97ce..2185f8f9 100644 --- a/pytorch3d/renderer/points/__init__.py +++ b/pytorch3d/renderer/points/__init__.py @@ -6,8 +6,13 @@ # pyre-unsafe +import torch + from .compositor import AlphaCompositor, NormWeightedCompositor -from .pulsar.unified import PulsarPointsRenderer + +# Pulsar not enabled on amd. 
+if not torch.version.hip: + from .pulsar.unified import PulsarPointsRenderer from .rasterize_points import rasterize_points from .rasterizer import PointsRasterizationSettings, PointsRasterizer