Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-20 21:02:48 +08:00
Compare commits
50 Commits
SHA1
50f8efa1cb
5043d15361
e3d3a67a89
e55ea90609
3aee2a6005
c5ea8fa49e
3ff6c5ab85
267bd8ef87
177eec6378
71db7a0ea2
6020323d94
182e845c19
f315ac131b
fc08621879
3f327a516b
366eff21d9
0a59450f0e
3987612062
06a76ef8dd
21205730d9
7e09505538
20bd8b33f6
7a3c0cbc9d
215590b497
43cd681d4f
42a4a7d432
699bc671ca
49cf5a0f37
89b851e64c
5247f6ad74
e41aff47db
64a5bfadc8
055ab3a2e3
f6c2ca6bfc
e20cbe9b0e
c17e6f947a
91c9f34137
81d82980bc
8fe6934885
c434957b2a
dd2a11b5fc
9563ef79ca
008c7ab58c
9eaed4c495
e13848265d
58566963d6
e17ed5cd50
8ed0c7a002
2da913c7e6
fca83e6369
@@ -88,7 +88,6 @@ def workflow_pair(
    upload=False,
    filter_branch,
):

    w = []
    py = python_version.replace(".", "")
    pyt = pytorch_version.replace(".", "")

@@ -127,7 +126,6 @@ def generate_base_workflow(
    btype,
    filter_branch=None,
):

    d = {
        "name": base_workflow_name,
        "python_version": python_version,
.github/workflows/build.yml (new file, vendored, 23 lines)
@@ -0,0 +1,23 @@
+name: facebookresearch/pytorch3d/build_and_test
+on:
+  pull_request:
+    branches:
+      - main
+  push:
+    branches:
+      - main
+jobs:
+  binary_linux_conda_cuda:
+    runs-on: 4-core-ubuntu-gpu-t4
+    env:
+      PYTHON_VERSION: "3.12"
+      BUILD_VERSION: "${{ github.run_number }}"
+      PYTORCH_VERSION: "2.4.1"
+      CU_VERSION: "cu121"
+      JUST_TESTRUN: 1
+    steps:
+      - uses: actions/checkout@v4
+      - name: Build and run tests
+        run: |-
+          conda create --name env --yes --quiet conda-build
+          conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda
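The workflow's build step amounts to a single script invocation. As a rough local approximation (a sketch, not part of the diff; it assumes a conda install and a pytorch3d checkout, with values mirroring the env block of build.yml except BUILD_VERSION, which GitHub derives from the run number):

```python
import os
import subprocess

# Values taken from the workflow's env block above; BUILD_VERSION is a
# placeholder for github.run_number.
env = {
    "PYTHON_VERSION": "3.12",
    "BUILD_VERSION": "1",
    "PYTORCH_VERSION": "2.4.1",
    "CU_VERSION": "cu121",
    "JUST_TESTRUN": "1",
}
subprocess.run(
    ["python3", "./packaging/build_conda.py", "--use-conda-cuda"],
    env={**os.environ, **env},
    check=True,
)
```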
@@ -36,5 +36,5 @@ then

  echo "Running pyre..."
  echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
- ( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
+ ( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
fi
@@ -10,6 +10,7 @@ This example demonstrates the most trivial, direct interface of the pulsar
sphere renderer. It renders and saves an image with 10 random spheres.
+Output: basic.png.
"""

import logging
import math
from os import path
@@ -11,6 +11,7 @@ interface for sphere renderering. It renders and saves an image with
10 random spheres.
+Output: basic-pt3d.png.
"""

import logging
from os import path
@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
+Output: cam.gif.
"""

import logging
import math
from os import path
@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
+Output: cam-pt3d.gif
"""

import logging
from os import path
@@ -18,6 +18,7 @@ This example is not available yet through the 'unified' interface,
because opacity support has not landed in PyTorch3D for general data
structures yet.
"""

import logging
import math
from os import path
@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""

import logging
import math
@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""

import logging
import math
@@ -4,10 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import argparse
import os.path
import runpy
import subprocess
-from typing import List
+from typing import List, Tuple

# required env vars:
# CU_VERSION: E.g. cu112

@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
source_root_dir = os.environ["PWD"]


-def version_constraint(version):
+def version_constraint(version) -> str:
    """
    Given version "11.3" returns " >=11.3,<11.4"
    """

@@ -32,7 +33,7 @@ def version_constraint(version):
    return f" >={version},<{upper}"


-def get_cuda_major_minor():
+def get_cuda_major_minor() -> Tuple[str, str]:
    if CU_VERSION == "cpu":
        raise ValueError("fn only for cuda builds")
    if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
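For reference, the newly annotated helper's behaviour can be sketched as below. The body of version_constraint between the docstring and the return statement is not shown in the hunks above, so the computation of `upper` here is an assumption reconstructed from the docstring:

```python
def version_constraint(version: str) -> str:
    """Given version "11.3", return " >=11.3,<11.4"."""
    # Bump the last version component by one to form the exclusive upper bound.
    parts = version.split(".")
    upper = ".".join(parts[:-1] + [str(int(parts[-1]) + 1)])
    return f" >={version},<{upper}"


assert version_constraint("11.3") == " >=11.3,<11.4"
```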
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
    return major, minor


-def setup_cuda():
+def setup_cuda(use_conda_cuda: bool) -> List[str]:
    if CU_VERSION == "cpu":
-       return
+       return []
    major, minor = get_cuda_major_minor()
-   os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
    os.environ["FORCE_CUDA"] = "1"

    basic_nvcc_flags = (

@@ -75,6 +75,15 @@ def setup_cuda():

    if os.environ.get("JUST_TESTRUN", "0") != "1":
        os.environ["NVCC_FLAGS"] = nvcc_flags
+   if use_conda_cuda:
+       os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
+       os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
+           f"- cuda-version={major}.{minor}"
+       )
+       return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
+   else:
+       os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
+       return []


def setup_conda_pytorch_constraint() -> List[str]:
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
    return ["-c", "pytorch", "-c", "nvidia"]


-def setup_conda_cudatoolkit_constraint():
+def setup_conda_cudatoolkit_constraint() -> None:
    if CU_VERSION == "cpu":
        os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
        os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""

@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
    os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit


-def do_build(start_args: List[str]):
+def do_build(start_args: List[str]) -> None:
    args = start_args.copy()

    test_flag = os.environ.get("TEST_FLAG")

@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):


if __name__ == "__main__":
+   parser = argparse.ArgumentParser(description="Build the conda package.")
+   parser.add_argument(
+       "--use-conda-cuda",
+       action="store_true",
+       help="get cuda from conda ignoring local cuda",
+   )
+   our_args = parser.parse_args()

    args = ["conda", "build"]
-   setup_cuda()
+   args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)

    init_path = source_root_dir + "/pytorch3d/__init__.py"
    build_version = runpy.run_path(init_path)["__version__"]
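Taken together, these hunks turn setup_cuda from a pure side-effect function into one that returns extra conda channel arguments, which the new `--use-conda-cuda` flag switches on. A minimal sketch of the resulting control flow (simplified; get_cuda_major_minor and the environment handling are stubbed out here):

```python
from typing import List, Tuple


def get_cuda_major_minor() -> Tuple[str, str]:
    # Stub: the real function parses CU_VERSION, e.g. "cu121" -> ("12", "1").
    return "12", "1"


def setup_cuda(use_conda_cuda: bool) -> List[str]:
    major, minor = get_cuda_major_minor()
    if use_conda_cuda:
        # Pull the CUDA toolkit from conda's nvidia channel...
        return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
    # ...otherwise rely on a local CUDA installation.
    return []


args = ["conda", "build"] + setup_cuda(use_conda_cuda=True)
print(args)  # ['conda', 'build', '-c', 'nvidia/label/cuda-12.1.0']
```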
@@ -8,10 +8,13 @@ source:
requirements:
  build:
    - {{ compiler('c') }} # [win]
+   {{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
+   {{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
    {{ environ.get('CONDA_CUB_CONSTRAINT') }}

  host:
    - python
+   - mkl =2023 # [x86_64]
    {{ environ.get('SETUPTOOLS_CONSTRAINT') }}
    {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
    {{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}

@@ -22,6 +25,7 @@ requirements:
    - python
    - numpy >=1.11
    - torchvision >=0.5
+   - mkl =2023 # [x86_64]
    - iopath
    {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
    {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}

@@ -47,8 +51,11 @@ test:
    - imageio
    - hydra-core
    - accelerate
    - matplotlib
+   - tabulate
+   - pandas
+   - sqlalchemy
  commands:
    #pytest .
    python -m unittest discover -v -s tests -t .
@@ -7,7 +7,7 @@

# pyre-unsafe

-""""
+""" "
This file is the entry point for launching experiments with Implicitron.

Launch Training

@@ -44,6 +44,7 @@ The outputs of the experiment are saved and logged in multiple ways:
  config file.

"""

import logging
import os
import warnings
@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)


class ModelFactoryBase(ReplaceableBase):

    resume: bool = True  # resume from the last checkpoint

    def __call__(self, **kwargs) -> ImplicitronModelBase:

@@ -116,7 +115,9 @@ class ImplicitronModelFactory(ModelFactoryBase):
            "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
        }
        model_state_dict = torch.load(
-           model_io.get_model_path(model_path), map_location=map_location
+           model_io.get_model_path(model_path),
+           map_location=map_location,
+           weights_only=True,
        )

        try:
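Several hunks in this compare pass weights_only=True to torch.load. The flag (available since PyTorch 1.13) restricts unpickling to tensors and simple containers, so loading an untrusted checkpoint cannot execute arbitrary code. A minimal sketch of the pattern:

```python
import torch

# A checkpoint made of plain tensors/dicts stays loadable under
# weights_only=True; arbitrary pickled Python objects would be rejected.
torch.save({"model": {"w": torch.randn(3, 3)}}, "/tmp/ckpt.pt")

state = torch.load("/tmp/ckpt.pt", map_location="cpu", weights_only=True)
print(state["model"]["w"].shape)  # torch.Size([3, 3])
```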
@@ -123,6 +123,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
        """
        # Get the parameters to optimize
        if hasattr(model, "_get_param_groups"):  # use the model function
+           # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
        else:
            p_groups = [

@@ -241,7 +242,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
            map_location = {
                "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
            }
-           optimizer_state = torch.load(opt_path, map_location)
+           optimizer_state = torch.load(opt_path, map_location, weights_only=True)
        else:
            raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
        return optimizer_state
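The p_groups built above feed a standard torch.optim parameter-group list. A small illustration of that mechanism (the module names here are hypothetical, not from the diff):

```python
import torch

encoder = torch.nn.Linear(4, 4)
decoder = torch.nn.Linear(4, 2)

# Each dict is one parameter group; per-group options such as lr or
# weight_decay override the optimizer-wide defaults.
p_groups = [
    {"params": encoder.parameters(), "lr": 1e-3},
    {"params": decoder.parameters(), "lr": 1e-4, "weight_decay": 1e-5},
]
opt = torch.optim.Adam(p_groups)
```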
@@ -161,7 +161,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
        for epoch in range(start_epoch, self.max_epochs):
            # automatic new_epoch and plotting of stats at every epoch start
            with stats:

                # Make sure to re-seed random generators to ensure reproducibility
                # even after restart.
                seed_all_random_engines(seed + epoch)

@@ -395,6 +394,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
        ):
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
            if hasattr(model, "visualize"):
+               # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                model.visualize(
                    viz,
                    visdom_env_imgs,
@@ -53,12 +53,8 @@ class TestExperiment(unittest.TestCase):
        cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
            "JsonIndexDatasetMapProvider"
        )
-       dataset_args = (
-           cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
-       )
-       dataloader_args = (
-           cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
-       )
+       dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+       dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
        dataset_args.category = "skateboard"
        dataset_args.test_restrict_sequence_id = 0
        dataset_args.dataset_root = "manifold://co3d/tree/extracted"

@@ -94,12 +90,8 @@ class TestExperiment(unittest.TestCase):
        cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
            "JsonIndexDatasetMapProvider"
        )
-       dataset_args = (
-           cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
-       )
-       dataloader_args = (
-           cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
-       )
+       dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+       dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
        dataset_args.category = "skateboard"
        dataset_args.test_restrict_sequence_id = 0
        dataset_args.dataset_root = "manifold://co3d/tree/extracted"

@@ -111,9 +103,7 @@ class TestExperiment(unittest.TestCase):
        cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
        cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
        cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
-       cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
-           2
-       )
+       cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2

        if DEBUG:
            experiment.dump_cfg(cfg)
@@ -81,8 +81,9 @@ class TestOptimizerFactory(unittest.TestCase):

    def test_param_overrides_self_param_group_assignment(self):
        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
-       na, nb = Node(params=[pa]), Node(
-           params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
+       na, nb = (
+           Node(params=[pa]),
+           Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
        )
        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
        param_groups = self._get_param_groups(root)
@@ -84,9 +84,9 @@ def get_nerf_datasets(

    if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
        # Automatically download the data files if missing.
-       download_data((dataset_name,), data_root=data_root)
+       download_data([dataset_name], data_root=data_root)

-   train_data = torch.load(cameras_path)
+   train_data = torch.load(cameras_path, weights_only=True)
    n_cameras = train_data["cameras"]["R"].shape[0]

    _image_max_image_pixels = Image.MAX_IMAGE_PIXELS
@@ -194,7 +194,6 @@ class Stats:
        it = self.it[stat_set]

        for stat in self.log_vars:

            if stat not in self.stats[stat_set]:
                self.stats[stat_set][stat] = AverageMeter()
@@ -24,7 +24,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")

@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"

@@ -63,7 +62,7 @@ def main(cfg: DictConfig):
        raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")

    print(f"Loading checkpoint {checkpoint_path}.")
-   loaded_data = torch.load(checkpoint_path)
+   loaded_data = torch.load(checkpoint_path, weights_only=True)
    # Do not load the cached xy grid.
    # - this allows setting an arbitrary evaluation image size.
    state_dict = {
@@ -42,7 +42,6 @@ class TestRaysampler(unittest.TestCase):
        cameras, rays = [], []

        for _ in range(batch_size):

            R = random_rotations(1)
            T = torch.randn(1, 3)
            focal_length = torch.rand(1, 2) + 0.5
@@ -25,7 +25,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")

@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):

    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

@@ -77,7 +76,7 @@ def main(cfg: DictConfig):
    # Resume training if requested.
    if cfg.resume and os.path.isfile(checkpoint_path):
        print(f"Resuming from checkpoint {checkpoint_path}.")
-       loaded_data = torch.load(checkpoint_path)
+       loaded_data = torch.load(checkpoint_path, weights_only=True)
        model.load_state_dict(loaded_data["model"])
        stats = pickle.loads(loaded_data["stats"])
        print(f" => resuming from epoch {stats.epoch}.")

@@ -219,7 +218,6 @@ def main(cfg: DictConfig):

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(val_dataloader.__iter__())
            val_image, val_camera, camera_idx = val_batch[0].values()
@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.


def meshgrid_ij(
-   *A: Union[torch.Tensor, Sequence[torch.Tensor]]
+   *A: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, ...]:  # pragma: no cover
    """
    Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
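The trailing comma added above is only a formatting/lint fix to the signature; the wrapper itself exists because torch.meshgrid changed its default indexing in PyTorch 1.10. A quick illustration of the "ij" behaviour the wrapper pins down:

```python
import torch

a = torch.arange(3)
b = torch.arange(4)

# indexing="ij" reproduces the pre-1.10 default: output shapes follow
# the input order, here (3, 4) for both grids.
gi, gj = torch.meshgrid(a, b, indexing="ij")
print(gi.shape, gj.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```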
@@ -81,6 +81,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(p1);
+ CHECK_CPU(p2);
  return BallQueryCpu(
      p1.contiguous(),
      p2.contiguous(),
@@ -7,7 +7,6 @@
 */

#include <torch/extension.h>
#include <queue>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
@@ -98,6 +98,11 @@ at::Tensor SigmoidAlphaBlendBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(distances);
+ CHECK_CPU(pix_to_face);
+ CHECK_CPU(alphas);
+ CHECK_CPU(grad_alphas);

  return SigmoidAlphaBlendBackwardCpu(
      grad_alphas, alphas, distances, pix_to_face, sigma);
}
@@ -28,17 +28,16 @@ __global__ void alphaCompositeCudaForwardKernel(
    const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;

  const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Iterate over each feature in each pixel
  for (int pid = tid; pid < num_pixels; pid += num_threads) {

@@ -79,17 +78,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
    const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;

  const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
@@ -74,6 +74,9 @@ torch::Tensor alphaCompositeForward(
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
+   CHECK_CPU(features);
+   CHECK_CPU(alphas);
+   CHECK_CPU(points_idx);
    return alphaCompositeCpuForward(features, alphas, points_idx);
  }
}

@@ -101,6 +104,11 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
+   CHECK_CPU(grad_outputs);
+   CHECK_CPU(features);
+   CHECK_CPU(alphas);
+   CHECK_CPU(points_idx);

    return alphaCompositeCpuBackward(
        grad_outputs, features, alphas, points_idx);
  }
@@ -28,17 +28,16 @@ __global__ void weightedSumNormCudaForwardKernel(
    const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;

  const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size

@@ -92,17 +91,16 @@ __global__ void weightedSumNormCudaBackwardKernel(
    const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;

  const int num_pixels = C * W * H;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
@@ -73,6 +73,10 @@ torch::Tensor weightedSumNormForward(
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
+   CHECK_CPU(features);
+   CHECK_CPU(alphas);
+   CHECK_CPU(points_idx);

    return weightedSumNormCpuForward(features, alphas, points_idx);
  }
}

@@ -100,6 +104,11 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
+   CHECK_CPU(grad_outputs);
+   CHECK_CPU(features);
+   CHECK_CPU(alphas);
+   CHECK_CPU(points_idx);

    return weightedSumNormCpuBackward(
        grad_outputs, features, alphas, points_idx);
  }
@@ -26,17 +26,16 @@ __global__ void weightedSumCudaForwardKernel(
    const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;

  const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size

@@ -74,17 +73,16 @@ __global__ void weightedSumCudaBackwardKernel(
    const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;

  const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Iterate over each pixel to compute the contribution to the
  // gradient for the features and weights
@@ -72,6 +72,9 @@ torch::Tensor weightedSumForward(
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
+   CHECK_CPU(features);
+   CHECK_CPU(alphas);
+   CHECK_CPU(points_idx);
    return weightedSumCpuForward(features, alphas, points_idx);
  }
}

@@ -98,6 +101,11 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
+   CHECK_CPU(grad_outputs);
+   CHECK_CPU(features);
+   CHECK_CPU(alphas);
+   CHECK_CPU(points_idx);

    return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
  }
}
@@ -7,15 +7,10 @@
 */

// clang-format off
#if !defined(USE_ROCM)
#include "./pulsar/global.h" // Include before <torch/extension.h>.
#endif
#include <torch/extension.h>
// clang-format on
#if !defined(USE_ROCM)
#include "./pulsar/pytorch/renderer.h"
#include "./pulsar/pytorch/tensor_util.h"
#endif
#include "ball_query/ball_query.h"
#include "blending/sigmoid_alpha_blend.h"
#include "compositing/alpha_composite.h"

@@ -104,7 +99,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

  // Pulsar.
  // Pulsar not enabled on AMD.
#if !defined(USE_ROCM)
#ifdef PULSAR_LOGGING_ENABLED
  c10::ShowLogInfoToStderr();
#endif

@@ -154,10 +148,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          py::arg("gamma"),
          py::arg("max_depth"),
          py::arg("min_depth") /* = 0.f*/,
-         py::arg(
-             "bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
+         py::arg("bg_col") /* = std::nullopt not exposed properly in
+                              pytorch 1.1. */
          ,
-         py::arg("opacity") /* = at::nullopt ... */,
+         py::arg("opacity") /* = std::nullopt ... */,
          py::arg("percent_allowed_difference") = 0.01f,
          py::arg("max_n_hits") = MAX_UINT,
          py::arg("mode") = 0)

@@ -189,5 +183,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.attr("MAX_UINT") = py::int_(MAX_UINT);
  m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
  m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
#endif
}
@@ -60,6 +60,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(verts);
+ CHECK_CPU(faces);
  return FaceAreasNormalsForwardCpu(verts, faces);
}

@@ -80,5 +82,9 @@ at::Tensor FaceAreasNormalsBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(grad_areas);
+ CHECK_CPU(grad_normals);
+ CHECK_CPU(verts);
+ CHECK_CPU(faces);
  return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
}
@@ -20,14 +20,14 @@ __global__ void GatherScatterCudaKernel(
    const size_t V,
    const size_t D,
    const size_t E) {
- const int tid = threadIdx.x;
+ const auto tid = threadIdx.x;

  // Reverse the vertex order if backward.
  const int v0_idx = backward ? 1 : 0;
  const int v1_idx = backward ? 0 : 1;

  // Edges are split evenly across the blocks.
- for (int e = blockIdx.x; e < E; e += gridDim.x) {
+ for (auto e = blockIdx.x; e < E; e += gridDim.x) {
    // Get indices of vertices which form the edge.
    const int64_t v0 = edges[2 * e + v0_idx];
    const int64_t v1 = edges[2 * e + v1_idx];

@@ -35,7 +35,7 @@ __global__ void GatherScatterCudaKernel(
    // Split vertex features evenly across threads.
    // This implementation will be quite wasteful when D<128 since there will be
    // a lot of threads doing nothing.
-   for (int d = tid; d < D; d += blockDim.x) {
+   for (auto d = tid; d < D; d += blockDim.x) {
      const float val = input[v1 * D + d];
      float* address = output + v0 * D + d;
      atomicAdd(address, val);
@@ -53,5 +53,7 @@ at::Tensor GatherScatter(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(input);
+ CHECK_CPU(edges);
  return GatherScatterCpu(input, edges, directed, backward);
}
@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
    const size_t P,
    const size_t F,
    const size_t D) {
- const int tid = threadIdx.x + blockIdx.x * blockDim.x;
- const int num_threads = blockDim.x * gridDim.x;
+ const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const auto num_threads = blockDim.x * gridDim.x;
  for (int pd = tid; pd < P * D; pd += num_threads) {
    const int p = pd / D;
    const int d = pd % D;

@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
    const size_t P,
    const size_t F,
    const size_t D) {
- const int tid = threadIdx.x + blockIdx.x * blockDim.x;
- const int num_threads = blockDim.x * gridDim.x;
+ const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const auto num_threads = blockDim.x * gridDim.x;
  for (int pd = tid; pd < P * D; pd += num_threads) {
    const int p = pd / D;
    const int d = pd % D;
@@ -57,6 +57,8 @@ at::Tensor InterpFaceAttrsForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(face_attrs);
+ CHECK_CPU(barycentric_coords);
  return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
}

@@ -106,6 +108,9 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(face_attrs);
+ CHECK_CPU(barycentric_coords);
+ CHECK_CPU(grad_pix_attrs);
  return InterpFaceAttrsBackwardCpu(
      pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
}
@@ -44,5 +44,7 @@ inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(boxes1);
+ CHECK_CPU(boxes2);
  return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
}
@@ -7,10 +7,7 @@
 */

#include <torch/extension.h>
#include <torch/torch.h>
#include <list>
#include <numeric>
#include <queue>
#include <tuple>
#include "iou_box3d/iou_utils.h"
@@ -461,10 +461,8 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
__device__ inline bool IsCoplanarTriTri(
    const FaceVerts& tri1,
    const FaceVerts& tri2) {
- const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
  const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});

- const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
  const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});

  // Check if parallel

@@ -500,7 +498,6 @@ __device__ inline bool IsCoplanarTriPlane(
    const FaceVerts& tri,
    const FaceVerts& plane,
    const float3& normal) {
- const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
  const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});

  // check if parallel

@@ -728,7 +725,7 @@ __device__ inline int BoxIntersections(
    }
  }
  // Update the face_verts_out tris
- num_tris = offset;
+ num_tris = min(MAX_TRIS, offset);
  for (int j = 0; j < num_tris; ++j) {
    face_verts_out[j] = tri_verts_updated[j];
  }
@@ -74,6 +74,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(p1);
+ CHECK_CPU(p2);
  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}

@@ -140,6 +142,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(p1);
+ CHECK_CPU(p2);
  return KNearestNeighborBackwardCpu(
      p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}
@@ -58,5 +58,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubes(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(vol);
  return MarchingCubesCpu(vol.contiguous(), isolevel);
}
@@ -88,6 +88,8 @@ at::Tensor PackedToPadded(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(inputs_packed);
+ CHECK_CPU(first_idxs);
  return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
}

@@ -105,5 +107,7 @@ at::Tensor PaddedToPacked(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(inputs_padded);
+ CHECK_CPU(first_idxs);
  return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
}
@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
  __syncthreads();

  // Perform reduction in shared memory.
- for (int s = blockDim.x / 2; s > 32; s >>= 1) {
+ for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
    if (tid < s) {
      if (min_dists[tid] > min_dists[tid + s]) {
        min_dists[tid] = min_dists[tid + s];

@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
  const float3* tris_f3 = (float3*)tris;

  // Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * T; t_i += num_threads) {
    const int t = t_i / P; // segment index.

@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
  const float3* tris_f3 = (float3*)tris;

  // Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * T; t_i += num_threads) {
    const int t = t_i / P; // triangle index.

@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
  float3* segms_f3 = (float3*)segms;

  // Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
    const int s = t_i / P; // segment index.

@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
  float3* segms_f3 = (float3*)segms;

  // Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < P * S; t_i += num_threads) {
    const int s = t_i / P; // segment index.
@@ -88,6 +88,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(points_first_idx);
+ CHECK_CPU(tris);
+ CHECK_CPU(tris_first_idx);
  return PointFaceDistanceForwardCpu(
      points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}

@@ -143,6 +147,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(tris);
+ CHECK_CPU(idx_points);
+ CHECK_CPU(grad_dists);
  return PointFaceDistanceBackwardCpu(
      points, tris, idx_points, grad_dists, min_triangle_area);
}

@@ -221,6 +229,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(points_first_idx);
+ CHECK_CPU(tris);
+ CHECK_CPU(tris_first_idx);
  return FacePointDistanceForwardCpu(
      points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}

@@ -277,6 +289,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(tris);
+ CHECK_CPU(idx_tris);
+ CHECK_CPU(grad_dists);
  return FacePointDistanceBackwardCpu(
      points, tris, idx_tris, grad_dists, min_triangle_area);
}

@@ -346,6 +362,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(points_first_idx);
+ CHECK_CPU(segms);
+ CHECK_CPU(segms_first_idx);
  return PointEdgeDistanceForwardCpu(
      points, points_first_idx, segms, segms_first_idx, max_points);
}

@@ -396,6 +416,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(segms);
+ CHECK_CPU(idx_points);
+ CHECK_CPU(grad_dists);
  return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
}

@@ -464,6 +488,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(points_first_idx);
+ CHECK_CPU(segms);
+ CHECK_CPU(segms_first_idx);
  return EdgePointDistanceForwardCpu(
      points, points_first_idx, segms, segms_first_idx, max_segms);
}

@@ -514,6 +542,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(segms);
+ CHECK_CPU(idx_segms);
+ CHECK_CPU(grad_dists);
  return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
}

@@ -567,6 +599,8 @@ torch::Tensor PointFaceArrayDistanceForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(tris);
  return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
}

@@ -613,6 +647,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(tris);
+ CHECK_CPU(grad_dists);
  return PointFaceArrayDistanceBackwardCpu(
      points, tris, grad_dists, min_triangle_area);
}

@@ -661,6 +698,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(segms);
  return PointEdgeArrayDistanceForwardCpu(points, segms);
}

@@ -703,5 +742,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points);
+ CHECK_CPU(segms);
+ CHECK_CPU(grad_dists);
  return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
}
@@ -104,6 +104,12 @@ inline void PointsToVolumesForward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points_3d);
+ CHECK_CPU(points_features);
+ CHECK_CPU(volume_densities);
+ CHECK_CPU(volume_features);
+ CHECK_CPU(grid_sizes);
+ CHECK_CPU(mask);
  PointsToVolumesForwardCpu(
      points_3d,
      points_features,

@@ -183,6 +189,14 @@ inline void PointsToVolumesBackward(
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
+ CHECK_CPU(points_3d);
+ CHECK_CPU(points_features);
+ CHECK_CPU(grid_sizes);
+ CHECK_CPU(mask);
+ CHECK_CPU(grad_volume_densities);
+ CHECK_CPU(grad_volume_features);
+ CHECK_CPU(grad_points_3d);
+ CHECK_CPU(grad_points_features);
  PointsToVolumesBackwardCpu(
      points_3d,
      points_features,
@@ -8,9 +8,7 @@

#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/extension.h>
#include <algorithm>
#include <cmath>
#include <thread>
#include <vector>

// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
@@ -15,8 +15,8 @@
#endif

#if defined(_WIN64) || defined(_WIN32)
-#define uint unsigned int
-#define ushort unsigned short
+using uint = unsigned int;
+using ushort = unsigned short;
#endif

#include "./logging.h" // <- include before torch/extension.h

@@ -36,11 +36,13 @@
#pragma nv_diag_suppress 2951
#pragma nv_diag_suppress 2967
#else
+#if !defined(USE_ROCM)
#pragma diag_suppress = attribute_not_allowed
#pragma diag_suppress = 1866
#pragma diag_suppress = 2941
#pragma diag_suppress = 2951
#pragma diag_suppress = 2967
+#endif //! USE_ROCM
#endif
#else // __CUDACC__
#define INLINE inline

@@ -56,7 +58,9 @@
#pragma clang diagnostic pop
#ifdef WITH_CUDA
#include <ATen/cuda/CUDAContext.h>
+#if !defined(USE_ROCM)
#include <vector_functions.h>
+#endif //! USE_ROCM
#else
#ifndef cudaStream_t
typedef void* cudaStream_t;
@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
#define SHARED __shared__
#define ACTIVEMASK() __activemask()
#define BALLOT(mask, val) __ballot_sync((mask), val)

+/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
+ * supports __shfl_*. Disabling until the move to ROCM-6.2.
+ */
+#if !defined(USE_ROCM)
/**
 * Find the cumulative sum within a warp up to the current
 * thread lane, with each mask thread contributing base.

@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
  ret.z = WARP_SUM(group, mask, base.z);
  return ret;
}
+#endif //! USE_ROCM

// Floating point.
// #define FMUL(a, b) __fmul_rn((a), (b))

@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
#define I2F(a) __int2float_rn(a)
#define FRCP(x) __frcp_rn(x)
+#if !defined(USE_ROCM)
__device__ static float atomicMax(float* address, float val) {
  int* address_as_i = (int*)address;
  int old = *address_as_i, assumed;

@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
  } while (assumed != old);
  return __int_as_float(old);
}
+#endif //! USE_ROCM
#define DMAX(a, b) FMAX(a, b)
#define DMIN(a, b) FMIN(a, b)
#define DSQRT(a) sqrt(a)

@@ -409,7 +417,7 @@ __device__ static float atomicMin(float* address, float val) {
      (OUT_PTR),                                \
      (NUM_SELECTED_PTR),                       \
      (NUM_ITEMS),                              \
-     stream = (STREAM));
+     (STREAM));

#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
  HANDLECUDA(cudaMemcpy(                        \
@@ -14,7 +14,7 @@
#include "./commands.h"

namespace pulsar {
-IHD CamGradInfo::CamGradInfo() {
+IHD CamGradInfo::CamGradInfo(int x) {
  cam_pos = make_float3(0.f, 0.f, 0.f);
  pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
  pixel_dir_x = make_float3(0.f, 0.f, 0.f);
@@ -63,18 +63,13 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
};

struct CamGradInfo {
- HOST DEVICE CamGradInfo();
+ HOST DEVICE CamGradInfo(int = 0);
  float3 cam_pos;
  float3 pixel_0_0_center;
  float3 pixel_dir_x;
  float3 pixel_dir_y;
};

-// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved.
-struct IntWrapper {
-  int val;
-};
-
} // namespace pulsar

#endif
@@ -24,7 +24,7 @@
// #pragma diag_suppress = 68
#include <ATen/cuda/CUDAContext.h>
// #pragma pop
-#include "../cuda/commands.h"
+#include "../gpu/commands.h"
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
}

// TODO: put intrinsics here.
+#if !defined(USE_ROCM)
IHD float3 operator+(const float3& a, const float3& b) {
  return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}

@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
IHD float3 operator*(const float& a, const float3& b) {
  return b * a;
}
+#endif //! USE_ROCM

INLINE DEVICE float length(const float3& v) {
  // TODO: benchmark what's faster.

@@ -147,11 +149,6 @@ IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) {
  return res;
}

-IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) {
-  IntWrapper res;
-  res.val = a.val + b.val;
-  return res;
-}
} // namespace pulsar

#endif
@@ -155,8 +155,8 @@ void backward(
      stream);
  CHECKLAUNCH();
  SUM_WS(
-     (IntWrapper*)(self->ids_sorted_d),
-     (IntWrapper*)(self->n_grad_contributions_d),
+     self->ids_sorted_d,
+     self->n_grad_contributions_d,
      static_cast<int>(num_balls),
      self->workspace_d,
      self->workspace_size,
@@ -283,9 +283,15 @@ GLOBAL void render(
      (percent_allowed_difference > 0.f &&
       max_closest_possible_intersection > depth_threshold) ||
      tracker.get_n_hits() >= max_n_hits;
+#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
+ unsigned long long warp_done = __ballot(done);
+ int warp_done_bit_cnt = __popcll(warp_done);
+#else
  uint warp_done = thread_warp.ballot(done);
+ int warp_done_bit_cnt = POPC(warp_done);
+#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
  if (thread_warp.thread_rank() == 0)
-   ATOMICADD_B(&n_pixels_done, POPC(warp_done));
+   ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
  // This sync is necessary to keep n_loaded until all threads are done with
  // painting.
  thread_block.sync();
@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
    const float& gamma,
    const float& max_depth,
    float& min_depth,
-   const c10::optional<torch::Tensor>& bg_col,
-   const c10::optional<torch::Tensor>& opacity,
+   const std::optional<torch::Tensor>& bg_col,
+   const std::optional<torch::Tensor>& opacity,
    const float& percent_allowed_difference,
    const uint& max_n_hits,
    const uint& mode) {

@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
    const float& gamma,
    const float& max_depth,
    float min_depth,
-   const c10::optional<torch::Tensor>& bg_col,
-   const c10::optional<torch::Tensor>& opacity,
+   const std::optional<torch::Tensor>& bg_col,
+   const std::optional<torch::Tensor>& opacity,
    const float& percent_allowed_difference,
    const uint& max_n_hits,
    const uint& mode) {

@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
};

std::tuple<
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>,
-   at::optional<torch::Tensor>>
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>,
+   std::optional<torch::Tensor>>
Renderer::backward(
    const torch::Tensor& grad_im,
    const torch::Tensor& image,

@@ -912,8 +912,8 @@ Renderer::backward(
    const float& gamma,
    const float& max_depth,
    float min_depth,
-   const c10::optional<torch::Tensor>& bg_col,
-   const c10::optional<torch::Tensor>& opacity,
+   const std::optional<torch::Tensor>& bg_col,
+   const std::optional<torch::Tensor>& opacity,
    const float& percent_allowed_difference,
    const uint& max_n_hits,
    const uint& mode,

@@ -922,7 +922,7 @@ Renderer::backward(
    const bool& dif_rad,
    const bool& dif_cam,
    const bool& dif_opy,
-   const at::optional<std::pair<uint, uint>>& dbg_pos) {
+   const std::optional<std::pair<uint, uint>>& dbg_pos) {
  this->ensure_on_device(this->device_tracker.device());
  size_t batch_size;
  size_t n_points;

@@ -1045,14 +1045,14 @@ Renderer::backward(
  }
  // Prepare the return value.
  std::tuple<
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>>
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>>
      ret;
  if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
    return ret;
@@ -44,21 +44,21 @@ struct Renderer {
      const float& gamma,
      const float& max_depth,
      float min_depth,
-     const c10::optional<torch::Tensor>& bg_col,
-     const c10::optional<torch::Tensor>& opacity,
+     const std::optional<torch::Tensor>& bg_col,
+     const std::optional<torch::Tensor>& opacity,
      const float& percent_allowed_difference,
      const uint& max_n_hits,
      const uint& mode);

  std::tuple<
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>,
-     at::optional<torch::Tensor>>
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>,
+     std::optional<torch::Tensor>>
  backward(
      const torch::Tensor& grad_im,
      const torch::Tensor& image,

@@ -75,8 +75,8 @@ struct Renderer {
      const float& gamma,
      const float& max_depth,
      float min_depth,
-     const c10::optional<torch::Tensor>& bg_col,
-     const c10::optional<torch::Tensor>& opacity,
+     const std::optional<torch::Tensor>& bg_col,
+     const std::optional<torch::Tensor>& opacity,
      const float& percent_allowed_difference,
      const uint& max_n_hits,
      const uint& mode,

@@ -85,7 +85,7 @@ struct Renderer {
      const bool& dif_rad,
      const bool& dif_cam,
      const bool& dif_opy,
-     const at::optional<std::pair<uint, uint>>& dbg_pos);
+     const std::optional<std::pair<uint, uint>>& dbg_pos);

  // Infrastructure.
  /**

@@ -115,8 +115,8 @@ struct Renderer {
      const float& gamma,
      const float& max_depth,
      float& min_depth,
-     const c10::optional<torch::Tensor>& bg_col,
-     const c10::optional<torch::Tensor>& opacity,
+     const std::optional<torch::Tensor>& bg_col,
+     const std::optional<torch::Tensor>& opacity,
      const float& percent_allowed_difference,
      const uint& max_n_hits,
      const uint& mode);
@@ -8,6 +8,7 @@

#ifdef WITH_CUDA
#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
#endif
#include <torch/extension.h>

@@ -33,13 +34,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
          .contiguous();
  if (forw_info.device().type() == c10::DeviceType::CUDA) {
#ifdef WITH_CUDA
-   cudaMemcpyAsync(
+   C10_CUDA_CHECK(cudaMemcpyAsync(
        result.data_ptr(),
        tmp.data_ptr(),
        sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
            tmp.size(3),
        cudaMemcpyDeviceToDevice,
-       at::cuda::getCurrentCUDAStream());
+       at::cuda::getCurrentCUDAStream()));
#else
    throw std::runtime_error(
        "Copy on CUDA device initiated but built "
@@ -7,6 +7,7 @@
 */

#ifdef WITH_CUDA
+#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>

namespace pulsar {

@@ -17,7 +18,8 @@ void cudaDevToDev(
    const void* src,
    const int& size,
    const cudaStream_t& stream) {
- cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
+ C10_CUDA_CHECK(
+     cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
}

void cudaDevToHost(

@@ -25,7 +27,8 @@ void cudaDevToHost(
    const void* src,
    const int& size,
    const cudaStream_t& stream) {
- cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
+ C10_CUDA_CHECK(
+     cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
}

} // namespace pytorch
@@ -6,9 +6,6 @@
 * LICENSE file in the root directory of this source tree.
 */

#include "./global.h"
#include "./logging.h"

/**
 * A compilation unit to provide warnings about the code and avoid
 * repeated messages.
@@ -25,7 +25,7 @@ class BitMask {

  // Use all threads in the current block to clear all bits of this BitMask
  __device__ void block_clear() {
-   for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
+   for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
      data[i] = 0;
    }
    __syncthreads();
@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
     const float blur_radius,
     float* bboxes, // (4, F)
     bool* skip_face) { // (F,)
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const int num_threads = blockDim.x * gridDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = blockDim.x * gridDim.x;
   const float sqrt_radius = sqrt(blur_radius);
   for (int f = tid; f < F; f += num_threads) {
     const float v0x = face_verts[f * 9 + 0 * 3 + 0];
@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
     const int P,
     float* bboxes, // (4, P)
     bool* skip_points) {
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const int num_threads = blockDim.x * gridDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = blockDim.x * gridDim.x;
   for (int p = tid; p < P; p += num_threads) {
     const float x = points[p * 3 + 0];
     const float y = points[p * 3 + 1];
@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
   const int chunks_per_batch = 1 + (E - 1) / chunk_size;
   const int num_chunks = N * chunks_per_batch;

-  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
+  for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
     const int batch_idx = chunk / chunks_per_batch; // batch index
     const int chunk_idx = chunk % chunks_per_batch;
     const int elem_chunk_start_idx = chunk_idx * chunk_size;
@@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel(
     const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];

     // Have each thread handle a different face within the chunk
-    for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
+    for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
       const int e_idx = elem_chunk_start_idx + e;

       // Check that we are still within the same element of the batch
@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
     // Now we have processed every elem in the current chunk. We need to
     // count the number of elems in each bin so we can write the indices
     // out to global memory. We have each thread handle a different bin.
-    for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
+    for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
          byx += blockDim.x) {
       const int by = byx / num_bins_x;
       const int bx = byx % num_bins_x;
@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
     float* pix_dists,
     float* bary) {
   // Simple version: One thread per output pixel
-  int num_threads = gridDim.x * blockDim.x;
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  auto num_threads = gridDim.x * blockDim.x;
+  auto tid = blockDim.x * blockIdx.x + threadIdx.x;

   for (int i = tid; i < N * H * W; i += num_threads) {
     // Convert linear index to 3D index
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(

   // Parallelize over each pixel in images of
   // size H * W, for each image in the batch of size N.
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
     // Convert linear index to 3D index
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
 ) {
   // This can be more than H * W if H or W are not divisible by bin_size.
   int num_pixels = N * BH * BW * bin_size * bin_size;
-  int num_threads = gridDim.x * blockDim.x;
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  auto num_threads = gridDim.x * blockDim.x;
+  auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int pid = tid; pid < num_pixels; pid += num_threads) {
     // Convert linear index into bin and pixel indices. We make the within
@@ -138,6 +138,9 @@ RasterizeMeshesNaive(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(mesh_to_face_first_idx);
+    CHECK_CPU(num_faces_per_mesh);
     return RasterizeMeshesNaiveCpu(
         face_verts,
         mesh_to_face_first_idx,
@@ -232,6 +235,11 @@ torch::Tensor RasterizeMeshesBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(pix_to_face);
+    CHECK_CPU(grad_zbuf);
+    CHECK_CPU(grad_bary);
+    CHECK_CPU(grad_dists);
     return RasterizeMeshesBackwardCpu(
         face_verts,
         pix_to_face,
@@ -306,6 +314,9 @@ torch::Tensor RasterizeMeshesCoarse(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(mesh_to_face_first_idx);
+    CHECK_CPU(num_faces_per_mesh);
     return RasterizeMeshesCoarseCpu(
         face_verts,
         mesh_to_face_first_idx,
@@ -423,6 +434,8 @@ RasterizeMeshesFine(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(bin_faces);
     AT_ERROR("NOT IMPLEMENTED");
   }
 }
@@ -9,7 +9,6 @@
 #include <torch/extension.h>
 #include <algorithm>
 #include <list>
 #include <queue>
 #include <thread>
 #include <tuple>
 #include "ATen/core/TensorAccessor.h"
@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
     float* zbuf, // (N, H, W, K)
     float* pix_dists) { // (N, H, W, K)
   // Simple version: One thread per output pixel
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
   for (int i = tid; i < N * H * W; i += num_threads) {
     // Convert linear index to 3D index
     const int n = i / (H * W); // Batch index
@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
     float* pix_dists) { // (N, H, W, K)
   // This can be more than H * W if H or W are not divisible by bin_size.
   const int num_pixels = N * BH * BW * bin_size * bin_size;
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int pid = tid; pid < num_pixels; pid += num_threads) {
     // Convert linear index into bin and pixel indices. We make the within
@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
     float* grad_points) { // (P, 3)
   // Parallelized over each of K points per pixel, for each pixel in images of
   // size H * W, for each image in the batch of size N.
-  int num_threads = gridDim.x * blockDim.x;
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  auto num_threads = gridDim.x * blockDim.x;
+  auto tid = blockIdx.x * blockDim.x + threadIdx.x;
   for (int i = tid; i < N * H * W * K; i += num_threads) {
     // const int n = i / (H * W * K); // batch index (not needed).
     const int yxk = i % (H * W * K);
@@ -91,6 +91,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(cloud_to_packed_first_idx);
+    CHECK_CPU(num_points_per_cloud);
+    CHECK_CPU(radius);
     return RasterizePointsNaiveCpu(
         points,
         cloud_to_packed_first_idx,
@@ -166,6 +170,10 @@ torch::Tensor RasterizePointsCoarse(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(cloud_to_packed_first_idx);
+    CHECK_CPU(num_points_per_cloud);
+    CHECK_CPU(radius);
     return RasterizePointsCoarseCpu(
         points,
         cloud_to_packed_first_idx,
@@ -232,6 +240,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(bin_points);
     AT_ERROR("NOT IMPLEMENTED");
   }
 }
@@ -284,6 +294,10 @@ torch::Tensor RasterizePointsBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(idxs);
+    CHECK_CPU(grad_zbuf);
+    CHECK_CPU(grad_dists);
     return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
   }
 }
@@ -35,8 +35,6 @@ __global__ void FarthestPointSamplingKernel(
   __shared__ int64_t selected_store;

-  // Get constants
-  const int64_t N = points.size(0);
   const int64_t P = points.size(1);
   const int64_t D = points.size(2);

   // Get batch index and thread index
@@ -109,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points, // (N, P, 3)
     const at::Tensor& lengths, // (N,)
     const at::Tensor& K, // (N,)
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1) {
   // Check inputs are on the same device
   at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
       k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
@@ -131,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda(

   const int64_t N = points.size(0);
   const int64_t P = points.size(1);
-  const int64_t max_K = at::max(K).item<int64_t>();
+  int64_t max_K;
+  if (max_K_known > 0) {
+    max_K = max_K_known;
+  } else {
+    max_K = at::max(K).item<int64_t>();
+  }

   // Initialize the output tensor with the sampled indices
   auto idxs = at::full({N, max_K}, -1, lengths.options());
@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs);
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1);

 at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
@@ -56,17 +57,23 @@ at::Tensor FarthestPointSampling(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1) {
   if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(points);
     CHECK_CUDA(lengths);
     CHECK_CUDA(K);
     CHECK_CUDA(start_idxs);
-    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
+    return FarthestPointSamplingCuda(
+        points, lengths, K, start_idxs, max_K_known);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(lengths);
+  CHECK_CPU(K);
+  CHECK_CPU(start_idxs);
   return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
 }
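Note: a minimal sketch of how this operator is typically driven from Python, via the public wrapper pytorch3d.ops.sample_farthest_points. The max_K_known argument added above is an internal fast path, presumably letting callers skip the at::max(K) device sync when K is a plain int, so it is not passed explicitly here.

# Hedged usage sketch: farthest point sampling from Python (CPU or CUDA).
import torch
from pytorch3d.ops import sample_farthest_points

points = torch.rand(2, 1000, 3)  # batch of 2 clouds with 1000 points each
sampled, idxs = sample_farthest_points(points, K=50, random_start_point=True)
print(sampled.shape, idxs.shape)  # torch.Size([2, 50, 3]) torch.Size([2, 50])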
@@ -71,6 +71,8 @@ inline void SamplePdf(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(weights);
+  CHECK_CPU(outputs);
   CHECK_CONTIGUOUS(outputs);
   SamplePdfCpu(bins, weights, outputs, eps);
 }
@@ -376,8 +376,6 @@ PointLineDistanceBackward(
   float tt = t_top / t_bot;
   tt = __saturatef(tt);
   const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
-  const float2 d = p - p_proj;
-  const float dist = sqrt(dot(d, d));

   const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
   const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
@@ -15,3 +15,7 @@
 #define CHECK_CONTIGUOUS_CUDA(x) \
   CHECK_CUDA(x);                 \
   CHECK_CONTIGUOUS(x)
+#define CHECK_CPU(x)                      \
+  TORCH_CHECK(                            \
+      x.device().type() == torch::kCPU,   \
+      "Cannot use CPU implementation: " #x " not on CPU.")
@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase):  # pragma: no cover
         ):
             synset_set.add(synset)
         elif (synset in self.synset_inv.keys()) and (
-            (path.isdir(path.join(data_dir, self.synset_inv[synset])))
+            path.isdir(path.join(data_dir, self.synset_inv[synset]))
         ):
             synset_set.add(self.synset_inv[synset])
         else:
@@ -36,7 +36,6 @@ def collate_batched_meshes(batch: List[Dict]):  # pragma: no cover

     collated_dict["mesh"] = None
     if {"verts", "faces"}.issubset(collated_dict.keys()):
-
         textures = None
         if "textures" in collated_dict:
             textures = TexturesAtlas(atlas=collated_dict["textures"])
@@ -26,7 +26,7 @@ from typing import (
 import numpy as np
 import torch

-from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset import orm_types, types
 from pytorch3d.implicitron.dataset.utils import (
     adjust_camera_to_bbox_crop_,
     adjust_camera_to_image_scale_,
@@ -48,8 +48,12 @@ from pytorch3d.implicitron.dataset.utils import (
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
 from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
+
+FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
+SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation


 @dataclass
 class FrameData(Mapping[str, Any]):
@@ -122,9 +126,9 @@ class FrameData(Mapping[str, Any]):
         meta: A dict for storing additional frame information.
     """

-    frame_number: Optional[torch.LongTensor]
-    sequence_name: Union[str, List[str]]
-    sequence_category: Union[str, List[str]]
+    frame_number: Optional[torch.LongTensor] = None
+    sequence_name: Union[str, List[str]] = ""
+    sequence_category: Union[str, List[str]] = ""
     frame_timestamp: Optional[torch.Tensor] = None
     image_size_hw: Optional[torch.LongTensor] = None
     effective_image_size_hw: Optional[torch.LongTensor] = None
@@ -155,7 +159,7 @@ class FrameData(Mapping[str, Any]):
         new_params = {}
         for field_name in iter(self):
             value = getattr(self, field_name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
                 new_params[field_name] = value.to(*args, **kwargs)
             else:
                 new_params[field_name] = value
@@ -417,7 +421,6 @@ class FrameData(Mapping[str, Any]):
             for f in fields(elem):
                 if not f.init:
                     continue
-
                 list_values = override_fields.get(
                     f.name, [getattr(d, f.name) for d in batch]
                 )
@@ -426,7 +429,7 @@ class FrameData(Mapping[str, Any]):
                     if all(list_value is not None for list_value in list_values)
                     else None
                 )
-            return cls(**collated)
+            return type(elem)(**collated)

         elif isinstance(elem, Pointclouds):
             return join_pointclouds_as_batch(batch)
@@ -434,6 +437,8 @@ class FrameData(Mapping[str, Any]):
         elif isinstance(elem, CamerasBase):
             # TODO: don't store K; enforce working in NDC space
             return join_cameras_as_batch(batch)
+        elif isinstance(elem, Meshes):
+            return join_meshes_as_batch(batch)
         else:
             return torch.utils.data.dataloader.default_collate(batch)
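Note: with Meshes added to the isinstance checks above, FrameData.collate now merges per-frame meshes the same way it already merged point clouds and cameras. A small sketch of the underlying call, using only public PyTorch3D APIs:

# Sketch: what the new Meshes branch of FrameData.collate does per batch.
import torch
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes

verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2]])
m1 = Meshes(verts=[verts], faces=[faces])
m2 = Meshes(verts=[verts + 1.0], faces=[faces])

batched = join_meshes_as_batch([m1, m2])  # the call made for Meshes elements
print(len(batched))  # 2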
@@ -454,8 +459,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
     @abstractmethod
     def build(
         self,
-        frame_annotation: types.FrameAnnotation,
-        sequence_annotation: types.SequenceAnnotation,
+        frame_annotation: FrameAnnotationT,
+        sequence_annotation: SequenceAnnotationT,
         *,
         load_blobs: bool = True,
         **kwargs,
@@ -541,8 +546,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):

     def build(
         self,
-        frame_annotation: types.FrameAnnotation,
-        sequence_annotation: types.SequenceAnnotation,
+        frame_annotation: FrameAnnotationT,
+        sequence_annotation: SequenceAnnotationT,
         *,
         load_blobs: bool = True,
         **kwargs,
@@ -586,58 +591,81 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
             ),
         )

-        fg_mask_np: Optional[np.ndarray] = None
+        dataset_root = self.dataset_root
         mask_annotation = frame_annotation.mask
-        if mask_annotation is not None:
-            if load_blobs and self.load_masks:
-                fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
+        depth_annotation = frame_annotation.depth
+
+        image_path: str | None = None
+        mask_path: str | None = None
+        depth_path: str | None = None
+        pcl_path: str | None = None
+        if dataset_root is not None:  # set all paths even if we won’t load blobs
+            if frame_annotation.image.path is not None:
+                image_path = os.path.join(dataset_root, frame_annotation.image.path)
+                frame_data.image_path = image_path
+
+            if mask_annotation is not None and mask_annotation.path:
+                mask_path = os.path.join(dataset_root, mask_annotation.path)
+                frame_data.mask_path = mask_path
+
+            if depth_annotation is not None and depth_annotation.path is not None:
+                depth_path = os.path.join(dataset_root, depth_annotation.path)
+                frame_data.depth_path = depth_path
+
+            if point_cloud is not None:
+                pcl_path = os.path.join(dataset_root, point_cloud.path)
+                frame_data.sequence_point_cloud_path = pcl_path
+
+        fg_mask_np: np.ndarray | None = None
+        bbox_xywh: tuple[float, float, float, float] | None = None
+
+        if mask_annotation is not None:
+            if load_blobs and self.load_masks and mask_path:
+                fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
                 frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)

             bbox_xywh = mask_annotation.bounding_box_xywh
+            if bbox_xywh is None and fg_mask_np is not None:
+                bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
+
+            frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

         if frame_annotation.image is not None:
             image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
             frame_data.image_size_hw = image_size_hw  # original image size
             # image size after crop/resize
             frame_data.effective_image_size_hw = image_size_hw
-            image_path = None
-            dataset_root = self.dataset_root
-            if frame_annotation.image.path is not None and dataset_root is not None:
-                image_path = os.path.join(dataset_root, frame_annotation.image.path)
-                frame_data.image_path = image_path

             if load_blobs and self.load_images:
                 if image_path is None:
                     raise ValueError("Image path is required to load images.")

-                image_np = load_image(self._local_path(image_path))
+                no_mask = fg_mask_np is None  # didn’t read the mask file
+                image_np = load_image(
+                    self._local_path(image_path), try_read_alpha=no_mask
+                )
+                if image_np.shape[0] == 4:  # RGBA image
+                    if no_mask:
+                        fg_mask_np = image_np[3:]
+                        frame_data.fg_probability = safe_as_tensor(
+                            fg_mask_np, torch.float
+                        )
+
+                    image_np = image_np[:3]
+
                 frame_data.image_rgb = self._postprocess_image(
                     image_np, frame_annotation.image.size, frame_data.fg_probability
                 )

-        if (
-            load_blobs
-            and self.load_depths
-            and frame_annotation.depth is not None
-            and frame_annotation.depth.path is not None
-        ):
-            (
-                frame_data.depth_map,
-                frame_data.depth_path,
-                frame_data.depth_mask,
-            ) = self._load_mask_depth(frame_annotation, fg_mask_np)
-            if bbox_xywh is None and fg_mask_np is not None:
-                bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
-            frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
+        if load_blobs and self.load_depths and depth_path is not None:
+            frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
+                frame_annotation, depth_path, fg_mask_np
+            )

         if load_blobs and self.load_point_clouds and point_cloud is not None:
-            pcl_path = self._fix_point_cloud_path(point_cloud.path)
-            assert pcl_path is not None
             frame_data.sequence_point_cloud = load_pointcloud(
                 self._local_path(pcl_path), max_points=self.max_points
             )
-            frame_data.sequence_point_cloud_path = pcl_path

         if frame_annotation.viewpoint is not None:
             frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
@@ -653,18 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):

         return frame_data

-    def _load_fg_probability(
-        self, entry: types.FrameAnnotation
-    ) -> Tuple[np.ndarray, str]:
-        assert self.dataset_root is not None and entry.mask is not None
-        full_path = os.path.join(self.dataset_root, entry.mask.path)
-        fg_probability = load_mask(self._local_path(full_path))
+    def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
+        fg_probability = load_mask(self._local_path(path))
         if fg_probability.shape[-2:] != entry.image.size:
             raise ValueError(
                 f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
             )

-        return fg_probability, full_path
+        return fg_probability

     def _postprocess_image(
         self,
@@ -685,14 +709,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):

     def _load_mask_depth(
         self,
-        entry: types.FrameAnnotation,
+        entry: FrameAnnotationT,
+        path: str,
         fg_mask: Optional[np.ndarray],
-    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         entry_depth = entry.depth
-        dataset_root = self.dataset_root
-        assert dataset_root is not None
-        assert entry_depth is not None and entry_depth.path is not None
-        path = os.path.join(dataset_root, entry_depth.path)
+        assert entry_depth is not None
         depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

         if self.mask_depths:
@@ -706,11 +730,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         else:
             depth_mask = (depth_map > 0.0).astype(np.float32)

-        return torch.tensor(depth_map), path, torch.tensor(depth_mask)
+        return torch.tensor(depth_map), torch.tensor(depth_mask)

     def _get_pytorch3d_camera(
         self,
-        entry: types.FrameAnnotation,
+        entry: FrameAnnotationT,
     ) -> PerspectiveCameras:
         entry_viewpoint = entry.viewpoint
         assert entry_viewpoint is not None
@@ -739,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
             T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
         )

-    def _fix_point_cloud_path(self, path: str) -> str:
-        """
-        Fix up a point cloud path from the dataset.
-        Some files in Co3Dv2 have an accidental absolute path stored.
-        """
-        unwanted_prefix = (
-            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
-        )
-        if path.startswith(unwanted_prefix):
-            path = path[len(unwanted_prefix) :]
-        assert self.dataset_root is not None
-        return os.path.join(self.dataset_root, path)
-
     def _local_path(self, path: str) -> str:
         if self.path_manager is None:
             return path
@@ -222,7 +222,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
         self.dataset_map = dataset_map

     def _load_category(self, category: str) -> DatasetMap:
-
         frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
             self.dataset_root, category, "sequence_annotations.jgz"
@@ -75,7 +75,6 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
 def _load_data(
     basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
 ):
-
     poses_arr = np.load(
         _local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
     )
@@ -164,7 +163,6 @@ def ptstocam(pts, c2w):


 def poses_avg(poses):
-
     hwf = poses[0, :3, -1:]

     center = poses[:, :3, 3].mean(0)
@@ -192,7 +190,6 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):


 def recenter_poses(poses):
-
     poses_ = poses + 0
     bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
     c2w = poses_avg(poses)
@@ -256,7 +253,6 @@ def spherify_poses(poses, bds):
     new_poses = []

     for th in np.linspace(0.0, 2.0 * np.pi, 120):
-
         camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
         up = np.array([0, 0, -1.0])

@@ -311,7 +307,6 @@ def load_llff_data(
     path_zflat=False,
     path_manager=None,
 ):
-
     poses, bds, imgs = _load_data(
         basedir, factor=factor, path_manager=path_manager
     )  # factor=8 downsamples original imgs by 8x
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+
 # This functionality requires SQLAlchemy 2.0 or later.

 import math
@@ -4,11 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+
 import hashlib
 import json
 import logging
 import os
-from dataclasses import dataclass
-
+import urllib
+from dataclasses import dataclass, Field, field
 from typing import (
     Any,
     ClassVar,
@@ -29,17 +33,18 @@ import sqlalchemy as sa
 import torch
 from pytorch3d.implicitron.dataset.dataset_base import DatasetBase

-from pytorch3d.implicitron.dataset.frame_data import (  # noqa
+from pytorch3d.implicitron.dataset.frame_data import (
     FrameData,
-    FrameDataBuilder,
+    FrameDataBuilder,  # noqa
     FrameDataBuilderBase,
 )

 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session, sessionmaker

 from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
@@ -51,7 +56,7 @@ _SET_LISTS_TABLE: str = "set_lists"


 @registry.register
-class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
+class SqlIndexDataset(DatasetBase, ReplaceableBase):
     """
     A dataset with annotations stored as SQLite tables. This is an index-based dataset.
     The length is returned after all sequence and frame filters are applied (see param
@@ -88,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         engine verbatim. Don’t expose it to end users of your application!
         pick_categories: Restrict the dataset to the given list of categories.
         pick_sequences: A Sequence of sequence names to restrict the dataset to.
+        pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
         exclude_sequences: A Sequence of the names of the sequences to exclude.
         limit_sequences_per_category_to: Limit the dataset to the first up to N
             sequences within each category (applies after all other sequence filters
@@ -102,9 +108,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             more frames than that; applied after other frame-level filters.
         seed: The seed of the random generator sampling `n_frames_per_sequence`
             random frames per sequence.
+        preload_metadata: If True, the metadata is preloaded into memory.
+        precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
+        scoped_session: If True, allows different parts of the code to share
+            a global session to access the database.
     """

     frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
+    sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
+        SqlSequenceAnnotation
+    )

     sqlite_metadata_file: str = ""
     dataset_root: Optional[str] = None
@@ -117,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     pick_categories: Tuple[str, ...] = ()

     pick_sequences: Tuple[str, ...] = ()
+    pick_sequences_sql_clause: Optional[str] = None
     exclude_sequences: Tuple[str, ...] = ()
     limit_sequences_per_category_to: int = 0
     limit_sequences_to: int = 0
@@ -124,12 +138,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     n_frames_per_sequence: int = -1
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
+    preload_metadata: bool = False
+    precompute_seq_to_idx: bool = False
     # we set it manually in the constructor
-    # _index: pd.DataFrame = field(init=False)
+    _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
+    _sql_engine: sa.engine.Engine = field(
+        init=False, metadata={"omegaconf_ignore": True}
+    )
+    eval_batches: Optional[List[Any]] = field(
+        init=False, metadata={"omegaconf_ignore": True}
+    )

-    frame_data_builder: FrameDataBuilderBase
+    frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"

+    scoped_session: bool = False
+
     def __post_init__(self) -> None:
         if sa.__version__ < "2.0":
             raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -138,19 +162,28 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             raise ValueError("sqlite_metadata_file must be set")

         if self.dataset_root:
-            frame_builder_type = self.frame_data_builder_class_type
-            getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
-                "dataset_root"
-            ] = self.dataset_root
+            frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
+            getattr(self, frame_args)["dataset_root"] = self.dataset_root
+            getattr(self, frame_args)["path_manager"] = self.path_manager

         run_auto_creation(self)
-        self.frame_data_builder.path_manager = self.path_manager

-        # pyre-ignore # NOTE: sqlite-specific args (read-only mode).
+        if self.path_manager is not None:
+            self.sqlite_metadata_file = self.path_manager.get_local_path(
+                self.sqlite_metadata_file
+            )
+            self.subset_lists_file = self.path_manager.get_local_path(
+                self.subset_lists_file
+            )
+
+        # NOTE: sqlite-specific args (read-only mode).
         self._sql_engine = sa.create_engine(
-            f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
+            f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
         )

+        if self.preload_metadata:
+            self._sql_engine = self._preload_database(self._sql_engine)
+
         sequences = self._get_filtered_sequences_if_any()

         if self.subsets:
@@ -166,16 +199,29 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         if len(index) == 0:
             raise ValueError(f"There are no frames in the subsets: {self.subsets}!")

-        self._index = index.set_index(["sequence_name", "frame_number"])  # pyre-ignore
+        self._index = index.set_index(["sequence_name", "frame_number"])

-        self.eval_batches = None  # pyre-ignore
+        self.eval_batches = None
         if self.eval_batches_file:
             self.eval_batches = self._load_filter_eval_batches()

         logger.info(str(self))

+        if self.scoped_session:
+            self._session_factory = sessionmaker(bind=self._sql_engine)  # pyre-ignore
+
+        if self.precompute_seq_to_idx:
+            # This is deprecated and will be removed in the future.
+            # After we backport https://github.com/facebookresearch/uco3d/pull/3
+            logger.warning(
+                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
+            )
+            self._index["rowid"] = np.arange(len(self._index))
+            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
+            self._seq_to_indices = dict(groupby.apply(list))  # pyre-ignore
+            del self._index["rowid"]
+
     def __len__(self) -> int:
-        # pyre-ignore[16]
         return len(self._index)

     def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
@@ -232,12 +278,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             self.frame_annotations_type.frame_number
             == int(frame),  # cast from np.int64
         )
-        seq_stmt = sa.select(SqlSequenceAnnotation).where(
-            SqlSequenceAnnotation.sequence_name == seq
+        seq_stmt = sa.select(self.sequence_annotations_type).where(
+            self.sequence_annotations_type.sequence_name == seq
         )
-        with Session(self._sql_engine) as session:
-            entry = session.scalars(stmt).one()
-            seq_metadata = session.scalars(seq_stmt).one()
+        if self.scoped_session:
+            # pyre-ignore
+            with scoped_session(self._session_factory)() as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
+        else:
+            with Session(self._sql_engine) as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()

         assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
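Note: the scoped_session branch above shares one session registry across callers instead of opening a fresh Session per access. A minimal sketch of that SQLAlchemy pattern, assuming an in-memory SQLite engine (sessionmaker and scoped_session are standard SQLAlchemy APIs):

# Sketch of the shared-session pattern used by the scoped_session=True path.
import sqlalchemy as sa
from sqlalchemy.orm import scoped_session, sessionmaker

engine = sa.create_engine("sqlite:///:memory:")
factory = sessionmaker(bind=engine)
SharedSession = scoped_session(factory)

with SharedSession() as session:  # every caller gets the same registry
    session.execute(sa.text("SELECT 1"))
SharedSession.remove()  # release the thread-local session when done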
@@ -250,7 +302,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         return frame_data

     def __str__(self) -> str:
-        # pyre-ignore[16]
         return f"SqlIndexDataset #frames={len(self._index)}"

     def sequence_names(self) -> Iterable[str]:
@@ -260,9 +311,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     # override
     def category_to_sequence_names(self) -> Dict[str, List[str]]:
         stmt = sa.select(
-            SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
+            self.sequence_annotations_type.category,
+            self.sequence_annotations_type.sequence_name,
         ).where(  # we limit results to sequences that have frames after all filters
-            SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
+            self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
         )
         with self._sql_engine.connect() as connection:
             cat_to_seqs = pd.read_sql(stmt, connection)
@@ -335,17 +387,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         rows = self._index.index.get_loc(seq_name)
         if isinstance(rows, slice):
             assert rows.stop is not None, "Unexpected result from pandas"
-            rows = range(rows.start or 0, rows.stop, rows.step or 1)
+            rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
         else:
-            rows = np.where(rows)[0]
+            rows_seq = list(np.where(rows)[0])

         index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
-            rows, seq_name, subset_filter
+            rows_seq, seq_name, subset_filter
         )
         index_slice["idx"] = idx

         yield from index_slice.itertuples(index=False)

     # override
+    def sequence_indices_in_order(
+        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
+    ) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns the iterator over
+        only dataset indices.
+        """
+        if self.precompute_seq_to_idx and subset_filter is None:
+            # pyre-ignore
+            yield from self._seq_to_indices[seq_name]
+        else:
+            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
+                yield idx
+
+    # override
     def get_eval_batches(self) -> Optional[List[Any]]:
         """
@@ -379,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             or self.limit_sequences_to > 0
             or self.limit_sequences_per_category_to > 0
             or len(self.pick_sequences) > 0
+            or self.pick_sequences_sql_clause is not None
             or len(self.exclude_sequences) > 0
             or len(self.pick_categories) > 0
             or self.n_frames_per_sequence > 0
         )

+    def _preload_database(
+        self, source_engine: sa.engine.base.Engine
+    ) -> sa.engine.base.Engine:
+        destination_engine = sa.create_engine("sqlite:///:memory:")
+        metadata = sa.MetaData()
+        metadata.reflect(bind=source_engine)
+        metadata.create_all(bind=destination_engine)
+
+        with source_engine.connect() as source_conn:
+            with destination_engine.connect() as destination_conn:
+                for table_obj in metadata.tables.values():
+                    # Select all rows from the source table
+                    source_rows = source_conn.execute(table_obj.select())
+
+                    # Insert rows into the destination table
+                    for row in source_rows:
+                        destination_conn.execute(table_obj.insert().values(row))
+
+                    # Commit the changes for each table
+                    destination_conn.commit()
+
+        return destination_engine
+
     def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
         # maximum possible filter (if limit_sequences_per_category_to == 0):
         # WHERE category IN 'self.pick_categories'
@@ -396,19 +486,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             *self._get_pick_filters(),
             *self._get_exclude_filters(),
         ]
+        if self.pick_sequences_sql_clause:
+            print("Applying the custom SQL clause.")
+            where_conditions.append(sa.text(self.pick_sequences_sql_clause))

         def add_where(stmt):
             return stmt.where(*where_conditions) if where_conditions else stmt

         if self.limit_sequences_per_category_to <= 0:
-            stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
+            stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
         else:
             subquery = sa.select(
-                SqlSequenceAnnotation.sequence_name,
+                self.sequence_annotations_type.sequence_name,
                 sa.func.row_number()
                 .over(
                     order_by=sa.text("ROWID"),  # NOTE: ROWID is SQLite-specific
-                    partition_by=SqlSequenceAnnotation.category,
+                    partition_by=self.sequence_annotations_type.category,
                 )
                 .label("row_number"),
             )
@@ -444,31 +537,34 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             return []

         logger.info(f"Limiting dataset to categories: {self.pick_categories}")
-        return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
+        return [self.sequence_annotations_type.category.in_(self.pick_categories)]

     def _get_pick_filters(self) -> List[sa.ColumnElement]:
         if not self.pick_sequences:
             return []

         logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
-        return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
+        return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]

     def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
         if not self.exclude_sequences:
             return []

         logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
-        return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
+        return [
+            self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
+        ]

     def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
-        assert self.subsets is not None
+        subsets = self.subsets
+        assert subsets is not None
         with open(subset_lists_path, "r") as f:
             subset_to_seq_frame = json.load(f)

         seq_frame_list = sum(
             (
                 [(*row, subset) for row in subset_to_seq_frame[subset]]
-                for subset in self.subsets
+                for subset in subsets
             ),
             [],
         )
@@ -522,7 +618,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         stmt = sa.select(
             self.frame_annotations_type.sequence_name,
             self.frame_annotations_type.frame_number,
-        ).where(self.frame_annotations_type._mask_mass == 0)
+        ).where(self.frame_annotations_type._mask_mass == 0)  # pyre-ignore[16]
         with Session(self._sql_engine) as session:
             to_remove = session.execute(stmt).all()

@@ -586,7 +682,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         stmt = sa.select(
             self.frame_annotations_type.sequence_name,
             self.frame_annotations_type.frame_number,
-            self.frame_annotations_type._image_path,
+            self.frame_annotations_type._image_path,  # pyre-ignore[16]
             sa.null().label("subset"),
         )
         where_conditions = []
@@ -600,7 +696,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         logger.info(" excluding samples with empty masks")
         where_conditions.append(
             sa.or_(
-                self.frame_annotations_type._mask_mass.is_(None),
+                self.frame_annotations_type._mask_mass.is_(None),  # pyre-ignore[16]
                 self.frame_annotations_type._mask_mass != 0,
             )
         )
@@ -634,7 +730,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         assert self.eval_batches_file
         logger.info(f"Loading eval batches from {self.eval_batches_file}")

-        if not os.path.isfile(self.eval_batches_file):
+        if (
+            self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
+        ) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
             # The batch indices file does not exist.
             # Most probably the user has not specified the root folder.
             raise ValueError(
@@ -642,7 +740,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
                 + "Please specify a correct dataset_root folder."
             )

-        with open(self.eval_batches_file, "r") as f:
+        eval_batches_file = self._local_path(self.eval_batches_file)
+        with open(eval_batches_file, "r") as f:
             eval_batches = json.load(f)

         # limit the dataset to sequences to allow multiple evaluations in one file
@@ -656,7 +755,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         if pick_sequences:
             old_len = len(eval_batches)
             eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
-            logger.warn(
+            logger.warning(
                 f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
             )

@@ -664,7 +763,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             old_len = len(eval_batches)
             exclude_sequences = set(self.exclude_sequences)
             eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
-            logger.warn(
+            logger.warning(
                 f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
             )
@@ -726,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             self.frame_annotations_type.sequence_name == seq_name,
             self.frame_annotations_type.frame_number.in_(frames),
         )
         frame_no_ts = None

-        with self._sql_engine.connect() as connection:
-            frame_no_ts = pd.read_sql_query(stmt, connection)
+        if self.scoped_session:
+            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
+            with scoped_session(self._session_factory)() as session:  # pyre-ignore
+                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
+        else:
+            with self._sql_engine.connect() as connection:
+                frame_no_ts = pd.read_sql_query(stmt, connection)

         if len(frame_no_ts) != len(index_slice):
             raise ValueError(
@@ -758,11 +863,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             prefixes=["TEMP"],  # NOTE SQLite specific!
         )

+    @classmethod
+    def pre_expand(cls) -> None:
+        # remove dataclass annotations that are not meant to be init params
+        # because they cause troubles for OmegaConf
+        for attr, attr_value in list(cls.__dict__.items()):  # need to copy as we mutate
+            if isinstance(attr_value, Field) and attr_value.metadata.get(
+                "omegaconf_ignore", False
+            ):
+                delattr(cls, attr)
+                del cls.__annotations__[attr]
+

 def _seq_name_to_seed(seq_name) -> int:
     """Generates numbers in [0, 2 ** 28)"""
     return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)


 def _safe_as_tensor(data, dtype):
     return torch.tensor(data, dtype=dtype) if data is not None else None
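Note: _seq_name_to_seed hashes a sequence name into a stable seed; seven hex digits of the SHA-1 digest cover exactly 16**7 == 2**28 values, which is the range its docstring claims. A quick self-contained check (the sequence name below is hypothetical):

# Verifies the [0, 2**28) range claim of _seq_name_to_seed.
import hashlib

def seq_name_to_seed(seq_name: str) -> int:
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)

seed = seq_name_to_seed("teddybear/34_1403_4393")  # hypothetical name
assert 0 <= seed < 2**28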
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+

 import logging
 import os
@@ -43,7 +45,7 @@ logger = logging.getLogger(__name__)


 @registry.register
-class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
+class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
     """
     Generates the training, validation, and testing dataset objects for
     a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
@@ -193,9 +195,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):

     # this is a mould that is never constructed, used to build self._dataset_map values
     dataset_class_type: str = "SqlIndexDataset"
-    dataset: SqlIndexDataset
+    dataset: SqlIndexDataset  # pyre-ignore [13]

-    path_manager_factory: PathManagerFactory
+    path_manager_factory: PathManagerFactory  # pyre-ignore [13]
     path_manager_factory_class_type: str = "PathManagerFactory"

     def __post_init__(self):
@@ -282,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
         logger.info(f"Val dataset: {str(val_dataset)}")

         logger.debug("Extracting test dataset.")
-        eval_batches_file = self._get_lists_file("eval_batches")
-        del common_dataset_kwargs["eval_batches_file"]
+        if self.eval_batches_path is None:
+            eval_batches_file = None
+        else:
+            eval_batches_file = self._get_lists_file("eval_batches")
+
+        if "eval_batches_file" in common_dataset_kwargs:
+            common_dataset_kwargs.pop("eval_batches_file", None)

         test_dataset = dataset_type(
             **common_dataset_kwargs,
             subsets=self._get_subsets(self.test_subsets, True),
@@ -87,6 +87,15 @@ def is_train_frame(
 def get_bbox_from_mask(
     mask: np.ndarray, thr: float, decrease_quant: float = 0.05
 ) -> Tuple[int, int, int, int]:
+    # these corner cases need to be handled in order to avoid an infinite loop
+    if mask.size == 0:
+        warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
+        return 0, 0, 1, 1
+
+    if not mask.min() >= 0.0:
+        warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
+        mask = mask.clip(min=0.0)
+
     # bbox in xywh
     masks_for_box = np.zeros_like(mask)
     while masks_for_box.sum() <= 1.0:
@@ -134,7 +143,15 @@ T = TypeVar("T", bound=torch.Tensor)
 def bbox_xyxy_to_xywh(xyxy: T) -> T:
     wh = xyxy[2:] - xyxy[:2]
     xywh = torch.cat([xyxy[:2], wh])
-    return xywh  # pyre-ignore
+    return xywh  # pyre-ignore[7]
+
+
+def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
+    wh = xywh[2:]
+    if clamp_size is not None:
+        wh = wh.clamp(min=clamp_size)
+    xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
+    return xyxy  # pyre-ignore[7]


 def get_clamp_bbox(
@@ -180,16 +197,6 @@ def rescale_bbox(
     return bbox * rel_size


-def bbox_xywh_to_xyxy(
-    xywh: torch.Tensor, clamp_size: Optional[int] = None
-) -> torch.Tensor:
-    xyxy = xywh.clone()
-    if clamp_size is not None:
-        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
-    xyxy[2:] += xyxy[:2]
-    return xyxy
-
-
 def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
     nz = np.flatnonzero(arr)
     return nz[0], nz[-1] + 1
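Note: the generic bbox_xywh_to_xyxy introduced above clamps width and height before adding them to the origin, so degenerate boxes stay at least clamp_size wide. A worked example (the function body mirrors the new version in this diff):

# Worked example of the clamp path in the new bbox_xywh_to_xyxy.
import torch

def bbox_xywh_to_xyxy(xywh, clamp_size=None):
    wh = xywh[2:]
    if clamp_size is not None:
        wh = wh.clamp(min=clamp_size)
    return torch.cat([xywh[:2], xywh[:2] + wh])

box = torch.tensor([10.0, 20.0, 0.0, 5.0])  # x, y, w, h with zero width
print(bbox_xywh_to_xyxy(box, clamp_size=2))  # tensor([10., 20., 12., 25.])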
@@ -201,18 +208,24 @@ def resize_image(
     image_height: Optional[int],
     image_width: Optional[int],
     mode: str = "bilinear",
 ) -> Tuple[torch.Tensor, float, torch.Tensor]:
-
     if isinstance(image, np.ndarray):
         image = torch.from_numpy(image)

-    if image_height is None or image_width is None:
+    if (
+        image_height is None
+        or image_width is None
+        or image.shape[-2] == 0
+        or image.shape[-1] == 0
+    ):
         # skip the resizing
         return image, 1.0, torch.ones_like(image[:1])

     # takes numpy array or tensor, returns pytorch tensor
     minscale = min(
         image_height / image.shape[-2],
         image_width / image.shape[-1],
     )

     imre = torch.nn.functional.interpolate(
         image[None],
         scale_factor=minscale,
@@ -220,6 +233,7 @@ def resize_image(
         align_corners=False if mode == "bilinear" else None,
         recompute_scale_factor=True,
     )[0]
+
     imre_ = torch.zeros(image.shape[0], image_height, image_width)
     imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
     mask = torch.zeros(1, image_height, image_width)
@@ -232,9 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
     return im.astype(np.float32) / 255.0


-def load_image(path: str) -> np.ndarray:
+def load_image(
+    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
+) -> np.ndarray:
+    """
+    Load an image from a path and return it as a numpy array.
+    If try_read_alpha is True, the image is read as RGBA and the alpha channel is
+    returned as the fourth channel.
+    Otherwise, the image is read as RGB and a three-channel image is returned.
+    """
     with Image.open(path) as pil_im:
-        im = np.array(pil_im.convert("RGB"))
+        # Check if the image has an alpha channel
+        if try_read_alpha and pil_im.mode == "RGBA":
+            im = np.array(pil_im)
+        else:
+            im = np.array(pil_im.convert(pil_format))

     return transpose_normalize_image(im)

@@ -329,6 +355,7 @@ def adjust_camera_to_bbox_crop_(

     focal_length_px, principal_point_px = _convert_ndc_to_pixels(
         camera.focal_length[0],
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
         camera.principal_point[0],
         image_size_wh,
     )
@@ -341,6 +368,7 @@ def adjust_camera_to_bbox_crop_(
     )

     camera.focal_length = focal_length[None]
+    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
     camera.principal_point = principal_point_cropped[None]


@@ -352,6 +380,7 @@ def adjust_camera_to_image_scale_(
 ) -> PerspectiveCameras:
     focal_length_px, principal_point_px = _convert_ndc_to_pixels(
         camera.focal_length[0],
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
         camera.principal_point[0],
         original_size_wh,
     )
@@ -368,7 +397,8 @@ def adjust_camera_to_image_scale_(
         image_size_wh_output,
     )
     camera.focal_length = focal_length_scaled[None]
-    camera.principal_point = principal_point_scaled[None]
+    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
+    camera.principal_point = principal_point_scaled[None]  # pyre-ignore[16]


 # NOTE this cache is per-worker; they are implemented as processes.
@@ -299,7 +299,6 @@ def eval_batch(
     )

     for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
-
         loss_mask_now = mask_crop * loss_fg_mask

         for rgb_metric_name, rgb_metric_fun in zip(
@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
         self.layers = torch.nn.ModuleList()
         self.proj_layers = torch.nn.ModuleList()
         for stage in range(self.max_stage):
-            stage_name = f"layer{stage+1}"
+            stage_name = f"layer{stage + 1}"
             feature_name = self._get_resnet_stage_feature_name(stage)
             if (stage + 1) in self.stages:
                 if (
@@ -139,12 +139,18 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
         self.stages = set(self.stages)  # convert to set for faster "in"

     def _get_resnet_stage_feature_name(self, stage) -> str:
-        return f"res_layer_{stage+1}"
+        return f"res_layer_{stage + 1}"

     def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
+        # pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
+        #  `Union[Tensor, Module]`.
+        # pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
+        #  `Union[Tensor, Module]`.
         return (img - self._resnet_mean) / self._resnet_std

     def get_feat_dims(self) -> int:
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
+        #  not a function.
         return sum(self._feat_dim.values())

     def forward(
@@ -183,7 +189,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
         else:
             imgs_normed = imgs_resized
-        #  is not a function.
+        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
         feats = self.stem(imgs_normed)
+        # pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
+        #  `Union[Tensor, Module]`.
+        # pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
+        #  `Union[Tensor, Module]`.
         for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
             feats = layer(feats)
             # just a sanity check below
@@ -478,6 +478,8 @@ class GenericModel(ImplicitronModelBase):
         )
         custom_args["global_code"] = global_code

+        # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
+        #  function.
         for func in self._implicit_functions:
             func.bind_args(**custom_args)

@@ -500,6 +502,8 @@ class GenericModel(ImplicitronModelBase):
         # Unbind the custom arguments to prevent pytorch from storing
         # large buffers of intermediate results due to points in the
         # bound arguments.
+        # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
+        #  function.
         for func in self._implicit_functions:
             func.unbind_args()

@@ -71,6 +71,7 @@ class Autodecoder(Configurable, torch.nn.Module):
         return key_map

     def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
+        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
         return (self._autodecoder_codes.weight**2).mean()

     def get_encoding_dim(self) -> int:
@@ -95,6 +96,7 @@ class Autodecoder(Configurable, torch.nn.Module):
             # pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
             #  `Tensor`.
             x = torch.tensor(
+                # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
                 [self._key_map[elem] for elem in x],
                 dtype=torch.long,
                 device=next(self.parameters()).device,
@@ -102,6 +104,7 @@ class Autodecoder(Configurable, torch.nn.Module):
         except StopIteration:
             raise ValueError("Not enough n_instances in the autodecoder") from None

+        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
         return self._autodecoder_codes(x)

     def _load_key_map_hook(
Some files were not shown because too many files have changed in this diff.