mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Hipify Pulsar for PyTorch3D
Summary:
- Hipified PyTorch Pulsar
- Created separate target for Pulsar tests and enabled RE testing
- PyTorch3D full test suite requires additional work like fixing EGL dependencies on AMD

Reviewed By: danzimm

Differential Revision: D61339912

fbshipit-source-id: 0d10bc966e4de4a959f3834a386bad24e449dc1f
This commit is contained in:
parent
8ed0c7a002
commit
e17ed5cd50
@ -7,15 +7,11 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#if !defined(USE_ROCM)
|
|
||||||
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
||||||
#endif
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#if !defined(USE_ROCM)
|
|
||||||
#include "./pulsar/pytorch/renderer.h"
|
#include "./pulsar/pytorch/renderer.h"
|
||||||
#include "./pulsar/pytorch/tensor_util.h"
|
#include "./pulsar/pytorch/tensor_util.h"
|
||||||
#endif
|
|
||||||
#include "ball_query/ball_query.h"
|
#include "ball_query/ball_query.h"
|
||||||
#include "blending/sigmoid_alpha_blend.h"
|
#include "blending/sigmoid_alpha_blend.h"
|
||||||
#include "compositing/alpha_composite.h"
|
#include "compositing/alpha_composite.h"
|
||||||
@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
|
|
||||||
// Pulsar.
|
// Pulsar.
|
||||||
// Pulsar not enabled on AMD.
|
// Pulsar not enabled on AMD.
|
||||||
#if !defined(USE_ROCM)
|
|
||||||
#ifdef PULSAR_LOGGING_ENABLED
|
#ifdef PULSAR_LOGGING_ENABLED
|
||||||
c10::ShowLogInfoToStderr();
|
c10::ShowLogInfoToStderr();
|
||||||
#endif
|
#endif
|
||||||
@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.attr("MAX_UINT") = py::int_(MAX_UINT);
|
m.attr("MAX_UINT") = py::int_(MAX_UINT);
|
||||||
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
|
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
|
||||||
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
|
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
@ -36,11 +36,13 @@
|
|||||||
#pragma nv_diag_suppress 2951
|
#pragma nv_diag_suppress 2951
|
||||||
#pragma nv_diag_suppress 2967
|
#pragma nv_diag_suppress 2967
|
||||||
#else
|
#else
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
#pragma diag_suppress = attribute_not_allowed
|
#pragma diag_suppress = attribute_not_allowed
|
||||||
#pragma diag_suppress = 1866
|
#pragma diag_suppress = 1866
|
||||||
#pragma diag_suppress = 2941
|
#pragma diag_suppress = 2941
|
||||||
#pragma diag_suppress = 2951
|
#pragma diag_suppress = 2951
|
||||||
#pragma diag_suppress = 2967
|
#pragma diag_suppress = 2967
|
||||||
|
#endif //! USE_ROCM
|
||||||
#endif
|
#endif
|
||||||
#else // __CUDACC__
|
#else // __CUDACC__
|
||||||
#define INLINE inline
|
#define INLINE inline
|
||||||
@ -56,7 +58,9 @@
|
|||||||
#pragma clang diagnostic pop
|
#pragma clang diagnostic pop
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
#include <vector_functions.h>
|
#include <vector_functions.h>
|
||||||
|
#endif //! USE_ROCM
|
||||||
#else
|
#else
|
||||||
#ifndef cudaStream_t
|
#ifndef cudaStream_t
|
||||||
typedef void* cudaStream_t;
|
typedef void* cudaStream_t;
|
||||||
|
@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
|
|||||||
#define SHARED __shared__
|
#define SHARED __shared__
|
||||||
#define ACTIVEMASK() __activemask()
|
#define ACTIVEMASK() __activemask()
|
||||||
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
||||||
|
|
||||||
|
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
|
||||||
|
* supports __shfl_*. Disabling until the move to ROCM-6.2.
|
||||||
|
*/
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
/**
|
/**
|
||||||
* Find the cumulative sum within a warp up to the current
|
* Find the cumulative sum within a warp up to the current
|
||||||
* thread lane, with each mask thread contributing base.
|
* thread lane, with each mask thread contributing base.
|
||||||
@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
|||||||
ret.z = WARP_SUM(group, mask, base.z);
|
ret.z = WARP_SUM(group, mask, base.z);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
#endif //! USE_ROCM
|
||||||
|
|
||||||
// Floating point.
|
// Floating point.
|
||||||
// #define FMUL(a, b) __fmul_rn((a), (b))
|
// #define FMUL(a, b) __fmul_rn((a), (b))
|
||||||
@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
|||||||
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
||||||
#define I2F(a) __int2float_rn(a)
|
#define I2F(a) __int2float_rn(a)
|
||||||
#define FRCP(x) __frcp_rn(x)
|
#define FRCP(x) __frcp_rn(x)
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__device__ static float atomicMax(float* address, float val) {
|
__device__ static float atomicMax(float* address, float val) {
|
||||||
int* address_as_i = (int*)address;
|
int* address_as_i = (int*)address;
|
||||||
int old = *address_as_i, assumed;
|
int old = *address_as_i, assumed;
|
||||||
@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
|
|||||||
} while (assumed != old);
|
} while (assumed != old);
|
||||||
return __int_as_float(old);
|
return __int_as_float(old);
|
||||||
}
|
}
|
||||||
|
#endif //! USE_ROCM
|
||||||
#define DMAX(a, b) FMAX(a, b)
|
#define DMAX(a, b) FMAX(a, b)
|
||||||
#define DMIN(a, b) FMIN(a, b)
|
#define DMIN(a, b) FMIN(a, b)
|
||||||
#define DSQRT(a) sqrt(a)
|
#define DSQRT(a) sqrt(a)
|
@ -14,7 +14,7 @@
|
|||||||
#include "./commands.h"
|
#include "./commands.h"
|
||||||
|
|
||||||
namespace pulsar {
|
namespace pulsar {
|
||||||
IHD CamGradInfo::CamGradInfo() {
|
IHD CamGradInfo::CamGradInfo(int x) {
|
||||||
cam_pos = make_float3(0.f, 0.f, 0.f);
|
cam_pos = make_float3(0.f, 0.f, 0.f);
|
||||||
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
||||||
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
||||||
|
@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct CamGradInfo {
|
struct CamGradInfo {
|
||||||
HOST DEVICE CamGradInfo();
|
HOST DEVICE CamGradInfo(int = 0);
|
||||||
float3 cam_pos;
|
float3 cam_pos;
|
||||||
float3 pixel_0_0_center;
|
float3 pixel_0_0_center;
|
||||||
float3 pixel_dir_x;
|
float3 pixel_dir_x;
|
||||||
|
@ -24,7 +24,7 @@
|
|||||||
// #pragma diag_suppress = 68
|
// #pragma diag_suppress = 68
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
// #pragma pop
|
// #pragma pop
|
||||||
#include "../cuda/commands.h"
|
#include "../gpu/commands.h"
|
||||||
#else
|
#else
|
||||||
#pragma clang diagnostic push
|
#pragma clang diagnostic push
|
||||||
#pragma clang diagnostic ignored "-Weverything"
|
#pragma clang diagnostic ignored "-Weverything"
|
||||||
|
@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: put intrinsics here.
|
// TODO: put intrinsics here.
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
IHD float3 operator+(const float3& a, const float3& b) {
|
IHD float3 operator+(const float3& a, const float3& b) {
|
||||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||||
}
|
}
|
||||||
@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
|
|||||||
IHD float3 operator*(const float& a, const float3& b) {
|
IHD float3 operator*(const float& a, const float3& b) {
|
||||||
return b * a;
|
return b * a;
|
||||||
}
|
}
|
||||||
|
#endif //! USE_ROCM
|
||||||
|
|
||||||
INLINE DEVICE float length(const float3& v) {
|
INLINE DEVICE float length(const float3& v) {
|
||||||
// TODO: benchmark what's faster.
|
// TODO: benchmark what's faster.
|
||||||
|
@ -283,9 +283,15 @@ GLOBAL void render(
|
|||||||
(percent_allowed_difference > 0.f &&
|
(percent_allowed_difference > 0.f &&
|
||||||
max_closest_possible_intersection > depth_threshold) ||
|
max_closest_possible_intersection > depth_threshold) ||
|
||||||
tracker.get_n_hits() >= max_n_hits;
|
tracker.get_n_hits() >= max_n_hits;
|
||||||
|
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
unsigned long long warp_done = __ballot(done);
|
||||||
|
int warp_done_bit_cnt = __popcll(warp_done);
|
||||||
|
#else
|
||||||
uint warp_done = thread_warp.ballot(done);
|
uint warp_done = thread_warp.ballot(done);
|
||||||
|
int warp_done_bit_cnt = POPC(warp_done);
|
||||||
|
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
|
||||||
if (thread_warp.thread_rank() == 0)
|
if (thread_warp.thread_rank() == 0)
|
||||||
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
|
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
|
||||||
// This sync is necessary to keep n_loaded until all threads are done with
|
// This sync is necessary to keep n_loaded until all threads are done with
|
||||||
// painting.
|
// painting.
|
||||||
thread_block.sync();
|
thread_block.sync();
|
||||||
|
@ -76,13 +76,9 @@ from .points import (
|
|||||||
PointsRasterizationSettings,
|
PointsRasterizationSettings,
|
||||||
PointsRasterizer,
|
PointsRasterizer,
|
||||||
PointsRenderer,
|
PointsRenderer,
|
||||||
|
PulsarPointsRenderer,
|
||||||
rasterize_points,
|
rasterize_points,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pulsar is not enabled on amd.
|
|
||||||
if not torch.version.hip:
|
|
||||||
from .points import PulsarPointsRenderer
|
|
||||||
|
|
||||||
from .splatter_blend import SplatterBlender
|
from .splatter_blend import SplatterBlender
|
||||||
from .utils import (
|
from .utils import (
|
||||||
convert_to_tensors_and_broadcast,
|
convert_to_tensors_and_broadcast,
|
||||||
|
@ -10,9 +10,7 @@ import torch
|
|||||||
|
|
||||||
from .compositor import AlphaCompositor, NormWeightedCompositor
|
from .compositor import AlphaCompositor, NormWeightedCompositor
|
||||||
|
|
||||||
# Pulsar not enabled on amd.
|
from .pulsar.unified import PulsarPointsRenderer
|
||||||
if not torch.version.hip:
|
|
||||||
from .pulsar.unified import PulsarPointsRenderer
|
|
||||||
|
|
||||||
from .rasterize_points import rasterize_points
|
from .rasterize_points import rasterize_points
|
||||||
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
|
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
|
||||||
|
Loading…
x
Reference in New Issue
Block a user