mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 23:00:34 +08:00
Compare commits
7 Commits
V0.7.8
...
bottler/ac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c586b1351 | ||
|
|
e13848265d | ||
|
|
58566963d6 | ||
|
|
e17ed5cd50 | ||
|
|
8ed0c7a002 | ||
|
|
2da913c7e6 | ||
|
|
fca83e6369 |
20
.github/workflows/build.yml
vendored
Normal file
20
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
name: facebookresearch/pytorch3d/build_and_test
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
binary_linux_conda_cuda:
|
||||
runs-on: 4-core-ubuntu-gpu-t4
|
||||
env:
|
||||
PYTHON_VERSION: "3.12"
|
||||
BUILD_VERSION: "${{ github.run_number }}"
|
||||
PYTORCH_VERSION: "2.4.1"
|
||||
CU_VERSION: "cu121"
|
||||
JUST_TESTRUN: 1
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Build and run tests
|
||||
run: |-
|
||||
conda create --name env --yes --quiet conda-build
|
||||
conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda
|
||||
@@ -4,10 +4,11 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os.path
|
||||
import runpy
|
||||
import subprocess
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
# required env vars:
|
||||
# CU_VERSION: E.g. cu112
|
||||
@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
|
||||
source_root_dir = os.environ["PWD"]
|
||||
|
||||
|
||||
def version_constraint(version):
|
||||
def version_constraint(version) -> str:
|
||||
"""
|
||||
Given version "11.3" returns " >=11.3,<11.4"
|
||||
"""
|
||||
@@ -32,7 +33,7 @@ def version_constraint(version):
|
||||
return f" >={version},<{upper}"
|
||||
|
||||
|
||||
def get_cuda_major_minor():
|
||||
def get_cuda_major_minor() -> Tuple[str, str]:
|
||||
if CU_VERSION == "cpu":
|
||||
raise ValueError("fn only for cuda builds")
|
||||
if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
|
||||
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
|
||||
return major, minor
|
||||
|
||||
|
||||
def setup_cuda():
|
||||
def setup_cuda(use_conda_cuda: bool) -> List[str]:
|
||||
if CU_VERSION == "cpu":
|
||||
return
|
||||
return []
|
||||
major, minor = get_cuda_major_minor()
|
||||
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
|
||||
os.environ["FORCE_CUDA"] = "1"
|
||||
|
||||
basic_nvcc_flags = (
|
||||
@@ -75,6 +75,15 @@ def setup_cuda():
|
||||
|
||||
if os.environ.get("JUST_TESTRUN", "0") != "1":
|
||||
os.environ["NVCC_FLAGS"] = nvcc_flags
|
||||
if use_conda_cuda:
|
||||
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
|
||||
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
|
||||
f"- cuda-version={major}.{minor}"
|
||||
)
|
||||
return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
|
||||
else:
|
||||
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
|
||||
return []
|
||||
|
||||
|
||||
def setup_conda_pytorch_constraint() -> List[str]:
|
||||
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
|
||||
return ["-c", "pytorch", "-c", "nvidia"]
|
||||
|
||||
|
||||
def setup_conda_cudatoolkit_constraint():
|
||||
def setup_conda_cudatoolkit_constraint() -> None:
|
||||
if CU_VERSION == "cpu":
|
||||
os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
|
||||
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""
|
||||
@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
|
||||
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit
|
||||
|
||||
|
||||
def do_build(start_args: List[str]):
|
||||
def do_build(start_args: List[str]) -> None:
|
||||
args = start_args.copy()
|
||||
|
||||
test_flag = os.environ.get("TEST_FLAG")
|
||||
@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Build the conda package.")
|
||||
parser.add_argument(
|
||||
"--use-conda-cuda",
|
||||
action="store_true",
|
||||
help="get cuda from conda ignoring local cuda",
|
||||
)
|
||||
our_args = parser.parse_args()
|
||||
|
||||
args = ["conda", "build"]
|
||||
setup_cuda()
|
||||
args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)
|
||||
|
||||
init_path = source_root_dir + "/pytorch3d/__init__.py"
|
||||
build_version = runpy.run_path(init_path)["__version__"]
|
||||
|
||||
@@ -8,10 +8,13 @@ source:
|
||||
requirements:
|
||||
build:
|
||||
- {{ compiler('c') }} # [win]
|
||||
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
|
||||
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
|
||||
{{ environ.get('CONDA_CUB_CONSTRAINT') }}
|
||||
|
||||
host:
|
||||
- python
|
||||
- mkl =2023 # [x86_64]
|
||||
{{ environ.get('SETUPTOOLS_CONSTRAINT') }}
|
||||
{{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
|
||||
{{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
|
||||
@@ -22,12 +25,14 @@ requirements:
|
||||
- python
|
||||
- numpy >=1.11
|
||||
- torchvision >=0.5
|
||||
- mkl =2023 # [x86_64]
|
||||
- iopath
|
||||
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
|
||||
{{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
|
||||
|
||||
build:
|
||||
string: py{{py}}_{{ environ['CU_VERSION'] }}_pyt{{ environ['PYTORCH_VERSION_NODOT']}}
|
||||
# script: LD_LIBRARY_PATH=$PREFIX/lib:$BUILD_PREFIX/lib:$LD_LIBRARY_PATH python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
||||
script: python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
||||
script_env:
|
||||
- CUDA_HOME
|
||||
@@ -47,6 +52,10 @@ test:
|
||||
- imageio
|
||||
- hydra-core
|
||||
- accelerate
|
||||
- matplotlib
|
||||
- tabulate
|
||||
- pandas
|
||||
- sqlalchemy
|
||||
commands:
|
||||
#pytest .
|
||||
python -m unittest discover -v -s tests -t .
|
||||
|
||||
@@ -7,15 +7,11 @@
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#if !defined(USE_ROCM)
|
||||
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
// clang-format on
|
||||
#if !defined(USE_ROCM)
|
||||
#include "./pulsar/pytorch/renderer.h"
|
||||
#include "./pulsar/pytorch/tensor_util.h"
|
||||
#endif
|
||||
#include "ball_query/ball_query.h"
|
||||
#include "blending/sigmoid_alpha_blend.h"
|
||||
#include "compositing/alpha_composite.h"
|
||||
@@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
// Pulsar.
|
||||
// Pulsar not enabled on AMD.
|
||||
#if !defined(USE_ROCM)
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
#endif
|
||||
@@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.attr("MAX_UINT") = py::int_(MAX_UINT);
|
||||
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
|
||||
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -36,11 +36,13 @@
|
||||
#pragma nv_diag_suppress 2951
|
||||
#pragma nv_diag_suppress 2967
|
||||
#else
|
||||
#if !defined(USE_ROCM)
|
||||
#pragma diag_suppress = attribute_not_allowed
|
||||
#pragma diag_suppress = 1866
|
||||
#pragma diag_suppress = 2941
|
||||
#pragma diag_suppress = 2951
|
||||
#pragma diag_suppress = 2967
|
||||
#endif //! USE_ROCM
|
||||
#endif
|
||||
#else // __CUDACC__
|
||||
#define INLINE inline
|
||||
@@ -56,7 +58,9 @@
|
||||
#pragma clang diagnostic pop
|
||||
#ifdef WITH_CUDA
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#if !defined(USE_ROCM)
|
||||
#include <vector_functions.h>
|
||||
#endif //! USE_ROCM
|
||||
#else
|
||||
#ifndef cudaStream_t
|
||||
typedef void* cudaStream_t;
|
||||
|
||||
@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
|
||||
#define SHARED __shared__
|
||||
#define ACTIVEMASK() __activemask()
|
||||
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
||||
|
||||
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
|
||||
* supports __shfl_*. Disabling until the move to ROCM-6.2.
|
||||
*/
|
||||
#if !defined(USE_ROCM)
|
||||
/**
|
||||
* Find the cumulative sum within a warp up to the current
|
||||
* thread lane, with each mask thread contributing base.
|
||||
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
ret.z = WARP_SUM(group, mask, base.z);
|
||||
return ret;
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
|
||||
// Floating point.
|
||||
// #define FMUL(a, b) __fmul_rn((a), (b))
|
||||
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
||||
#define I2F(a) __int2float_rn(a)
|
||||
#define FRCP(x) __frcp_rn(x)
|
||||
#if !defined(USE_ROCM)
|
||||
__device__ static float atomicMax(float* address, float val) {
|
||||
int* address_as_i = (int*)address;
|
||||
int old = *address_as_i, assumed;
|
||||
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
#define DMAX(a, b) FMAX(a, b)
|
||||
#define DMIN(a, b) FMIN(a, b)
|
||||
#define DSQRT(a) sqrt(a)
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "./commands.h"
|
||||
|
||||
namespace pulsar {
|
||||
IHD CamGradInfo::CamGradInfo() {
|
||||
IHD CamGradInfo::CamGradInfo(int x) {
|
||||
cam_pos = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
||||
|
||||
@@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
|
||||
};
|
||||
|
||||
struct CamGradInfo {
|
||||
HOST DEVICE CamGradInfo();
|
||||
HOST DEVICE CamGradInfo(int = 0);
|
||||
float3 cam_pos;
|
||||
float3 pixel_0_0_center;
|
||||
float3 pixel_dir_x;
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
// #pragma diag_suppress = 68
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
// #pragma pop
|
||||
#include "../cuda/commands.h"
|
||||
#include "../gpu/commands.h"
|
||||
#else
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
|
||||
@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
|
||||
}
|
||||
|
||||
// TODO: put intrinsics here.
|
||||
#if !defined(USE_ROCM)
|
||||
IHD float3 operator+(const float3& a, const float3& b) {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
|
||||
IHD float3 operator*(const float& a, const float3& b) {
|
||||
return b * a;
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
|
||||
INLINE DEVICE float length(const float3& v) {
|
||||
// TODO: benchmark what's faster.
|
||||
|
||||
@@ -283,9 +283,15 @@ GLOBAL void render(
|
||||
(percent_allowed_difference > 0.f &&
|
||||
max_closest_possible_intersection > depth_threshold) ||
|
||||
tracker.get_n_hits() >= max_n_hits;
|
||||
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
|
||||
unsigned long long warp_done = __ballot(done);
|
||||
int warp_done_bit_cnt = __popcll(warp_done);
|
||||
#else
|
||||
uint warp_done = thread_warp.ballot(done);
|
||||
int warp_done_bit_cnt = POPC(warp_done);
|
||||
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
|
||||
if (thread_warp.thread_rank() == 0)
|
||||
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
|
||||
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
|
||||
// This sync is necessary to keep n_loaded until all threads are done with
|
||||
// painting.
|
||||
thread_block.sync();
|
||||
|
||||
@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float& min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode) {
|
||||
@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode) {
|
||||
@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
||||
};
|
||||
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>>
|
||||
Renderer::backward(
|
||||
const torch::Tensor& grad_im,
|
||||
const torch::Tensor& image,
|
||||
@@ -912,8 +912,8 @@ Renderer::backward(
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode,
|
||||
@@ -922,7 +922,7 @@ Renderer::backward(
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const at::optional<std::pair<uint, uint>>& dbg_pos) {
|
||||
const std::optional<std::pair<uint, uint>>& dbg_pos) {
|
||||
this->ensure_on_device(this->device_tracker.device());
|
||||
size_t batch_size;
|
||||
size_t n_points;
|
||||
@@ -1045,14 +1045,14 @@ Renderer::backward(
|
||||
}
|
||||
// Prepare the return value.
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>>
|
||||
ret;
|
||||
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
|
||||
return ret;
|
||||
|
||||
@@ -44,21 +44,21 @@ struct Renderer {
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>>
|
||||
backward(
|
||||
const torch::Tensor& grad_im,
|
||||
const torch::Tensor& image,
|
||||
@@ -75,8 +75,8 @@ struct Renderer {
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode,
|
||||
@@ -85,7 +85,7 @@ struct Renderer {
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const at::optional<std::pair<uint, uint>>& dbg_pos);
|
||||
const std::optional<std::pair<uint, uint>>& dbg_pos);
|
||||
|
||||
// Infrastructure.
|
||||
/**
|
||||
@@ -115,8 +115,8 @@ struct Renderer {
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float& min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
@@ -163,6 +163,9 @@ def _read_chunks(
|
||||
if binary_data is not None:
|
||||
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
|
||||
|
||||
# pyre-fixme[7]: Expected `Optional[Tuple[Dict[str, typing.Any],
|
||||
# ndarray[typing.Any, typing.Any]]]` but got `Tuple[typing.Any,
|
||||
# Optional[ndarray[typing.Any, dtype[typing.Any]]]]`.
|
||||
return json_data, binary_data
|
||||
|
||||
|
||||
|
||||
@@ -1250,6 +1250,9 @@ def _save_ply(
|
||||
if verts_normals is not None:
|
||||
verts_dtype.append(("normals", np.float32, 3))
|
||||
if verts_colors is not None:
|
||||
# pyre-fixme[6]: For 1st argument expected `Tuple[str,
|
||||
# Type[floating[_32Bit]], int]` but got `Tuple[str,
|
||||
# Type[Union[floating[_32Bit], unsignedinteger[typing.Any]]], int]`.
|
||||
verts_dtype.append(("colors", color_np_type, 3))
|
||||
|
||||
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)
|
||||
|
||||
@@ -76,13 +76,9 @@ from .points import (
|
||||
PointsRasterizationSettings,
|
||||
PointsRasterizer,
|
||||
PointsRenderer,
|
||||
PulsarPointsRenderer,
|
||||
rasterize_points,
|
||||
)
|
||||
|
||||
# Pulsar is not enabled on amd.
|
||||
if not torch.version.hip:
|
||||
from .points import PulsarPointsRenderer
|
||||
|
||||
from .splatter_blend import SplatterBlender
|
||||
from .utils import (
|
||||
convert_to_tensors_and_broadcast,
|
||||
|
||||
@@ -10,9 +10,7 @@ import torch
|
||||
|
||||
from .compositor import AlphaCompositor, NormWeightedCompositor
|
||||
|
||||
# Pulsar not enabled on amd.
|
||||
if not torch.version.hip:
|
||||
from .pulsar.unified import PulsarPointsRenderer
|
||||
from .pulsar.unified import PulsarPointsRenderer
|
||||
|
||||
from .rasterize_points import rasterize_points
|
||||
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
|
||||
|
||||
@@ -674,7 +674,6 @@ def _add_mesh_trace(
|
||||
verts[~verts_used] = verts_center
|
||||
|
||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||
fig.add_trace(
|
||||
go.Mesh3d(
|
||||
x=verts[:, 0],
|
||||
@@ -741,7 +740,6 @@ def _add_pointcloud_trace(
|
||||
|
||||
row = subplot_idx // ncols + 1
|
||||
col = subplot_idx % ncols + 1
|
||||
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=verts[:, 0],
|
||||
@@ -803,7 +801,6 @@ def _add_camera_trace(
|
||||
x, y, z = all_cam_wires.detach().cpu().numpy().T.astype(float)
|
||||
|
||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||
fig.add_trace(
|
||||
go.Scatter3d(x=x, y=y, z=z, marker={"size": 1}, name=trace_name),
|
||||
row=row,
|
||||
@@ -898,7 +895,6 @@ def _add_ray_bundle_trace(
|
||||
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
|
||||
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
|
||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=x,
|
||||
@@ -1010,7 +1006,6 @@ def _update_axes_bounds(
|
||||
|
||||
# Ensure that within a subplot, the bounds capture all traces
|
||||
old_xrange, old_yrange, old_zrange = (
|
||||
# pyre-fixme[16]: `Scene` has no attribute `__getitem__`.
|
||||
current_layout["xaxis"]["range"],
|
||||
current_layout["yaxis"]["range"],
|
||||
current_layout["zaxis"]["range"],
|
||||
@@ -1029,7 +1024,6 @@ def _update_axes_bounds(
|
||||
xaxis = {"range": x_range}
|
||||
yaxis = {"range": y_range}
|
||||
zaxis = {"range": z_range}
|
||||
# pyre-fixme[16]: `Scene` has no attribute `update`.
|
||||
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})
|
||||
|
||||
|
||||
|
||||
@@ -59,6 +59,8 @@ def texturesuv_image_matplotlib(
|
||||
for i in indices:
|
||||
# setting clip_on=False makes it obvious when
|
||||
# we have UV coordinates outside the correct range
|
||||
# pyre-fixme[6]: For 1st argument expected `Tuple[float, float]` but got
|
||||
# `ndarray[Any, Any]`.
|
||||
ax.add_patch(Circle(centers[i], radius, color=color, clip_on=False))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user