diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 6a17dbb0..49ec02c6 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -7,11 +7,15 @@ */ // clang-format off +#if !defined(USE_ROCM) #include "./pulsar/global.h" // Include before . +#endif #include // clang-format on +#if !defined(USE_ROCM) #include "./pulsar/pytorch/renderer.h" #include "./pulsar/pytorch/tensor_util.h" +#endif #include "ball_query/ball_query.h" #include "blending/sigmoid_alpha_blend.h" #include "compositing/alpha_composite.h" @@ -99,6 +103,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("marching_cubes", &MarchingCubes); // Pulsar. + // Pulsar not enabled on AMD. +#if !defined(USE_ROCM) #ifdef PULSAR_LOGGING_ENABLED c10::ShowLogInfoToStderr(); #endif @@ -183,4 +189,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.attr("MAX_UINT") = py::int_(MAX_UINT); m.attr("MAX_USHORT") = py::int_(MAX_USHORT); m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES); +#endif } diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index 21ff7e50..9dd3e266 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -144,7 +144,7 @@ __device__ void CheckPixelInsideFace( const bool zero_face_area = (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon); - if (zmax < 0 || cull_backfaces && back_face || outside_bbox || + if (zmax < 0 || (cull_backfaces && back_face) || outside_bbox || zero_face_area) { return; } diff --git a/pytorch3d/csrc/utils/float_math.cuh b/pytorch3d/csrc/utils/float_math.cuh index e48e960e..2a0e3e38 100644 --- a/pytorch3d/csrc/utils/float_math.cuh +++ b/pytorch3d/csrc/utils/float_math.cuh @@ -18,6 +18,8 @@ const auto vEpsilon = 1e-8; // Common functions and operators for float2. +// Complex arithmetic is already defined for AMD. +#if !defined(USE_ROCM) __device__ inline float2 operator-(const float2& a, const float2& b) { return make_float2(a.x - b.x, a.y - b.y); } @@ -41,6 +43,7 @@ __device__ inline float2 operator*(const float2& a, const float2& b) { __device__ inline float2 operator*(const float a, const float2& b) { return make_float2(a * b.x, a * b.y); } +#endif __device__ inline float FloatMin3(const float a, const float b, const float c) { return fminf(a, fminf(b, c)); diff --git a/pytorch3d/csrc/utils/warp_reduce.cuh b/pytorch3d/csrc/utils/warp_reduce.cuh index 3c903019..172f67b2 100644 --- a/pytorch3d/csrc/utils/warp_reduce.cuh +++ b/pytorch3d/csrc/utils/warp_reduce.cuh @@ -23,37 +23,51 @@ WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) { min_idxs[tid] = min_idxs[tid + 32]; min_dists[tid] = min_dists[tid + 32]; } +// AMD does not use explicit syncwarp and instead automatically inserts memory +// fences during compilation. +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 16 if (min_dists[tid] > min_dists[tid + 16]) { min_idxs[tid] = min_idxs[tid + 16]; min_dists[tid] = min_dists[tid + 16]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 8 if (min_dists[tid] > min_dists[tid + 8]) { min_idxs[tid] = min_idxs[tid + 8]; min_dists[tid] = min_dists[tid + 8]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 4 if (min_dists[tid] > min_dists[tid + 4]) { min_idxs[tid] = min_idxs[tid + 4]; min_dists[tid] = min_dists[tid + 4]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 2 if (min_dists[tid] > min_dists[tid + 2]) { min_idxs[tid] = min_idxs[tid + 2]; min_dists[tid] = min_dists[tid + 2]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif // s = 1 if (min_dists[tid] > min_dists[tid + 1]) { min_idxs[tid] = min_idxs[tid + 1]; min_dists[tid] = min_dists[tid + 1]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif } template @@ -65,30 +79,42 @@ __device__ void WarpReduceMax( dists[tid] = dists[tid + 32]; dists_idx[tid] = dists_idx[tid + 32]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 16]) { dists[tid] = dists[tid + 16]; dists_idx[tid] = dists_idx[tid + 16]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 8]) { dists[tid] = dists[tid + 8]; dists_idx[tid] = dists_idx[tid + 8]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 4]) { dists[tid] = dists[tid + 4]; dists_idx[tid] = dists_idx[tid + 4]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 2]) { dists[tid] = dists[tid + 2]; dists_idx[tid] = dists_idx[tid + 2]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif if (dists[tid] < dists[tid + 1]) { dists[tid] = dists[tid + 1]; dists_idx[tid] = dists_idx[tid + 1]; } +#if !defined(USE_ROCM) __syncwarp(); +#endif } diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index affe30d6..e19e3eef 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -6,6 +6,8 @@ # pyre-unsafe +import torch + from .blending import ( BlendParams, hard_rgb_blend, @@ -74,9 +76,13 @@ from .points import ( PointsRasterizationSettings, PointsRasterizer, PointsRenderer, - PulsarPointsRenderer, rasterize_points, ) + +# Pulsar is not enabled on amd. +if not torch.version.hip: + from .points import PulsarPointsRenderer + from .splatter_blend import SplatterBlender from .utils import ( convert_to_tensors_and_broadcast, diff --git a/pytorch3d/renderer/points/__init__.py b/pytorch3d/renderer/points/__init__.py index 2abb97ce..2185f8f9 100644 --- a/pytorch3d/renderer/points/__init__.py +++ b/pytorch3d/renderer/points/__init__.py @@ -6,8 +6,13 @@ # pyre-unsafe +import torch + from .compositor import AlphaCompositor, NormWeightedCompositor -from .pulsar.unified import PulsarPointsRenderer + +# Pulsar not enabled on amd. +if not torch.version.hip: + from .pulsar.unified import PulsarPointsRenderer from .rasterize_points import rasterize_points from .rasterizer import PointsRasterizationSettings, PointsRasterizer