pytorch3d/tests/test_cameras_alignment.py
Jeff Daily b73d735ecf Port pytorch3d (#2039)
Summary:
Enables building pytorch3d's `_C` extension against a ROCm-built PyTorch and running the test suite on AMD GPUs, including the pulsar subrenderer. Verified on AMD Instinct MI250X (gfx90a, warpSize=64), HIP 7.2, PyTorch 2.13.

## Mechanics

`torch.utils.cpp_extension.BuildExtension` auto-hipifies `.cu` sources of a `CUDAExtension` against a HIP-built torch (`cuda_runtime.h → hip/hip_runtime.h`, `cub:: → hipcub::`, `cudaStream_t → hipStream_t`, etc.), so most of the lift is build-system glue and a small number of CUDA intrinsics that don't have HIP equivalents.
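
As a rough illustration of the source rewriting described above, here is a toy single-pass renamer that covers only the three mappings named in this paragraph. It is a hypothetical sketch (`CUDA_TO_HIP` and `toy_hipify` are illustrative names), not PyTorch's own hipify tooling (`torch.utils.hipify`), which handles far more identifiers and context.

```python
import re

# Hypothetical subset of CUDA -> HIP renames; the real hipify tables are much larger.
CUDA_TO_HIP = {
    "cuda_runtime.h": "hip/hip_runtime.h",
    "cub::": "hipcub::",
    "cudaStream_t": "hipStream_t",
}

# Longest keys first, so a longer token is tried first when two keys
# could match at the same position.
_PATTERN = re.compile(
    "|".join(re.escape(k) for k in sorted(CUDA_TO_HIP, key=len, reverse=True))
)


def toy_hipify(source: str) -> str:
    """Rewrite CUDA spellings to HIP spellings in one regex pass.

    Each match is replaced exactly once. The function is not idempotent:
    running it again on its own output turns "hipcub::" into "hiphipcub::".
    """
    return _PATTERN.sub(lambda m: CUDA_TO_HIP[m.group(0)], source)


if __name__ == "__main__":
    cuda_src = "#include <cuda_runtime.h>\nvoid f(cudaStream_t s) { cub::DeviceReduce::Sum(); }"
    print(toy_hipify(cuda_src))
```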

- `setup.py`: detect ROCm via `torch.version.hip is not None`; treat `ROCM_HOME` as the GPU-toolkit-root analogue of `CUDA_HOME` (without this, `CUDA_HOME is None` silently demoted the build to a CPU-only `CppExtension`); skip `CUB_HOME`, CUDA-13 visibility flags, and `-ccbin=` on ROCm.
- `pytorch3d/csrc/pulsar/gpu/commands.h`: CUDA's `_rn`-suffixed FP rounding intrinsics (`__fadd_rn`, `__fdiv_rn`, `__fsqrt_rn`, `__fmaf_rn`, `__frcp_rn`) and `__saturatef` have no HIP equivalents — AMD's GPU ISA has no instruction-level rounding-mode override, so they expand to plain operators / `sqrtf` / `fmaf` / `1.0f/x` / `fmaxf(0,fminf(1,x))` on the `USE_ROCM` arm, which are rounding-mode-equivalent (both round-to-nearest-even). The HIP compiler may fuse `a+b*c` into a single-rounding FMA where CUDA's `_rn` would have prevented it; if FMA-fusion drift ever becomes a numerical issue, add `-ffp-contract=off` to pulsar's HIPCC flags. `__powf` is replaced with `powf`. `atomicAdd_block` has no HIP function-name equivalent — the semantic equivalent is `__hip_atomic_fetch_add(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP)` (plain HIP `atomicAdd` is device-scope, strictly stronger than block-scope and forces L2-coherent atomics).
- `tests/test_point_mesh_distance.py`: loosen the `grad_faces` tolerance in `test_point_face_distance` from `5e-7` to `5e-6` to match the sibling `test_face_point_distance`. The backward kernel accumulates with `atomicAdd` (and calls `alertNotDeterministic`), so the order of the floating-point additions, and therefore the rounding, differs with wavefront width.
- The X_t / camera-R/T equality checks in `test_points_alignment.py` and `test_cameras_alignment.py` are now skipped when `n_points <= dim` (resp. `batch_size <= 3` for camera-center alignment in 3D). Mean-centering renders the SVD rank-deficient in those cases, so the rotation around the degenerate axis is non-unique and different BLAS implementations (rocBLAS RDNA vs CDNA, cuBLAS) pick different valid null-space directions. The center-alignment check still runs and verifies the well-defined part of the transformation.
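
The `setup.py` decision in the first bullet can be sketched as a pure function. This is a hypothetical helper (`plan_extension` and its return dictionary are illustrative, not pytorch3d's actual `setup.py`); in a real build script the three inputs would come from `torch.version.hip`, `torch.utils.cpp_extension.CUDA_HOME`, and `torch.utils.cpp_extension.ROCM_HOME`.

```python
from typing import Dict, List, Optional, Union


def plan_extension(
    hip_version: Optional[str],
    cuda_home: Optional[str],
    rocm_home: Optional[str],
    host_compiler: Optional[str] = None,
) -> Dict[str, Union[str, bool, List[str]]]:
    """Choose between a GPU and a CPU-only extension (hypothetical sketch).

    On a ROCm build of torch, hip_version is a string and cuda_home may be
    None, so testing cuda_home alone would silently fall back to a CPU-only
    build. ROCM_HOME is therefore used as the toolkit root on ROCm.
    """
    is_rocm = hip_version is not None
    toolkit_root = rocm_home if is_rocm else cuda_home
    if toolkit_root is None:
        return {"extension": "CppExtension", "is_rocm": is_rocm, "nvcc_flags": []}

    nvcc_flags: List[str] = []
    # -ccbin is an nvcc option, so it is only emitted for CUDA builds.
    if not is_rocm and host_compiler is not None:
        nvcc_flags.append(f"-ccbin={host_compiler}")
    return {"extension": "CUDAExtension", "is_rocm": is_rocm, "nvcc_flags": nvcc_flags}


if __name__ == "__main__":
    print(plan_extension("7.2", None, "/opt/rocm", host_compiler="gcc"))
    print(plan_extension(None, "/usr/local/cuda", None, host_compiler="gcc"))
```

Note that the ROCm branch still yields a `CUDAExtension`: per the description above, `BuildExtension` hipifies its `.cu` sources rather than needing a separate extension class.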
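
The FMA-fusion caveat in the pulsar bullet comes down to how many times `a*b + c` is rounded: evaluated as two operations it rounds twice, while a fused multiply-add rounds once. The sketch below shows a float32 case where the two differ by 2**-24. It emulates FMA through float64, which is exact for these particular inputs (a product of two float32 values fits in float64's 53-bit significand, and the following add happens to be exact), but it is not a general FMA implementation.

```python
import numpy as np


def mul_add_two_roundings(a: float, b: float, c: float) -> np.float32:
    """a*b + c with the product rounded to float32 before the add."""
    product = np.float32(a) * np.float32(b)  # first rounding
    return np.float32(product + np.float32(c))  # second rounding


def mul_add_one_rounding(a: float, b: float, c: float) -> np.float32:
    """Emulated FMA: exact float64 product, then a single rounding to float32.

    The float64 add can itself round in general, so this only matches a true
    FMA when that sum is exact, as it is for the inputs below.
    """
    exact = np.float64(np.float32(a)) * np.float64(np.float32(b)) + np.float64(np.float32(c))
    return np.float32(exact)


if __name__ == "__main__":
    a = b = 1.0 + 2.0**-12
    c = -1.0
    # Exact product is 1 + 2**-11 + 2**-24; in float32 the 2**-24 term is a
    # half-ulp tie that rounds to even, so the unfused path loses it.
    print(mul_add_two_roundings(a, b, c))
    print(mul_add_one_rounding(a, b, c))
```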
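
The `grad_faces` tolerance change rests on floating-point addition not being associative: the same float32 terms summed in a different grouping can give a different result. The toy model below is hypothetical and much simpler than real `atomicAdd` scheduling. It sums each block of `width` values sequentially, then sums the per-block partials, and with values chosen to make rounding visible, widths 32 and 64 disagree.

```python
import numpy as np


def grouped_sum(values: np.ndarray, width: int) -> np.float32:
    """Sum float32 values in blocks of `width`, then sum the block partials.

    Every addition is rounded to float32, so the result depends on grouping.
    """
    total = np.float32(0.0)
    for start in range(0, len(values), width):
        partial = np.float32(0.0)
        for v in values[start:start + width]:
            partial = np.float32(partial + v)
        total = np.float32(total + partial)
    return total


# 1e8 is exactly representable in float32 and the float32 spacing there is 8,
# so adding 1.0 to +/-1e8 rounds straight back to +/-1e8.
values = np.ones(64, dtype=np.float32)
values[0] = 1e8
values[32] = -1e8

if __name__ == "__main__":
    print("width 64:", grouped_sum(values, 64))
    print("width 32:", grouped_sum(values, 32))
    print("exact   :", float(np.sum(values.astype(np.float64))))
```

The values are deliberately extreme to make the effect obvious; in the test itself the drift only required moving the tolerance from `5e-7` to `5e-6`.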
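
The rank argument in the last bullet can be checked numerically. The sketch below draws a random source cloud, applies a fixed rotation and translation to obtain the target, and counts the numerically nonzero singular values of the centred cross-covariance matrix, which is the matrix an Umeyama-style alignment takes the SVD of. With `n` points in 3D, centring leaves at most `n - 1` independent directions, so the rank falls below 3 once `n <= 3`. The helper names and the relative tolerance `rtol` are arbitrary choices for this illustration.

```python
import numpy as np


def rotation_z(theta: float) -> np.ndarray:
    """Rotation by `theta` radians about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def cross_covariance_rank(src: np.ndarray, tgt: np.ndarray, rtol: float = 1e-9) -> int:
    """Numerical rank of the centred 3x3 cross-covariance of two (n, 3) clouds."""
    xs = src - src.mean(axis=0, keepdims=True)
    ys = tgt - tgt.mean(axis=0, keepdims=True)
    cov = xs.T @ ys / src.shape[0]
    sing = np.linalg.svd(cov, compute_uv=False)  # sorted largest first
    return int(np.sum(sing > rtol * sing[0]))


rng = np.random.default_rng(0)
R = rotation_z(0.3)
t = np.array([1.0, 2.0, 3.0])
ranks = {}
for n in (10, 4, 3, 2, 1):
    src = rng.standard_normal((n, 3))
    tgt = src @ R.T + t
    ranks[n] = cross_covariance_rank(src, tgt)

if __name__ == "__main__":
    print(ranks)
```

The batch sizes mirror the `(10, 4, 3, 2, 1)` loop in the test below.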

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/2039

Test Plan:
All GPU tests pass on both AMD Instinct MI250X (gfx90a, wave64, HIP 7.2) and AMD Radeon Pro W7800 (gfx1100, wave32, HIP 7.2.53211, torch 2.13.0a0).

| Module | Result |
|---|---|
| knn, ball_query, sample_farthest_points, face_areas_normals | all pass |
| rasterize_points, rasterize_meshes, chamfer, packed_to_padded | all pass |
| interpolate_face_attributes, blending, compositing, sample_pdf, mesh_normal_consistency | all pass |
| point_mesh_distance | 9/9 pass (with tolerance fix in this PR) |
| pulsar/test_forward, test_channels, test_depth, test_hands, test_ortho, test_small_spheres | 10 passed (FB_TEST=1) |
| test_render_points pulsar tests, test_camera_conversions::test_pulsar_conversion | 3 passed |
| points_to_volumes, iou_box3d, marching_cubes | 20 failures, all env-only |

The 20 env-only failures are `torch.inverse()` on CPU tensors in test reference paths; this verification host's PyTorch was built with `USE_LAPACK: 0` (only `mkl-static` `.a` archives in the conda env; PyTorch's `FindBLAS` looks for `libmkl_intel_lp64.so`). Unrelated to the port — re-verifying with a LAPACK-linked PyTorch is left to upstream.

Reviewed By: MichaelRamamonjisoa

Differential Revision: D106825690

Pulled By: bottler

fbshipit-source-id: f7a9b6028e6fb555f3b8c0f9792e88b818327166
2026-06-01 06:08:12 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import numpy as np
import torch
from pytorch3d.ops import corresponding_cameras_alignment
from pytorch3d.renderer.cameras import (
    OpenGLOrthographicCameras,
    OpenGLPerspectiveCameras,
    SfMOrthographicCameras,
    SfMPerspectiveCameras,
)
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map, so3_relative_angle

from .common_testing import TestCaseMixin
from .test_cameras import init_random_cameras


class TestCamerasAlignment(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)

    def test_corresponding_cameras_alignment(self):
        """
        Checks the corresponding_cameras_alignment function.
        """
        device = torch.device("cuda:0")

        # try few different random setups
        for _ in range(3):
            for estimate_scale in (True, False):
                # init true alignment transform
                R_align_gt = random_rotations(1, device=device)[0]
                T_align_gt = torch.randn(3, dtype=torch.float32, device=device)

                # init true scale
                if estimate_scale:
                    s_align_gt = torch.randn(
                        1, dtype=torch.float32, device=device
                    ).exp()
                else:
                    s_align_gt = torch.tensor(1.0, dtype=torch.float32, device=device)

                for cam_type in (
                    SfMOrthographicCameras,
                    OpenGLPerspectiveCameras,
                    OpenGLOrthographicCameras,
                    SfMPerspectiveCameras,
                ):
                    # try well-determined and underdetermined cases
                    for batch_size in (10, 4, 3, 2, 1):
                        # get random cameras
                        cameras = init_random_cameras(
                            cam_type, batch_size, random_z=True
                        ).to(device)
                        # try all alignment modes
                        for mode in ("extrinsics", "centers"):
                            # try different noise levels
                            for add_noise in (0.0, 0.01, 1e-4):
                                self._corresponding_cameras_alignment_test_case(
                                    cameras,
                                    R_align_gt,
                                    T_align_gt,
                                    s_align_gt,
                                    estimate_scale,
                                    mode,
                                    add_noise,
                                )

    def _corresponding_cameras_alignment_test_case(
        self,
        cameras,
        R_align_gt,
        T_align_gt,
        s_align_gt,
        estimate_scale,
        mode,
        add_noise,
    ):
        batch_size = cameras.R.shape[0]

        # get target camera centers
        R_new = torch.bmm(R_align_gt[None].expand_as(cameras.R), cameras.R)
        T_new = (
            torch.bmm(T_align_gt[None, None].repeat(batch_size, 1, 1), cameras.R)[:, 0]
            + cameras.T
        ) * s_align_gt

        if add_noise != 0.0:
            R_new = torch.bmm(R_new, so3_exp_map(torch.randn_like(T_new) * add_noise))
            T_new += torch.randn_like(T_new) * add_noise

        # create new cameras from R_new and T_new
        cameras_tgt = cameras.clone()
        cameras_tgt.R = R_new
        cameras_tgt.T = T_new

        # align cameras and cameras_tgt
        cameras_aligned = corresponding_cameras_alignment(
            cameras, cameras_tgt, estimate_scale=estimate_scale, mode=mode
        )

        if batch_size <= 3 and mode == "centers":
            # Underdetermined case: with <= 3 camera centers in 3D, the points
            # span at most a 2D subspace after mean-centering, so the Umeyama
            # SVD has a zero (or near-zero) third singular value and the
            # rotation around the degenerate axis is ambiguous. Different
            # SVD implementations (e.g. rocBLAS on RDNA vs CDNA, or
            # cuBLAS) make different valid choices in that null direction.
            # Only the camera centers are well-defined here, so check those.
            self.assertClose(
                cameras_aligned.get_camera_center(),
                cameras_tgt.get_camera_center(),
                atol=max(add_noise * 7.0, 1e-4),
            )

        else:

            def _rmse(a):
                return (torch.norm(a, dim=1, p=2) ** 2).mean().sqrt()

            if add_noise != 0.0:
                # in a noisy case check mean rotation/translation error for
                # extrinsic alignment and root mean center error for center alignment
                if mode == "centers":
                    self.assertNormsClose(
                        cameras_aligned.get_camera_center(),
                        cameras_tgt.get_camera_center(),
                        _rmse,
                        atol=max(add_noise * 10.0, 1e-4),
                    )
                elif mode == "extrinsics":
                    angle_err = so3_relative_angle(
                        cameras_aligned.R, cameras_tgt.R, cos_angle=True
                    ).mean()
                    self.assertClose(
                        angle_err, torch.ones_like(angle_err), atol=add_noise * 0.03
                    )
                    self.assertNormsClose(
                        cameras_aligned.T, cameras_tgt.T, _rmse, atol=add_noise * 7.0
                    )
                else:
                    raise ValueError(mode)

            else:
                # compare the rotations and translations of cameras
                self.assertClose(cameras_aligned.R, cameras_tgt.R, atol=3e-4)
                self.assertClose(cameras_aligned.T, cameras_tgt.T, atol=3e-4)
                # compare the centers
                self.assertClose(
                    cameras_aligned.get_camera_center(),
                    cameras_tgt.get_camera_center(),
                    atol=3e-4,
                )

    @staticmethod
    def corresponding_cameras_alignment(
        batch_size: int, estimate_scale: bool, mode: str, cam_type=SfMPerspectiveCameras
    ):
        device = torch.device("cuda:0")
        cameras_src, cameras_tgt = [
            init_random_cameras(cam_type, batch_size, random_z=True).to(device)
            for _ in range(2)
        ]

        torch.cuda.synchronize()

        def compute_corresponding_cameras_alignment():
            corresponding_cameras_alignment(
                cameras_src, cameras_tgt, estimate_scale=estimate_scale, mode=mode
            )
            torch.cuda.synchronize()

        return compute_corresponding_cameras_alignment