Mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
Compare commits
57 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 33824be3cb |  |
|  | 2d4d345b6f |  |
|  | 45df20e9e2 |  |
|  | fc6a6b8951 |  |
|  | 7711bf34a8 |  |
|  | d098beb7a7 |  |
|  | dd068703d1 |  |
|  | 50f8efa1cb |  |
|  | 5043d15361 |  |
|  | e3d3a67a89 |  |
|  | e55ea90609 |  |
|  | 3aee2a6005 |  |
|  | c5ea8fa49e |  |
|  | 3ff6c5ab85 |  |
|  | 267bd8ef87 |  |
|  | 177eec6378 |  |
|  | 71db7a0ea2 |  |
|  | 6020323d94 |  |
|  | 182e845c19 |  |
|  | f315ac131b |  |
|  | fc08621879 |  |
|  | 3f327a516b |  |
|  | 366eff21d9 |  |
|  | 0a59450f0e |  |
|  | 3987612062 |  |
|  | 06a76ef8dd |  |
|  | 21205730d9 |  |
|  | 7e09505538 |  |
|  | 20bd8b33f6 |  |
|  | 7a3c0cbc9d |  |
|  | 215590b497 |  |
|  | 43cd681d4f |  |
|  | 42a4a7d432 |  |
|  | 699bc671ca |  |
|  | 49cf5a0f37 |  |
|  | 89b851e64c |  |
|  | 5247f6ad74 |  |
|  | e41aff47db |  |
|  | 64a5bfadc8 |  |
|  | 055ab3a2e3 |  |
|  | f6c2ca6bfc |  |
|  | e20cbe9b0e |  |
|  | c17e6f947a |  |
|  | 91c9f34137 |  |
|  | 81d82980bc |  |
|  | 8fe6934885 |  |
|  | c434957b2a |  |
|  | dd2a11b5fc |  |
|  | 9563ef79ca |  |
|  | 008c7ab58c |  |
|  | 9eaed4c495 |  |
|  | e13848265d |  |
|  | 58566963d6 |  |
|  | e17ed5cd50 |  |
|  | 8ed0c7a002 |  |
|  | 2da913c7e6 |  |
|  | fca83e6369 |  |
@@ -88,7 +88,6 @@ def workflow_pair(
upload=False,
filter_branch,
):

w = []
py = python_version.replace(".", "")
pyt = pytorch_version.replace(".", "")

@@ -127,7 +126,6 @@ def generate_base_workflow(
btype,
filter_branch=None,
):

d = {
"name": base_workflow_name,
"python_version": python_version,

23 .github/workflows/build.yml vendored Normal file

@@ -0,0 +1,23 @@
name: facebookresearch/pytorch3d/build_and_test
on:
  pull_request:
    branches:
    - main
  push:
    branches:
    - main
jobs:
  binary_linux_conda_cuda:
    runs-on: 4-core-ubuntu-gpu-t4
    env:
      PYTHON_VERSION: "3.12"
      BUILD_VERSION: "${{ github.run_number }}"
      PYTORCH_VERSION: "2.4.1"
      CU_VERSION: "cu121"
      JUST_TESTRUN: 1
    steps:
    - uses: actions/checkout@v4
    - name: Build and run tests
      run: |-
        conda create --name env --yes --quiet conda-build
        conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda

@@ -10,7 +10,7 @@
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
DIR=$(dirname "${DIR}")

if [[ -f "${DIR}/TARGETS" ]]
if [[ -f "${DIR}/BUCK" ]]
then
pyfmt "${DIR}"
else

@@ -36,5 +36,5 @@ then

echo "Running pyre..."
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
fi

@@ -10,6 +10,7 @@ This example demonstrates the most trivial, direct interface of the pulsar
sphere renderer. It renders and saves an image with 10 random spheres.
Output: basic.png.
"""

import logging
import math
from os import path

@@ -11,6 +11,7 @@ interface for sphere renderering. It renders and saves an image with
10 random spheres.
Output: basic-pt3d.png.
"""

import logging
from os import path

@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
Output: cam.gif.
"""

import logging
import math
from os import path

@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
Output: cam-pt3d.gif
"""

import logging
from os import path

@@ -18,6 +18,7 @@ This example is not available yet through the 'unified' interface,
because opacity support has not landed in PyTorch3D for general data
structures yet.
"""

import logging
import math
from os import path

@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""

import logging
import math

@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""

import logging
import math

@@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os.path
import runpy
import subprocess
from typing import List
from typing import List, Tuple

# required env vars:
# CU_VERSION: E.g. cu112

@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
source_root_dir = os.environ["PWD"]

def version_constraint(version):
def version_constraint(version) -> str:
"""
Given version "11.3" returns " >=11.3,<11.4"
"""

@@ -32,7 +33,7 @@ def version_constraint(version):
return f" >={version},<{upper}"

def get_cuda_major_minor():
def get_cuda_major_minor() -> Tuple[str, str]:
if CU_VERSION == "cpu":
raise ValueError("fn only for cuda builds")
if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":

@@ -42,11 +43,10 @@ def get_cuda_major_minor():
return major, minor

def setup_cuda():
def setup_cuda(use_conda_cuda: bool) -> List[str]:
if CU_VERSION == "cpu":
return
return []
major, minor = get_cuda_major_minor()
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
os.environ["FORCE_CUDA"] = "1"

basic_nvcc_flags = (

@@ -75,6 +75,15 @@ def setup_cuda():

if os.environ.get("JUST_TESTRUN", "0") != "1":
os.environ["NVCC_FLAGS"] = nvcc_flags
if use_conda_cuda:
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
f"- cuda-version={major}.{minor}"
)
return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
else:
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
return []

def setup_conda_pytorch_constraint() -> List[str]:

@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
return ["-c", "pytorch", "-c", "nvidia"]

def setup_conda_cudatoolkit_constraint():
def setup_conda_cudatoolkit_constraint() -> None:
if CU_VERSION == "cpu":
os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""

@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit

def do_build(start_args: List[str]):
def do_build(start_args: List[str]) -> None:
args = start_args.copy()

test_flag = os.environ.get("TEST_FLAG")

@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Build the conda package.")
parser.add_argument(
"--use-conda-cuda",
action="store_true",
help="get cuda from conda ignoring local cuda",
)
our_args = parser.parse_args()

args = ["conda", "build"]
setup_cuda()
args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)

init_path = source_root_dir + "/pytorch3d/__init__.py"
build_version = runpy.run_path(init_path)["__version__"]
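
As a quick aside on the helper touched in the build_conda.py hunks above: its docstring states that "11.3" maps to " >=11.3,<11.4". A minimal sketch of that behaviour (illustrative only, not the repository's exact implementation):

```python
# Sketch of the documented behaviour of version_constraint().
def version_constraint(version: str) -> str:
    major, minor = version.split(".")[:2]
    upper = f"{major}.{int(minor) + 1}"
    return f" >={version},<{upper}"

assert version_constraint("11.3") == " >=11.3,<11.4"
```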

@@ -8,10 +8,13 @@ source:
requirements:
  build:
    - {{ compiler('c') }} # [win]
    {{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
    {{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
    {{ environ.get('CONDA_CUB_CONSTRAINT') }}

  host:
    - python
    - mkl =2023 # [x86_64]
    {{ environ.get('SETUPTOOLS_CONSTRAINT') }}
    {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
    {{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}

@@ -22,6 +25,7 @@ requirements:
    - python
    - numpy >=1.11
    - torchvision >=0.5
    - mkl =2023 # [x86_64]
    - iopath
    {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
    {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}

@@ -47,8 +51,11 @@ test:
    - imageio
    - hydra-core
    - accelerate
    - matplotlib
    - tabulate
    - pandas
    - sqlalchemy
  commands:
    #pytest .
    python -m unittest discover -v -s tests -t .

@@ -7,7 +7,7 @@

# pyre-unsafe

""""
""" "
This file is the entry point for launching experiments with Implicitron.

Launch Training

@@ -44,6 +44,7 @@ The outputs of the experiment are saved and logged in multiple ways:
config file.

"""

import logging
import os
import warnings

@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)

class ModelFactoryBase(ReplaceableBase):

resume: bool = True  # resume from the last checkpoint

def __call__(self, **kwargs) -> ImplicitronModelBase:

@@ -116,7 +115,9 @@ class ImplicitronModelFactory(ModelFactoryBase):
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
model_state_dict = torch.load(
model_io.get_model_path(model_path), map_location=map_location
model_io.get_model_path(model_path),
map_location=map_location,
weights_only=True,
)

try:

@@ -123,6 +123,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"):  # use the model function
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
else:
p_groups = [

@@ -241,7 +242,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
optimizer_state = torch.load(opt_path, map_location)
optimizer_state = torch.load(opt_path, map_location, weights_only=True)
else:
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
return optimizer_state
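
The model-factory and optimizer-factory hunks above both switch their `torch.load` calls to pass `weights_only=True`. A minimal sketch of the pattern (the file names are placeholders, not paths from the repository):

```python
import torch

# Hypothetical checkpoint files, for illustration only.
model_state = torch.load("model.pth", map_location="cpu", weights_only=True)
opt_state = torch.load("optimizer.pth", map_location="cpu", weights_only=True)

# weights_only=True uses a restricted unpickler that only rebuilds tensors and
# plain containers, so loading an untrusted checkpoint cannot run arbitrary code.
```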

@@ -161,7 +161,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
for epoch in range(start_epoch, self.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:

# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
seed_all_random_engines(seed + epoch)

@@ -395,6 +394,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
if hasattr(model, "visualize"):
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
model.visualize(
viz,
visdom_env_imgs,

@@ -53,12 +53,8 @@ class TestExperiment(unittest.TestCase):
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
"JsonIndexDatasetMapProvider"
)
dataset_args = (
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataloader_args = (
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
)
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"

@@ -94,12 +90,8 @@ class TestExperiment(unittest.TestCase):
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
"JsonIndexDatasetMapProvider"
)
dataset_args = (
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataloader_args = (
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
)
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"

@@ -111,9 +103,7 @@ class TestExperiment(unittest.TestCase):
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
2
)
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2

if DEBUG:
experiment.dump_cfg(cfg)

@@ -81,8 +81,9 @@ class TestOptimizerFactory(unittest.TestCase):

def test_param_overrides_self_param_group_assignment(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = Node(params=[pa]), Node(
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
na, nb = (
Node(params=[pa]),
Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
)
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
param_groups = self._get_param_groups(root)

@@ -84,9 +84,9 @@ def get_nerf_datasets(

if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
# Automatically download the data files if missing.
download_data((dataset_name,), data_root=data_root)
download_data([dataset_name], data_root=data_root)

train_data = torch.load(cameras_path)
train_data = torch.load(cameras_path, weights_only=True)
n_cameras = train_data["cameras"]["R"].shape[0]

_image_max_image_pixels = Image.MAX_IMAGE_PIXELS

@@ -194,7 +194,6 @@ class Stats:
it = self.it[stat_set]

for stat in self.log_vars:

if stat not in self.stats[stat_set]:
self.stats[stat_set][stat] = AverageMeter()

@@ -24,7 +24,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"

@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):

# Device on which to run.
if torch.cuda.is_available():
device = "cuda"

@@ -63,7 +62,7 @@ def main(cfg: DictConfig):
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")

print(f"Loading checkpoint {checkpoint_path}.")
loaded_data = torch.load(checkpoint_path)
loaded_data = torch.load(checkpoint_path, weights_only=True)
# Do not load the cached xy grid.
# - this allows setting an arbitrary evaluation image size.
state_dict = {

@@ -42,7 +42,6 @@ class TestRaysampler(unittest.TestCase):
cameras, rays = [], []

for _ in range(batch_size):

R = random_rotations(1)
T = torch.randn(1, 3)
focal_length = torch.rand(1, 2) + 0.5

@@ -25,7 +25,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"

@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):

# Set the relevant seeds for reproducibility.
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

@@ -77,7 +76,7 @@ def main(cfg: DictConfig):
# Resume training if requested.
if cfg.resume and os.path.isfile(checkpoint_path):
print(f"Resuming from checkpoint {checkpoint_path}.")
loaded_data = torch.load(checkpoint_path)
loaded_data = torch.load(checkpoint_path, weights_only=True)
model.load_state_dict(loaded_data["model"])
stats = pickle.loads(loaded_data["stats"])
print(f" => resuming from epoch {stats.epoch}.")

@@ -219,7 +218,6 @@ def main(cfg: DictConfig):

# Validation
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:

# Sample a validation camera/image.
val_batch = next(val_dataloader.__iter__())
val_image, val_camera, camera_idx = val_batch[0].values()

@@ -6,4 +6,4 @@

# pyre-unsafe

__version__ = "0.7.8"
__version__ = "0.7.9"

@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.

def meshgrid_ij(
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
*A: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, ...]:  # pragma: no cover
"""
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
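
For context on the helper whose signature gains a trailing comma above: per its docstring it reproduces the pre-1.10 `torch.meshgrid` behaviour (ij indexing). A minimal usage sketch, assuming the helper is importable from pytorch3d.common.compat:

```python
import torch
from pytorch3d.common.compat import meshgrid_ij

# Behaves like torch.meshgrid(a, b, indexing="ij") on recent PyTorch:
# i and j both have shape (2, 3); i varies along rows, j along columns.
i, j = meshgrid_ij(torch.arange(2), torch.arange(3))
```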

@@ -32,7 +32,9 @@ __global__ void BallQueryKernel(
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
const int64_t K,
const float radius2) {
const float radius,
const float radius2,
const bool skip_points_outside_cube) {
const int64_t N = p1.size(0);
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;

@@ -51,7 +53,19 @@ __global__ void BallQueryKernel(
// Iterate over points in p2 until desired count is reached or
// all points have been considered
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
// Calculate the distance between the points
if (skip_points_outside_cube) {
bool is_within_radius = true;
// Filter when any one coordinate is already outside the radius
for (int d = 0; is_within_radius && d < D; ++d) {
scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
is_within_radius = (abs_diff <= radius);
}
if (!is_within_radius) {
continue;
}
}

// Else, calculate the distance between the points and compare
scalar_t dist2 = 0.0;
for (int d = 0; d < D; ++d) {
scalar_t diff = p1[n][i][d] - p2[n][j][d];

@@ -77,7 +91,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& lengths1, // (N,)
const at::Tensor& lengths2, // (N,)
int K,
float radius) {
float radius,
bool skip_points_outside_cube) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};

@@ -120,7 +135,9 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
K_64,
radius2);
radius,
radius2,
skip_points_outside_cube);
}));

AT_CUDA_CHECK(cudaGetLastError());

@@ -25,6 +25,9 @@
// within the radius
// radius: the radius around each point within which the neighbors need to be
// located
// skip_points_outside_cube: If true, reduce multiplications of float values
// by not explicitly calculating distances to points that fall outside the
// D-cube with side length (2*radius) centered at each point in p1.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where

@@ -46,7 +49,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius);
const float radius,
const bool skip_points_outside_cube);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(

@@ -55,7 +59,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius);
const float radius,
const bool skip_points_outside_cube);

// Implementation which is exposed
// Note: the backward pass reuses the KNearestNeighborBackward kernel

@@ -65,7 +70,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
float radius) {
float radius,
bool skip_points_outside_cube) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);

@@ -76,16 +82,20 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
lengths1.contiguous(),
lengths2.contiguous(),
K,
radius);
radius,
skip_points_outside_cube);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(p1);
CHECK_CPU(p2);
return BallQueryCpu(
p1.contiguous(),
p2.contiguous(),
lengths1.contiguous(),
lengths2.contiguous(),
K,
radius);
radius,
skip_points_outside_cube);
}

@@ -6,8 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

#include <math.h>
#include <torch/extension.h>
#include <queue>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> BallQueryCpu(

@@ -16,7 +16,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
float radius) {
float radius,
bool skip_points_outside_cube) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);

@@ -38,6 +39,16 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const int64_t length2 = lengths2_a[n];
for (int64_t i = 0; i < length1; ++i) {
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
if (skip_points_outside_cube) {
bool is_within_radius = true;
for (int d = 0; is_within_radius && d < D; ++d) {
float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]);
is_within_radius = (abs_diff <= radius);
}
if (!is_within_radius) {
continue;
}
}
float dist2 = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i][d] - p2_a[n][j][d];
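
The ball-query hunks above (CUDA kernel, header, and CPU path) add a `skip_points_outside_cube` fast path: a candidate neighbor is rejected as soon as any single coordinate differs by more than `radius`, before the full squared distance is accumulated. A small Python sketch of the same filtering idea, illustrative only and not the extension's API:

```python
import torch

def ball_query_naive(p1, p2, K, radius, skip_points_outside_cube=True):
    # p1: (P1, D), p2: (P2, D). Returns up to K neighbor indices per p1 point.
    results = []
    for x in p1:
        diff = (p2 - x).abs()                     # (P2, D) coordinate differences
        cand = torch.arange(p2.shape[0])
        if skip_points_outside_cube:
            # Cheap axis-aligned pre-filter: keep only points inside the
            # D-cube of side 2*radius centered at x.
            keep = (diff <= radius).all(dim=1)
            cand, diff = cand[keep], diff[keep]
        dist2 = (diff ** 2).sum(dim=1)            # full distance only for survivors
        results.append(cand[dist2 <= radius * radius][:K])
    return results
```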
|
||||
|
||||
@@ -98,6 +98,11 @@ at::Tensor SigmoidAlphaBlendBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(distances);
|
||||
CHECK_CPU(pix_to_face);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(grad_alphas);
|
||||
|
||||
return SigmoidAlphaBlendBackwardCpu(
|
||||
grad_alphas, alphas, distances, pix_to_face, sigma);
|
||||
}
|
||||
|
||||
@@ -28,17 +28,16 @@ __global__ void alphaCompositeCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
const auto batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.y * blockDim.x;
|
||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Iterate over each feature in each pixel
|
||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||
@@ -79,17 +78,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
const auto batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.y * blockDim.x;
|
||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Parallelize over each feature in each pixel in images of size H * W,
|
||||
// for each image in the batch of size batch_size
|
||||
|
||||
@@ -74,6 +74,9 @@ torch::Tensor alphaCompositeForward(
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
CHECK_CPU(features);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(points_idx);
|
||||
return alphaCompositeCpuForward(features, alphas, points_idx);
|
||||
}
|
||||
}
|
||||
@@ -101,6 +104,11 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
CHECK_CPU(grad_outputs);
|
||||
CHECK_CPU(features);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(points_idx);
|
||||
|
||||
return alphaCompositeCpuBackward(
|
||||
grad_outputs, features, alphas, points_idx);
|
||||
}
|
||||
|
||||
@@ -28,17 +28,16 @@ __global__ void weightedSumNormCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
const auto batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.y * blockDim.x;
|
||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Parallelize over each feature in each pixel in images of size H * W,
|
||||
// for each image in the batch of size batch_size
|
||||
@@ -92,17 +91,16 @@ __global__ void weightedSumNormCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
const auto batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * W * H;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.y * blockDim.x;
|
||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Parallelize over each feature in each pixel in images of size H * W,
|
||||
// for each image in the batch of size batch_size
|
||||
|
||||
@@ -73,6 +73,10 @@ torch::Tensor weightedSumNormForward(
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
CHECK_CPU(features);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(points_idx);
|
||||
|
||||
return weightedSumNormCpuForward(features, alphas, points_idx);
|
||||
}
|
||||
}
|
||||
@@ -100,6 +104,11 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
CHECK_CPU(grad_outputs);
|
||||
CHECK_CPU(features);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(points_idx);
|
||||
|
||||
return weightedSumNormCpuBackward(
|
||||
grad_outputs, features, alphas, points_idx);
|
||||
}
|
||||
|
||||
@@ -26,17 +26,16 @@ __global__ void weightedSumCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
const auto batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.y * blockDim.x;
|
||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Parallelize over each feature in each pixel in images of size H * W,
|
||||
// for each image in the batch of size batch_size
|
||||
@@ -74,17 +73,16 @@ __global__ void weightedSumCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
const auto batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.y * blockDim.x;
|
||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Iterate over each pixel to compute the contribution to the
|
||||
// gradient for the features and weights
|
||||
|
||||
@@ -72,6 +72,9 @@ torch::Tensor weightedSumForward(
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
CHECK_CPU(features);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(points_idx);
|
||||
return weightedSumCpuForward(features, alphas, points_idx);
|
||||
}
|
||||
}
|
||||
@@ -98,6 +101,11 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
} else {
|
||||
CHECK_CPU(grad_outputs);
|
||||
CHECK_CPU(features);
|
||||
CHECK_CPU(alphas);
|
||||
CHECK_CPU(points_idx);
|
||||
|
||||
return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,15 +7,10 @@
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#if !defined(USE_ROCM)
|
||||
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
// clang-format on
|
||||
#if !defined(USE_ROCM)
|
||||
#include "./pulsar/pytorch/renderer.h"
|
||||
#include "./pulsar/pytorch/tensor_util.h"
|
||||
#endif
|
||||
#include "ball_query/ball_query.h"
|
||||
#include "blending/sigmoid_alpha_blend.h"
|
||||
#include "compositing/alpha_composite.h"
|
||||
@@ -104,22 +99,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
// Pulsar.
|
||||
// Pulsar not enabled on AMD.
|
||||
#if !defined(USE_ROCM)
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
#endif
|
||||
py::class_<
|
||||
pulsar::pytorch::Renderer,
|
||||
std::shared_ptr<pulsar::pytorch::Renderer>>(m, "PulsarRenderer")
|
||||
.def(py::init<
|
||||
const uint&,
|
||||
const uint&,
|
||||
const uint&,
|
||||
const bool&,
|
||||
const bool&,
|
||||
const float&,
|
||||
const uint&,
|
||||
const uint&>())
|
||||
.def(
|
||||
py::init<
|
||||
const uint&,
|
||||
const uint&,
|
||||
const uint&,
|
||||
const bool&,
|
||||
const bool&,
|
||||
const float&,
|
||||
const uint&,
|
||||
const uint&>())
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const pulsar::pytorch::Renderer& a,
|
||||
@@ -154,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
py::arg("gamma"),
|
||||
py::arg("max_depth"),
|
||||
py::arg("min_depth") /* = 0.f*/,
|
||||
py::arg(
|
||||
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
|
||||
py::arg("bg_col") /* = std::nullopt not exposed properly in
|
||||
pytorch 1.1. */
|
||||
,
|
||||
py::arg("opacity") /* = at::nullopt ... */,
|
||||
py::arg("opacity") /* = std::nullopt ... */,
|
||||
py::arg("percent_allowed_difference") = 0.01f,
|
||||
py::arg("max_n_hits") = MAX_UINT,
|
||||
py::arg("mode") = 0)
|
||||
@@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.attr("MAX_UINT") = py::int_(MAX_UINT);
|
||||
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
|
||||
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -60,6 +60,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(verts);
|
||||
CHECK_CPU(faces);
|
||||
return FaceAreasNormalsForwardCpu(verts, faces);
|
||||
}
|
||||
|
||||
@@ -80,5 +82,9 @@ at::Tensor FaceAreasNormalsBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(grad_areas);
|
||||
CHECK_CPU(grad_normals);
|
||||
CHECK_CPU(verts);
|
||||
CHECK_CPU(faces);
|
||||
return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
|
||||
}
|
||||
|
||||
@@ -20,14 +20,14 @@ __global__ void GatherScatterCudaKernel(
|
||||
const size_t V,
|
||||
const size_t D,
|
||||
const size_t E) {
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
|
||||
// Reverse the vertex order if backward.
|
||||
const int v0_idx = backward ? 1 : 0;
|
||||
const int v1_idx = backward ? 0 : 1;
|
||||
|
||||
// Edges are split evenly across the blocks.
|
||||
for (int e = blockIdx.x; e < E; e += gridDim.x) {
|
||||
for (auto e = blockIdx.x; e < E; e += gridDim.x) {
|
||||
// Get indices of vertices which form the edge.
|
||||
const int64_t v0 = edges[2 * e + v0_idx];
|
||||
const int64_t v1 = edges[2 * e + v1_idx];
|
||||
@@ -35,7 +35,7 @@ __global__ void GatherScatterCudaKernel(
|
||||
// Split vertex features evenly across threads.
|
||||
// This implementation will be quite wasteful when D<128 since there will be
|
||||
// a lot of threads doing nothing.
|
||||
for (int d = tid; d < D; d += blockDim.x) {
|
||||
for (auto d = tid; d < D; d += blockDim.x) {
|
||||
const float val = input[v1 * D + d];
|
||||
float* address = output + v0 * D + d;
|
||||
atomicAdd(address, val);
|
||||
|
||||
@@ -53,5 +53,7 @@ at::Tensor GatherScatter(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(input);
|
||||
CHECK_CPU(edges);
|
||||
return GatherScatterCpu(input, edges, directed, backward);
|
||||
}
|
||||
|
||||
@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
|
||||
const size_t P,
|
||||
const size_t F,
|
||||
const size_t D) {
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int num_threads = blockDim.x * gridDim.x;
|
||||
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const auto num_threads = blockDim.x * gridDim.x;
|
||||
for (int pd = tid; pd < P * D; pd += num_threads) {
|
||||
const int p = pd / D;
|
||||
const int d = pd % D;
|
||||
@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
|
||||
const size_t P,
|
||||
const size_t F,
|
||||
const size_t D) {
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int num_threads = blockDim.x * gridDim.x;
|
||||
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const auto num_threads = blockDim.x * gridDim.x;
|
||||
for (int pd = tid; pd < P * D; pd += num_threads) {
|
||||
const int p = pd / D;
|
||||
const int d = pd % D;
|
||||
|
||||
@@ -57,6 +57,8 @@ at::Tensor InterpFaceAttrsForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(face_attrs);
|
||||
CHECK_CPU(barycentric_coords);
|
||||
return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
|
||||
}
|
||||
|
||||
@@ -106,6 +108,9 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(face_attrs);
|
||||
CHECK_CPU(barycentric_coords);
|
||||
CHECK_CPU(grad_pix_attrs);
|
||||
return InterpFaceAttrsBackwardCpu(
|
||||
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
|
||||
}
|
||||
|
||||
@@ -44,5 +44,7 @@ inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(boxes1);
|
||||
CHECK_CPU(boxes2);
|
||||
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
|
||||
}
|
||||
|
||||
@@ -7,10 +7,7 @@
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "iou_box3d/iou_utils.h"
|
||||
|
||||
|
||||
@@ -461,10 +461,8 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
|
||||
__device__ inline bool IsCoplanarTriTri(
|
||||
const FaceVerts& tri1,
|
||||
const FaceVerts& tri2) {
|
||||
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
|
||||
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
||||
|
||||
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
|
||||
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
||||
|
||||
// Check if parallel
|
||||
@@ -500,7 +498,6 @@ __device__ inline bool IsCoplanarTriPlane(
|
||||
const FaceVerts& tri,
|
||||
const FaceVerts& plane,
|
||||
const float3& normal) {
|
||||
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
|
||||
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
||||
|
||||
// check if parallel
|
||||
@@ -728,7 +725,7 @@ __device__ inline int BoxIntersections(
|
||||
}
|
||||
}
|
||||
// Update the face_verts_out tris
|
||||
num_tris = offset;
|
||||
num_tris = min(MAX_TRIS, offset);
|
||||
for (int j = 0; j < num_tris; ++j) {
|
||||
face_verts_out[j] = tri_verts_updated[j];
|
||||
}
|
||||
|
||||
@@ -74,6 +74,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(p1);
|
||||
CHECK_CPU(p2);
|
||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
|
||||
}
|
||||
|
||||
@@ -140,6 +142,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(p1);
|
||||
CHECK_CPU(p2);
|
||||
return KNearestNeighborBackwardCpu(
|
||||
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
||||
}
|
||||
|
||||
@@ -58,5 +58,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubes(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(vol);
|
||||
return MarchingCubesCpu(vol.contiguous(), isolevel);
|
||||
}
|
||||
|
||||
@@ -88,6 +88,8 @@ at::Tensor PackedToPadded(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(inputs_packed);
|
||||
CHECK_CPU(first_idxs);
|
||||
return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
|
||||
}
|
||||
|
||||
@@ -105,5 +107,7 @@ at::Tensor PaddedToPacked(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(inputs_padded);
|
||||
CHECK_CPU(first_idxs);
|
||||
return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
|
||||
}
|
||||
|
||||
@@ -174,8 +174,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
|
||||
at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
|
||||
// clang-format on
|
||||
|
||||
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
|
||||
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
|
||||
auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
|
||||
auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
|
||||
auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
|
||||
auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
|
||||
auto dists_a = dists.accessor<float, 1>();
|
||||
@@ -230,10 +230,10 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
|
||||
at::Tensor grad_as = at::zeros_like(as);
|
||||
at::Tensor grad_bs = at::zeros_like(bs);
|
||||
|
||||
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
|
||||
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
|
||||
auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > ();
|
||||
auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > ();
|
||||
auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
|
||||
auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
|
||||
auto grad_as_a = grad_as.accessor<float, H1 == 1 ? 2 : 3>();
|
||||
auto grad_bs_a = grad_bs.accessor<float, H2 == 1 ? 2 : 3>();
|
||||
auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
|
||||
auto grad_dists_a = grad_dists.accessor<float, 1>();
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction in shared memory.
|
||||
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||
if (tid < s) {
|
||||
if (min_dists[tid] > min_dists[tid + s]) {
|
||||
min_dists[tid] = min_dists[tid + s];
|
||||
@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
|
||||
const float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.x * blockDim.x;
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
||||
const int t = t_i / P; // segment index.
|
||||
@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
|
||||
const float3* tris_f3 = (float3*)tris;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.x * blockDim.x;
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
||||
const int t = t_i / P; // triangle index.
|
||||
@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.x * blockDim.x;
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
|
||||
float3* segms_f3 = (float3*)segms;
|
||||
|
||||
// Parallelize over P * S computations
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.x * blockDim.x;
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||
const int s = t_i / P; // segment index.
|
||||
|
||||
@@ -88,6 +88,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(points_first_idx);
|
||||
CHECK_CPU(tris);
|
||||
CHECK_CPU(tris_first_idx);
|
||||
return PointFaceDistanceForwardCpu(
|
||||
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
|
||||
}
|
||||
@@ -143,6 +147,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(tris);
|
||||
CHECK_CPU(idx_points);
|
||||
CHECK_CPU(grad_dists);
|
||||
return PointFaceDistanceBackwardCpu(
|
||||
points, tris, idx_points, grad_dists, min_triangle_area);
|
||||
}
|
||||
@@ -221,6 +229,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(points_first_idx);
|
||||
CHECK_CPU(tris);
|
||||
CHECK_CPU(tris_first_idx);
|
||||
return FacePointDistanceForwardCpu(
|
||||
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
|
||||
}
|
||||
@@ -277,6 +289,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(tris);
|
||||
CHECK_CPU(idx_tris);
|
||||
CHECK_CPU(grad_dists);
|
||||
return FacePointDistanceBackwardCpu(
|
||||
points, tris, idx_tris, grad_dists, min_triangle_area);
|
||||
}
|
||||
@@ -346,6 +362,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(points_first_idx);
|
||||
CHECK_CPU(segms);
|
||||
CHECK_CPU(segms_first_idx);
|
||||
return PointEdgeDistanceForwardCpu(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||
}
|
||||
@@ -396,6 +416,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(segms);
|
||||
CHECK_CPU(idx_points);
|
||||
CHECK_CPU(grad_dists);
|
||||
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
|
||||
}
|
||||
|
||||
@@ -464,6 +488,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(points_first_idx);
|
||||
CHECK_CPU(segms);
|
||||
CHECK_CPU(segms_first_idx);
|
||||
return EdgePointDistanceForwardCpu(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||
}
|
||||
@@ -514,6 +542,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(segms);
|
||||
CHECK_CPU(idx_segms);
|
||||
CHECK_CPU(grad_dists);
|
||||
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
|
||||
}
|
||||
|
||||
@@ -567,6 +599,8 @@ torch::Tensor PointFaceArrayDistanceForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(tris);
|
||||
return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
|
||||
}
|
||||
|
||||
@@ -613,6 +647,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(tris);
|
||||
CHECK_CPU(grad_dists);
|
||||
return PointFaceArrayDistanceBackwardCpu(
|
||||
points, tris, grad_dists, min_triangle_area);
|
||||
}
|
||||
@@ -661,6 +698,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(segms);
|
||||
return PointEdgeArrayDistanceForwardCpu(points, segms);
|
||||
}
|
||||
|
||||
@@ -703,5 +742,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points);
|
||||
CHECK_CPU(segms);
|
||||
CHECK_CPU(grad_dists);
|
||||
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
|
||||
}
|
||||
|
||||
@@ -104,6 +104,12 @@ inline void PointsToVolumesForward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points_3d);
|
||||
CHECK_CPU(points_features);
|
||||
CHECK_CPU(volume_densities);
|
||||
CHECK_CPU(volume_features);
|
||||
CHECK_CPU(grid_sizes);
|
||||
CHECK_CPU(mask);
|
||||
PointsToVolumesForwardCpu(
|
||||
points_3d,
|
||||
points_features,
|
||||
@@ -183,6 +189,14 @@ inline void PointsToVolumesBackward(
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CPU(points_3d);
|
||||
CHECK_CPU(points_features);
|
||||
CHECK_CPU(grid_sizes);
|
||||
CHECK_CPU(mask);
|
||||
CHECK_CPU(grad_volume_densities);
|
||||
CHECK_CPU(grad_volume_features);
|
||||
CHECK_CPU(grad_points_3d);
|
||||
CHECK_CPU(grad_points_features);
|
||||
PointsToVolumesBackwardCpu(
|
||||
points_3d,
|
||||
points_features,
|
||||
|
||||
@@ -8,9 +8,7 @@
|
||||
|
||||
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
#endif
|
||||
|
||||
#if defined(_WIN64) || defined(_WIN32)
|
||||
#define uint unsigned int
|
||||
#define ushort unsigned short
|
||||
using uint = unsigned int;
|
||||
using ushort = unsigned short;
|
||||
#endif
|
||||
|
||||
#include "./logging.h" // <- include before torch/extension.h
|
||||
@@ -36,11 +36,13 @@
|
||||
#pragma nv_diag_suppress 2951
|
||||
#pragma nv_diag_suppress 2967
|
||||
#else
|
||||
#if !defined(USE_ROCM)
|
||||
#pragma diag_suppress = attribute_not_allowed
|
||||
#pragma diag_suppress = 1866
|
||||
#pragma diag_suppress = 2941
|
||||
#pragma diag_suppress = 2951
|
||||
#pragma diag_suppress = 2967
|
||||
#endif //! USE_ROCM
|
||||
#endif
|
||||
#else // __CUDACC__
|
||||
#define INLINE inline
|
||||
@@ -56,7 +58,9 @@
|
||||
#pragma clang diagnostic pop
|
||||
#ifdef WITH_CUDA
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#if !defined(USE_ROCM)
|
||||
#include <vector_functions.h>
|
||||
#endif //! USE_ROCM
|
||||
#else
|
||||
#ifndef cudaStream_t
|
||||
typedef void* cudaStream_t;
|
||||
|
||||
@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
|
||||
#define SHARED __shared__
|
||||
#define ACTIVEMASK() __activemask()
|
||||
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
||||
|
||||
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
|
||||
* supports __shfl_*. Disabling until the move to ROCM-6.2.
|
||||
*/
|
||||
#if !defined(USE_ROCM)
|
||||
/**
|
||||
* Find the cumulative sum within a warp up to the current
|
||||
* thread lane, with each mask thread contributing base.
|
||||
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
ret.z = WARP_SUM(group, mask, base.z);
|
||||
return ret;
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
|
||||
// Floating point.
|
||||
// #define FMUL(a, b) __fmul_rn((a), (b))
|
||||
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
||||
#define I2F(a) __int2float_rn(a)
|
||||
#define FRCP(x) __frcp_rn(x)
|
||||
#if !defined(USE_ROCM)
|
||||
__device__ static float atomicMax(float* address, float val) {
|
||||
int* address_as_i = (int*)address;
|
||||
int old = *address_as_i, assumed;
|
||||
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
#define DMAX(a, b) FMAX(a, b)
|
||||
#define DMIN(a, b) FMIN(a, b)
|
||||
#define DSQRT(a) sqrt(a)
|
||||
@@ -409,7 +417,7 @@ __device__ static float atomicMin(float* address, float val) {
|
||||
(OUT_PTR), \
|
||||
(NUM_SELECTED_PTR), \
|
||||
(NUM_ITEMS), \
|
||||
stream = (STREAM));
|
||||
(STREAM));
|
||||
|
||||
#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
|
||||
HANDLECUDA(cudaMemcpy( \
|
||||
@@ -357,11 +357,11 @@ void MAX_WS(
|
||||
//
|
||||
//
|
||||
#define END_PARALLEL() \
|
||||
end_parallel :; \
|
||||
end_parallel:; \
|
||||
}
|
||||
#define END_PARALLEL_NORET() }
|
||||
#define END_PARALLEL_2D() \
|
||||
end_parallel :; \
|
||||
end_parallel:; \
|
||||
} \
|
||||
}
|
||||
#define END_PARALLEL_2D_NORET() \
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "./commands.h"
|
||||
|
||||
namespace pulsar {
|
||||
IHD CamGradInfo::CamGradInfo() {
|
||||
IHD CamGradInfo::CamGradInfo(int x) {
|
||||
cam_pos = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
||||
|
||||
@@ -63,18 +63,13 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
|
||||
};
|
||||
|
||||
struct CamGradInfo {
|
||||
HOST DEVICE CamGradInfo();
|
||||
HOST DEVICE CamGradInfo(int = 0);
|
||||
float3 cam_pos;
|
||||
float3 pixel_0_0_center;
|
||||
float3 pixel_dir_x;
|
||||
float3 pixel_dir_y;
|
||||
};
|
||||
|
||||
// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved.
|
||||
struct IntWrapper {
|
||||
int val;
|
||||
};
|
||||
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
// #pragma diag_suppress = 68
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
// #pragma pop
|
||||
#include "../cuda/commands.h"
|
||||
#include "../gpu/commands.h"
|
||||
#else
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
|
||||
@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
|
||||
}
|
||||
|
||||
// TODO: put intrinsics here.
|
||||
#if !defined(USE_ROCM)
|
||||
IHD float3 operator+(const float3& a, const float3& b) {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
|
||||
IHD float3 operator*(const float& a, const float3& b) {
|
||||
return b * a;
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
|
||||
INLINE DEVICE float length(const float3& v) {
|
||||
// TODO: benchmark what's faster.
|
||||
@@ -147,11 +149,6 @@ IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) {
|
||||
return res;
|
||||
}
|
||||
|
||||
IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) {
|
||||
IntWrapper res;
|
||||
res.val = a.val + b.val;
|
||||
return res;
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
||||
|
||||
@@ -155,8 +155,8 @@ void backward(
stream);
CHECKLAUNCH();
SUM_WS(
(IntWrapper*)(self->ids_sorted_d),
(IntWrapper*)(self->n_grad_contributions_d),
self->ids_sorted_d,
self->n_grad_contributions_d,
static_cast<int>(num_balls),
self->workspace_d,
self->workspace_size,

@@ -52,7 +52,7 @@ HOST void construct(
self->cam.film_width = width;
self->cam.film_height = height;
self->max_num_balls = max_num_balls;
MALLOC(self->result_d, float, width* height* n_channels);
MALLOC(self->result_d, float, width * height * n_channels);
self->cam.orthogonal_projection = orthogonal_projection;
self->cam.right_handed = right_handed_system;
self->cam.background_normalization_depth = background_normalization_depth;
@@ -93,7 +93,7 @@ HOST void construct(
MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
MALLOC(self->region_flags_d, char, max_num_balls);
MALLOC(self->num_selected_d, size_t, 1);
MALLOC(self->forw_info_d, float, width* height * (3 + 2 * n_track));
MALLOC(self->forw_info_d, float, width * height * (3 + 2 * n_track));
MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
MALLOC(self->grad_pos_d, float3, max_num_balls);
MALLOC(self->grad_col_d, float, max_num_balls* n_channels);

@@ -255,7 +255,7 @@ GLOBAL void calc_signature(
* for every iteration through the loading loop every thread could add a
* 'hit' to the buffer.
*/
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE * 2
/**
* The threshold after which the spheres that are in the render buffer
* are rendered and the buffer is flushed.

@@ -283,9 +283,15 @@ GLOBAL void render(
(percent_allowed_difference > 0.f &&
max_closest_possible_intersection > depth_threshold) ||
tracker.get_n_hits() >= max_n_hits;
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
unsigned long long warp_done = __ballot(done);
int warp_done_bit_cnt = __popcll(warp_done);
#else
uint warp_done = thread_warp.ballot(done);
int warp_done_bit_cnt = POPC(warp_done);
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
if (thread_warp.thread_rank() == 0)
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
// This sync is necessary to keep n_loaded until all threads are done with
// painting.
thread_block.sync();
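The change above adds an AMD/HIP path for the warp vote (__ballot/__popcll) alongside the cooperative-groups ballot, and adds the precomputed popcount once per warp. A rough standalone illustration of that counting pattern with raw CUDA warp intrinsics (not the cooperative-groups objects used here; names are hypothetical):

// Count the lanes of this warp that are done and add the count once per warp.
__device__ void add_warp_done_count(bool done, int* n_pixels_done) {
  unsigned mask = __activemask();             // lanes taking part in the vote
  unsigned warp_done = __ballot_sync(mask, done);
  int warp_done_bit_cnt = __popc(warp_done);  // number of finished lanes
  if ((threadIdx.x & 31) == 0) {              // one lane updates the counter
    atomicAdd(n_pixels_done, warp_done_bit_cnt);
  }
}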
@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
const float& gamma,
const float& max_depth,
float& min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode) {
@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode) {
@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
};

std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>>
Renderer::backward(
const torch::Tensor& grad_im,
const torch::Tensor& image,
@@ -912,8 +912,8 @@ Renderer::backward(
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode,
@@ -922,7 +922,7 @@ Renderer::backward(
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const at::optional<std::pair<uint, uint>>& dbg_pos) {
const std::optional<std::pair<uint, uint>>& dbg_pos) {
this->ensure_on_device(this->device_tracker.device());
size_t batch_size;
size_t n_points;
@@ -1045,14 +1045,14 @@ Renderer::backward(
}
// Prepare the return value.
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>>
ret;
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
return ret;
@@ -44,21 +44,21 @@ struct Renderer {
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode);

std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>>
backward(
const torch::Tensor& grad_im,
const torch::Tensor& image,
@@ -75,8 +75,8 @@ struct Renderer {
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode,
@@ -85,7 +85,7 @@ struct Renderer {
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const at::optional<std::pair<uint, uint>>& dbg_pos);
const std::optional<std::pair<uint, uint>>& dbg_pos);

// Infrastructure.
/**
@@ -115,8 +115,8 @@ struct Renderer {
const float& gamma,
const float& max_depth,
float& min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const std::optional<torch::Tensor>& bg_col,
const std::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode);
@@ -8,6 +8,7 @@

#ifdef WITH_CUDA
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
#endif
#include <torch/extension.h>
@@ -33,13 +34,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
.contiguous();
if (forw_info.device().type() == c10::DeviceType::CUDA) {
#ifdef WITH_CUDA
cudaMemcpyAsync(
C10_CUDA_CHECK(cudaMemcpyAsync(
result.data_ptr(),
tmp.data_ptr(),
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
tmp.size(3),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream()));
#else
throw std::runtime_error(
"Copy on CUDA device initiated but built "

@@ -7,6 +7,7 @@
*/

#ifdef WITH_CUDA
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>

namespace pulsar {
@@ -17,7 +18,8 @@ void cudaDevToDev(
const void* src,
const int& size,
const cudaStream_t& stream) {
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
C10_CUDA_CHECK(
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
}

void cudaDevToHost(
@@ -25,7 +27,8 @@ void cudaDevToHost(
const void* src,
const int& size,
const cudaStream_t& stream) {
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
C10_CUDA_CHECK(
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
}

} // namespace pytorch

@@ -6,9 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/

#include "./global.h"
#include "./logging.h"

/**
* A compilation unit to provide warnings about the code and avoid
* repeated messages.

@@ -25,7 +25,7 @@ class BitMask {

// Use all threads in the current block to clear all bits of this BitMask
__device__ void block_clear() {
for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
data[i] = 0;
}
__syncthreads();

@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
const float blur_radius,
float* bboxes, // (4, F)
bool* skip_face) { // (F,)
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto num_threads = blockDim.x * gridDim.x;
const float sqrt_radius = sqrt(blur_radius);
for (int f = tid; f < F; f += num_threads) {
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
const int P,
float* bboxes, // (4, P)
bool* skip_points) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto num_threads = blockDim.x * gridDim.x;
for (int p = tid; p < P; p += num_threads) {
const float x = points[p * 3 + 0];
const float y = points[p * 3 + 1];
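The bounding-box kernels above (and the rasterization kernels that follow) keep the same grid-stride loop while switching the thread-index bookkeeping to auto. A minimal standalone sketch of the pattern, using a hypothetical kernel that is not part of this diff:

// Grid-stride loop: each thread starts at its global id and strides by the
// total number of launched threads, so any grid size covers all N elements.
__global__ void ScaleKernel(const float* in, float* out, const int N, const float s) {
  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
  const auto num_threads = gridDim.x * blockDim.x;
  for (int i = tid; i < N; i += num_threads) {
    out[i] = s * in[i];
  }
}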
@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;

for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
const int elem_chunk_start_idx = chunk_idx * chunk_size;
@@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];

// Have each thread handle a different face within the chunk
for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
const int e_idx = elem_chunk_start_idx + e;

// Check that we are still within the same element of the batch
@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
// Now we have processed every elem in the current chunk. We need to
// count the number of elems in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;

@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
float* pix_dists,
float* bary) {
// Simple version: One thread per output pixel
int num_threads = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
auto num_threads = gridDim.x * blockDim.x;
auto tid = blockDim.x * blockIdx.x + threadIdx.x;

for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(

// Parallelize over each pixel in images of
// size H * W, for each image in the batch of size N.
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
// Convert linear index to 3D index
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
) {
// This can be more than H * W if H or W are not divisible by bin_size.
int num_pixels = N * BH * BW * bin_size * bin_size;
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
auto num_threads = gridDim.x * blockDim.x;
auto tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within

@@ -138,6 +138,9 @@ RasterizeMeshesNaive(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(mesh_to_face_first_idx);
CHECK_CPU(num_faces_per_mesh);
return RasterizeMeshesNaiveCpu(
face_verts,
mesh_to_face_first_idx,
@@ -232,6 +235,11 @@ torch::Tensor RasterizeMeshesBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(pix_to_face);
CHECK_CPU(grad_zbuf);
CHECK_CPU(grad_bary);
CHECK_CPU(grad_dists);
return RasterizeMeshesBackwardCpu(
face_verts,
pix_to_face,
@@ -306,6 +314,9 @@ torch::Tensor RasterizeMeshesCoarse(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(mesh_to_face_first_idx);
CHECK_CPU(num_faces_per_mesh);
return RasterizeMeshesCoarseCpu(
face_verts,
mesh_to_face_first_idx,
@@ -423,6 +434,8 @@ RasterizeMeshesFine(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(bin_faces);
AT_ERROR("NOT IMPLEMENTED");
}
}
@@ -9,7 +9,6 @@
#include <torch/extension.h>
#include <algorithm>
#include <list>
#include <queue>
#include <thread>
#include <tuple>
#include "ATen/core/TensorAccessor.h"

@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
float* zbuf, // (N, H, W, K)
float* pix_dists) { // (N, H, W, K)
// Simple version: One thread per output pixel
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
const int n = i / (H * W); // Batch index
@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
float* pix_dists) { // (N, H, W, K)
// This can be more than H * W if H or W are not divisible by bin_size.
const int num_pixels = N * BH * BW * bin_size * bin_size;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
float* grad_points) { // (P, 3)
// Parallelized over each of K points per pixel, for each pixel in images of
// size H * W, for each image in the batch of size N.
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
auto num_threads = gridDim.x * blockDim.x;
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = tid; i < N * H * W * K; i += num_threads) {
// const int n = i / (H * W * K); // batch index (not needed).
const int yxk = i % (H * W * K);

@@ -91,6 +91,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(cloud_to_packed_first_idx);
CHECK_CPU(num_points_per_cloud);
CHECK_CPU(radius);
return RasterizePointsNaiveCpu(
points,
cloud_to_packed_first_idx,
@@ -166,6 +170,10 @@ torch::Tensor RasterizePointsCoarse(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(cloud_to_packed_first_idx);
CHECK_CPU(num_points_per_cloud);
CHECK_CPU(radius);
return RasterizePointsCoarseCpu(
points,
cloud_to_packed_first_idx,
@@ -232,6 +240,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(bin_points);
AT_ERROR("NOT IMPLEMENTED");
}
}
@@ -284,6 +294,10 @@ torch::Tensor RasterizePointsBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(idxs);
CHECK_CPU(grad_zbuf);
CHECK_CPU(grad_dists);
return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
}
}
@@ -35,8 +35,6 @@ __global__ void FarthestPointSamplingKernel(
__shared__ int64_t selected_store;

// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2);

// Get batch index and thread index
@@ -109,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points, // (N, P, 3)
const at::Tensor& lengths, // (N,)
const at::Tensor& K, // (N,)
const at::Tensor& start_idxs) {
const at::Tensor& start_idxs,
const int64_t max_K_known = -1) {
// Check inputs are on the same device
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
@@ -131,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda(

const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t max_K = at::max(K).item<int64_t>();
int64_t max_K;
if (max_K_known > 0) {
max_K = max_K_known;
} else {
max_K = at::max(K).item<int64_t>();
}

// Initialize the output tensor with the sampled indices
auto idxs = at::full({N, max_K}, -1, lengths.options());

@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const at::Tensor& start_idxs);
const at::Tensor& start_idxs,
const int64_t max_K_known = -1);

at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
@@ -56,17 +57,23 @@ at::Tensor FarthestPointSampling(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const at::Tensor& start_idxs) {
const at::Tensor& start_idxs,
const int64_t max_K_known = -1) {
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(lengths);
CHECK_CUDA(K);
CHECK_CUDA(start_idxs);
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
return FarthestPointSamplingCuda(
points, lengths, K, start_idxs, max_K_known);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(lengths);
CHECK_CPU(K);
CHECK_CPU(start_idxs);
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
}
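The new max_K_known argument threaded through FarthestPointSamplingCuda and FarthestPointSampling above lets a caller that already knows the bound skip at::max(K).item<int64_t>(), which forces a device-to-host read when K lives on the GPU. A minimal sketch of that host-side decision (an assumed simplification, not the library's actual helper):

#include <ATen/ATen.h>

// Prefer a host-known bound; otherwise reduce K on the device and copy the
// scalar back, which is a synchronization point for CUDA tensors.
int64_t resolve_max_k(const at::Tensor& K, int64_t max_K_known = -1) {
  return max_K_known > 0 ? max_K_known : at::max(K).item<int64_t>();
}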
@@ -71,6 +71,8 @@ inline void SamplePdf(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(weights);
CHECK_CPU(outputs);
CHECK_CONTIGUOUS(outputs);
SamplePdfCpu(bins, weights, outputs, eps);
}

@@ -99,8 +99,7 @@ namespace {
// and increment it via template recursion until it is equal to the run-time
// argument N.
template <
template <typename, int64_t>
class Kernel,
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -124,8 +123,7 @@ struct DispatchKernelHelper1D {
// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template <
template <typename, int64_t>
class Kernel,
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -145,8 +143,7 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template <
template <typename, int64_t, int64_t>
class Kernel,
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -203,8 +200,7 @@ struct DispatchKernelHelper2D {

// 2D dispatch, specialization for curN == maxN
template <
template <typename, int64_t, int64_t>
class Kernel,
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -243,8 +239,7 @@ struct DispatchKernelHelper2D<

// 2D dispatch, specialization for curM == maxM
template <
template <typename, int64_t, int64_t>
class Kernel,
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -283,8 +278,7 @@ struct DispatchKernelHelper2D<

// 2D dispatch, specialization for curN == maxN, curM == maxM
template <
template <typename, int64_t, int64_t>
class Kernel,
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -313,8 +307,7 @@ struct DispatchKernelHelper2D<

// This is the function we expect users to call to dispatch to 1D functions
template <
template <typename, int64_t>
class Kernel,
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -330,8 +323,7 @@ void DispatchKernel1D(const int64_t N, Args... args) {

// This is the function we expect users to call to dispatch to 2D functions
template <
template <typename, int64_t, int64_t>
class Kernel,
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
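The DispatchKernelHelper changes above are formatting-only, but the helpers themselves implement a useful trick: turning a runtime size into a compile-time template argument by recursing from minN to maxN. A compressed C++17 sketch of the 1D idea (simplified relative to the real helper, which also threads the scalar type T and extra Args):

#include <cstdint>

// Walk curN upward at compile time until it equals the runtime N, then
// instantiate Kernel<curN>; Kernel is assumed to expose a static run().
template <template <int64_t> class Kernel, int64_t curN, int64_t maxN>
struct Dispatch1D {
  static void run(const int64_t N) {
    if (N == curN) {
      Kernel<curN>::run();
    } else if constexpr (curN + 1 <= maxN) {
      Dispatch1D<Kernel, curN + 1, maxN>::run(N);
    }
  }
};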
@@ -376,8 +376,6 @@ PointLineDistanceBackward(
float tt = t_top / t_bot;
tt = __saturatef(tt);
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
const float2 d = p - p_proj;
const float dist = sqrt(dot(d, d));

const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);

@@ -15,3 +15,7 @@
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_CPU(x) \
TORCH_CHECK( \
x.device().type() == torch::kCPU, \
"Cannot use CPU implementation: " #x " not on CPU.")
@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
):
synset_set.add(synset)
elif (synset in self.synset_inv.keys()) and (
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
path.isdir(path.join(data_dir, self.synset_inv[synset]))
):
synset_set.add(self.synset_inv[synset])
else:

@@ -36,7 +36,6 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover

collated_dict["mesh"] = None
if {"verts", "faces"}.issubset(collated_dict.keys()):

textures = None
if "textures" in collated_dict:
textures = TexturesAtlas(atlas=collated_dict["textures"])

@@ -26,7 +26,7 @@ from typing import (
import numpy as np
import torch

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset import orm_types, types
from pytorch3d.implicitron.dataset.utils import (
adjust_camera_to_bbox_crop_,
adjust_camera_to_image_scale_,
@@ -48,8 +48,12 @@ from pytorch3d.implicitron.dataset.utils import (
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds

FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation


@dataclass
class FrameData(Mapping[str, Any]):
@@ -122,9 +126,9 @@ class FrameData(Mapping[str, Any]):
meta: A dict for storing additional frame information.
"""

frame_number: Optional[torch.LongTensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_number: Optional[torch.LongTensor] = None
sequence_name: Union[str, List[str]] = ""
sequence_category: Union[str, List[str]] = ""
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.LongTensor] = None
effective_image_size_hw: Optional[torch.LongTensor] = None
@@ -155,7 +159,7 @@ class FrameData(Mapping[str, Any]):
new_params = {}
for field_name in iter(self):
value = getattr(self, field_name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
new_params[field_name] = value.to(*args, **kwargs)
else:
new_params[field_name] = value
@@ -417,7 +421,6 @@ class FrameData(Mapping[str, Any]):
for f in fields(elem):
if not f.init:
continue

list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
@@ -426,7 +429,7 @@ class FrameData(Mapping[str, Any]):
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
return type(elem)(**collated)

elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
@@ -434,6 +437,8 @@ class FrameData(Mapping[str, Any]):
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
elif isinstance(elem, Meshes):
return join_meshes_as_batch(batch)
else:
return torch.utils.data.dataloader.default_collate(batch)
@@ -454,8 +459,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
@abstractmethod
def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
frame_annotation: FrameAnnotationT,
sequence_annotation: SequenceAnnotationT,
*,
load_blobs: bool = True,
**kwargs,
@@ -541,8 +546,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):

def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
frame_annotation: FrameAnnotationT,
sequence_annotation: SequenceAnnotationT,
*,
load_blobs: bool = True,
**kwargs,
@@ -586,58 +591,81 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
),
)

fg_mask_np: Optional[np.ndarray] = None
dataset_root = self.dataset_root
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
if load_blobs and self.load_masks:
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
depth_annotation = frame_annotation.depth
image_path: str | None = None
mask_path: str | None = None
depth_path: str | None = None
pcl_path: str | None = None
if dataset_root is not None: # set all paths even if we won’t load blobs
if frame_annotation.image.path is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path

if mask_annotation is not None and mask_annotation.path:
mask_path = os.path.join(dataset_root, mask_annotation.path)
frame_data.mask_path = mask_path

if depth_annotation is not None and depth_annotation.path is not None:
depth_path = os.path.join(dataset_root, depth_annotation.path)
frame_data.depth_path = depth_path

if point_cloud is not None:
pcl_path = os.path.join(dataset_root, point_cloud.path)
frame_data.sequence_point_cloud_path = pcl_path

fg_mask_np: np.ndarray | None = None
bbox_xywh: tuple[float, float, float, float] | None = None

if mask_annotation is not None:
if load_blobs and self.load_masks and mask_path:
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)

bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)

frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
frame_data.image_size_hw = image_size_hw # original image size
# image size after crop/resize
frame_data.effective_image_size_hw = image_size_hw
image_path = None
dataset_root = self.dataset_root
if frame_annotation.image.path is not None and dataset_root is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path

if load_blobs and self.load_images:
if image_path is None:
raise ValueError("Image path is required to load images.")

image_np = load_image(self._local_path(image_path))
no_mask = fg_mask_np is None # didn’t read the mask file
image_np = load_image(
self._local_path(image_path), try_read_alpha=no_mask
)
if image_np.shape[0] == 4: # RGBA image
if no_mask:
fg_mask_np = image_np[3:]
frame_data.fg_probability = safe_as_tensor(
fg_mask_np, torch.float
)

image_np = image_np[:3]

frame_data.image_rgb = self._postprocess_image(
image_np, frame_annotation.image.size, frame_data.fg_probability
)

if (
load_blobs
and self.load_depths
and frame_annotation.depth is not None
and frame_annotation.depth.path is not None
):
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(frame_annotation, fg_mask_np)
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

if load_blobs and self.load_depths and depth_path is not None:
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
frame_annotation, depth_path, fg_mask_np
)

if load_blobs and self.load_point_clouds and point_cloud is not None:
pcl_path = self._fix_point_cloud_path(point_cloud.path)
assert pcl_path is not None
frame_data.sequence_point_cloud = load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
frame_data.sequence_point_cloud_path = pcl_path

if frame_annotation.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
@@ -653,18 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):

return frame_data

def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[np.ndarray, str]:
assert self.dataset_root is not None and entry.mask is not None
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = load_mask(self._local_path(full_path))
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
fg_probability = load_mask(self._local_path(path))
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)

return fg_probability, full_path
return fg_probability

def _postprocess_image(
self,
@@ -685,14 +709,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):

def _load_mask_depth(
self,
entry: types.FrameAnnotation,
entry: FrameAnnotationT,
path: str,
fg_mask: Optional[np.ndarray],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
entry_depth = entry.depth
dataset_root = self.dataset_root
assert dataset_root is not None
assert entry_depth is not None and entry_depth.path is not None
path = os.path.join(dataset_root, entry_depth.path)
assert entry_depth is not None
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

if self.mask_depths:
@@ -706,11 +730,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
else:
depth_mask = (depth_map > 0.0).astype(np.float32)

return torch.tensor(depth_map), path, torch.tensor(depth_mask)
return torch.tensor(depth_map), torch.tensor(depth_mask)

def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
entry: FrameAnnotationT,
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
@@ -739,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)

def _fix_point_cloud_path(self, path: str) -> str:
"""
Fix up a point cloud path from the dataset.
Some files in Co3Dv2 have an accidental absolute path stored.
"""
unwanted_prefix = (
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
assert self.dataset_root is not None
return os.path.join(self.dataset_root, path)

def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path

@@ -222,7 +222,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
self.dataset_map = dataset_map

def _load_category(self, category: str) -> DatasetMap:

frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(
self.dataset_root, category, "sequence_annotations.jgz"
@@ -75,7 +75,6 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
def _load_data(
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
):

poses_arr = np.load(
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
)
@@ -164,7 +163,6 @@ def ptstocam(pts, c2w):


def poses_avg(poses):

hwf = poses[0, :3, -1:]

center = poses[:, :3, 3].mean(0)
@@ -192,7 +190,6 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):


def recenter_poses(poses):

poses_ = poses + 0
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
c2w = poses_avg(poses)
@@ -256,7 +253,6 @@ def spherify_poses(poses, bds):
new_poses = []

for th in np.linspace(0.0, 2.0 * np.pi, 120):

camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0, 0, -1.0])

@@ -311,7 +307,6 @@ def load_llff_data(
path_zflat=False,
path_manager=None,
):

poses, bds, imgs = _load_data(
basedir, factor=factor, path_manager=path_manager
) # factor=8 downsamples original imgs by 8x
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# This functionality requires SQLAlchemy 2.0 or later.

import math
Some files were not shown because too many files have changed in this diff