pytorch3d/pytorch3d/csrc/utils/warp_reduce.cuh
Josh Fromm 05cbea115a Hipify Pytorch3D (#1851)
Summary:
X-link: https://github.com/pytorch/pytorch/pull/133343

X-link: https://github.com/fairinternal/pytorch3d/pull/45

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1851

Enables pytorch3d to build on AMD. An important part of enabling this was not compiling the Pulsar backend when the target is AMD. There are simply too many kernel incompatibilities to make it work (I tried haha). Fortunately, it doesn't seem like most modern applications of pytorch3d rely on Pulsar. We should be able to unlock most of pytorch3d's goodness on AMD without it.

Reviewed By: bottler, houseroad

Differential Revision: D61171993

fbshipit-source-id: fd4aee378a3568b22676c5bf2b727c135ff710af
2024-08-15 16:18:22 -07:00

121 lines
2.9 KiB
Plaintext

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <float.h>
#include <math.h>
#include <cstdio>
// Helper functions WarpReduceMin and WarpReduceMax used in .cu files.
// Starting with Volta (independent thread scheduling), threads within a warp
// are no longer guaranteed to execute in lockstep. We therefore call
// __syncwarp() to synchronize the 32 threads of the warp instead of all the
// threads in the block.
template <typename scalar_t>
// Warp-level argmin over 64 (distance, index) pairs.
//
// Intended to be called by all 32 lanes of one warp, with tid being the lane
// in [0, 32) (assumed from the strides used below; confirm at call sites).
// __syncwarp() with no argument uses the full-warp mask, so every lane must
// reach it. Each step s in {32, 16, 8, 4, 2, 1} folds slot tid + s into slot
// tid. On return, min_dists[0] holds the smallest of min_dists[0..63] and
// min_idxs[0] holds the index paired with it. Other slots hold partial
// results. On ties, the entry already in slot tid is kept (strict '>').
//
// Args:
//   min_dists: at least 64 distances (slot tid + 32 is read, up to 63).
//     Presumably __shared__ memory. Overwritten in place.
//   min_idxs: indices paired element-wise with min_dists. Overwritten in place.
//   tid: this thread's lane index.
//
// The arrays are accessed through volatile pointers so the compiler re-reads
// them on every step instead of caching or reordering values that other lanes
// write. That matters on the ROCm path, where no __syncwarp() is emitted.
// Plain (non-volatile) pointers still convert implicitly to these parameters.
__device__ void WarpReduceMin(
    volatile scalar_t* min_dists,
    volatile int64_t* min_idxs,
    const size_t tid) {
#pragma unroll
  for (int s = 32; s > 0; s >>= 1) {
    // Snapshot everything this lane needs BEFORE any lane writes in this step.
    // Lane tid + s updates its own slot during the same step. Reading that
    // slot concurrently is a race that can yield a torn (distance, index) pair.
    const scalar_t other_dist = min_dists[tid + s];
    const int64_t other_idx = min_idxs[tid + s];
    const scalar_t my_dist = min_dists[tid];
    // AMD does not use explicit syncwarp; the ROCm path presumably relies on
    // the wavefront executing these lanes in lockstep.
#if !defined(USE_ROCM)
    __syncwarp(); // all reads of this step complete before any write
#endif
    if (my_dist > other_dist) {
      min_idxs[tid] = other_idx;
      min_dists[tid] = other_dist;
    }
#if !defined(USE_ROCM)
    __syncwarp(); // writes visible to the warp before the next step's reads
#endif
  }
}
template <typename scalar_t>
// Warp-level argmax over 64 (distance, index) pairs.
//
// Intended to be called by all 32 lanes of one warp, with tid being the lane
// in [0, 32) (assumed from the strides used below; confirm at call sites).
// __syncwarp() with no argument uses the full-warp mask, so every lane must
// reach it. Each step s in {32, 16, 8, 4, 2, 1} folds slot tid + s into slot
// tid. On return, dists[0] holds the largest of dists[0..63] and dists_idx[0]
// holds the index paired with it. Other slots hold partial results. On ties,
// the entry already in slot tid is kept (strict '<').
//
// Args:
//   dists: at least 64 values (slot tid + 32 is read, up to 63). Presumably
//     __shared__ memory. Overwritten in place.
//   dists_idx: indices paired element-wise with dists. Overwritten in place.
//   tid: this thread's lane index.
//
// volatile forces a fresh load/store on every access, so values written by
// other lanes are not cached in registers. The ROCm path emits no
// __syncwarp() and depends on this.
__device__ void WarpReduceMax(
    volatile scalar_t* dists,
    volatile int64_t* dists_idx,
    const size_t tid) {
#pragma unroll
  for (int s = 32; s > 0; s >>= 1) {
    // Read both pairs into registers before anyone writes in this step.
    // Lane tid + s is concurrently updating its own slot, so reading and
    // writing in one phase can produce a torn (value, index) pair.
    const scalar_t other_dist = dists[tid + s];
    const int64_t other_idx = dists_idx[tid + s];
    const scalar_t my_dist = dists[tid];
    // AMD does not use explicit syncwarp; the ROCm path presumably relies on
    // the wavefront executing these lanes in lockstep.
#if !defined(USE_ROCM)
    __syncwarp(); // end of read phase
#endif
    if (my_dist < other_dist) {
      dists[tid] = other_dist;
      dists_idx[tid] = other_idx;
    }
#if !defined(USE_ROCM)
    __syncwarp(); // end of write phase
#endif
  }
}