7 Commits

Author SHA1 Message Date
bottler
9c586b1351 Run tests in github action not circleci (#1896)
Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1896

Differential Revision: D65272512

Pulled By: bottler
2024-10-31 08:41:20 -07:00
Richard Barnes
e13848265d at::optional -> std::optional (#1170)
Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/1170

Reviewed By: gineshidalgo99

Differential Revision: D64938040

fbshipit-source-id: 57f98b90676ad0164a6975ea50e4414fd85ae6c4
2024-10-25 06:37:57 -07:00
generatedunixname89002005307016
58566963d6 Add type error suppressions for upcoming upgrade
Reviewed By: MaggieMoss

Differential Revision: D64502797

fbshipit-source-id: cee9a54dfa8a005d5912b895d0bd094f352c5c6f
2024-10-16 19:22:01 -07:00
Suresh Babu Kolla
e17ed5cd50 Hipify Pulsar for PyTorch3D
Summary:
- Hipified Pytorch Pulsar
   - Created separate target for Pulsar tests and enabled RE testing
   - Pytorch3D full test suite requires additional work like fixing EGL
     dependencies on AMD

Reviewed By: danzimm

Differential Revision: D61339912

fbshipit-source-id: 0d10bc966e4de4a959f3834a386bad24e449dc1f
2024-10-09 14:38:42 -07:00
Richard Barnes
8ed0c7a002 c10::optional -> std::optional
Summary: `c10::optional` is an alias for `std::optional`. Let's remove the alias and use the real thing.

Reviewed By: meyering

Differential Revision: D63402341

fbshipit-source-id: 241383e7ca4b2f3f1f9cac3af083056123dfd02b
2024-10-03 14:38:37 -07:00
Richard Barnes
2da913c7e6 c10::optional -> std::optional
Summary: `c10::optional` is an alias for `std::optional`. Let's remove the alias and use the real thing.

Reviewed By: palmje

Differential Revision: D63409387

fbshipit-source-id: fb6db59a14db9e897e2e6b6ad378f33bf2af86e8
2024-10-02 11:09:29 -07:00
generatedunixname89002005307016
fca83e6369 Convert .pyre_configuration.local to fast by default architecture] [batch:23/263] [shard:3/N] [A]
Reviewed By: connernilsen

Differential Revision: D63415925

fbshipit-source-id: c3e28405c70f9edcf8c21457ac4faf7315b07322
2024-09-25 17:34:03 -07:00
32 changed files with 127 additions and 71 deletions

20
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,20 @@
name: facebookresearch/pytorch3d/build_and_test
on:
pull_request:
branches:
- main
jobs:
binary_linux_conda_cuda:
runs-on: 4-core-ubuntu-gpu-t4
env:
PYTHON_VERSION: "3.12"
BUILD_VERSION: "${{ github.run_number }}"
PYTORCH_VERSION: "2.4.1"
CU_VERSION: "cu121"
JUST_TESTRUN: 1
steps:
- uses: actions/checkout@v4
- name: Build and run tests
run: |-
conda create --name env --yes --quiet conda-build
conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda

View File

@@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os.path
import runpy
import subprocess
from typing import List
from typing import List, Tuple
# required env vars:
# CU_VERSION: E.g. cu112
@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
source_root_dir = os.environ["PWD"]
def version_constraint(version):
def version_constraint(version) -> str:
"""
Given version "11.3" returns " >=11.3,<11.4"
"""
@@ -32,7 +33,7 @@ def version_constraint(version):
return f" >={version},<{upper}"
def get_cuda_major_minor():
def get_cuda_major_minor() -> Tuple[str, str]:
if CU_VERSION == "cpu":
raise ValueError("fn only for cuda builds")
if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
return major, minor
def setup_cuda():
def setup_cuda(use_conda_cuda: bool) -> List[str]:
if CU_VERSION == "cpu":
return
return []
major, minor = get_cuda_major_minor()
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
os.environ["FORCE_CUDA"] = "1"
basic_nvcc_flags = (
@@ -75,6 +75,15 @@ def setup_cuda():
if os.environ.get("JUST_TESTRUN", "0") != "1":
os.environ["NVCC_FLAGS"] = nvcc_flags
if use_conda_cuda:
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
f"- cuda-version={major}.{minor}"
)
return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
else:
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
return []
def setup_conda_pytorch_constraint() -> List[str]:
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
return ["-c", "pytorch", "-c", "nvidia"]
def setup_conda_cudatoolkit_constraint():
def setup_conda_cudatoolkit_constraint() -> None:
if CU_VERSION == "cpu":
os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""
@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit
def do_build(start_args: List[str]):
def do_build(start_args: List[str]) -> None:
args = start_args.copy()
test_flag = os.environ.get("TEST_FLAG")
@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Build the conda package.")
parser.add_argument(
"--use-conda-cuda",
action="store_true",
help="get cuda from conda ignoring local cuda",
)
our_args = parser.parse_args()
args = ["conda", "build"]
setup_cuda()
args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)
init_path = source_root_dir + "/pytorch3d/__init__.py"
build_version = runpy.run_path(init_path)["__version__"]

View File

@@ -8,10 +8,13 @@ source:
requirements:
build:
- {{ compiler('c') }} # [win]
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
{{ environ.get('CONDA_CUB_CONSTRAINT') }}
host:
- python
- mkl =2023 # [x86_64]
{{ environ.get('SETUPTOOLS_CONSTRAINT') }}
{{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
{{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
@@ -22,12 +25,14 @@ requirements:
- python
- numpy >=1.11
- torchvision >=0.5
- mkl =2023 # [x86_64]
- iopath
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
{{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
build:
string: py{{py}}_{{ environ['CU_VERSION'] }}_pyt{{ environ['PYTORCH_VERSION_NODOT']}}
# script: LD_LIBRARY_PATH=$PREFIX/lib:$BUILD_PREFIX/lib:$LD_LIBRARY_PATH python setup.py install --single-version-externally-managed --record=record.txt # [not win]
script: python setup.py install --single-version-externally-managed --record=record.txt # [not win]
script_env:
- CUDA_HOME
@@ -47,6 +52,10 @@ test:
- imageio
- hydra-core
- accelerate
- matplotlib
- tabulate
- pandas
- sqlalchemy
commands:
#pytest .
python -m unittest discover -v -s tests -t .

View File

@@ -7,15 +7,11 @@
*/
// clang-format off
#if !defined(USE_ROCM)
#include "./pulsar/global.h" // Include before <torch/extension.h>.
#endif
#include <torch/extension.h>
// clang-format on
#if !defined(USE_ROCM)
#include "./pulsar/pytorch/renderer.h"
#include "./pulsar/pytorch/tensor_util.h"
#endif
#include "ball_query/ball_query.h"
#include "blending/sigmoid_alpha_blend.h"
#include "compositing/alpha_composite.h"
@@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Pulsar.
// Pulsar not enabled on AMD.
#if !defined(USE_ROCM)
#ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr();
#endif
@@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.attr("MAX_UINT") = py::int_(MAX_UINT);
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
#endif
}

View File

@@ -36,11 +36,13 @@
#pragma nv_diag_suppress 2951
#pragma nv_diag_suppress 2967
#else
#if !defined(USE_ROCM)
#pragma diag_suppress = attribute_not_allowed
#pragma diag_suppress = 1866
#pragma diag_suppress = 2941
#pragma diag_suppress = 2951
#pragma diag_suppress = 2967
#endif //! USE_ROCM
#endif
#else // __CUDACC__
#define INLINE inline
@@ -56,7 +58,9 @@
#pragma clang diagnostic pop
#ifdef WITH_CUDA
#include <ATen/cuda/CUDAContext.h>
#if !defined(USE_ROCM)
#include <vector_functions.h>
#endif //! USE_ROCM
#else
#ifndef cudaStream_t
typedef void* cudaStream_t;

View File

@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
#define SHARED __shared__
#define ACTIVEMASK() __activemask()
#define BALLOT(mask, val) __ballot_sync((mask), val)
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
* supports __shfl_*. Disabling until the move to ROCM-6.2.
*/
#if !defined(USE_ROCM)
/**
* Find the cumulative sum within a warp up to the current
* thread lane, with each mask thread contributing base.
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
ret.z = WARP_SUM(group, mask, base.z);
return ret;
}
#endif //! USE_ROCM
// Floating point.
// #define FMUL(a, b) __fmul_rn((a), (b))
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
#define I2F(a) __int2float_rn(a)
#define FRCP(x) __frcp_rn(x)
#if !defined(USE_ROCM)
__device__ static float atomicMax(float* address, float val) {
int* address_as_i = (int*)address;
int old = *address_as_i, assumed;
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
} while (assumed != old);
return __int_as_float(old);
}
#endif //! USE_ROCM
#define DMAX(a, b) FMAX(a, b)
#define DMIN(a, b) FMIN(a, b)
#define DSQRT(a) sqrt(a)

View File

@@ -14,7 +14,7 @@
#include "./commands.h"
namespace pulsar {
IHD CamGradInfo::CamGradInfo() {
IHD CamGradInfo::CamGradInfo(int x) {
cam_pos = make_float3(0.f, 0.f, 0.f);
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
pixel_dir_x = make_float3(0.f, 0.f, 0.f);

View File

@@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
};
struct CamGradInfo {
HOST DEVICE CamGradInfo();
HOST DEVICE CamGradInfo(int = 0);
float3 cam_pos;
float3 pixel_0_0_center;
float3 pixel_dir_x;

View File

@@ -24,7 +24,7 @@
// #pragma diag_suppress = 68
#include <ATen/cuda/CUDAContext.h>
// #pragma pop
#include "../cuda/commands.h"
#include "../gpu/commands.h"
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"

View File

@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
}
// TODO: put intrinsics here.
#if !defined(USE_ROCM)
IHD float3 operator+(const float3& a, const float3& b) {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
IHD float3 operator*(const float& a, const float3& b) {
return b * a;
}
#endif //! USE_ROCM
INLINE DEVICE float length(const float3& v) {
// TODO: benchmark what's faster.

View File

@@ -283,9 +283,15 @@ GLOBAL void render(
(percent_allowed_difference > 0.f &&
max_closest_possible_intersection > depth_threshold) ||
tracker.get_n_hits() >= max_n_hits;
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
unsigned long long warp_done = __ballot(done);
int warp_done_bit_cnt = __popcll(warp_done);
#else
uint warp_done = thread_warp.ballot(done);
int warp_done_bit_cnt = POPC(warp_done);
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
if (thread_warp.thread_rank() == 0)
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
// This sync is necessary to keep n_loaded until all threads are done with
// painting.
thread_block.sync();

View File

@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
const float& gamma,
const float& max_depth,
float& min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode) {
@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode) {
@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
};
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>>
Renderer::backward(
const torch::Tensor& grad_im,
const torch::Tensor& image,
@@ -912,8 +912,8 @@ Renderer::backward(
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode,
@@ -922,7 +922,7 @@ Renderer::backward(
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const at::optional<std::pair<uint, uint>>& dbg_pos) {
const std::optional<std::pair<uint, uint>>& dbg_pos) {
this->ensure_on_device(this->device_tracker.device());
size_t batch_size;
size_t n_points;
@@ -1045,14 +1045,14 @@ Renderer::backward(
}
// Prepare the return value.
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>>
ret;
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
return ret;

View File

@@ -44,21 +44,21 @@ struct Renderer {
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode);
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>>
backward(
const torch::Tensor& grad_im,
const torch::Tensor& image,
@@ -75,8 +75,8 @@ struct Renderer {
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode,
@@ -85,7 +85,7 @@ struct Renderer {
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const at::optional<std::pair<uint, uint>>& dbg_pos);
const std::optional<std::pair<uint, uint>>& dbg_pos);
// Infrastructure.
/**
@@ -115,8 +115,8 @@ struct Renderer {
const float& gamma,
const float& max_depth,
float& min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode);

View File

@@ -163,6 +163,9 @@ def _read_chunks(
if binary_data is not None:
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
# pyre-fixme[7]: Expected `Optional[Tuple[Dict[str, typing.Any],
# ndarray[typing.Any, typing.Any]]]` but got `Tuple[typing.Any,
# Optional[ndarray[typing.Any, dtype[typing.Any]]]]`.
return json_data, binary_data

View File

@@ -1250,6 +1250,9 @@ def _save_ply(
if verts_normals is not None:
verts_dtype.append(("normals", np.float32, 3))
if verts_colors is not None:
# pyre-fixme[6]: For 1st argument expected `Tuple[str,
# Type[floating[_32Bit]], int]` but got `Tuple[str,
# Type[Union[floating[_32Bit], unsignedinteger[typing.Any]]], int]`.
verts_dtype.append(("colors", color_np_type, 3))
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)

View File

@@ -76,13 +76,9 @@ from .points import (
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
PulsarPointsRenderer,
rasterize_points,
)
# Pulsar is not enabled on amd.
if not torch.version.hip:
from .points import PulsarPointsRenderer
from .splatter_blend import SplatterBlender
from .utils import (
convert_to_tensors_and_broadcast,

View File

@@ -10,9 +10,7 @@ import torch
from .compositor import AlphaCompositor, NormWeightedCompositor
# Pulsar not enabled on amd.
if not torch.version.hip:
from .pulsar.unified import PulsarPointsRenderer
from .pulsar.unified import PulsarPointsRenderer
from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer

View File

@@ -674,7 +674,6 @@ def _add_mesh_trace(
verts[~verts_used] = verts_center
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
fig.add_trace(
go.Mesh3d(
x=verts[:, 0],
@@ -741,7 +740,6 @@ def _add_pointcloud_trace(
row = subplot_idx // ncols + 1
col = subplot_idx % ncols + 1
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
fig.add_trace(
go.Scatter3d(
x=verts[:, 0],
@@ -803,7 +801,6 @@ def _add_camera_trace(
x, y, z = all_cam_wires.detach().cpu().numpy().T.astype(float)
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
fig.add_trace(
go.Scatter3d(x=x, y=y, z=z, marker={"size": 1}, name=trace_name),
row=row,
@@ -898,7 +895,6 @@ def _add_ray_bundle_trace(
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
fig.add_trace(
go.Scatter3d(
x=x,
@@ -1010,7 +1006,6 @@ def _update_axes_bounds(
# Ensure that within a subplot, the bounds capture all traces
old_xrange, old_yrange, old_zrange = (
# pyre-fixme[16]: `Scene` has no attribute `__getitem__`.
current_layout["xaxis"]["range"],
current_layout["yaxis"]["range"],
current_layout["zaxis"]["range"],
@@ -1029,7 +1024,6 @@ def _update_axes_bounds(
xaxis = {"range": x_range}
yaxis = {"range": y_range}
zaxis = {"range": z_range}
# pyre-fixme[16]: `Scene` has no attribute `update`.
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})

View File

@@ -59,6 +59,8 @@ def texturesuv_image_matplotlib(
for i in indices:
# setting clip_on=False makes it obvious when
# we have UV coordinates outside the correct range
# pyre-fixme[6]: For 1st argument expected `Tuple[float, float]` but got
# `ndarray[Any, Any]`.
ax.add_patch(Circle(centers[i], radius, color=color, clip_on=False))