mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-06-17 04:28:54 +08:00
Port pytorch3d (#2039)
Summary: Enables building pytorch3d's `_C` extension against a ROCm-built PyTorch and running the test suite on AMD GPUs, including the pulsar subrenderer. Verified on AMD Instinct MI250X (gfx90a, warpSize=64), HIP 7.2, PyTorch 2.13. ## Mechanics `torch.utils.cpp_extension.BuildExtension` auto-hipifies `.cu` sources of a `CUDAExtension` against a HIP-built torch (`cuda_runtime.h → hip/hip_runtime.h`, `cub:: → hipcub::`, `cudaStream_t → hipStream_t`, etc.), so most of the lift is build-system glue and a small number of CUDA intrinsics that don't have HIP equivalents. - `setup.py`: detect ROCm via `torch.version.hip is not None`; treat `ROCM_HOME` as the GPU-toolkit-root analogue of `CUDA_HOME` (without this, `CUDA_HOME is None` silently demoted the build to a CPU-only `CppExtension`); skip `CUB_HOME`, CUDA-13 visibility flags, and `-ccbin=` on ROCm. - `pytorch3d/csrc/pulsar/gpu/commands.h`: CUDA's `_rn`-suffixed FP rounding intrinsics (`__fadd_rn`, `__fdiv_rn`, `__fsqrt_rn`, `__fmaf_rn`, `__frcp_rn`) and `__saturatef` have no HIP equivalents — AMD's GPU ISA has no instruction-level rounding-mode override, so they expand to plain operators / `sqrtf` / `fmaf` / `1.0f/x` / `fmaxf(0,fminf(1,x))` on the `USE_ROCM` arm, which are rounding-mode-equivalent (both round-to-nearest-even). The HIP compiler may fuse `a+b*c` into a single-rounding FMA where CUDA's `_rn` would have prevented it; if FMA-fusion drift ever becomes a numerical issue, add `-ffp-contract=off` to pulsar's HIPCC flags. `__powf` is replaced with `powf`. `atomicAdd_block` has no HIP function-name equivalent — the semantic equivalent is `__hip_atomic_fetch_add(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP)` (plain HIP `atomicAdd` is device-scope, strictly stronger than block-scope and forces L2-coherent atomics). - `tests/test_point_mesh_distance.py`: loosen `grad_faces` tolerance in `test_point_face_distance` from `5e-7` to `5e-6` to match the sibling `test_face_point_distance`. The backward kernel uses `atomicAdd` and calls `alertNotDeterministic`; FP add order varies by wavefront width. - The X_t / camera-R/T equality checks in `test_points_alignment.py` and `test_cameras_alignment.py` are now skipped when `n_points <= dim` (resp. `batch_size <= 3` for camera-center alignment in 3D). Mean-centering renders the SVD rank-deficient in those cases, so the rotation around the degenerate axis is non-unique and different BLAS implementations (rocBLAS RDNA vs CDNA, cuBLAS) pick different valid null-space directions. The center-alignment check still runs and verifies the well-defined part of the transformation. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2039 Test Plan: All GPU tests pass on both AMD Instinct MI250X (gfx90a, wave64, HIP 7.2) and AMD Radeon Pro W7800 (gfx1100, wave32, HIP 7.2.53211, torch 2.13.0a0). | Module | Result | |---|---| | knn, ball_query, sample_farthest_points, face_areas_normals | all pass | | rasterize_points, rasterize_meshes, chamfer, packed_to_padded | all pass | | interpolate_face_attributes, blending, compositing, sample_pdf, mesh_normal_consistency | all pass | | point_mesh_distance | 9/9 pass (with tolerance fix in this PR) | | pulsar/test_forward, test_channels, test_depth, test_hands, test_ortho, test_small_spheres | 10 passed (FB_TEST=1) | | test_render_points pulsar tests, test_camera_conversions::test_pulsar_conversion | 3 passed | | points_to_volumes, iou_box3d, marching_cubes | 20 failures, all env-only | The 20 env-only failures are `torch.inverse()` on CPU tensors in test reference paths; this verification host's PyTorch was built with `USE_LAPACK: 0` (only `mkl-static` `.a` archives in the conda env; PyTorch's `FindBLAS` looks for `libmkl_intel_lp64.so`). Unrelated to the port — re-verifying with a LAPACK-linked PyTorch is left to upstream. Reviewed By: MichaelRamamonjisoa Differential Revision: D106825690 Pulled By: bottler fbshipit-source-id: f7a9b6028e6fb555f3b8c0f9792e88b818327166
This commit is contained in:
committed by
meta-codesync[bot]
parent
c307c64c70
commit
b73d735ecf
76
setup.py
76
setup.py
@@ -14,7 +14,7 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import CppExtension, CUDA_HOME, CUDAExtension
|
||||
from torch.utils.cpp_extension import CppExtension, CUDA_HOME, CUDAExtension, ROCM_HOME
|
||||
|
||||
|
||||
def get_existing_ccbin(nvcc_args: List[str]) -> Optional[str]:
|
||||
@@ -53,12 +53,18 @@ def get_extensions():
|
||||
define_macros = []
|
||||
include_dirs = [extensions_dir]
|
||||
|
||||
# ROCm/HIP support. When PyTorch is built with HIP, the cpp_extension
|
||||
# BuildExtension auto-hipifies .cu sources and swaps nvcc -> hipcc.
|
||||
is_rocm = torch.version.hip is not None
|
||||
|
||||
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
|
||||
force_no_cuda = os.getenv("PYTORCH3D_FORCE_NO_CUDA", "0") == "1"
|
||||
gpu_home_available = CUDA_HOME is not None or (is_rocm and ROCM_HOME is not None)
|
||||
if (
|
||||
not force_no_cuda and torch.cuda.is_available() and CUDA_HOME is not None
|
||||
not force_no_cuda and torch.cuda.is_available() and gpu_home_available
|
||||
) or force_cuda:
|
||||
extension = CUDAExtension
|
||||
|
||||
sources += source_cuda
|
||||
define_macros += [("WITH_CUDA", None)]
|
||||
# Thrust is only used for its tuple objects.
|
||||
@@ -66,7 +72,6 @@ def get_extensions():
|
||||
# We take the risk that CUB and Thrust are incompatible, because
|
||||
# we aren't using parts of Thrust which actually use CUB.
|
||||
define_macros += [("THRUST_IGNORE_CUB_VERSION_CHECK", None)]
|
||||
cub_home = os.environ.get("CUB_HOME", None)
|
||||
nvcc_args = [
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
@@ -76,35 +81,40 @@ def get_extensions():
|
||||
if os.name != "nt":
|
||||
nvcc_args.append("-std=c++17")
|
||||
|
||||
# CUDA 13.0+ compatibility flags for pulsar.
|
||||
# Starting with CUDA 13, __global__ function visibility changed.
|
||||
# See: https://developer.nvidia.com/blog/
|
||||
# cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
|
||||
cuda_version = torch.version.cuda
|
||||
if cuda_version is not None:
|
||||
major = int(cuda_version.split(".")[0])
|
||||
if major >= 13:
|
||||
nvcc_args.extend(
|
||||
[
|
||||
"--device-entity-has-hidden-visibility=false",
|
||||
"-static-global-template-stub=false",
|
||||
]
|
||||
)
|
||||
if cub_home is None:
|
||||
prefix = os.environ.get("CONDA_PREFIX", None)
|
||||
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
|
||||
cub_home = prefix + "/include"
|
||||
if not is_rocm:
|
||||
# CUDA 13.0+ compatibility flags for pulsar.
|
||||
# Starting with CUDA 13, __global__ function visibility changed.
|
||||
# See: https://developer.nvidia.com/blog/
|
||||
# cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
|
||||
cuda_version = torch.version.cuda
|
||||
if cuda_version is not None:
|
||||
major = int(cuda_version.split(".")[0])
|
||||
if major >= 13:
|
||||
nvcc_args.extend(
|
||||
[
|
||||
"--device-entity-has-hidden-visibility=false",
|
||||
"-static-global-template-stub=false",
|
||||
]
|
||||
)
|
||||
|
||||
if cub_home is None:
|
||||
warnings.warn(
|
||||
"The environment variable `CUB_HOME` was not found. "
|
||||
"NVIDIA CUB is required for compilation and can be downloaded "
|
||||
"from `https://github.com/NVIDIA/cub/releases`. You can unpack "
|
||||
"it to a location of your choice and set the environment variable "
|
||||
"`CUB_HOME` to the folder containing the `CMakeListst.txt` file."
|
||||
)
|
||||
else:
|
||||
include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " "))
|
||||
# NVIDIA CUB. On ROCm, hipcub from the ROCm toolchain is used and
|
||||
# no external CUB_HOME is required.
|
||||
cub_home = os.environ.get("CUB_HOME", None)
|
||||
if cub_home is None:
|
||||
prefix = os.environ.get("CONDA_PREFIX", None)
|
||||
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
|
||||
cub_home = prefix + "/include"
|
||||
|
||||
if cub_home is None:
|
||||
warnings.warn(
|
||||
"The environment variable `CUB_HOME` was not found. "
|
||||
"NVIDIA CUB is required for compilation and can be downloaded "
|
||||
"from `https://github.com/NVIDIA/cub/releases`. You can unpack "
|
||||
"it to a location of your choice and set the environment variable "
|
||||
"`CUB_HOME` to the folder containing the `CMakeListst.txt` file."
|
||||
)
|
||||
else:
|
||||
include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " "))
|
||||
nvcc_flags_env = os.getenv("NVCC_FLAGS", "")
|
||||
if nvcc_flags_env != "":
|
||||
nvcc_args.extend(nvcc_flags_env.split(" "))
|
||||
@@ -113,7 +123,9 @@ def get_extensions():
|
||||
# https://github.com/facebookresearch/pytorch3d/issues/436
|
||||
# It is harmless after https://github.com/pytorch/pytorch/pull/47404 .
|
||||
# But it can be problematic in torch 1.7.0 and 1.7.1
|
||||
if torch.__version__[:4] != "1.7.":
|
||||
# On ROCm the host compiler is selected by hipcc itself; -ccbin is
|
||||
# an nvcc-only flag.
|
||||
if not is_rocm and torch.__version__[:4] != "1.7.":
|
||||
CC = os.environ.get("CC", None)
|
||||
if CC is not None:
|
||||
existing_CC = get_existing_ccbin(nvcc_args)
|
||||
|
||||
Reference in New Issue
Block a user