mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Compare commits
34 Commits
V0.7.8
...
bottler/un
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62a2031dd4 | ||
|
|
3987612062 | ||
|
|
06a76ef8dd | ||
|
|
21205730d9 | ||
|
|
7e09505538 | ||
|
|
20bd8b33f6 | ||
|
|
7a3c0cbc9d | ||
|
|
215590b497 | ||
|
|
43cd681d4f | ||
|
|
42a4a7d432 | ||
|
|
699bc671ca | ||
|
|
49cf5a0f37 | ||
|
|
89b851e64c | ||
|
|
5247f6ad74 | ||
|
|
e41aff47db | ||
|
|
64a5bfadc8 | ||
|
|
055ab3a2e3 | ||
|
|
f6c2ca6bfc | ||
|
|
e20cbe9b0e | ||
|
|
c17e6f947a | ||
|
|
91c9f34137 | ||
|
|
81d82980bc | ||
|
|
8fe6934885 | ||
|
|
c434957b2a | ||
|
|
dd2a11b5fc | ||
|
|
9563ef79ca | ||
|
|
008c7ab58c | ||
|
|
9eaed4c495 | ||
|
|
e13848265d | ||
|
|
58566963d6 | ||
|
|
e17ed5cd50 | ||
|
|
8ed0c7a002 | ||
|
|
2da913c7e6 | ||
|
|
fca83e6369 |
@@ -88,7 +88,6 @@ def workflow_pair(
|
||||
upload=False,
|
||||
filter_branch,
|
||||
):
|
||||
|
||||
w = []
|
||||
py = python_version.replace(".", "")
|
||||
pyt = pytorch_version.replace(".", "")
|
||||
@@ -127,7 +126,6 @@ def generate_base_workflow(
|
||||
btype,
|
||||
filter_branch=None,
|
||||
):
|
||||
|
||||
d = {
|
||||
"name": base_workflow_name,
|
||||
"python_version": python_version,
|
||||
|
||||
23
.github/workflows/build.yml
vendored
Normal file
23
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: facebookresearch/pytorch3d/build_and_test
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
binary_linux_conda_cuda:
|
||||
runs-on: 4-core-ubuntu-gpu-t4
|
||||
env:
|
||||
PYTHON_VERSION: "3.12"
|
||||
BUILD_VERSION: "${{ github.run_number }}"
|
||||
PYTORCH_VERSION: "2.4.1"
|
||||
CU_VERSION: "cu121"
|
||||
JUST_TESTRUN: 1
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Build and run tests
|
||||
run: |-
|
||||
conda create --name env --yes --quiet conda-build
|
||||
conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda
|
||||
@@ -36,5 +36,5 @@ then
|
||||
|
||||
echo "Running pyre..."
|
||||
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
|
||||
( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
|
||||
( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
|
||||
fi
|
||||
|
||||
@@ -10,6 +10,7 @@ This example demonstrates the most trivial, direct interface of the pulsar
|
||||
sphere renderer. It renders and saves an image with 10 random spheres.
|
||||
Output: basic.png.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from os import path
|
||||
|
||||
@@ -11,6 +11,7 @@ interface for sphere renderering. It renders and saves an image with
|
||||
10 random spheres.
|
||||
Output: basic-pt3d.png.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import path
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
|
||||
original camera parameters.
|
||||
Output: cam.gif.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from os import path
|
||||
|
||||
@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
|
||||
original camera parameters.
|
||||
Output: cam-pt3d.gif
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import path
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ This example is not available yet through the 'unified' interface,
|
||||
because opacity support has not landed in PyTorch3D for general data
|
||||
structures yet.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from os import path
|
||||
|
||||
@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
|
||||
optimization is used to converge towards a faithful
|
||||
scene representation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
|
||||
optimization is used to converge towards a faithful
|
||||
scene representation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os.path
|
||||
import runpy
|
||||
import subprocess
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
# required env vars:
|
||||
# CU_VERSION: E.g. cu112
|
||||
@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
|
||||
source_root_dir = os.environ["PWD"]
|
||||
|
||||
|
||||
def version_constraint(version):
|
||||
def version_constraint(version) -> str:
|
||||
"""
|
||||
Given version "11.3" returns " >=11.3,<11.4"
|
||||
"""
|
||||
@@ -32,7 +33,7 @@ def version_constraint(version):
|
||||
return f" >={version},<{upper}"
|
||||
|
||||
|
||||
def get_cuda_major_minor():
|
||||
def get_cuda_major_minor() -> Tuple[str, str]:
|
||||
if CU_VERSION == "cpu":
|
||||
raise ValueError("fn only for cuda builds")
|
||||
if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
|
||||
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
|
||||
return major, minor
|
||||
|
||||
|
||||
def setup_cuda():
|
||||
def setup_cuda(use_conda_cuda: bool) -> List[str]:
|
||||
if CU_VERSION == "cpu":
|
||||
return
|
||||
return []
|
||||
major, minor = get_cuda_major_minor()
|
||||
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
|
||||
os.environ["FORCE_CUDA"] = "1"
|
||||
|
||||
basic_nvcc_flags = (
|
||||
@@ -75,6 +75,15 @@ def setup_cuda():
|
||||
|
||||
if os.environ.get("JUST_TESTRUN", "0") != "1":
|
||||
os.environ["NVCC_FLAGS"] = nvcc_flags
|
||||
if use_conda_cuda:
|
||||
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
|
||||
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
|
||||
f"- cuda-version={major}.{minor}"
|
||||
)
|
||||
return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
|
||||
else:
|
||||
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
|
||||
return []
|
||||
|
||||
|
||||
def setup_conda_pytorch_constraint() -> List[str]:
|
||||
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
|
||||
return ["-c", "pytorch", "-c", "nvidia"]
|
||||
|
||||
|
||||
def setup_conda_cudatoolkit_constraint():
|
||||
def setup_conda_cudatoolkit_constraint() -> None:
|
||||
if CU_VERSION == "cpu":
|
||||
os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
|
||||
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""
|
||||
@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
|
||||
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit
|
||||
|
||||
|
||||
def do_build(start_args: List[str]):
|
||||
def do_build(start_args: List[str]) -> None:
|
||||
args = start_args.copy()
|
||||
|
||||
test_flag = os.environ.get("TEST_FLAG")
|
||||
@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Build the conda package.")
|
||||
parser.add_argument(
|
||||
"--use-conda-cuda",
|
||||
action="store_true",
|
||||
help="get cuda from conda ignoring local cuda",
|
||||
)
|
||||
our_args = parser.parse_args()
|
||||
|
||||
args = ["conda", "build"]
|
||||
setup_cuda()
|
||||
args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)
|
||||
|
||||
init_path = source_root_dir + "/pytorch3d/__init__.py"
|
||||
build_version = runpy.run_path(init_path)["__version__"]
|
||||
|
||||
@@ -8,10 +8,13 @@ source:
|
||||
requirements:
|
||||
build:
|
||||
- {{ compiler('c') }} # [win]
|
||||
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
|
||||
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
|
||||
{{ environ.get('CONDA_CUB_CONSTRAINT') }}
|
||||
|
||||
host:
|
||||
- python
|
||||
- mkl =2023 # [x86_64]
|
||||
{{ environ.get('SETUPTOOLS_CONSTRAINT') }}
|
||||
{{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
|
||||
{{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
|
||||
@@ -22,6 +25,7 @@ requirements:
|
||||
- python
|
||||
- numpy >=1.11
|
||||
- torchvision >=0.5
|
||||
- mkl =2023 # [x86_64]
|
||||
- iopath
|
||||
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
|
||||
{{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
|
||||
@@ -47,8 +51,11 @@ test:
|
||||
- imageio
|
||||
- hydra-core
|
||||
- accelerate
|
||||
- matplotlib
|
||||
- tabulate
|
||||
- pandas
|
||||
- sqlalchemy
|
||||
commands:
|
||||
#pytest .
|
||||
python -m unittest discover -v -s tests -t .
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
""""
|
||||
""" "
|
||||
This file is the entry point for launching experiments with Implicitron.
|
||||
|
||||
Launch Training
|
||||
@@ -44,6 +44,7 @@ The outputs of the experiment are saved and logged in multiple ways:
|
||||
config file.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelFactoryBase(ReplaceableBase):
|
||||
|
||||
resume: bool = True # resume from the last checkpoint
|
||||
|
||||
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
||||
@@ -116,7 +115,9 @@ class ImplicitronModelFactory(ModelFactoryBase):
|
||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||
}
|
||||
model_state_dict = torch.load(
|
||||
model_io.get_model_path(model_path), map_location=map_location
|
||||
model_io.get_model_path(model_path),
|
||||
map_location=map_location,
|
||||
weights_only=True,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -123,6 +123,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
||||
"""
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
||||
else:
|
||||
p_groups = [
|
||||
@@ -241,7 +242,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
||||
map_location = {
|
||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||
}
|
||||
optimizer_state = torch.load(opt_path, map_location)
|
||||
optimizer_state = torch.load(opt_path, map_location, weights_only=True)
|
||||
else:
|
||||
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
|
||||
return optimizer_state
|
||||
|
||||
@@ -161,7 +161,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
||||
for epoch in range(start_epoch, self.max_epochs):
|
||||
# automatic new_epoch and plotting of stats at every epoch start
|
||||
with stats:
|
||||
|
||||
# Make sure to re-seed random generators to ensure reproducibility
|
||||
# even after restart.
|
||||
seed_all_random_engines(seed + epoch)
|
||||
@@ -395,6 +394,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
||||
):
|
||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||
if hasattr(model, "visualize"):
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
model.visualize(
|
||||
viz,
|
||||
visdom_env_imgs,
|
||||
|
||||
@@ -53,12 +53,8 @@ class TestExperiment(unittest.TestCase):
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||
"JsonIndexDatasetMapProvider"
|
||||
)
|
||||
dataset_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataloader_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
)
|
||||
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
@@ -94,12 +90,8 @@ class TestExperiment(unittest.TestCase):
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||
"JsonIndexDatasetMapProvider"
|
||||
)
|
||||
dataset_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataloader_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
)
|
||||
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
@@ -111,9 +103,7 @@ class TestExperiment(unittest.TestCase):
|
||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
||||
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
|
||||
2
|
||||
)
|
||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2
|
||||
|
||||
if DEBUG:
|
||||
experiment.dump_cfg(cfg)
|
||||
|
||||
@@ -81,8 +81,9 @@ class TestOptimizerFactory(unittest.TestCase):
|
||||
|
||||
def test_param_overrides_self_param_group_assignment(self):
|
||||
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
|
||||
na, nb = Node(params=[pa]), Node(
|
||||
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
|
||||
na, nb = (
|
||||
Node(params=[pa]),
|
||||
Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
|
||||
)
|
||||
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
|
||||
param_groups = self._get_param_groups(root)
|
||||
|
||||
@@ -84,9 +84,9 @@ def get_nerf_datasets(
|
||||
|
||||
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
|
||||
# Automatically download the data files if missing.
|
||||
download_data((dataset_name,), data_root=data_root)
|
||||
download_data([dataset_name], data_root=data_root)
|
||||
|
||||
train_data = torch.load(cameras_path)
|
||||
train_data = torch.load(cameras_path, weights_only=True)
|
||||
n_cameras = train_data["cameras"]["R"].shape[0]
|
||||
|
||||
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS
|
||||
|
||||
@@ -194,7 +194,6 @@ class Stats:
|
||||
it = self.it[stat_set]
|
||||
|
||||
for stat in self.log_vars:
|
||||
|
||||
if stat not in self.stats[stat_set]:
|
||||
self.stats[stat_set][stat] = AverageMeter()
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
||||
|
||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||
def main(cfg: DictConfig):
|
||||
|
||||
# Device on which to run.
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
@@ -63,7 +62,7 @@ def main(cfg: DictConfig):
|
||||
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
|
||||
|
||||
print(f"Loading checkpoint {checkpoint_path}.")
|
||||
loaded_data = torch.load(checkpoint_path)
|
||||
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
||||
# Do not load the cached xy grid.
|
||||
# - this allows setting an arbitrary evaluation image size.
|
||||
state_dict = {
|
||||
|
||||
@@ -42,7 +42,6 @@ class TestRaysampler(unittest.TestCase):
|
||||
cameras, rays = [], []
|
||||
|
||||
for _ in range(batch_size):
|
||||
|
||||
R = random_rotations(1)
|
||||
T = torch.randn(1, 3)
|
||||
focal_length = torch.rand(1, 2) + 0.5
|
||||
|
||||
@@ -25,7 +25,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
||||
|
||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||
def main(cfg: DictConfig):
|
||||
|
||||
# Set the relevant seeds for reproducibility.
|
||||
np.random.seed(cfg.seed)
|
||||
torch.manual_seed(cfg.seed)
|
||||
@@ -77,7 +76,7 @@ def main(cfg: DictConfig):
|
||||
# Resume training if requested.
|
||||
if cfg.resume and os.path.isfile(checkpoint_path):
|
||||
print(f"Resuming from checkpoint {checkpoint_path}.")
|
||||
loaded_data = torch.load(checkpoint_path)
|
||||
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
||||
model.load_state_dict(loaded_data["model"])
|
||||
stats = pickle.loads(loaded_data["stats"])
|
||||
print(f" => resuming from epoch {stats.epoch}.")
|
||||
@@ -219,7 +218,6 @@ def main(cfg: DictConfig):
|
||||
|
||||
# Validation
|
||||
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
|
||||
|
||||
# Sample a validation camera/image.
|
||||
val_batch = next(val_dataloader.__iter__())
|
||||
val_image, val_camera, camera_idx = val_batch[0].values()
|
||||
|
||||
@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.
|
||||
|
||||
|
||||
def meshgrid_ij(
|
||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
||||
"""
|
||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||
|
||||
@@ -28,7 +28,6 @@ __global__ void alphaCompositeCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
@@ -79,7 +78,6 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
@@ -28,7 +28,6 @@ __global__ void weightedSumNormCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
@@ -92,7 +91,6 @@ __global__ void weightedSumNormCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
@@ -26,7 +26,6 @@ __global__ void weightedSumCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
@@ -74,7 +73,6 @@ __global__ void weightedSumCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
@@ -7,15 +7,11 @@
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#if !defined(USE_ROCM)
|
||||
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
// clang-format on
|
||||
#if !defined(USE_ROCM)
|
||||
#include "./pulsar/pytorch/renderer.h"
|
||||
#include "./pulsar/pytorch/tensor_util.h"
|
||||
#endif
|
||||
#include "ball_query/ball_query.h"
|
||||
#include "blending/sigmoid_alpha_blend.h"
|
||||
#include "compositing/alpha_composite.h"
|
||||
@@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
// Pulsar.
|
||||
// Pulsar not enabled on AMD.
|
||||
#if !defined(USE_ROCM)
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
#endif
|
||||
@@ -154,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
py::arg("gamma"),
|
||||
py::arg("max_depth"),
|
||||
py::arg("min_depth") /* = 0.f*/,
|
||||
py::arg(
|
||||
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
|
||||
py::arg("bg_col") /* = std::nullopt not exposed properly in
|
||||
pytorch 1.1. */
|
||||
,
|
||||
py::arg("opacity") /* = at::nullopt ... */,
|
||||
py::arg("opacity") /* = std::nullopt ... */,
|
||||
py::arg("percent_allowed_difference") = 0.01f,
|
||||
py::arg("max_n_hits") = MAX_UINT,
|
||||
py::arg("mode") = 0)
|
||||
@@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.attr("MAX_UINT") = py::int_(MAX_UINT);
|
||||
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
|
||||
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -7,10 +7,7 @@
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "iou_box3d/iou_utils.h"
|
||||
|
||||
|
||||
@@ -461,10 +461,8 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
|
||||
__device__ inline bool IsCoplanarTriTri(
|
||||
const FaceVerts& tri1,
|
||||
const FaceVerts& tri2) {
|
||||
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
|
||||
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
||||
|
||||
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
|
||||
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
||||
|
||||
// Check if parallel
|
||||
@@ -500,7 +498,6 @@ __device__ inline bool IsCoplanarTriPlane(
|
||||
const FaceVerts& tri,
|
||||
const FaceVerts& plane,
|
||||
const float3& normal) {
|
||||
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
|
||||
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
||||
|
||||
// check if parallel
|
||||
@@ -728,7 +725,7 @@ __device__ inline int BoxIntersections(
|
||||
}
|
||||
}
|
||||
// Update the face_verts_out tris
|
||||
num_tris = offset;
|
||||
num_tris = min(MAX_TRIS, offset);
|
||||
for (int j = 0; j < num_tris; ++j) {
|
||||
face_verts_out[j] = tri_verts_updated[j];
|
||||
}
|
||||
|
||||
@@ -8,9 +8,7 @@
|
||||
|
||||
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
||||
|
||||
@@ -36,11 +36,13 @@
|
||||
#pragma nv_diag_suppress 2951
|
||||
#pragma nv_diag_suppress 2967
|
||||
#else
|
||||
#if !defined(USE_ROCM)
|
||||
#pragma diag_suppress = attribute_not_allowed
|
||||
#pragma diag_suppress = 1866
|
||||
#pragma diag_suppress = 2941
|
||||
#pragma diag_suppress = 2951
|
||||
#pragma diag_suppress = 2967
|
||||
#endif //! USE_ROCM
|
||||
#endif
|
||||
#else // __CUDACC__
|
||||
#define INLINE inline
|
||||
@@ -56,7 +58,9 @@
|
||||
#pragma clang diagnostic pop
|
||||
#ifdef WITH_CUDA
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#if !defined(USE_ROCM)
|
||||
#include <vector_functions.h>
|
||||
#endif //! USE_ROCM
|
||||
#else
|
||||
#ifndef cudaStream_t
|
||||
typedef void* cudaStream_t;
|
||||
|
||||
@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
|
||||
#define SHARED __shared__
|
||||
#define ACTIVEMASK() __activemask()
|
||||
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
||||
|
||||
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
|
||||
* supports __shfl_*. Disabling until the move to ROCM-6.2.
|
||||
*/
|
||||
#if !defined(USE_ROCM)
|
||||
/**
|
||||
* Find the cumulative sum within a warp up to the current
|
||||
* thread lane, with each mask thread contributing base.
|
||||
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
ret.z = WARP_SUM(group, mask, base.z);
|
||||
return ret;
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
|
||||
// Floating point.
|
||||
// #define FMUL(a, b) __fmul_rn((a), (b))
|
||||
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
||||
#define I2F(a) __int2float_rn(a)
|
||||
#define FRCP(x) __frcp_rn(x)
|
||||
#if !defined(USE_ROCM)
|
||||
__device__ static float atomicMax(float* address, float val) {
|
||||
int* address_as_i = (int*)address;
|
||||
int old = *address_as_i, assumed;
|
||||
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
#define DMAX(a, b) FMAX(a, b)
|
||||
#define DMIN(a, b) FMIN(a, b)
|
||||
#define DSQRT(a) sqrt(a)
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "./commands.h"
|
||||
|
||||
namespace pulsar {
|
||||
IHD CamGradInfo::CamGradInfo() {
|
||||
IHD CamGradInfo::CamGradInfo(int x) {
|
||||
cam_pos = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
||||
|
||||
@@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
|
||||
};
|
||||
|
||||
struct CamGradInfo {
|
||||
HOST DEVICE CamGradInfo();
|
||||
HOST DEVICE CamGradInfo(int = 0);
|
||||
float3 cam_pos;
|
||||
float3 pixel_0_0_center;
|
||||
float3 pixel_dir_x;
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
// #pragma diag_suppress = 68
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
// #pragma pop
|
||||
#include "../cuda/commands.h"
|
||||
#include "../gpu/commands.h"
|
||||
#else
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
|
||||
@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
|
||||
}
|
||||
|
||||
// TODO: put intrinsics here.
|
||||
#if !defined(USE_ROCM)
|
||||
IHD float3 operator+(const float3& a, const float3& b) {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
|
||||
IHD float3 operator*(const float& a, const float3& b) {
|
||||
return b * a;
|
||||
}
|
||||
#endif //! USE_ROCM
|
||||
|
||||
INLINE DEVICE float length(const float3& v) {
|
||||
// TODO: benchmark what's faster.
|
||||
|
||||
@@ -283,9 +283,15 @@ GLOBAL void render(
|
||||
(percent_allowed_difference > 0.f &&
|
||||
max_closest_possible_intersection > depth_threshold) ||
|
||||
tracker.get_n_hits() >= max_n_hits;
|
||||
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
|
||||
unsigned long long warp_done = __ballot(done);
|
||||
int warp_done_bit_cnt = __popcll(warp_done);
|
||||
#else
|
||||
uint warp_done = thread_warp.ballot(done);
|
||||
int warp_done_bit_cnt = POPC(warp_done);
|
||||
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
|
||||
if (thread_warp.thread_rank() == 0)
|
||||
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
|
||||
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
|
||||
// This sync is necessary to keep n_loaded until all threads are done with
|
||||
// painting.
|
||||
thread_block.sync();
|
||||
|
||||
@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float& min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode) {
|
||||
@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode) {
|
||||
@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
||||
};
|
||||
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>>
|
||||
Renderer::backward(
|
||||
const torch::Tensor& grad_im,
|
||||
const torch::Tensor& image,
|
||||
@@ -912,8 +912,8 @@ Renderer::backward(
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode,
|
||||
@@ -922,7 +922,7 @@ Renderer::backward(
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const at::optional<std::pair<uint, uint>>& dbg_pos) {
|
||||
const std::optional<std::pair<uint, uint>>& dbg_pos) {
|
||||
this->ensure_on_device(this->device_tracker.device());
|
||||
size_t batch_size;
|
||||
size_t n_points;
|
||||
@@ -1045,14 +1045,14 @@ Renderer::backward(
|
||||
}
|
||||
// Prepare the return value.
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>>
|
||||
ret;
|
||||
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
|
||||
return ret;
|
||||
|
||||
@@ -44,21 +44,21 @@ struct Renderer {
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>,
|
||||
std::optional<torch::Tensor>>
|
||||
backward(
|
||||
const torch::Tensor& grad_im,
|
||||
const torch::Tensor& image,
|
||||
@@ -75,8 +75,8 @@ struct Renderer {
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode,
|
||||
@@ -85,7 +85,7 @@ struct Renderer {
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const at::optional<std::pair<uint, uint>>& dbg_pos);
|
||||
const std::optional<std::pair<uint, uint>>& dbg_pos);
|
||||
|
||||
// Infrastructure.
|
||||
/**
|
||||
@@ -115,8 +115,8 @@ struct Renderer {
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float& min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const std::optional<torch::Tensor>& bg_col,
|
||||
const std::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
@@ -33,13 +34,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
|
||||
.contiguous();
|
||||
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
||||
#ifdef WITH_CUDA
|
||||
cudaMemcpyAsync(
|
||||
C10_CUDA_CHECK(cudaMemcpyAsync(
|
||||
result.data_ptr(),
|
||||
tmp.data_ptr(),
|
||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||
tmp.size(3),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
at::cuda::getCurrentCUDAStream()));
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Copy on CUDA device initiated but built "
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
*/
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace pulsar {
|
||||
@@ -17,7 +18,8 @@ void cudaDevToDev(
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
|
||||
C10_CUDA_CHECK(
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
void cudaDevToHost(
|
||||
@@ -25,7 +27,8 @@ void cudaDevToHost(
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
|
||||
C10_CUDA_CHECK(
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include "ATen/core/TensorAccessor.h"
|
||||
|
||||
@@ -35,8 +35,6 @@ __global__ void FarthestPointSamplingKernel(
|
||||
__shared__ int64_t selected_store;
|
||||
|
||||
// Get constants
|
||||
const int64_t N = points.size(0);
|
||||
const int64_t P = points.size(1);
|
||||
const int64_t D = points.size(2);
|
||||
|
||||
// Get batch index and thread index
|
||||
|
||||
@@ -376,8 +376,6 @@ PointLineDistanceBackward(
|
||||
float tt = t_top / t_bot;
|
||||
tt = __saturatef(tt);
|
||||
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
|
||||
const float2 d = p - p_proj;
|
||||
const float dist = sqrt(dot(d, d));
|
||||
|
||||
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
||||
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
||||
|
||||
@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
|
||||
):
|
||||
synset_set.add(synset)
|
||||
elif (synset in self.synset_inv.keys()) and (
|
||||
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
|
||||
path.isdir(path.join(data_dir, self.synset_inv[synset]))
|
||||
):
|
||||
synset_set.add(self.synset_inv[synset])
|
||||
else:
|
||||
|
||||
@@ -36,7 +36,6 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
|
||||
|
||||
collated_dict["mesh"] = None
|
||||
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
||||
|
||||
textures = None
|
||||
if "textures" in collated_dict:
|
||||
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
||||
|
||||
@@ -26,7 +26,7 @@ from typing import (
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset import orm_types, types
|
||||
from pytorch3d.implicitron.dataset.utils import (
|
||||
adjust_camera_to_bbox_crop_,
|
||||
adjust_camera_to_image_scale_,
|
||||
@@ -48,8 +48,12 @@ from pytorch3d.implicitron.dataset.utils import (
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
|
||||
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
@@ -122,9 +126,9 @@ class FrameData(Mapping[str, Any]):
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
frame_number: Optional[torch.LongTensor] = None
|
||||
sequence_name: Union[str, List[str]] = ""
|
||||
sequence_category: Union[str, List[str]] = ""
|
||||
frame_timestamp: Optional[torch.Tensor] = None
|
||||
image_size_hw: Optional[torch.LongTensor] = None
|
||||
effective_image_size_hw: Optional[torch.LongTensor] = None
|
||||
@@ -155,7 +159,7 @@ class FrameData(Mapping[str, Any]):
|
||||
new_params = {}
|
||||
for field_name in iter(self):
|
||||
value = getattr(self, field_name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
|
||||
new_params[field_name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[field_name] = value
|
||||
@@ -417,7 +421,6 @@ class FrameData(Mapping[str, Any]):
|
||||
for f in fields(elem):
|
||||
if not f.init:
|
||||
continue
|
||||
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
@@ -426,7 +429,7 @@ class FrameData(Mapping[str, Any]):
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return cls(**collated)
|
||||
return type(elem)(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
@@ -434,6 +437,8 @@ class FrameData(Mapping[str, Any]):
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
elif isinstance(elem, Meshes):
|
||||
return join_meshes_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data.dataloader.default_collate(batch)
|
||||
|
||||
@@ -454,8 +459,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
frame_annotation: FrameAnnotationT,
|
||||
sequence_annotation: SequenceAnnotationT,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
@@ -541,8 +546,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
frame_annotation: FrameAnnotationT,
|
||||
sequence_annotation: SequenceAnnotationT,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
@@ -586,58 +591,81 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
),
|
||||
)
|
||||
|
||||
fg_mask_np: Optional[np.ndarray] = None
|
||||
dataset_root = self.dataset_root
|
||||
mask_annotation = frame_annotation.mask
|
||||
if mask_annotation is not None:
|
||||
if load_blobs and self.load_masks:
|
||||
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
|
||||
depth_annotation = frame_annotation.depth
|
||||
image_path: str | None = None
|
||||
mask_path: str | None = None
|
||||
depth_path: str | None = None
|
||||
pcl_path: str | None = None
|
||||
if dataset_root is not None: # set all paths even if we won’t load blobs
|
||||
if frame_annotation.image.path is not None:
|
||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||
frame_data.image_path = image_path
|
||||
|
||||
if mask_annotation is not None and mask_annotation.path:
|
||||
mask_path = os.path.join(dataset_root, mask_annotation.path)
|
||||
frame_data.mask_path = mask_path
|
||||
|
||||
if depth_annotation is not None and depth_annotation.path is not None:
|
||||
depth_path = os.path.join(dataset_root, depth_annotation.path)
|
||||
frame_data.depth_path = depth_path
|
||||
|
||||
if point_cloud is not None:
|
||||
pcl_path = os.path.join(dataset_root, point_cloud.path)
|
||||
frame_data.sequence_point_cloud_path = pcl_path
|
||||
|
||||
fg_mask_np: np.ndarray | None = None
|
||||
bbox_xywh: tuple[float, float, float, float] | None = None
|
||||
|
||||
if mask_annotation is not None:
|
||||
if load_blobs and self.load_masks and mask_path:
|
||||
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
|
||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
|
||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
if frame_annotation.image is not None:
|
||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||
frame_data.image_size_hw = image_size_hw # original image size
|
||||
# image size after crop/resize
|
||||
frame_data.effective_image_size_hw = image_size_hw
|
||||
image_path = None
|
||||
dataset_root = self.dataset_root
|
||||
if frame_annotation.image.path is not None and dataset_root is not None:
|
||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||
frame_data.image_path = image_path
|
||||
|
||||
if load_blobs and self.load_images:
|
||||
if image_path is None:
|
||||
raise ValueError("Image path is required to load images.")
|
||||
|
||||
image_np = load_image(self._local_path(image_path))
|
||||
no_mask = fg_mask_np is None # didn’t read the mask file
|
||||
image_np = load_image(
|
||||
self._local_path(image_path), try_read_alpha=no_mask
|
||||
)
|
||||
if image_np.shape[0] == 4: # RGBA image
|
||||
if no_mask:
|
||||
fg_mask_np = image_np[3:]
|
||||
frame_data.fg_probability = safe_as_tensor(
|
||||
fg_mask_np, torch.float
|
||||
)
|
||||
|
||||
image_np = image_np[:3]
|
||||
|
||||
frame_data.image_rgb = self._postprocess_image(
|
||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||
)
|
||||
|
||||
if (
|
||||
load_blobs
|
||||
and self.load_depths
|
||||
and frame_annotation.depth is not None
|
||||
and frame_annotation.depth.path is not None
|
||||
):
|
||||
(
|
||||
frame_data.depth_map,
|
||||
frame_data.depth_path,
|
||||
frame_data.depth_mask,
|
||||
) = self._load_mask_depth(frame_annotation, fg_mask_np)
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
if load_blobs and self.load_depths and depth_path is not None:
|
||||
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
|
||||
frame_annotation, depth_path, fg_mask_np
|
||||
)
|
||||
|
||||
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
||||
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
||||
assert pcl_path is not None
|
||||
frame_data.sequence_point_cloud = load_pointcloud(
|
||||
self._local_path(pcl_path), max_points=self.max_points
|
||||
)
|
||||
frame_data.sequence_point_cloud_path = pcl_path
|
||||
|
||||
if frame_annotation.viewpoint is not None:
|
||||
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
||||
@@ -653,18 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
return frame_data
|
||||
|
||||
def _load_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
assert self.dataset_root is not None and entry.mask is not None
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
fg_probability = load_mask(self._local_path(full_path))
|
||||
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
|
||||
fg_probability = load_mask(self._local_path(path))
|
||||
if fg_probability.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
|
||||
return fg_probability, full_path
|
||||
return fg_probability
|
||||
|
||||
def _postprocess_image(
|
||||
self,
|
||||
@@ -685,14 +709,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def _load_mask_depth(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
entry: FrameAnnotationT,
|
||||
path: str,
|
||||
fg_mask: Optional[np.ndarray],
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
entry_depth = entry.depth
|
||||
dataset_root = self.dataset_root
|
||||
assert dataset_root is not None
|
||||
assert entry_depth is not None and entry_depth.path is not None
|
||||
path = os.path.join(dataset_root, entry_depth.path)
|
||||
assert entry_depth is not None
|
||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||
|
||||
if self.mask_depths:
|
||||
@@ -706,11 +730,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
else:
|
||||
depth_mask = (depth_map > 0.0).astype(np.float32)
|
||||
|
||||
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
|
||||
return torch.tensor(depth_map), torch.tensor(depth_mask)
|
||||
|
||||
def _get_pytorch3d_camera(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
entry: FrameAnnotationT,
|
||||
) -> PerspectiveCameras:
|
||||
entry_viewpoint = entry.viewpoint
|
||||
assert entry_viewpoint is not None
|
||||
@@ -739,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
||||
)
|
||||
|
||||
def _fix_point_cloud_path(self, path: str) -> str:
|
||||
"""
|
||||
Fix up a point cloud path from the dataset.
|
||||
Some files in Co3Dv2 have an accidental absolute path stored.
|
||||
"""
|
||||
unwanted_prefix = (
|
||||
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
||||
)
|
||||
if path.startswith(unwanted_prefix):
|
||||
path = path[len(unwanted_prefix) :]
|
||||
assert self.dataset_root is not None
|
||||
return os.path.join(self.dataset_root, path)
|
||||
|
||||
def _local_path(self, path: str) -> str:
|
||||
if self.path_manager is None:
|
||||
return path
|
||||
|
||||
@@ -222,7 +222,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
||||
self.dataset_map = dataset_map
|
||||
|
||||
def _load_category(self, category: str) -> DatasetMap:
|
||||
|
||||
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(
|
||||
self.dataset_root, category, "sequence_annotations.jgz"
|
||||
|
||||
@@ -75,7 +75,6 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
|
||||
def _load_data(
|
||||
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
||||
):
|
||||
|
||||
poses_arr = np.load(
|
||||
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
||||
)
|
||||
@@ -164,7 +163,6 @@ def ptstocam(pts, c2w):
|
||||
|
||||
|
||||
def poses_avg(poses):
|
||||
|
||||
hwf = poses[0, :3, -1:]
|
||||
|
||||
center = poses[:, :3, 3].mean(0)
|
||||
@@ -192,7 +190,6 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
|
||||
|
||||
|
||||
def recenter_poses(poses):
|
||||
|
||||
poses_ = poses + 0
|
||||
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
|
||||
c2w = poses_avg(poses)
|
||||
@@ -256,7 +253,6 @@ def spherify_poses(poses, bds):
|
||||
new_poses = []
|
||||
|
||||
for th in np.linspace(0.0, 2.0 * np.pi, 120):
|
||||
|
||||
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
||||
up = np.array([0, 0, -1.0])
|
||||
|
||||
@@ -311,7 +307,6 @@ def load_llff_data(
|
||||
path_zflat=False,
|
||||
path_manager=None,
|
||||
):
|
||||
|
||||
poses, bds, imgs = _load_data(
|
||||
basedir, factor=factor, path_manager=path_manager
|
||||
) # factor=8 downsamples original imgs by 8x
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
# This functionality requires SQLAlchemy 2.0 or later.
|
||||
|
||||
import math
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import urllib
|
||||
from dataclasses import dataclass, Field, field
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
@@ -29,17 +33,18 @@ import sqlalchemy as sa
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
|
||||
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
|
||||
from pytorch3d.implicitron.dataset.frame_data import (
|
||||
FrameData,
|
||||
FrameDataBuilder,
|
||||
FrameDataBuilder, # noqa
|
||||
FrameDataBuilderBase,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import scoped_session, Session, sessionmaker
|
||||
|
||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||
|
||||
@@ -51,7 +56,7 @@ _SET_LISTS_TABLE: str = "set_lists"
|
||||
|
||||
|
||||
@registry.register
|
||||
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
"""
|
||||
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
||||
The length is returned after all sequence and frame filters are applied (see param
|
||||
@@ -88,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
engine verbatim. Don’t expose it to end users of your application!
|
||||
pick_categories: Restrict the dataset to the given list of categories.
|
||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
|
||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||
sequences within each category (applies after all other sequence filters
|
||||
@@ -102,9 +108,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
more frames than that; applied after other frame-level filters.
|
||||
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||
random frames per sequence.
|
||||
preload_metadata: If True, the metadata is preloaded into memory.
|
||||
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
|
||||
scoped_session: If True, allows different parts of the code to share
|
||||
a global session to access the database.
|
||||
"""
|
||||
|
||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
|
||||
SqlSequenceAnnotation
|
||||
)
|
||||
|
||||
sqlite_metadata_file: str = ""
|
||||
dataset_root: Optional[str] = None
|
||||
@@ -117,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
pick_categories: Tuple[str, ...] = ()
|
||||
|
||||
pick_sequences: Tuple[str, ...] = ()
|
||||
pick_sequences_sql_clause: Optional[str] = None
|
||||
exclude_sequences: Tuple[str, ...] = ()
|
||||
limit_sequences_per_category_to: int = 0
|
||||
limit_sequences_to: int = 0
|
||||
@@ -124,12 +138,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
n_frames_per_sequence: int = -1
|
||||
seed: int = 0
|
||||
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||
preload_metadata: bool = False
|
||||
precompute_seq_to_idx: bool = False
|
||||
# we set it manually in the constructor
|
||||
# _index: pd.DataFrame = field(init=False)
|
||||
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
|
||||
_sql_engine: sa.engine.Engine = field(
|
||||
init=False, metadata={"omegaconf_ignore": True}
|
||||
)
|
||||
eval_batches: Optional[List[Any]] = field(
|
||||
init=False, metadata={"omegaconf_ignore": True}
|
||||
)
|
||||
|
||||
frame_data_builder: FrameDataBuilderBase
|
||||
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
|
||||
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||
|
||||
scoped_session: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if sa.__version__ < "2.0":
|
||||
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
||||
@@ -138,19 +162,28 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
raise ValueError("sqlite_metadata_file must be set")
|
||||
|
||||
if self.dataset_root:
|
||||
frame_builder_type = self.frame_data_builder_class_type
|
||||
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
|
||||
"dataset_root"
|
||||
] = self.dataset_root
|
||||
frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
|
||||
getattr(self, frame_args)["dataset_root"] = self.dataset_root
|
||||
getattr(self, frame_args)["path_manager"] = self.path_manager
|
||||
|
||||
run_auto_creation(self)
|
||||
self.frame_data_builder.path_manager = self.path_manager
|
||||
|
||||
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
|
||||
if self.path_manager is not None:
|
||||
self.sqlite_metadata_file = self.path_manager.get_local_path(
|
||||
self.sqlite_metadata_file
|
||||
)
|
||||
self.subset_lists_file = self.path_manager.get_local_path(
|
||||
self.subset_lists_file
|
||||
)
|
||||
|
||||
# NOTE: sqlite-specific args (read-only mode).
|
||||
self._sql_engine = sa.create_engine(
|
||||
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
|
||||
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
|
||||
)
|
||||
|
||||
if self.preload_metadata:
|
||||
self._sql_engine = self._preload_database(self._sql_engine)
|
||||
|
||||
sequences = self._get_filtered_sequences_if_any()
|
||||
|
||||
if self.subsets:
|
||||
@@ -166,16 +199,29 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
if len(index) == 0:
|
||||
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
||||
|
||||
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
|
||||
self._index = index.set_index(["sequence_name", "frame_number"])
|
||||
|
||||
self.eval_batches = None # pyre-ignore
|
||||
self.eval_batches = None
|
||||
if self.eval_batches_file:
|
||||
self.eval_batches = self._load_filter_eval_batches()
|
||||
|
||||
logger.info(str(self))
|
||||
|
||||
if self.scoped_session:
|
||||
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
|
||||
|
||||
if self.precompute_seq_to_idx:
|
||||
# This is deprecated and will be removed in the future.
|
||||
# After we backport https://github.com/facebookresearch/uco3d/pull/3
|
||||
logger.warning(
|
||||
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
|
||||
)
|
||||
self._index["rowid"] = np.arange(len(self._index))
|
||||
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
|
||||
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
|
||||
del self._index["rowid"]
|
||||
|
||||
def __len__(self) -> int:
|
||||
# pyre-ignore[16]
|
||||
return len(self._index)
|
||||
|
||||
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||
@@ -232,12 +278,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
self.frame_annotations_type.frame_number
|
||||
== int(frame), # cast from np.int64
|
||||
)
|
||||
seq_stmt = sa.select(SqlSequenceAnnotation).where(
|
||||
SqlSequenceAnnotation.sequence_name == seq
|
||||
seq_stmt = sa.select(self.sequence_annotations_type).where(
|
||||
self.sequence_annotations_type.sequence_name == seq
|
||||
)
|
||||
with Session(self._sql_engine) as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
if self.scoped_session:
|
||||
# pyre-ignore
|
||||
with scoped_session(self._session_factory)() as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
else:
|
||||
with Session(self._sql_engine) as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
|
||||
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
||||
|
||||
@@ -250,7 +302,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
return frame_data
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pyre-ignore[16]
|
||||
return f"SqlIndexDataset #frames={len(self._index)}"
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
@@ -260,9 +311,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
# override
|
||||
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
||||
stmt = sa.select(
|
||||
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
|
||||
self.sequence_annotations_type.category,
|
||||
self.sequence_annotations_type.sequence_name,
|
||||
).where( # we limit results to sequences that have frames after all filters
|
||||
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
|
||||
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
|
||||
)
|
||||
with self._sql_engine.connect() as connection:
|
||||
cat_to_seqs = pd.read_sql(stmt, connection)
|
||||
@@ -335,17 +387,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
rows = self._index.index.get_loc(seq_name)
|
||||
if isinstance(rows, slice):
|
||||
assert rows.stop is not None, "Unexpected result from pandas"
|
||||
rows = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||
rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||
else:
|
||||
rows = np.where(rows)[0]
|
||||
rows_seq = list(np.where(rows)[0])
|
||||
|
||||
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
||||
rows, seq_name, subset_filter
|
||||
rows_seq, seq_name, subset_filter
|
||||
)
|
||||
index_slice["idx"] = idx
|
||||
|
||||
yield from index_slice.itertuples(index=False)
|
||||
|
||||
# override
|
||||
def sequence_indices_in_order(
|
||||
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
|
||||
) -> Iterator[int]:
|
||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||
only dataset indices.
|
||||
"""
|
||||
if self.precompute_seq_to_idx and subset_filter is None:
|
||||
# pyre-ignore
|
||||
yield from self._seq_to_indices[seq_name]
|
||||
else:
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
|
||||
yield idx
|
||||
|
||||
# override
|
||||
def get_eval_batches(self) -> Optional[List[Any]]:
|
||||
"""
|
||||
@@ -379,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
or self.limit_sequences_to > 0
|
||||
or self.limit_sequences_per_category_to > 0
|
||||
or len(self.pick_sequences) > 0
|
||||
or self.pick_sequences_sql_clause is not None
|
||||
or len(self.exclude_sequences) > 0
|
||||
or len(self.pick_categories) > 0
|
||||
or self.n_frames_per_sequence > 0
|
||||
)
|
||||
|
||||
def _preload_database(
|
||||
self, source_engine: sa.engine.base.Engine
|
||||
) -> sa.engine.base.Engine:
|
||||
destination_engine = sa.create_engine("sqlite:///:memory:")
|
||||
metadata = sa.MetaData()
|
||||
metadata.reflect(bind=source_engine)
|
||||
metadata.create_all(bind=destination_engine)
|
||||
|
||||
with source_engine.connect() as source_conn:
|
||||
with destination_engine.connect() as destination_conn:
|
||||
for table_obj in metadata.tables.values():
|
||||
# Select all rows from the source table
|
||||
source_rows = source_conn.execute(table_obj.select())
|
||||
|
||||
# Insert rows into the destination table
|
||||
for row in source_rows:
|
||||
destination_conn.execute(table_obj.insert().values(row))
|
||||
|
||||
# Commit the changes for each table
|
||||
destination_conn.commit()
|
||||
|
||||
return destination_engine
|
||||
|
||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||
# WHERE category IN 'self.pick_categories'
|
||||
@@ -396,19 +486,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
*self._get_pick_filters(),
|
||||
*self._get_exclude_filters(),
|
||||
]
|
||||
if self.pick_sequences_sql_clause:
|
||||
print("Applying the custom SQL clause.")
|
||||
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
||||
|
||||
def add_where(stmt):
|
||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||
|
||||
if self.limit_sequences_per_category_to <= 0:
|
||||
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
||||
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
|
||||
else:
|
||||
subquery = sa.select(
|
||||
SqlSequenceAnnotation.sequence_name,
|
||||
self.sequence_annotations_type.sequence_name,
|
||||
sa.func.row_number()
|
||||
.over(
|
||||
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||
partition_by=SqlSequenceAnnotation.category,
|
||||
partition_by=self.sequence_annotations_type.category,
|
||||
)
|
||||
.label("row_number"),
|
||||
)
|
||||
@@ -444,31 +537,34 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
return []
|
||||
|
||||
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
||||
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
|
||||
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
|
||||
|
||||
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
||||
if not self.pick_sequences:
|
||||
return []
|
||||
|
||||
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
||||
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
|
||||
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
|
||||
|
||||
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
||||
if not self.exclude_sequences:
|
||||
return []
|
||||
|
||||
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
||||
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
|
||||
return [
|
||||
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
|
||||
]
|
||||
|
||||
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||
assert self.subsets is not None
|
||||
subsets = self.subsets
|
||||
assert subsets is not None
|
||||
with open(subset_lists_path, "r") as f:
|
||||
subset_to_seq_frame = json.load(f)
|
||||
|
||||
seq_frame_list = sum(
|
||||
(
|
||||
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
||||
for subset in self.subsets
|
||||
for subset in subsets
|
||||
),
|
||||
[],
|
||||
)
|
||||
@@ -522,7 +618,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
stmt = sa.select(
|
||||
self.frame_annotations_type.sequence_name,
|
||||
self.frame_annotations_type.frame_number,
|
||||
).where(self.frame_annotations_type._mask_mass == 0)
|
||||
).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16]
|
||||
with Session(self._sql_engine) as session:
|
||||
to_remove = session.execute(stmt).all()
|
||||
|
||||
@@ -586,7 +682,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
stmt = sa.select(
|
||||
self.frame_annotations_type.sequence_name,
|
||||
self.frame_annotations_type.frame_number,
|
||||
self.frame_annotations_type._image_path,
|
||||
self.frame_annotations_type._image_path, # pyre-ignore[16]
|
||||
sa.null().label("subset"),
|
||||
)
|
||||
where_conditions = []
|
||||
@@ -600,7 +696,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
logger.info(" excluding samples with empty masks")
|
||||
where_conditions.append(
|
||||
sa.or_(
|
||||
self.frame_annotations_type._mask_mass.is_(None),
|
||||
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
|
||||
self.frame_annotations_type._mask_mass != 0,
|
||||
)
|
||||
)
|
||||
@@ -634,7 +730,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
assert self.eval_batches_file
|
||||
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
||||
|
||||
if not os.path.isfile(self.eval_batches_file):
|
||||
if (
|
||||
self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
|
||||
) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
|
||||
# The batch indices file does not exist.
|
||||
# Most probably the user has not specified the root folder.
|
||||
raise ValueError(
|
||||
@@ -642,7 +740,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
+ "Please specify a correct dataset_root folder."
|
||||
)
|
||||
|
||||
with open(self.eval_batches_file, "r") as f:
|
||||
eval_batches_file = self._local_path(self.eval_batches_file)
|
||||
with open(eval_batches_file, "r") as f:
|
||||
eval_batches = json.load(f)
|
||||
|
||||
# limit the dataset to sequences to allow multiple evaluations in one file
|
||||
@@ -726,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
self.frame_annotations_type.sequence_name == seq_name,
|
||||
self.frame_annotations_type.frame_number.in_(frames),
|
||||
)
|
||||
frame_no_ts = None
|
||||
|
||||
with self._sql_engine.connect() as connection:
|
||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||
if self.scoped_session:
|
||||
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
with scoped_session(self._session_factory)() as session: # pyre-ignore
|
||||
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
|
||||
else:
|
||||
with self._sql_engine.connect() as connection:
|
||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||
|
||||
if len(frame_no_ts) != len(index_slice):
|
||||
raise ValueError(
|
||||
@@ -758,11 +863,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
prefixes=["TEMP"], # NOTE SQLite specific!
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls) -> None:
|
||||
# remove dataclass annotations that are not meant to be init params
|
||||
# because they cause troubles for OmegaConf
|
||||
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
|
||||
if isinstance(attr_value, Field) and attr_value.metadata.get(
|
||||
"omegaconf_ignore", False
|
||||
):
|
||||
delattr(cls, attr)
|
||||
del cls.__annotations__[attr]
|
||||
|
||||
|
||||
def _seq_name_to_seed(seq_name) -> int:
|
||||
"""Generates numbers in [0, 2 ** 28)"""
|
||||
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
||||
|
||||
|
||||
def _safe_as_tensor(data, dtype):
|
||||
return torch.tensor(data, dtype=dtype) if data is not None else None
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
@@ -43,7 +45,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||
"""
|
||||
Generates the training, validation, and testing dataset objects for
|
||||
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
||||
@@ -193,9 +195,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
|
||||
# this is a mould that is never constructed, used to build self._dataset_map values
|
||||
dataset_class_type: str = "SqlIndexDataset"
|
||||
dataset: SqlIndexDataset
|
||||
dataset: SqlIndexDataset # pyre-ignore [13]
|
||||
|
||||
path_manager_factory: PathManagerFactory
|
||||
path_manager_factory: PathManagerFactory # pyre-ignore [13]
|
||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -282,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||
|
||||
logger.debug("Extracting test dataset.")
|
||||
eval_batches_file = self._get_lists_file("eval_batches")
|
||||
del common_dataset_kwargs["eval_batches_file"]
|
||||
if self.eval_batches_path is None:
|
||||
eval_batches_file = None
|
||||
else:
|
||||
eval_batches_file = self._get_lists_file("eval_batches")
|
||||
|
||||
if "eval_batches_file" in common_dataset_kwargs:
|
||||
common_dataset_kwargs.pop("eval_batches_file", None)
|
||||
|
||||
test_dataset = dataset_type(
|
||||
**common_dataset_kwargs,
|
||||
subsets=self._get_subsets(self.test_subsets, True),
|
||||
|
||||
@@ -87,6 +87,15 @@ def is_train_frame(
|
||||
def get_bbox_from_mask(
|
||||
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
||||
) -> Tuple[int, int, int, int]:
|
||||
# these corner cases need to be handled in order to avoid an infinite loop
|
||||
if mask.size == 0:
|
||||
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
|
||||
return 0, 0, 1, 1
|
||||
|
||||
if not mask.min() >= 0.0:
|
||||
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
|
||||
mask = mask.clip(min=0.0)
|
||||
|
||||
# bbox in xywh
|
||||
masks_for_box = np.zeros_like(mask)
|
||||
while masks_for_box.sum() <= 1.0:
|
||||
@@ -134,7 +143,15 @@ T = TypeVar("T", bound=torch.Tensor)
|
||||
def bbox_xyxy_to_xywh(xyxy: T) -> T:
|
||||
wh = xyxy[2:] - xyxy[:2]
|
||||
xywh = torch.cat([xyxy[:2], wh])
|
||||
return xywh # pyre-ignore
|
||||
return xywh # pyre-ignore[7]
|
||||
|
||||
|
||||
def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
|
||||
wh = xywh[2:]
|
||||
if clamp_size is not None:
|
||||
wh = wh.clamp(min=clamp_size)
|
||||
xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
|
||||
return xyxy # pyre-ignore[7]
|
||||
|
||||
|
||||
def get_clamp_bbox(
|
||||
@@ -180,16 +197,6 @@ def rescale_bbox(
|
||||
return bbox * rel_size
|
||||
|
||||
|
||||
def bbox_xywh_to_xyxy(
|
||||
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
xyxy = xywh.clone()
|
||||
if clamp_size is not None:
|
||||
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
||||
xyxy[2:] += xyxy[:2]
|
||||
return xyxy
|
||||
|
||||
|
||||
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
|
||||
nz = np.flatnonzero(arr)
|
||||
return nz[0], nz[-1] + 1
|
||||
@@ -201,18 +208,24 @@ def resize_image(
|
||||
image_width: Optional[int],
|
||||
mode: str = "bilinear",
|
||||
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
if image_height is None or image_width is None:
|
||||
if (
|
||||
image_height is None
|
||||
or image_width is None
|
||||
or image.shape[-2] == 0
|
||||
or image.shape[-1] == 0
|
||||
):
|
||||
# skip the resizing
|
||||
return image, 1.0, torch.ones_like(image[:1])
|
||||
|
||||
# takes numpy array or tensor, returns pytorch tensor
|
||||
minscale = min(
|
||||
image_height / image.shape[-2],
|
||||
image_width / image.shape[-1],
|
||||
)
|
||||
|
||||
imre = torch.nn.functional.interpolate(
|
||||
image[None],
|
||||
scale_factor=minscale,
|
||||
@@ -220,6 +233,7 @@ def resize_image(
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
recompute_scale_factor=True,
|
||||
)[0]
|
||||
|
||||
imre_ = torch.zeros(image.shape[0], image_height, image_width)
|
||||
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
||||
mask = torch.zeros(1, image_height, image_width)
|
||||
@@ -232,9 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
||||
return im.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
def load_image(path: str) -> np.ndarray:
|
||||
def load_image(
|
||||
path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Load an image from a path and return it as a numpy array.
|
||||
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
|
||||
returned as the fourth channel.
|
||||
Otherwise, the image is read as RGB and a three-channel image is returned.
|
||||
"""
|
||||
with Image.open(path) as pil_im:
|
||||
im = np.array(pil_im.convert("RGB"))
|
||||
# Check if the image has an alpha channel
|
||||
if try_read_alpha and pil_im.mode == "RGBA":
|
||||
im = np.array(pil_im)
|
||||
else:
|
||||
im = np.array(pil_im.convert(pil_format))
|
||||
|
||||
return transpose_normalize_image(im)
|
||||
|
||||
@@ -329,6 +355,7 @@ def adjust_camera_to_bbox_crop_(
|
||||
|
||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||
camera.focal_length[0],
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
camera.principal_point[0],
|
||||
image_size_wh,
|
||||
)
|
||||
@@ -341,6 +368,7 @@ def adjust_camera_to_bbox_crop_(
|
||||
)
|
||||
|
||||
camera.focal_length = focal_length[None]
|
||||
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||
camera.principal_point = principal_point_cropped[None]
|
||||
|
||||
|
||||
@@ -352,6 +380,7 @@ def adjust_camera_to_image_scale_(
|
||||
) -> PerspectiveCameras:
|
||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||
camera.focal_length[0],
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
camera.principal_point[0],
|
||||
original_size_wh,
|
||||
)
|
||||
@@ -368,7 +397,8 @@ def adjust_camera_to_image_scale_(
|
||||
image_size_wh_output,
|
||||
)
|
||||
camera.focal_length = focal_length_scaled[None]
|
||||
camera.principal_point = principal_point_scaled[None]
|
||||
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||
camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]
|
||||
|
||||
|
||||
# NOTE this cache is per-worker; they are implemented as processes.
|
||||
|
||||
@@ -299,7 +299,6 @@ def eval_batch(
|
||||
)
|
||||
|
||||
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
|
||||
|
||||
loss_mask_now = mask_crop * loss_fg_mask
|
||||
|
||||
for rgb_metric_name, rgb_metric_fun in zip(
|
||||
|
||||
@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
self.layers = torch.nn.ModuleList()
|
||||
self.proj_layers = torch.nn.ModuleList()
|
||||
for stage in range(self.max_stage):
|
||||
stage_name = f"layer{stage+1}"
|
||||
stage_name = f"layer{stage + 1}"
|
||||
feature_name = self._get_resnet_stage_feature_name(stage)
|
||||
if (stage + 1) in self.stages:
|
||||
if (
|
||||
@@ -139,12 +139,18 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
self.stages = set(self.stages) # convert to set for faster "in"
|
||||
|
||||
def _get_resnet_stage_feature_name(self, stage) -> str:
|
||||
return f"res_layer_{stage+1}"
|
||||
return f"res_layer_{stage + 1}"
|
||||
|
||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
return (img - self._resnet_mean) / self._resnet_std
|
||||
|
||||
def get_feat_dims(self) -> int:
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||
# not a function.
|
||||
return sum(self._feat_dim.values())
|
||||
|
||||
def forward(
|
||||
@@ -183,7 +189,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
else:
|
||||
imgs_normed = imgs_resized
|
||||
# is not a function.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
feats = self.stem(imgs_normed)
|
||||
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
||||
feats = layer(feats)
|
||||
# just a sanity check below
|
||||
|
||||
@@ -478,6 +478,8 @@ class GenericModel(ImplicitronModelBase):
|
||||
)
|
||||
custom_args["global_code"] = global_code
|
||||
|
||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||
# function.
|
||||
for func in self._implicit_functions:
|
||||
func.bind_args(**custom_args)
|
||||
|
||||
@@ -500,6 +502,8 @@ class GenericModel(ImplicitronModelBase):
|
||||
# Unbind the custom arguments to prevent pytorch from storing
|
||||
# large buffers of intermediate results due to points in the
|
||||
# bound arguments.
|
||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||
# function.
|
||||
for func in self._implicit_functions:
|
||||
func.unbind_args()
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
return key_map
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
|
||||
return (self._autodecoder_codes.weight**2).mean()
|
||||
|
||||
def get_encoding_dim(self) -> int:
|
||||
@@ -95,6 +96,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
||||
# `Tensor`.
|
||||
x = torch.tensor(
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||
[self._key_map[elem] for elem in x],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
@@ -102,6 +104,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
except StopIteration:
|
||||
raise ValueError("Not enough n_instances in the autodecoder") from None
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self._autodecoder_codes(x)
|
||||
|
||||
def _load_key_map_hook(
|
||||
|
||||
@@ -122,6 +122,7 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
||||
if frame_timestamp.shape[-1] != 1:
|
||||
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||
time = frame_timestamp / self.time_divisor
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self._harmonic_embedding(time)
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
|
||||
@@ -232,9 +232,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
||||
# if the skip tensor is None, we use `x` instead.
|
||||
z = x
|
||||
skipi = 0
|
||||
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
for li, layer in enumerate(self.mlp):
|
||||
# pyre-fixme[58]: `in` is not supported for right operand type
|
||||
# `Union[Tensor, Module]`.
|
||||
if li in self._input_skips:
|
||||
if self._skip_affine_trans:
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
|
||||
else:
|
||||
y = torch.cat((y, z), dim=-1)
|
||||
|
||||
@@ -141,11 +141,16 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
self.embed_fn is None and fun_viewpool is None and global_code is None
|
||||
):
|
||||
return torch.tensor(
|
||||
[], device=rays_points_world.device, dtype=rays_points_world.dtype
|
||||
[],
|
||||
device=rays_points_world.device,
|
||||
dtype=rays_points_world.dtype,
|
||||
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
|
||||
# `Union[Module, Tensor]`.
|
||||
).view(0, self.out_dim)
|
||||
|
||||
embeddings = []
|
||||
if self.embed_fn is not None:
|
||||
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
|
||||
embeddings.append(self.embed_fn(rays_points_world))
|
||||
|
||||
if fun_viewpool is not None:
|
||||
@@ -164,13 +169,19 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
embedding = torch.cat(embeddings, dim=-1)
|
||||
x = embedding
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
||||
for layer_idx in range(self.num_layers - 1):
|
||||
if layer_idx in self.skip_in:
|
||||
x = torch.cat([x, embedding], dim=-1) / 2**0.5
|
||||
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
x = self.linear_layers[layer_idx](x)
|
||||
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
||||
if layer_idx < self.num_layers - 2:
|
||||
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
|
||||
x = self.softplus(x)
|
||||
|
||||
return x
|
||||
|
||||
@@ -123,8 +123,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
||||
|
||||
@staticmethod
|
||||
@@ -195,6 +197,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
embeds = create_embeddings_for_implicit_function(
|
||||
xyz_world=rays_points_world,
|
||||
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
|
||||
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
|
||||
# got `Union[None, Tensor, Module]`.
|
||||
xyz_embedding_function=(
|
||||
self.harmonic_embedding_xyz if self.input_xyz else None
|
||||
),
|
||||
@@ -206,12 +210,14 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
)
|
||||
|
||||
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
features = self.xyz_encoder(embeds)
|
||||
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
|
||||
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
||||
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
|
||||
features = features.reshape(*rays_points_world.shape[:-1], -1)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
raw_densities = self.density_layer(features)
|
||||
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
|
||||
|
||||
@@ -219,6 +225,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
|
||||
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
directions = ray_bundle.directions @ camera.R
|
||||
else:
|
||||
directions = ray_bundle.directions
|
||||
|
||||
@@ -103,6 +103,8 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
||||
|
||||
embeds = create_embeddings_for_implicit_function(
|
||||
xyz_world=rays_points_world,
|
||||
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
|
||||
# got `Union[Tensor, Module]`.
|
||||
xyz_embedding_function=self._harmonic_embedding,
|
||||
global_code=global_code,
|
||||
fun_viewpool=fun_viewpool,
|
||||
@@ -112,6 +114,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
||||
|
||||
# Before running the network, we have to resize embeds to ndims=3,
|
||||
# otherwise the SRN layers consume huge amounts of memory.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
raymarch_features = self._net(
|
||||
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
|
||||
)
|
||||
@@ -166,7 +169,9 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
rays_embedding = self._harmonic_embedding(rays_directions_normed)
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self._color_layer((features, rays_embedding))
|
||||
|
||||
def forward(
|
||||
@@ -195,6 +200,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
denoting the color of each ray point.
|
||||
"""
|
||||
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
features = self._net(raymarch_features)
|
||||
# features.shape = [minibatch x ... x self.n_hidden_units]
|
||||
|
||||
@@ -202,6 +208,8 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
|
||||
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
directions = ray_bundle.directions @ camera.R
|
||||
else:
|
||||
directions = ray_bundle.directions
|
||||
@@ -209,6 +217,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
||||
features = features.reshape(*raymarch_features.shape[:-1], -1)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
raw_densities = self._density_layer(features)
|
||||
|
||||
rays_colors = self._get_colors(features, directions)
|
||||
@@ -269,6 +278,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
||||
srn_raymarch_function.
|
||||
"""
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
net = self._hypernet(global_code)
|
||||
|
||||
# use the hyper-net generated network to instantiate the raymarch module
|
||||
@@ -296,7 +306,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
||||
global_code=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if global_code is None:
|
||||
raise ValueError("SRN Hypernetwork requires a non-trivial global code.")
|
||||
|
||||
@@ -304,6 +313,8 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
||||
# across LSTM iterations for the same global_code.
|
||||
if self.cached_srn_raymarch_function is None:
|
||||
# generate the raymarching network from the hypernet
|
||||
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
|
||||
# `cached_srn_raymarch_function`.
|
||||
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
|
||||
(srn_raymarch_function,) = cast(
|
||||
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
|
||||
@@ -331,6 +342,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
def create_raymarch_function(self) -> None:
|
||||
self.raymarch_function = SRNRaymarchFunction(
|
||||
latent_dim=self.latent_dim,
|
||||
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||
**self.raymarch_function_args,
|
||||
)
|
||||
|
||||
@@ -389,6 +401,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
self.hypernet = SRNRaymarchHyperNet(
|
||||
latent_dim=self.latent_dim,
|
||||
latent_dim_hypernet=self.latent_dim_hypernet,
|
||||
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||
**self.hypernet_args,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,6 @@ def create_embeddings_for_implicit_function(
|
||||
xyz_embedding_function: Optional[Callable],
|
||||
diag_cov: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
|
||||
|
||||
if xyz_in_camera_coords:
|
||||
@@ -64,7 +63,6 @@ def create_embeddings_for_implicit_function(
|
||||
0,
|
||||
)
|
||||
else:
|
||||
|
||||
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
|
||||
embeds = embeds.reshape(
|
||||
bs,
|
||||
|
||||
@@ -269,6 +269,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
for name, tensor in vars(grid_values_with_wanted_resolution).items()
|
||||
}
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self.values_type(**params), True
|
||||
|
||||
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
|
||||
@@ -882,6 +883,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
torch.Tensor of shape (..., n_features)
|
||||
"""
|
||||
locator = self._get_volume_locator()
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||
@@ -895,6 +897,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
replace current parameters
|
||||
"""
|
||||
if self.hold_voxel_grid_as_parameters:
|
||||
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{
|
||||
k: torch.nn.Parameter(val)
|
||||
@@ -945,6 +948,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
Returns:
|
||||
True if parameter change has happened else False.
|
||||
"""
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
grid_values, change = self.voxel_grid.change_resolution(
|
||||
grid_values, epoch=epoch
|
||||
@@ -992,16 +996,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
'''
|
||||
new_params = {}
|
||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||
# function.
|
||||
for name in self.params:
|
||||
key = prefix + "params." + name
|
||||
if key in state_dict:
|
||||
new_params[name] = torch.zeros_like(state_dict[key])
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
|
||||
|
||||
def get_device(self) -> torch.device:
|
||||
"""
|
||||
Returns torch.device on which module parameters are located
|
||||
"""
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||
# not a function.
|
||||
return next(val for val in self.params.values() if val is not None).device
|
||||
|
||||
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
|
||||
@@ -1018,6 +1027,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
locator = self._get_volume_locator()
|
||||
# torch.nn.modules.module.Module]` is not a function.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
old_grid_values = self.voxel_grid.values_type(**self.params)
|
||||
new_grid_values = self.voxel_grid.crop_world(
|
||||
min_point, max_point, old_grid_values, locator
|
||||
@@ -1025,6 +1035,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
grid_values, _ = self.voxel_grid.change_resolution(
|
||||
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{
|
||||
k: torch.nn.Parameter(val)
|
||||
|
||||
@@ -192,16 +192,26 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
run_auto_creation(self)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `voxel_grid_scaffold`.
|
||||
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `harmonic_embedder_xyz_density`.
|
||||
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_xyz_density_args
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `harmonic_embedder_xyz_color`.
|
||||
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_xyz_color_args
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `harmonic_embedder_dir_color`.
|
||||
self.harmonic_embedder_dir_color = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_dir_color_args
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `_scaffold_ready`.
|
||||
self._scaffold_ready = False
|
||||
|
||||
def forward(
|
||||
@@ -252,6 +262,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
# ########## filter the points using the scaffold ########## #
|
||||
if self._scaffold_ready and self.scaffold_filter_points:
|
||||
with torch.no_grad():
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
points = points[non_empty_points]
|
||||
if len(points) == 0:
|
||||
@@ -363,6 +374,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
feature dimensionality which `decoder_density` returns
|
||||
"""
|
||||
embeds_density = self.voxel_grid_density(points)
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
|
||||
# shape = [..., density_dim]
|
||||
return self.decoder_density(harmonic_embedding_density)
|
||||
@@ -397,6 +409,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
if self.xyz_ray_dir_in_camera_coords:
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
directions = directions @ camera.R
|
||||
|
||||
# ########## get voxel grid output ########## #
|
||||
@@ -405,11 +419,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
# ########## embed with the harmonic function ########## #
|
||||
# Obtain the harmonic embedding of the voxel grid output.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
|
||||
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
|
||||
rays_directions_normed
|
||||
)
|
||||
@@ -478,8 +494,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
an object inside, else False.
|
||||
"""
|
||||
# find bounding box
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
||||
# `get_grid_points`.
|
||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
non_zero_idxs = torch.nonzero(occupancy)
|
||||
if len(non_zero_idxs) == 0:
|
||||
@@ -511,6 +530,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
"""
|
||||
|
||||
planes = []
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
||||
# `get_grid_points`.
|
||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||
|
||||
chunk_size = (
|
||||
@@ -530,7 +551,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
stride=1,
|
||||
)
|
||||
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `params`.
|
||||
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `_scaffold_ready`.
|
||||
self._scaffold_ready = True
|
||||
|
||||
return False
|
||||
@@ -547,6 +571,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
decoding function to this value.
|
||||
"""
|
||||
grid_args = self.voxel_grid_density_args
|
||||
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||
|
||||
embedder_args = self.harmonic_embedder_xyz_density_args
|
||||
@@ -575,6 +601,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
decoding function to this value.
|
||||
"""
|
||||
grid_args = self.voxel_grid_color_args
|
||||
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||
|
||||
embedder_args = self.harmonic_embedder_xyz_color_args
|
||||
@@ -608,7 +636,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
`self.voxel_grid_density`
|
||||
"""
|
||||
return VoxelGridModule(
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
extents=self.voxel_grid_density_args["extents"],
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
translation=self.voxel_grid_density_args["translation"],
|
||||
voxel_grid_class_type="FullResolutionVoxelGrid",
|
||||
hold_voxel_grid_as_parameters=False,
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
|
||||
_rgb_metrics(
|
||||
image_rgb,
|
||||
image_rgb_pred,
|
||||
fg_probability,
|
||||
fg_probability_pred,
|
||||
mask_crop,
|
||||
masks=fg_probability,
|
||||
masks_crop=mask_crop,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
|
||||
metrics["mask_neg_iou"] = utils.neg_iou_loss(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
if torch.is_autocast_enabled():
|
||||
# To avoid issues with mixed precision
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred.logit(),
|
||||
fg_probability,
|
||||
mask=mask_crop,
|
||||
pred_logits=True,
|
||||
)
|
||||
else:
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred,
|
||||
fg_probability,
|
||||
mask=mask_crop,
|
||||
pred_logits=False,
|
||||
)
|
||||
|
||||
if depth_map is not None and depth_map_pred is not None:
|
||||
assert mask_crop is not None
|
||||
@@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
|
||||
if fg_probability is not None:
|
||||
mask = fg_probability * mask_crop
|
||||
_, abs_ = utils.eval_depth(
|
||||
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
|
||||
depth_map_pred,
|
||||
depth_map,
|
||||
get_best_scale=True,
|
||||
mask=mask,
|
||||
crop=0,
|
||||
)
|
||||
metrics["depth_abs_fg"] = abs_.mean()
|
||||
|
||||
@@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
|
||||
return metrics
|
||||
|
||||
|
||||
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
||||
def _rgb_metrics(
|
||||
images,
|
||||
images_pred,
|
||||
masks=None,
|
||||
masks_crop=None,
|
||||
huber_scaling: float = 0.03,
|
||||
):
|
||||
assert masks_crop is not None
|
||||
if images.shape[1] != images_pred.shape[1]:
|
||||
raise ValueError(
|
||||
f"Network output's RGB images had {images_pred.shape[1]} "
|
||||
f"channels. {images.shape[1]} expected."
|
||||
)
|
||||
rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
|
||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
|
||||
crop_mass = masks_crop.sum().clamp(1.0)
|
||||
results = {
|
||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||
"rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
|
||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
||||
}
|
||||
|
||||
@@ -135,6 +135,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
break
|
||||
|
||||
# run the lstm marcher
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
state_h, state_c = self._lstm(
|
||||
raymarch_features.view(-1, raymarch_features.shape[-1]),
|
||||
states[-1],
|
||||
@@ -142,6 +143,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
if state_h.requires_grad:
|
||||
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
|
||||
# predict the next step size
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
|
||||
# log the lstm states
|
||||
states.append((state_h, state_c))
|
||||
|
||||
@@ -207,6 +207,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
"""
|
||||
sample_mask = None
|
||||
if (
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
|
||||
and mask is not None
|
||||
):
|
||||
@@ -223,6 +224,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
EvaluationMode.EVALUATION: self._evaluation_raysampler,
|
||||
}[evaluation_mode]
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
ray_bundle = raysampler(
|
||||
cameras=cameras,
|
||||
mask=sample_mask,
|
||||
@@ -240,6 +242,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
"Heterogeneous ray bundle is not supported for conical frustum computation yet"
|
||||
)
|
||||
elif self.cast_ray_bundle_as_cone:
|
||||
# pyre-fixme[9]: pixel_hw has type `Tuple[float, float]`; used as
|
||||
# `Tuple[Union[Tensor, Module], Union[Tensor, Module]]`.
|
||||
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
||||
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
||||
return ImplicitronRayBundle(
|
||||
|
||||
@@ -179,8 +179,10 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
rays_densities = torch.relu(rays_densities)
|
||||
|
||||
weighted_densities = deltas * rays_densities
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
capped_densities = self._capping_function(weighted_densities)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
rays_opacities = self._capping_function(
|
||||
torch.cumsum(weighted_densities, dim=-1)
|
||||
)
|
||||
@@ -190,6 +192,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
)
|
||||
absorption_shifted[..., : self.surface_thickness] = 1.0
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
weights = self._weight_function(capped_densities, absorption_shifted)
|
||||
features = (weights[..., None] * rays_features).sum(dim=-2)
|
||||
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
||||
@@ -197,6 +200,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
alpha = opacities if self.blend_output else 1
|
||||
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
|
||||
raise ValueError("Wrong number of background color channels.")
|
||||
# pyre-fixme[58]: `*` is not supported for operand types `int` and
|
||||
# `Union[Tensor, Module]`.
|
||||
features = alpha * features + (1 - opacities) * self._bg_color
|
||||
|
||||
return RendererOutput(
|
||||
|
||||
@@ -61,6 +61,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
|
||||
def create_ray_tracer(self) -> None:
|
||||
self.ray_tracer = RayTracing(
|
||||
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||
**self.ray_tracer_args,
|
||||
object_bounding_sphere=self.object_bounding_sphere,
|
||||
)
|
||||
@@ -149,6 +150,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
n_eik_points,
|
||||
3,
|
||||
# but got `Union[device, Tensor, Module]`.
|
||||
# pyre-fixme[6]: For 3rd argument expected `Union[None, int, str,
|
||||
# device]` but got `Union[device, Tensor, Module]`.
|
||||
device=self._bg_color.device,
|
||||
).uniform_(-eik_bounding_box, eik_bounding_box)
|
||||
eikonal_pixel_points = points.clone()
|
||||
@@ -205,6 +208,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
]
|
||||
normals_full.view(-1, 3)[surface_mask] = normals
|
||||
render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
self._rgb_network(
|
||||
features,
|
||||
differentiable_surface_points[None],
|
||||
@@ -216,8 +220,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
)
|
||||
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
|
||||
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
|
||||
-self.soft_mask_alpha
|
||||
* sdf_output[~surface_mask]
|
||||
-self.soft_mask_alpha * sdf_output[~surface_mask]
|
||||
)
|
||||
|
||||
# scatter points with surface_mask
|
||||
|
||||
@@ -532,6 +532,7 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
|
||||
|
||||
# does not produce nans randomly unlike get_camera_center() below
|
||||
cam_centers_rep = -torch.bmm(
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
camera_rep.T[:, None],
|
||||
camera_rep.R.permute(0, 2, 1),
|
||||
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
||||
|
||||
@@ -209,6 +209,7 @@ def handle_seq_id(
|
||||
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
|
||||
# pyre-fixme[16]: Item `List` of `Union[List[int], List[str], LongTensor]` has
|
||||
# no attribute `to`.
|
||||
# pyre-fixme[7]: Expected `LongTensor` but got `Tensor`.
|
||||
return seq_id.to(device)
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ def cleanup_eval_depth(
|
||||
sigma: float = 0.01,
|
||||
image=None,
|
||||
):
|
||||
|
||||
ba, _, H, W = depth.shape
|
||||
|
||||
pcl = point_cloud.points_padded()
|
||||
|
||||
@@ -6,12 +6,15 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def eval_depth(
|
||||
pred: torch.Tensor,
|
||||
@@ -21,6 +24,8 @@ def eval_depth(
|
||||
get_best_scale: bool = True,
|
||||
mask_thr: float = 0.5,
|
||||
best_scale_clamp_thr: float = 1e-4,
|
||||
use_disparity: bool = False,
|
||||
disparity_eps: float = 1e-4,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Evaluate the depth error between the prediction `pred` and the ground
|
||||
@@ -64,6 +69,13 @@ def eval_depth(
|
||||
# s.t. we get best possible mse error
|
||||
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
|
||||
pred = pred * scale_best[:, None, None, None]
|
||||
if use_disparity:
|
||||
gt = torch.div(1.0, (gt + disparity_eps))
|
||||
pred = torch.div(1.0, (pred + disparity_eps))
|
||||
scale_best = estimate_depth_scale_factor(
|
||||
pred, gt, dmask, best_scale_clamp_thr
|
||||
).detach()
|
||||
pred = pred * scale_best[:, None, None, None]
|
||||
|
||||
df = gt - pred
|
||||
|
||||
@@ -117,6 +129,7 @@ def calc_bce(
|
||||
pred_eps: float = 0.01,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
lerp_bound: Optional[float] = None,
|
||||
pred_logits: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the binary cross entropy.
|
||||
@@ -139,9 +152,23 @@ def calc_bce(
|
||||
weight = torch.ones_like(gt) * mask
|
||||
|
||||
if lerp_bound is not None:
|
||||
# binary_cross_entropy_lerp requires pred to be in [0, 1]
|
||||
if pred_logits:
|
||||
pred = F.sigmoid(pred)
|
||||
|
||||
return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
|
||||
else:
|
||||
return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
|
||||
if pred_logits:
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
pred,
|
||||
gt,
|
||||
reduction="none",
|
||||
weight=weight,
|
||||
)
|
||||
else:
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
def binary_cross_entropy_lerp(
|
||||
|
||||
@@ -111,10 +111,10 @@ def load_model(fl, map_location: Optional[dict]):
|
||||
flstats = get_stats_path(fl)
|
||||
flmodel = get_model_path(fl)
|
||||
flopt = get_optimizer_path(fl)
|
||||
model_state_dict = torch.load(flmodel, map_location=map_location)
|
||||
model_state_dict = torch.load(flmodel, map_location=map_location, weights_only=True)
|
||||
stats = load_stats(flstats)
|
||||
if os.path.isfile(flopt):
|
||||
optimizer = torch.load(flopt, map_location=map_location)
|
||||
optimizer = torch.load(flopt, map_location=map_location, weights_only=True)
|
||||
else:
|
||||
optimizer = None
|
||||
|
||||
|
||||
@@ -100,7 +100,6 @@ def render_point_cloud_pytorch3d(
|
||||
bin_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# feature dimension
|
||||
featdim = point_cloud.features_packed().shape[-1]
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ class AverageMeter:
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1, epoch=0):
|
||||
|
||||
# make sure the history is of the same len as epoch
|
||||
while len(self.history) <= epoch:
|
||||
self.history.append([])
|
||||
@@ -115,7 +114,6 @@ class Stats:
|
||||
visdom_server="http://localhost",
|
||||
visdom_port=8097,
|
||||
):
|
||||
|
||||
self.log_vars = log_vars
|
||||
self.visdom_env = visdom_env
|
||||
self.visdom_server = visdom_server
|
||||
@@ -202,7 +200,6 @@ class Stats:
|
||||
self.log_vars.append(add_log_var)
|
||||
|
||||
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
||||
|
||||
if self.epoch == -1: # uninitialized
|
||||
logger.warning(
|
||||
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
|
||||
@@ -219,7 +216,6 @@ class Stats:
|
||||
epoch = self.epoch
|
||||
|
||||
for stat in self.log_vars:
|
||||
|
||||
if stat not in self.stats[stat_set]:
|
||||
self.stats[stat_set][stat] = AverageMeter()
|
||||
|
||||
@@ -248,7 +244,6 @@ class Stats:
|
||||
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
|
||||
|
||||
def get_epoch_averages(self, epoch=None):
|
||||
|
||||
stat_sets = list(self.stats.keys())
|
||||
|
||||
if epoch is None:
|
||||
@@ -345,7 +340,6 @@ class Stats:
|
||||
def plot_stats(
|
||||
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
|
||||
):
|
||||
|
||||
# use the cached visdom env if none supplied
|
||||
if visdom_env is None:
|
||||
visdom_env = self.visdom_env
|
||||
@@ -449,7 +443,6 @@ class Stats:
|
||||
warnings.warn("Cant dump stats due to insufficient permissions!")
|
||||
|
||||
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
|
||||
|
||||
stat_sets = list(self.stats.keys())
|
||||
|
||||
# remove the additional log_vars
|
||||
@@ -490,11 +483,12 @@ class Stats:
|
||||
for ep in range(lastep):
|
||||
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
|
||||
epoch_generated = self.stats[stat_set][stat].get_epoch()
|
||||
assert (
|
||||
epoch_generated == self.epoch + 1
|
||||
), "bad epoch of synchronized log_var! %d vs %d" % (
|
||||
self.epoch + 1,
|
||||
epoch_generated,
|
||||
assert epoch_generated == self.epoch + 1, (
|
||||
"bad epoch of synchronized log_var! %d vs %d"
|
||||
% (
|
||||
self.epoch + 1,
|
||||
epoch_generated,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
|
||||
_NO_TORCHVISION = False
|
||||
try:
|
||||
import torchvision
|
||||
except ImportError:
|
||||
_NO_TORCHVISION = True
|
||||
|
||||
|
||||
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -36,6 +45,7 @@ class VideoWriter:
|
||||
fps: int = 20,
|
||||
output_format: str = "visdom",
|
||||
rmdir_allowed: bool = False,
|
||||
use_torchvision_video_writer: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -49,6 +59,8 @@ class VideoWriter:
|
||||
is supported.
|
||||
rmdir_allowed: If `True` delete and create `cache_dir` in case
|
||||
it is not empty.
|
||||
use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
|
||||
to write the video
|
||||
"""
|
||||
self.rmdir_allowed = rmdir_allowed
|
||||
self.output_format = output_format
|
||||
@@ -56,10 +68,14 @@ class VideoWriter:
|
||||
self.out_path = out_path
|
||||
self.cache_dir = cache_dir
|
||||
self.ffmpeg_bin = ffmpeg_bin
|
||||
self.use_torchvision_video_writer = use_torchvision_video_writer
|
||||
self.frames = []
|
||||
self.regexp = "frame_%08d.png"
|
||||
self.frame_num = 0
|
||||
|
||||
if self.use_torchvision_video_writer:
|
||||
assert not _NO_TORCHVISION, "torchvision not available"
|
||||
|
||||
if self.cache_dir is not None:
|
||||
self.tmp_dir = None
|
||||
if os.path.isdir(self.cache_dir):
|
||||
@@ -114,7 +130,7 @@ class VideoWriter:
|
||||
resize = im.size
|
||||
# make sure size is divisible by 2
|
||||
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
|
||||
# pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
|
||||
|
||||
im = im.resize(resize, Image.ANTIALIAS)
|
||||
im.save(outfile)
|
||||
|
||||
@@ -139,38 +155,56 @@ class VideoWriter:
|
||||
# got `Optional[str]`.
|
||||
regexp = os.path.join(self.cache_dir, self.regexp)
|
||||
|
||||
if shutil.which(self.ffmpeg_bin) is None:
|
||||
raise ValueError(
|
||||
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
|
||||
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
|
||||
)
|
||||
|
||||
if self.output_format == "visdom": # works for ppt too
|
||||
args = [
|
||||
self.ffmpeg_bin,
|
||||
"-r",
|
||||
str(self.fps),
|
||||
"-i",
|
||||
regexp,
|
||||
"-vcodec",
|
||||
"h264",
|
||||
"-f",
|
||||
"mp4",
|
||||
"-y",
|
||||
"-crf",
|
||||
"18",
|
||||
"-b",
|
||||
"2000k",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
self.out_path,
|
||||
]
|
||||
if quiet:
|
||||
subprocess.check_call(
|
||||
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
||||
# Video codec parameters
|
||||
video_codec = "h264"
|
||||
crf = "18"
|
||||
b = "2000k"
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
if self.use_torchvision_video_writer:
|
||||
torchvision.io.write_video(
|
||||
self.out_path,
|
||||
torch.stack(
|
||||
[torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
|
||||
),
|
||||
fps=self.fps,
|
||||
video_codec=video_codec,
|
||||
options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
|
||||
)
|
||||
|
||||
else:
|
||||
subprocess.check_call(args)
|
||||
if shutil.which(self.ffmpeg_bin) is None:
|
||||
raise ValueError(
|
||||
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
|
||||
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
|
||||
)
|
||||
|
||||
args = [
|
||||
self.ffmpeg_bin,
|
||||
"-r",
|
||||
str(self.fps),
|
||||
"-i",
|
||||
regexp,
|
||||
"-vcodec",
|
||||
video_codec,
|
||||
"-f",
|
||||
"mp4",
|
||||
"-y",
|
||||
"-crf",
|
||||
crf,
|
||||
"-b",
|
||||
b,
|
||||
"-pix_fmt",
|
||||
pix_fmt,
|
||||
self.out_path,
|
||||
]
|
||||
if quiet:
|
||||
subprocess.check_call(
|
||||
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
||||
)
|
||||
else:
|
||||
subprocess.check_call(args)
|
||||
else:
|
||||
raise ValueError("no such output type %s" % str(self.output_format))
|
||||
|
||||
|
||||
@@ -163,6 +163,8 @@ def _read_chunks(
|
||||
if binary_data is not None:
|
||||
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
|
||||
|
||||
assert binary_data is not None
|
||||
|
||||
return json_data, binary_data
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
# pyre-unsafe
|
||||
|
||||
"""This module implements utility functions for loading .mtl files and textures."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
|
||||
"""This module implements utility functions for loading and saving meshes."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
@@ -813,7 +814,6 @@ def _save(
|
||||
save_texture: bool = False,
|
||||
save_normals: bool = False,
|
||||
) -> None:
|
||||
|
||||
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
|
||||
message = "'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
@@ -14,6 +14,7 @@ meshes as .off files.
|
||||
This format is introduced, for example, at
|
||||
http://www.geomview.org/docs/html/OFF.html .
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import cast, Optional, Tuple, Union
|
||||
|
||||
@@ -84,7 +85,7 @@ def _read_faces_lump(
|
||||
)
|
||||
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces)
|
||||
except ValueError as e:
|
||||
if n_faces > 1 and "Wrong number of columns" in e.args[0]:
|
||||
if n_faces > 1 and "number of columns" in e.args[0]:
|
||||
file.seek(old_offset)
|
||||
return None
|
||||
raise ValueError("Not enough face data.") from None
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
This module implements utility functions for loading and saving
|
||||
meshes and point clouds as PLY files.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import struct
|
||||
@@ -1246,7 +1247,7 @@ def _save_ply(
|
||||
return
|
||||
|
||||
color_np_type = np.ubyte if colors_as_uint8 else np.float32
|
||||
verts_dtype = [("verts", np.float32, 3)]
|
||||
verts_dtype: list = [("verts", np.float32, 3)]
|
||||
if verts_normals is not None:
|
||||
verts_dtype.append(("normals", np.float32, 3))
|
||||
if verts_colors is not None:
|
||||
|
||||
@@ -122,12 +122,17 @@ def corresponding_cameras_alignment(
|
||||
|
||||
# create a new cameras object and set the R and T accordingly
|
||||
cameras_src_aligned = cameras_src.clone()
|
||||
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
|
||||
cameras_src_aligned.T = (
|
||||
torch.bmm(
|
||||
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
|
||||
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
cameras_src.R,
|
||||
)[:, 0]
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||
# float, int, Tensor]) -> Tensor, Tensor, Module]` is not a function.
|
||||
+ cameras_src.T * align_t_s
|
||||
)
|
||||
|
||||
@@ -175,6 +180,7 @@ def _align_camera_extrinsics(
|
||||
R_A = (U V^T)^T
|
||||
```
|
||||
"""
|
||||
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
|
||||
U, _, V = torch.svd(RRcov)
|
||||
align_t_R = V @ U.t()
|
||||
@@ -204,7 +210,11 @@ def _align_camera_extrinsics(
|
||||
T_A = mean(B) - mean(A) * s_A
|
||||
```
|
||||
"""
|
||||
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
|
||||
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
|
||||
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
|
||||
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
|
||||
Amu = A.mean(0, keepdim=True)
|
||||
Bmu = B.mean(0, keepdim=True)
|
||||
|
||||
@@ -62,7 +62,7 @@ def cubify(
|
||||
*,
|
||||
feats: Optional[torch.Tensor] = None,
|
||||
device=None,
|
||||
align: str = "topleft"
|
||||
align: str = "topleft",
|
||||
) -> Meshes:
|
||||
r"""
|
||||
Converts a voxel to a mesh by replacing each occupied voxel with a cube
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user